diff --git a/compile_all.py b/compile_all.py index d209752..752b1b6 100755 --- a/compile_all.py +++ b/compile_all.py @@ -4,19 +4,13 @@ import os import sys import stat import shutil -import random import subprocess def build(algo_dir, template_dir, build_dir): # create a new directory for the build - while build_dir is None: - r = "%09d" % random.randint(0, 999999999) - d = os.path.join("build", r) - if not os.path.isdir(d): - build_dir = d print("Building in %s" % build_dir) - + # copy all the files from the submitted algorithm into the build directory shutil.copytree(algo_dir, build_dir) @@ -69,7 +63,6 @@ def build(algo_dir, template_dir, build_dir): p.wait() assert p.returncode == 0 - finally: sys.stdout.flush() sys.stderr.flush() @@ -101,6 +94,7 @@ def find_test_vectors(d): def main(argv): submissions_dir = "all-lwc-submission-files" template_dir = "templates/linux" + build_dir = 'build' include_list = None if len(argv) > 1: template_dir = argv[1] @@ -115,7 +109,8 @@ def main(argv): # get all the submissions by looking for files named "api.h" subfiles = [] for submission in subs: - implementations_dir = os.path.join(submissions_dir, submission, "Implementations", "crypto_aead") + implementations_dir = os.path.join( + submissions_dir, submission, "Implementations", "crypto_aead") if not os.path.isdir(implementations_dir): continue @@ -157,30 +152,28 @@ def main(argv): pieces = f.split(os.sep) n = pieces[1] + "." + ".".join(pieces[4:-1]) print(n) - + # if include_list was provided, skip elements not in the list if include_list is not None: - if not n in include_list: + if n not in include_list: continue # Put all in a tuple and count files.append((t, d, n)) - - # For testing, we only do the first 1 - #files = files[:1] + # files = files[:1] print("%d algorithms will be compiled" % len(files)) # Clear the build directory as it is a leftover from the previous execution - if os.path.isdir('build'): - shutil.rmtree('build') - os.mkdir('build') + if os.path.isdir(build_dir): + shutil.rmtree(build_dir) + os.mkdir(build_dir) print() # Write a script that executes all the tests one after the other - test_script_path = os.path.join("build", "test_all.sh") + test_script_path = os.path.join(build_dir, "test_all.sh") with open(test_script_path, 'w') as test_script: test_script.write("#!/bin/sh\n") test_script.write("mkdir -p logs\n") @@ -189,10 +182,13 @@ def main(argv): print() print(d) try: - build_dir = os.path.join("build", name) - b = build(d, template_dir, build_dir) - test_script.write("\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d)) - test_script.write("python3 -u ./test.py %s %s 2> %s | tee %s\n" % ( + b = build(d, template_dir, os.path.join(build_dir, name)) + if b is None: + continue + test_script.write( + "\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d)) + test_script.write( + "python3 -u ./test.py %s %s 2> %s | tee %s\n" % ( t, os.path.join(b, 'test'), os.path.join(b, 'test_stderr.log'), diff --git a/templates/f7/test b/templates/f7/test index 27cada9..ef7368a 100755 --- a/templates/f7/test +++ b/templates/f7/test @@ -112,7 +112,7 @@ class UARTP: tag_old = tag tag = self.uart_read() tag_old = tag - + l = self.uart_read() if l & 0x80: l &= 0x7f @@ -137,6 +137,7 @@ class UARTP: else: return buf + def main(argv): eprint(argv[0]) script_dir = os.path.split(argv[0])[0] @@ -238,8 +239,7 @@ def main(argv): else: raise Exception("Unknown action %c" % action) - - + return 0 diff --git a/test_common.py b/test_common.py new file mode 100644 index 0000000..af4b938 --- /dev/null +++ b/test_common.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3 + +import os +import re +import sys +import time +import struct +import serial + + +def eprint(*args, **kargs): + print(*args, file=sys.stderr, **kargs) + + +class DeviceUnderTest: + def __init__(self): + pass + + def flash(self): + """ + This method should be overridden to flash the DUT. + """ + pass + + def ram_dump(self): + """ + This should be overridden to return the RAM dump as a bytes string. + """ + return None + + +class DeviceUnderTestAeadUARTP(DeviceUnderTest): + def __init__(self, ser): + self.ser = ser + + def prepare(self): + time.sleep(0.1) + exp_hello = b"Hello, World!" + hello = self.ser.read(self.ser.in_waiting) + if hello[-13:] != exp_hello: + raise Exception( + "Improper board initialization message: %s" % hello) + self.uartp = UARTP(self.ser) + + def send_var(self, key, value): + self.uartp.send(struct.pack("B", key) + value) + ack = self.uartp.recv() + if len(ack) != 1 or ack[0] != key: + raise Exception("Unacknowledged variable transfer") + + def obtain_var(self, key): + c = struct.pack("B", key) + self.uartp.send(c) + v = self.uartp.recv() + if len(v) < 1 or v[0] != key: + raise Exception("Could not obtain variable from board") + return v[1:] + + def do_cmd(self, action): + c = struct.pack("B", action) + self.uartp.send(c) + ack = self.uartp.recv() + if len(ack) != 1 or ack[0] != action: + raise Exception("Unacknowledged command") + + +class UARTP: + def __init__(self, ser): + UARTP.SYN = 0xf9 + UARTP.FIN = 0xf3 + self.ser = ser + + def uart_read(self): + r = self.ser.read(1) + if len(r) != 1: + raise Exception("Serial read error") + return r[0] + + def uart_write(self, c): + b = struct.pack("B", c) + r = self.ser.write(b) + if r != len(b): + raise Exception("Serial write error") + return r + + def send(self, buf): + self.uart_write(UARTP.SYN) + len_ind_0 = 0xff & len(buf) + len_ind_1 = 0xff & (len(buf) >> 7) + if len(buf) < 128: + self.uart_write(len_ind_0) + else: + self.uart_write(len_ind_0 | 0x80) + self.uart_write(len_ind_1) + fcs = 0 + for i in range(len(buf)): + info = buf[i] + fcs = (fcs + info) & 0xff + self.uart_write(buf[i]) + fcs = (0xff - fcs) & 0xff + self.uart_write(fcs) + self.uart_write(UARTP.FIN) + eprint("sent frame '%s'" % buf.hex()) + + def recv(self): + tag_old = UARTP.FIN + while 1: + tag = tag_old + while 1: + if tag_old == UARTP.FIN: + if tag == UARTP.SYN: + break + tag_old = tag + tag = self.uart_read() + tag_old = tag + + pkt = self.uart_read() + if pkt & 0x80: + pkt &= 0x7f + pkt |= self.uart_read() << 7 + + fcs = 0 + buf = [] + for i in range(pkt): + info = self.uart_read() + buf.append(info) + fcs = (fcs + info) & 0xff + fcs = (fcs + self.uart_read()) & 0xff + + tag = self.uart_read() + if fcs == 0xff: + if tag == UARTP.FIN: + buf = bytes(buf) + eprint("rcvd frame '%s'" % buf.hex()) + if len(buf) >= 1 and buf[0] == 0xde: + sys.stderr.buffer.write(buf[1:]) + sys.stderr.flush() + else: + return buf + + +def run_nist_aead_test(dut, kat): + dump_a = dut.dump_ram() + + for i, m, ad, k, npub, c in kat: + tool = SaleaeTimeMeasurements() + tool.begin_measurement() + try: + run_nist_aead_test_line(dut, i, m, ad, k, npub, c) + finally: + tool.end_measurement() + + if dump_a is not None and i == 0: + dump_b = dut.dump_ram() + longest = compare_dumps(dump_a, dump_b) + print(" longest chunk of untouched memory = %d" % longest) + + +def run_nist_aead_test_line(dut, i, m, ad, k, npub, c): + eprint() + eprint("Count = %d" % i) + eprint(" m = %s" % m.hex()) + eprint(" ad = %s" % ad.hex()) + eprint("npub = %s" % npub.hex()) + eprint(" k = %s" % k.hex()) + eprint(" c = %s" % c.hex()) + + dut.send_var(ord('c'), b"\0" * (len(m) + 32)) + dut.send_var(ord('s'), b"") + + dut.send_var(ord('m'), m) + dut.send_var(ord('a'), ad) + dut.send_var(ord('k'), k) + dut.send_var(ord('p'), npub) + + dut.do_cmd(ord('e')) + output = dut.get_var(ord('C')) + print(" c = %s" % output.hex()) + if c != output: + raise Exception("output of encryption is different from " + + "expected ciphertext") + + dut.send_var(ord('m'), b"\0" * len(c)) + dut.send_var(ord('s'), b"") + + dut.send_var(ord('c'), c) + dut.send_var(ord('a'), ad) + dut.send_var(ord('k'), k) + dut.send_var(ord('p'), npub) + + dut.do_cmd(ord('d')) + output = dut.get_var(ord('M')) + print(" m = %s" % output.hex()) + if m != output: + raise Exception("output of encryption is different from " + + "expected ciphertext") + + +def compare_dumps(dump_a, dump_b): + """ + Gets the length of the longes streaks of equal bytes in two RAM dumps + """ + streaks = [] + streak_beg = 0 + streak_end = 0 + for i in range(len(dump_a)): + if dump_a[i] == dump_b[i]: + streak_end = i + else: + if streak_end != streak_beg: + streaks.append((streak_beg, streak_end)) + streak_beg = i + streak_end = i + + for b, e in streaks: + eprint("equal bytes from 0x%x to 0x%x (length: %d)" % + (b, e, e-b)) + + b, e = max(streaks, key=lambda a: a[1]-a[0]) + eprint( + "longest equal bytes streak from 0x%x to 0x%x (length: %d)" % + (b, e, e-b)) + return e-b + + +def parse_nist_aead_test_vectors(test_file_path): + with open(test_file_path, 'r') as test_file: + lineprog = re.compile( + r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$", + re.IGNORECASE) + m = b"" + ad = b"" + k = b"" + npub = b"" + c = b"" + i = -1 + for line in test_file.readlines(): + line = line.strip() + res = lineprog.match(line) + if line == "": + yield i, m, ad, k, npub, c + m = b"" + ad = b"" + k = b"" + npub = b"" + c = b"" + elif res is not None: + if res[1].lower() == 'count': + i = int(res[2], 10) + elif res[1].lower() == 'key': + k = bytes.fromhex(res[2]) + elif res[1].lower() == 'nonce': + npub = bytes.fromhex(res[2]) + elif res[1].lower() == 'pt': + m = bytes.fromhex(res[2]) + elif res[1].lower() == 'ad': + ad = bytes.fromhex(res[2]) + elif res[1].lower() == 'ct': + c = bytes.fromhex(res[2]) + else: + raise Exception( + "ERROR: unparsed line in test vectors file: '%s'" + % res) + else: + raise Exception( + "ERROR: unparsed line in test vectors file: '%s'" % line) + + +class SaleaeTimeMeasurements: + __slots__ = ['sal'] + + def __init__(self): + import saleae + self.sal = saleae.Saleae() + + def begin_measurement(self): + # Channel 0 is reset + # Channel 1 is crypto_busy + import time + sal = self.sal + + sal.set_active_channels([0, 1], []) + sal.set_sample_rate(sal.get_all_sample_rates()[0]) + sal.set_capture_seconds(6000) + sal.capture_start() + time.sleep(1) + if sal.is_processing_complete(): + raise Exception("Capture didn't start successfully") + + def end_measurement(self): + import time + sal = self.sal + + if sal.is_processing_complete(): + raise Exception("Capture finished before expected") + time.sleep(1) + sal.capture_stop() + time.sleep(.1) + for attempt in range(3): + if not sal.is_processing_complete(): + print("Waiting for capture to complete...") + time.sleep(1) + continue + outfile = "measurement_%s.csv" % time.strftime("%Y%m%d-%H%M%S") + outfile = os.path.join("measurements", outfile) + if os.path.isfile(outfile): + os.unlink(outfile) + sal.export_data2(os.path.abspath(outfile)) + print("Measurements written to '%s'" % outfile) + mdbfile = os.path.join("measurements", "measurements.txt") + mdbfile = open(mdbfile, "a") + mdbfile.write("%s > %s\n" % (' '.join(sys.argv), outfile)) + mdbfile.close() + return 0 + raise Exception("Capture didn't complete successfully") + + +def main(argv): + if len(argv) < 3: + print("Usage: test_common.py port LWC_AEAD_KAT.txt") + + eprint(argv[0]) + script_dir = os.path.split(argv[0])[0] + if len(script_dir) > 0: + os.chdir(script_dir) + + kat = list(parse_nist_aead_test_vectors(argv[2])) + + dev = argv[1] + ser = serial.Serial(dev, baudrate=115200, timeout=5) + dut = DeviceUnderTestAeadUARTP(ser) + + dut.flash() + eprint("Flashed") + dut.prepare() + eprint("Board initialized properly") + sys.stdout.write("Hello, World!\n") + sys.stdout.flush() + + try: + run_nist_aead_test(dut, kat) + return 0 + + except Exception as ex: + print("TEST FAILED") + raise ex + + finally: + sys.stdout.flush() + sys.stderr.flush() + + +if __name__ == "__main__": + sys.exit(main(sys.argv))