diff --git a/compile_all.py b/compile_all.py index a0be5aa..e310851 100755 --- a/compile_all.py +++ b/compile_all.py @@ -5,7 +5,6 @@ import sys import stat import argparse import shutil -import random import subprocess import shutil @@ -15,7 +14,7 @@ def build(algo_dir, template_dir, build_dir): return None print("Building in %s" % build_dir) - + # copy all the files from the submitted algorithm into the build directory shutil.copytree(algo_dir, build_dir) @@ -68,7 +67,6 @@ def build(algo_dir, template_dir, build_dir): p.wait() assert p.returncode == 0 - finally: sys.stdout.flush() sys.stderr.flush() @@ -122,7 +120,8 @@ def main(argv): # get all the submissions by looking for files named "api.h" subfiles = [] for submission in subs: - implementations_dir = os.path.join(submissions_dir, submission, "Implementations", "crypto_aead") + implementations_dir = os.path.join( + submissions_dir, submission, "Implementations", "crypto_aead") if not os.path.isdir(implementations_dir): continue @@ -164,19 +163,17 @@ def main(argv): pieces = f.split(os.sep) n = pieces[1] + "." + ".".join(pieces[4:-1]) print(n) - + # if include_list was provided, skip elements not in the list if include_list is not None: - if not n in include_list: + if n not in include_list: continue # Put all in a tuple and count files.append((t, d, n)) - - # For testing, we only do the first 1 - #files = files[:1] + # files = files[:1] print("%d algorithms will be compiled" % len(files)) if not os.path.isdir(build_root_dir): @@ -198,8 +195,10 @@ def main(argv): b = build(d, template_dir, build_dir) if b is None: continue - test_script.write("\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d)) - test_script.write("python3 -u ./test.py %s %s 2> %s | tee %s\n" % ( + test_script.write( + "\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d)) + test_script.write( + "python3 -u ./test.py %s %s 2> %s | tee %s\n" % ( t, os.path.join(b, 'test'), os.path.join(b, 'test_stderr.log'), diff --git a/templates/bluepill/openocd.cfg b/templates/bluepill/openocd.cfg index d70df06..d0732b8 100644 --- a/templates/bluepill/openocd.cfg +++ b/templates/bluepill/openocd.cfg @@ -21,6 +21,6 @@ source [find target/stm32f1x.cfg] #tpiu config internal swodump.stm32f103-generic.log uart off 72000000 -#reset_config srst_only srst_push_pull srst_nogate connect_assert_srst -reset_config none srst_push_pull srst_nogate +reset_config srst_only srst_push_pull srst_nogate connect_assert_srst +#reset_config none srst_push_pull srst_nogate diff --git a/templates/bluepill/test b/templates/bluepill/test index 7550a6e..9114b6d 100755 --- a/templates/bluepill/test +++ b/templates/bluepill/test @@ -2,197 +2,126 @@ import os import sys -import time -import struct -import serial import subprocess +import serial.tools.list_ports +from test_common import ( + LogicMultiplexerTimeMeasurements, + parse_nist_aead_test_vectors, + DeviceUnderTestAeadUARTP, + compare_dumps, + eprint, + run_nist_aead_test_line, +) +def get_serial(): + ports = serial.tools.list_ports.comports() + devices = [ + p.device + for p in ports + if p.serial_number == 'FT2XCRZ1' + ] + devices.sort() + return serial.Serial( + devices[0], + baudrate=115200, + timeout=5) + + +class BluePill(DeviceUnderTestAeadUARTP): + RAM_SIZE = 0x5000 + + def __init__(self, build_dir): + DeviceUnderTestAeadUARTP.__init__(self, get_serial()) + + self.firmware_path = os.path.join( + build_dir, '.pio/build/bluepill_f103c8/firmware.elf') + self.ram_pattern_path = os.path.join( + build_dir, 'empty_ram.bin') + self.ram_dump_path = os.path.join( + build_dir, 'ram_dump.bin') + self.openocd_cfg_path = os.path.join( + build_dir, 'openocd.cfg') + + def flash(self): + # pipe = subprocess.PIPE + cmd = [ + 'openocd', '-f', 'openocd.cfg', '-c', + 'program %s verify reset exit' % self.firmware_path] + p = subprocess.Popen( + cmd, stdout=sys.stderr, stdin=sys.stdout) + stdout, stderr = p.communicate("") + eprint("Firmware flashed.") + + cmd = [ + 'openocd', '-f', self.openocd_cfg_path, '-c', + 'program %s reset exit 0x20000000' % self.ram_pattern_path] + p = subprocess.Popen( + cmd, stdout=sys.stderr, stdin=sys.stdout) + stdout, stderr = p.communicate("") + eprint("RAM flashed.") + + def dump_ram(self): + cmd = [ + 'openocd', '-f', self.openocd_cfg_path, + '-c', 'init', + '-c', 'halt', + '-c', 'dump_image %s 0x20000000 0x%x' % ( + self.ram_dump_path, BluePill.RAM_SIZE), + '-c', 'resume', + '-c', 'exit'] + p = subprocess.Popen( + cmd, stdout=sys.stderr, stdin=sys.stdout) + stdout, stderr = p.communicate("") + eprint("RAM dumped.") + with open(self.ram_dump_path, 'rb') as ram: + ram = ram.read() + if len(ram) != BluePill.RAM_SIZE: + raise Exception( + "RAM dump was %d bytes instead of %d" % + (len(ram), BluePill.RAM_SIZE)) + return ram -def eprint(*args, **kargs): - print(*args, file=sys.stderr, **kargs) +def main(argv): + if len(argv) != 3: + print("Usage: test LWC_AEAD_KAT.txt build_dir") + return 1 -def flash(): - pipe = subprocess.PIPE - cmd = ['openocd', '-f', 'openocd.cfg', '-c' 'program ' + - '.pio/build/bluepill_f103c8/firmware.elf verify reset exit'] - p = subprocess.Popen(cmd, - stdout=sys.stderr, stdin=sys.stdout) - stdout, stderr = p.communicate("") + kat = list(parse_nist_aead_test_vectors(argv[1])) + build_dir = argv[2] + dut = BluePill(build_dir) -def fill_ram(): - pipe = subprocess.PIPE - cmd = ['openocd', '-f', 'openocd.cfg', '-c' 'program ' + - 'empty_ram.bin reset exit 0x20000000'] - p = subprocess.Popen(cmd, - stdout=sys.stderr, stdin=sys.stdout) - stdout, stderr = p.communicate("") + try: + tool = LogicMultiplexerTimeMeasurements(0x0003) + tool.begin_measurement() + dut.flash() + dut.prepare() + sys.stdout.write("Board prepared\n") + sys.stdout.flush() -def get_serial(): - import serial.tools.list_ports - ports = serial.tools.list_ports.comports() - devices = [ p.device for p in ports ] - devices.sort() - return devices[-1] - - -class UARTP: - def __init__(self, ser): - UARTP.SYN = 0xf9 - UARTP.FIN = 0xf3 - self.ser = ser - - def uart_read(self): - r = self.ser.read(1) - if len(r) != 1: - raise Exception("Serial read error") - return r[0] - - def uart_write(self, c): - b = struct.pack("B", c) - r = self.ser.write(b) - if r != len(b): - raise Exception("Serial write error") - return r - - def send(self, buf): - self.uart_write(UARTP.SYN) - len_ind_0 = 0xff & len(buf) - len_ind_1 = 0xff & (len(buf) >> 7) - if len(buf) < 128: - self.uart_write(len_ind_0) - else: - self.uart_write(len_ind_0 | 0x80) - self.uart_write(len_ind_1) - fcs = 0 - for i in range(len(buf)): - info = buf[i] - fcs = (fcs + info) & 0xff - self.uart_write(buf[i]) - fcs = (0xff - fcs) & 0xff - self.uart_write(fcs) - self.uart_write(UARTP.FIN) - eprint("sent frame '%s'" % buf.hex()) - - def recv(self): - tag_old = UARTP.FIN - while 1: - tag = tag_old - while 1: - if tag_old == UARTP.FIN: - if tag == UARTP.SYN: - break - tag_old = tag - tag = self.uart_read() - tag_old = tag - - l = self.uart_read() - if l & 0x80: - l &= 0x7f - l |= self.uart_read() << 7 - - fcs = 0 - buf = [] - for i in range(l): - info = self.uart_read() - buf.append(info) - fcs = (fcs + info) & 0xff - fcs = (fcs + self.uart_read()) & 0xff - - tag = self.uart_read() - if fcs == 0xff: - if tag == UARTP.FIN: - buf = bytes(buf) - eprint("rcvd frame '%s'" % buf.hex()) - if len(buf) >= 1 and buf[0] == 0xde: - sys.stderr.buffer.write(buf[1:]) - sys.stderr.flush() - else: - return buf + dump_a = dut.dump_ram() -def main(argv): - eprint(argv[0]) - script_dir = os.path.split(argv[0])[0] - if len(script_dir) > 0: - os.chdir(script_dir) - - dev = get_serial() - ser = serial.Serial(dev, baudrate=115200, timeout=5) - uartp = UARTP(ser) - - flash() - fill_ram() - eprint("Flashed") - time.sleep(0.1) - - ser.setDTR(False) # IO0=HIGH - ser.setRTS(True) # EN=LOW, chip in reset - time.sleep(0.1) - ser.setDTR(False) # IO0=HIGH - ser.setRTS(False) # EN=HIGH, chip out of reset - time.sleep(1) - - def stdin_read(n): - b = sys.stdin.buffer.read(n) - if len(b) != n: - sys.exit(1) - return b - - def stdin_readvar(): - l = stdin_read(4) - (l, ) = struct.unpack("> 7) - if len(buf) < 128: - self.uart_write(len_ind_0) - else: - self.uart_write(len_ind_0 | 0x80) - self.uart_write(len_ind_1) - fcs = 0 - for i in range(len(buf)): - info = buf[i] - fcs = (fcs + info) & 0xff - self.uart_write(buf[i]) - fcs = (0xff - fcs) & 0xff - self.uart_write(fcs) - self.uart_write(UARTP.FIN) - eprint("sent frame '%s'" % buf.hex()) - - def recv(self): - tag_old = UARTP.FIN - while 1: - tag = tag_old - while 1: - if tag_old == UARTP.FIN: - if tag == UARTP.SYN: - break - tag_old = tag - tag = self.uart_read() - tag_old = tag - - l = self.uart_read() - if l & 0x80: - l &= 0x7f - l |= self.uart_read() << 7 - - fcs = 0 - buf = [] - for i in range(l): - info = self.uart_read() - buf.append(info) - fcs = (fcs + info) & 0xff - fcs = (fcs + self.uart_read()) & 0xff - - tag = self.uart_read() - if fcs == 0xff: - if tag == UARTP.FIN: - buf = bytes(buf) - eprint("rcvd frame '%s'" % buf.hex()) - if len(buf) >= 1 and buf[0] == 0xde: - sys.stderr.buffer.write(buf[1:]) - sys.stderr.flush() - else: - return buf def main(argv): - eprint(argv[0]) - script_dir = os.path.split(argv[0])[0] - if len(script_dir) > 0: - os.chdir(script_dir) - - dev = get_serial() - flash(dev) - eprint("Flashed") - time.sleep(0.1) - - ser = serial.Serial(dev, baudrate=500000, timeout=5) - uartp = UARTP(ser) - - ser.setDTR(False) # IO0=HIGH - ser.setRTS(True) # EN=LOW, chip in reset - time.sleep(0.1) - ser.setDTR(False) # IO0=HIGH - ser.setRTS(False) # EN=HIGH, chip out of reset - time.sleep(1) - - def stdin_read(n): - b = sys.stdin.buffer.read(n) - if len(b) != n: - sys.exit(1) - return b - - def stdin_readvar(): - l = stdin_read(4) - (l, ) = struct.unpack("> 7) - if len(buf) < 128: - self.uart_write(len_ind_0) - else: - self.uart_write(len_ind_0 | 0x80) - self.uart_write(len_ind_1) - fcs = 0 - for i in range(len(buf)): - info = buf[i] - fcs = (fcs + info) & 0xff - self.uart_write(buf[i]) - fcs = (0xff - fcs) & 0xff - self.uart_write(fcs) - self.uart_write(UARTP.FIN) - eprint("sent frame '%s'" % buf.hex()) - - def recv(self): - tag_old = UARTP.FIN - while 1: - tag = tag_old - while 1: - if tag_old == UARTP.FIN: - if tag == UARTP.SYN: - break - tag_old = tag - tag = self.uart_read() - tag_old = tag - - l = self.uart_read() - if l & 0x80: - l &= 0x7f - l |= self.uart_read() << 7 - - fcs = 0 - buf = [] - for i in range(l): - info = self.uart_read() - buf.append(info) - fcs = (fcs + info) & 0xff - fcs = (fcs + self.uart_read()) & 0xff - - tag = self.uart_read() - if fcs == 0xff: - if tag == UARTP.FIN: - buf = bytes(buf) - eprint("rcvd frame '%s'" % buf.hex()) - if len(buf) >= 1 and buf[0] == 0xde: - sys.stderr.buffer.write(buf[1:]) - sys.stderr.flush() - else: - return buf + dut.flash() + dut.prepare() + sys.stdout.write("Board prepared\n") + sys.stdout.flush() -def main(argv): - eprint(argv[0]) - script_dir = os.path.split(argv[0])[0] - if len(script_dir) > 0: - os.chdir(script_dir) - - dev = get_serial() - ser = serial.Serial(dev, baudrate=115200, timeout=5) - uartp = UARTP(ser) - - flash() - fill_ram() - eprint("Flashed") - time.sleep(0.1) - - def stdin_read(n): - b = sys.stdin.buffer.read(n) - if len(b) != n: - sys.exit(1) - return b - - def stdin_readvar(): - l = stdin_read(4) - (l, ) = struct.unpack("> 7) - if len(buf) < 128: - self.uart_write(len_ind_0) - else: - self.uart_write(len_ind_0 | 0x80) - self.uart_write(len_ind_1) - fcs = 0 - for i in range(len(buf)): - info = buf[i] - fcs = (fcs + info) & 0xff - self.uart_write(buf[i]) - fcs = (0xff - fcs) & 0xff - self.uart_write(fcs) - self.uart_write(UARTP.FIN) - eprint("sent frame '%s'" % buf.hex()) - - def recv(self): - tag_old = UARTP.FIN - while 1: - tag = tag_old - while 1: - if tag_old == UARTP.FIN: - if tag == UARTP.SYN: - break - tag_old = tag - tag = self.uart_read() - tag_old = tag - - l = self.uart_read() - if l & 0x80: - l &= 0x7f - l |= self.uart_read() << 7 - - fcs = 0 - buf = [] - for i in range(l): - info = self.uart_read() - buf.append(info) - fcs = (fcs + info) & 0xff - fcs = (fcs + self.uart_read()) & 0xff - - tag = self.uart_read() - if fcs == 0xff: - if tag == UARTP.FIN: - buf = bytes(buf) - eprint("rcvd frame '%s'" % buf.hex()) - if len(buf) >= 1 and buf[0] == 0xde: - sys.stderr.buffer.write(buf[1:]) - sys.stderr.flush() - else: - return buf - - -def stdin_read(n): - b = sys.stdin.buffer.read(n) - if len(b) != n: - sys.exit(1) - return b - - -def stdin_readvar(): - l = stdin_read(4) - (l, ) = struct.unpack(" 0: - os.chdir(script_dir) - - - dev = get_serial() - flash(dev) - eprint("Flashed") - time.sleep(0.1) - ser = serial.Serial(dev, baudrate=1500000, timeout=5) - uartp = UARTP(ser) - - ser.setRTS(True) - time.sleep(0.1) - ser.setRTS(False) - time.sleep(0.1) - ser.setRTS(True) - time.sleep(1) - - exp_hello = b"Hello, World!" - hello = ser.read(len(exp_hello)) - if hello != exp_hello: - eprint("Improper board initialization message: ") + if len(argv) != 3: + print("Usage: test LWC_AEAD_KAT.txt build_dir") return 1 - eprint("Board initialized properly") - sys.stdout.write("Hello, World!\n") - sys.stdout.flush() - - while 1: - action = stdin_read(1)[0] - eprint("Command %c from stdin" % action) - - if action in b"ackmps": - v = stdin_readvar() - uartp.send(struct.pack("B", action) + v) - ack = uartp.recv() - if len(ack) != 1 or ack[0] != action: - raise Exception("Unacknowledged variable transfer") - eprint("Var %c successfully sent to board" % action) - - elif action in b"ACKMPS": - c = struct.pack("B", action) - uartp.send(c) - v = uartp.recv() - if len(v) < 1 or v[0] != action: - raise Exception("Could not obtain variable from board") - v = v[1:] - eprint("Var %c received from board: %s" % (action, v.hex())) - l = struct.pack("> 7) - if len(buf) < 128: - self.uart_write(len_ind_0) - else: - self.uart_write(len_ind_0 | 0x80) - self.uart_write(len_ind_1) - fcs = 0 - for i in range(len(buf)): - info = buf[i] - fcs = (fcs + info) & 0xff - self.uart_write(buf[i]) - fcs = (0xff - fcs) & 0xff - self.uart_write(fcs) - self.uart_write(UARTP.FIN) - eprint("sent frame '%s'" % buf.hex()) - - def recv(self): - tag_old = UARTP.FIN - while 1: - tag = tag_old - while 1: - if tag_old == UARTP.FIN: - if tag == UARTP.SYN: - break - tag_old = tag - tag = self.uart_read() - tag_old = tag - - l = self.uart_read() - if l & 0x80: - l &= 0x7f - l |= self.uart_read() << 7 - - fcs = 0 - buf = [] - for i in range(l): - info = self.uart_read() - buf.append(info) - fcs = (fcs + info) & 0xff - fcs = (fcs + self.uart_read()) & 0xff - - tag = self.uart_read() - if fcs == 0xff: - if tag == UARTP.FIN: - buf = bytes(buf) - eprint("rcvd frame '%s'" % buf.hex()) - if len(buf) >= 1 and buf[0] == 0xde: - sys.stderr.buffer.write(buf[1:]) - sys.stderr.flush() - else: - return buf - - -def stdin_read(n): - b = sys.stdin.buffer.read(n) - if len(b) != n: - sys.exit(1) - return b - - -def stdin_readvar(): - l = stdin_read(4) - (l, ) = struct.unpack(" 0: - os.chdir(script_dir) - - dev = get_serial() - flash(dev) - eprint("Flashed") - time.sleep(0.1) - - ser = serial.Serial(dev, baudrate=115200, timeout=5) - uartp = UARTP(ser) - - ser.setDTR(True) - time.sleep(0.01) - ser.setDTR(False) - time.sleep(1) - - exp_hello = b"Hello, World!" - hello = ser.read(len(exp_hello)) - if hello != exp_hello: - eprint("Improper board initialization message: ") + if len(argv) != 3: + print("Usage: test LWC_AEAD_KAT.txt build_dir") return 1 - eprint("Board initialized properly") - sys.stdout.write("Hello, World!\n") - sys.stdout.flush() - - while 1: - action = stdin_read(1)[0] - eprint("Command %c from stdin" % action) - - if action in b"ackmps": - v = stdin_readvar() - uartp.send(struct.pack("B", action) + v) - ack = uartp.recv() - if len(ack) != 1 or ack[0] != action: - raise Exception("Unacknowledged variable transfer") - eprint("Var %c successfully sent to board" % action) - - elif action in b"ACKMPS": - c = struct.pack("B", action) - uartp.send(c) - v = uartp.recv() - if len(v) < 1 or v[0] != action: - raise Exception("Could not obtain variable from board") - v = v[1:] - eprint("Var %c received from board: %s" % (action, v.hex())) - l = struct.pack(" %s\n" % (' '.join(sys.argv), outfile)) - mdbfile.close() - return 0 - raise Exception("Capture didn't complete successfully") - -if __name__ == "__main__": - sys.exit(main(sys.argv)) diff --git a/test_common.py b/test_common.py new file mode 100644 index 0000000..384c276 --- /dev/null +++ b/test_common.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 + +import os +import re +import sys +import time +import struct +import serial + + +def eprint(*args, **kargs): + print(*args, file=sys.stderr, **kargs) + + +class DeviceUnderTest: + def __init__(self): + pass + + def flash(self): + """ + This method should be overridden to flash the DUT. + """ + pass + + def dump_ram(self): + """ + This should be overridden to return the RAM dump as a bytes string. + """ + return None + + +class DeviceUnderTestAeadUARTP(DeviceUnderTest): + def __init__(self, ser=None): + self.ser = ser + + def prepare(self): + exp_hello = b"Hello, World!" + time.sleep(0.1) + if self.ser.in_waiting < 13: + time.sleep(2) + hello = self.ser.read(self.ser.in_waiting) + if hello[-13:] != exp_hello: + raise Exception( + "Improper board initialization message: %s" % hello) + self.uartp = UARTP(self.ser) + + def send_var(self, key, value): + self.uartp.send(struct.pack("B", key) + value) + ack = self.uartp.recv() + if len(ack) != 1 or ack[0] != key: + raise Exception("Unacknowledged variable transfer") + + def obtain_var(self, key): + c = struct.pack("B", key) + self.uartp.send(c) + v = self.uartp.recv() + if len(v) < 1 or v[0] != key: + raise Exception("Could not obtain variable from board") + return v[1:] + + def do_cmd(self, action): + c = struct.pack("B", action) + self.uartp.send(c) + ack = self.uartp.recv() + if len(ack) != 1 or ack[0] != action: + raise Exception("Unacknowledged command") + + +class UARTP: + def __init__(self, ser): + UARTP.SYN = 0xf9 + UARTP.FIN = 0xf3 + self.ser = ser + + def uart_read(self): + r = self.ser.read(1) + if len(r) != 1: + raise Exception("Serial read error") + return r[0] + + def uart_write(self, c): + b = struct.pack("B", c) + r = self.ser.write(b) + if r != len(b): + raise Exception("Serial write error") + return r + + def send(self, buf): + self.uart_write(UARTP.SYN) + len_ind_0 = 0xff & len(buf) + len_ind_1 = 0xff & (len(buf) >> 7) + if len(buf) < 128: + self.uart_write(len_ind_0) + else: + self.uart_write(len_ind_0 | 0x80) + self.uart_write(len_ind_1) + fcs = 0 + for i in range(len(buf)): + info = buf[i] + fcs = (fcs + info) & 0xff + self.uart_write(buf[i]) + fcs = (0xff - fcs) & 0xff + self.uart_write(fcs) + self.uart_write(UARTP.FIN) + # eprint("sent frame '%s'" % buf.hex()) + + def recv(self): + tag_old = UARTP.FIN + while 1: + tag = tag_old + while 1: + if tag_old == UARTP.FIN: + if tag == UARTP.SYN: + break + tag_old = tag + tag = self.uart_read() + tag_old = tag + + pkt = self.uart_read() + if pkt & 0x80: + pkt &= 0x7f + pkt |= self.uart_read() << 7 + + fcs = 0 + buf = [] + for i in range(pkt): + info = self.uart_read() + buf.append(info) + fcs = (fcs + info) & 0xff + fcs = (fcs + self.uart_read()) & 0xff + + tag = self.uart_read() + if fcs == 0xff: + if tag == UARTP.FIN: + buf = bytes(buf) + # eprint("rcvd frame '%s'" % buf.hex()) + if len(buf) >= 1 and buf[0] == 0xde: + sys.stderr.buffer.write(buf[1:]) + sys.stderr.flush() + else: + return buf + + +def run_nist_aead_test_line(dut, i, m, ad, k, npub, c): + eprint() + eprint("Count = %d" % i) + eprint(" m = %s" % m.hex()) + eprint(" ad = %s" % ad.hex()) + eprint("npub = %s" % npub.hex()) + eprint(" k = %s" % k.hex()) + eprint(" c = %s" % c.hex()) + + dut.send_var(ord('c'), b"\0" * (len(m) + 32)) + dut.send_var(ord('s'), b"") + + dut.send_var(ord('m'), m) + dut.send_var(ord('a'), ad) + dut.send_var(ord('k'), k) + dut.send_var(ord('p'), npub) + + dut.do_cmd(ord('e')) + output = dut.obtain_var(ord('C')) + print(" c = %s" % output.hex()) + if c != output: + raise Exception("output of encryption is different from " + + "expected ciphertext") + + dut.send_var(ord('m'), b"\0" * len(c)) + dut.send_var(ord('s'), b"") + + dut.send_var(ord('c'), c) + dut.send_var(ord('a'), ad) + dut.send_var(ord('k'), k) + dut.send_var(ord('p'), npub) + + dut.do_cmd(ord('d')) + output = dut.obtain_var(ord('M')) + print(" m = %s" % output.hex()) + if m != output: + raise Exception("output of encryption is different from " + + "expected ciphertext") + + +def compare_dumps(dump_a, dump_b): + """ + Gets the length of the longes streaks of equal bytes in two RAM dumps + """ + streaks = [] + streak_beg = 0 + streak_end = 0 + for i in range(len(dump_a)): + if dump_a[i] == dump_b[i]: + streak_end = i + else: + if streak_end != streak_beg: + streaks.append((streak_beg, streak_end)) + streak_beg = i + streak_end = i + + for b, e in streaks: + eprint("equal bytes from 0x%x to 0x%x (length: %d)" % + (b, e, e-b)) + + b, e = max(streaks, key=lambda a: a[1]-a[0]) + eprint( + "longest equal bytes streak from 0x%x to 0x%x (length: %d)" % + (b, e, e-b)) + return e-b + + +def parse_nist_aead_test_vectors(test_file_path): + with open(test_file_path, 'r') as test_file: + lineprog = re.compile( + r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$", + re.IGNORECASE) + m = b"" + ad = b"" + k = b"" + npub = b"" + c = b"" + i = -1 + for line in test_file.readlines(): + line = line.strip() + res = lineprog.match(line) + if line == "": + yield i, m, ad, k, npub, c + m = b"" + ad = b"" + k = b"" + npub = b"" + c = b"" + elif res is not None: + if res[1].lower() == 'count': + i = int(res[2], 10) + elif res[1].lower() == 'key': + k = bytes.fromhex(res[2]) + elif res[1].lower() == 'nonce': + npub = bytes.fromhex(res[2]) + elif res[1].lower() == 'pt': + m = bytes.fromhex(res[2]) + elif res[1].lower() == 'ad': + ad = bytes.fromhex(res[2]) + elif res[1].lower() == 'ct': + c = bytes.fromhex(res[2]) + else: + raise Exception( + "ERROR: unparsed line in test vectors file: '%s'" + % res) + else: + raise Exception( + "ERROR: unparsed line in test vectors file: '%s'" % line) + + +class TimeMeasurementTool: + def begin_measurement(self): + pass + + def arm(self): + pass + + def unarm(self): + pass + + def end_measurement(self): + pass + + +class LogicMultiplexerTimeMeasurements(TimeMeasurementTool): + + def __init__(self, mask=0xffffffffffffffff): + self.mask = mask + self.sock = None + self.capture = [] + + def recv_samples(self): + import socket + + capture = [] + while 1: + try: + rcvd = self.sock.recv(16) + except socket.timeout: + break + except BlockingIOError: + break + if len(rcvd) != 16: + raise Exception("Could not receive 16 bytes of logic sample!") + + time, value = struct.unpack(" %s\n" % (' '.join(sys.argv), outfile)) + mdbfile.close() + return 0 + raise Exception("Capture didn't complete successfully") + + +def main(argv): + if len(argv) < 3: + print("Usage: test_common.py port LWC_AEAD_KAT.txt") + + eprint(argv[0]) + script_dir = os.path.split(argv[0])[0] + if len(script_dir) > 0: + os.chdir(script_dir) + + kat = list(parse_nist_aead_test_vectors(argv[2])) + + dev = argv[1] + ser = serial.Serial(dev, baudrate=115200, timeout=5) + dut = DeviceUnderTestAeadUARTP(ser) + + try: + tool = SaleaeTimeMeasurements() + tool.begin_measurement() + dut.flash() + eprint("Flashed") + dut.prepare() + eprint("Prepared") + sys.stdout.write("Hello, World!\n") + sys.stdout.flush() + + dump_a = dut.dump_ram() + + for i, m, ad, k, npub, c in kat: + tool.arm() + run_nist_aead_test_line(dut, i, m, ad, k, npub, c) + tool.unarm() + + if dump_a is not None and i == 1: + dump_b = dut.dump_ram() + longest = compare_dumps(dump_a, dump_b) + print(" longest chunk of untouched memory = %d" % longest) + + except Exception as ex: + print("TEST FAILED") + raise ex + + finally: + tool.end_measurement() + sys.stdout.flush() + sys.stderr.flush() + + +if __name__ == "__main__": + sys.exit(main(sys.argv))