#!/usr/bin/env python3

import os
import sys
import time
import struct
import serial
import pylink
sys.path.insert(0, '.')
from test_common import (
    SaleaeTimeMeasurements,
    parse_nist_aead_test_vectors,
    DeviceUnderTestAeadUARTP,
    run_nist_aead_test,
)


def eprint(*args, **kargs):
    print(*args, file=sys.stderr, **kargs)


def get_serial():
    import serial.tools.list_ports
    ports = serial.tools.list_ports.comports()
    for port in ports:
        print(port.serial_number)
    devices = [
        p.device
        for p in ports
        if p.serial_number == 'FT2XA9MY'
    ]
    devices.sort()
    return serial.Serial(
        devices[0],
        baudrate=115200,
        timeout=5)


class F7(DeviceUnderTestAeadUARTP):
    RAM_SIZE = 0x50000
    def __init__(self):
        DeviceUnderTestAeadUARTP.__init__(self, get_serial())
        self.jlink = pylink.JLink()
        self.jlink.open(779340002)


    def flash(self):
        jlink = self.jlink
        jlink.connect('STM32F746ZG')
        jlink.flash_file('build/f7.bin', 0x8000000)
        jlink.flash_file('ram_pattern.bin', 0x20000000)
        jlink.reset()
        jlink.restart()

    def dump_ram(self):
        jlink = self.jlink
        return bytes(jlink.memory_read8(0x20000000, F7.RAM_SIZE))


def main(argv):
    if len(argv) < 2:
        print("Usage: test LWC_AEAD_KAT.txt")

    kat = list(parse_nist_aead_test_vectors(argv[1]))

    eprint(argv[0])
    script_dir = os.path.split(argv[0])[0]
    if len(script_dir) > 0:
        os.chdir(script_dir)

    dut = F7()

    dut.flash()
    eprint("Flashed")
    dut.prepare()
    eprint("Board initialized properly")
    sys.stdout.write("Hello, World!\n")
    sys.stdout.flush()

    try:
        run_nist_aead_test(dut, kat)
        return 0

    except Exception as ex:
        print("TEST FAILED")
        raise ex

    finally:
        sys.stdout.flush()
        sys.stderr.flush()


if __name__ == "__main__":
    sys.exit(main(sys.argv))