#!/usr/bin/env python3

import os
import sys
import time
import struct
import serial
import subprocess


RAM_SIZE = 0x50000

def eprint(*args, **kargs):
    print(*args, file=sys.stderr, **kargs)


def popen_jlink():
    pipe = subprocess.PIPE
    cmd = ['JLinkExe']
    cmd.extend(['-autoconnect', '1'])
    cmd.extend(['-device', 'STM32F746ZG'])
    cmd.extend(['-if', 'swd'])
    cmd.extend(['-speed', '4000'])
    return subprocess.Popen(cmd, stdout=sys.stderr, stdin=pipe)


def flash():
    p = popen_jlink()
    return p.communicate(("""
loadbin build/f7.bin 0x8000000
r
g
exit
    """).encode('ascii'))


def fill_ram():
    p = popen_jlink()
    return p.communicate(("""
h
loadbin ram_pattern.bin 0x20000000
savebin ram_copy.bin 0x20000000 0x%x
r
g
exit
    """ % RAM_SIZE).encode('ascii'))


def dump_ram():
    p = popen_jlink()
    return p.communicate(("""
h
savebin ram_dump.bin 0x20000000 0x%x
exit
    """ % RAM_SIZE).encode('ascii'))


def get_serial():
    import serial.tools.list_ports
    ports = serial.tools.list_ports.comports()
    devices = [ p.device for p in ports ]
    devices.sort()
    return devices[-1]


class UARTP:
    def __init__(self, ser):
        UARTP.SYN = 0xf9
        UARTP.FIN = 0xf3
        self.ser = ser

    def uart_read(self):
        r = self.ser.read(1)
        if len(r) != 1:
            raise Exception("Serial read error")
        return r[0]

    def uart_write(self, c):
        b = struct.pack("B", c)
        r = self.ser.write(b)
        if r != len(b):
            raise Exception("Serial write error")
        return r

    def send(self, buf):
        self.uart_write(UARTP.SYN)
        len_ind_0 = 0xff & len(buf)
        len_ind_1 = 0xff & (len(buf) >> 7)
        if len(buf) < 128:
            self.uart_write(len_ind_0)
        else:
            self.uart_write(len_ind_0 | 0x80)
            self.uart_write(len_ind_1)
        fcs = 0
        for i in range(len(buf)):
            info = buf[i]
            fcs = (fcs + info) & 0xff
            self.uart_write(buf[i])
        fcs = (0xff - fcs) & 0xff
        self.uart_write(fcs)
        self.uart_write(UARTP.FIN)
        eprint("sent frame '%s'" % buf.hex())

    def recv(self):
        tag_old = UARTP.FIN
        while 1:
            tag = tag_old
            while 1:
                if tag_old == UARTP.FIN:
                    if tag == UARTP.SYN:
                        break
                tag_old = tag
                tag = self.uart_read()
            tag_old = tag
            
            l = self.uart_read()
            if l & 0x80:
                l &= 0x7f
                l |= self.uart_read() << 7

            fcs = 0
            buf = []
            for i in range(l):
                info = self.uart_read()
                buf.append(info)
                fcs = (fcs + info) & 0xff
            fcs = (fcs + self.uart_read()) & 0xff

            tag = self.uart_read()
            if fcs == 0xff:
                if tag == UARTP.FIN:
                    buf = bytes(buf)
                    eprint("rcvd frame '%s'" % buf.hex())
                    if len(buf) >= 1 and buf[0] == 0xde:
                        sys.stderr.buffer.write(buf[1:])
                        sys.stderr.flush()
                    else:
                        return buf

def main(argv):
    eprint(argv[0])
    script_dir = os.path.split(argv[0])[0]
    if len(script_dir) > 0:
        os.chdir(script_dir)

    dev = get_serial()
    ser = serial.Serial(dev, baudrate=115200, timeout=5)
    uartp = UARTP(ser)

    flash()
    fill_ram()
    eprint("Flashed")
    time.sleep(0.1)

    def stdin_read(n):
        b = sys.stdin.buffer.read(n)
        if len(b) != n:
            sys.exit(1)
        return b

    def stdin_readvar():
        l = stdin_read(4)
        (l, ) = struct.unpack("<I", l)
        v = stdin_read(l)
        return v

    exp_hello = b"Hello, World!"
    hello = ser.read(ser.in_waiting)
    if hello[-13:] != exp_hello:
        eprint("Improper board initialization message: %s" % hello)
        return 1
    eprint("Board initialized properly")
    sys.stdout.write("Hello, World!\n")
    sys.stdout.flush()

    while 1:
        action = stdin_read(1)[0]
        eprint("Command %c from stdin" % action)

        if action in b"ackmps":
            v = stdin_readvar()
            uartp.send(struct.pack("B", action) + v)
            ack = uartp.recv()
            if len(ack) != 1 or ack[0] != action:
                raise Exception("Unacknowledged variable transfer")
            eprint("Var %c successfully sent to board" % action)

        elif action in b"ACKMPS":
            c = struct.pack("B", action)
            uartp.send(c)
            v = uartp.recv()
            if len(v) < 1 or v[0] != action:
                raise Exception("Could not obtain variable from board")
            v = v[1:]
            eprint("Var %c received from board: %s" % (action, v.hex()))
            l = struct.pack("<I", len(v))
            sys.stdout.buffer.write(l + v)
            sys.stdout.flush()

        elif action in b"ed":
            c = struct.pack("B", action)
            uartp.send(c)
            ack = uartp.recv()
            if len(ack) != 1 or ack[0] != action:
                raise Exception("Unacknowledged variable transfer")
            eprint("Operation %c completed successfully" % action)

        elif action in b"u":
            dump_ram()
            with open("ram_copy.bin", 'rb') as dump:
                dump_a = dump.read()
            with open("ram_dump.bin", 'rb') as dump:
                dump_b = dump.read()
            
            if len(dump_a) != RAM_SIZE or len(dump_b) != RAM_SIZE:
                raise Exception("Wrong dump sizes: 0x%x, 0x%x" % (len(dump_a), len(dump_b)))
           
            streaks = []
            streak_beg = 0
            streak_end = 0
            for i in range(len(dump_a)):
                if dump_a[i] == dump_b[i]:
                    streak_end = i
                else:
                    if streak_end != streak_beg:
                        streaks.append((streak_beg, streak_end))
                    streak_beg = i
                    streak_end = i

            for b, e in streaks:
                eprint("equal bytes from 0x%x to 0x%x (length: %d)" % (b, e, e-b))

            b, e = max(streaks, key=lambda a: a[1]-a[0])
            eprint("longest equal bytes streak from 0x%x to 0x%x (length: %d)" % (b, e, e-b))
            v = struct.pack("<II", 4, e-b)
            sys.stdout.buffer.write(v)
            sys.stdout.flush()

        else:
            raise Exception("Unknown action %c" % action)
        
    
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))