#!/usr/bin/env python3 import os import re import sys import time import struct import serial def eprint(*args, **kargs): print(*args, file=sys.stderr, **kargs) class DeviceUnderTest: def __init__(self): pass def flash(self): """ This method should be overridden to flash the DUT. """ pass def dump_ram(self): """ This should be overridden to return the RAM dump as a bytes string. """ return None class DeviceUnderTestAeadUARTP(DeviceUnderTest): def __init__(self, ser): self.ser = ser def prepare(self): exp_hello = b"Hello, World!" time.sleep(0.1) if self.ser.in_waiting < 13: time.sleep(2) hello = self.ser.read(self.ser.in_waiting) if hello[-13:] != exp_hello: raise Exception( "Improper board initialization message: %s" % hello) self.uartp = UARTP(self.ser) def send_var(self, key, value): self.uartp.send(struct.pack("B", key) + value) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != key: raise Exception("Unacknowledged variable transfer") def obtain_var(self, key): c = struct.pack("B", key) self.uartp.send(c) v = self.uartp.recv() if len(v) < 1 or v[0] != key: raise Exception("Could not obtain variable from board") return v[1:] def do_cmd(self, action): c = struct.pack("B", action) self.uartp.send(c) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != action: raise Exception("Unacknowledged command") class UARTP: def __init__(self, ser): UARTP.SYN = 0xf9 UARTP.FIN = 0xf3 self.ser = ser def uart_read(self): r = self.ser.read(1) if len(r) != 1: raise Exception("Serial read error") return r[0] def uart_write(self, c): b = struct.pack("B", c) r = self.ser.write(b) if r != len(b): raise Exception("Serial write error") return r def send(self, buf): self.uart_write(UARTP.SYN) len_ind_0 = 0xff & len(buf) len_ind_1 = 0xff & (len(buf) >> 7) if len(buf) < 128: self.uart_write(len_ind_0) else: self.uart_write(len_ind_0 | 0x80) self.uart_write(len_ind_1) fcs = 0 for i in range(len(buf)): info = buf[i] fcs = (fcs + info) & 0xff self.uart_write(buf[i]) fcs = (0xff - fcs) & 0xff self.uart_write(fcs) self.uart_write(UARTP.FIN) # eprint("sent frame '%s'" % buf.hex()) def recv(self): tag_old = UARTP.FIN while 1: tag = tag_old while 1: if tag_old == UARTP.FIN: if tag == UARTP.SYN: break tag_old = tag tag = self.uart_read() tag_old = tag pkt = self.uart_read() if pkt & 0x80: pkt &= 0x7f pkt |= self.uart_read() << 7 fcs = 0 buf = [] for i in range(pkt): info = self.uart_read() buf.append(info) fcs = (fcs + info) & 0xff fcs = (fcs + self.uart_read()) & 0xff tag = self.uart_read() if fcs == 0xff: if tag == UARTP.FIN: buf = bytes(buf) # eprint("rcvd frame '%s'" % buf.hex()) if len(buf) >= 1 and buf[0] == 0xde: sys.stderr.buffer.write(buf[1:]) sys.stderr.flush() else: return buf def run_nist_aead_test(dut, kat): dump_a = dut.dump_ram() tool = SaleaeTimeMeasurements() tool.begin_measurement() try: for i, m, ad, k, npub, c in kat: run_nist_aead_test_line(dut, i, m, ad, k, npub, c) if dump_a is not None and i == 1: dump_b = dut.dump_ram() longest = compare_dumps(dump_a, dump_b) print(" longest chunk of untouched memory = %d" % longest) finally: tool.end_measurement() def run_nist_aead_test_line(dut, i, m, ad, k, npub, c): eprint() eprint("Count = %d" % i) eprint(" m = %s" % m.hex()) eprint(" ad = %s" % ad.hex()) eprint("npub = %s" % npub.hex()) eprint(" k = %s" % k.hex()) eprint(" c = %s" % c.hex()) dut.send_var(ord('c'), b"\0" * (len(m) + 32)) dut.send_var(ord('s'), b"") dut.send_var(ord('m'), m) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('e')) output = dut.obtain_var(ord('C')) print(" c = %s" % output.hex()) if c != output: raise Exception("output of encryption is different from " + "expected ciphertext") dut.send_var(ord('m'), b"\0" * len(c)) dut.send_var(ord('s'), b"") dut.send_var(ord('c'), c) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('d')) output = dut.obtain_var(ord('M')) print(" m = %s" % output.hex()) if m != output: raise Exception("output of encryption is different from " + "expected ciphertext") def compare_dumps(dump_a, dump_b): """ Gets the length of the longes streaks of equal bytes in two RAM dumps """ streaks = [] streak_beg = 0 streak_end = 0 for i in range(len(dump_a)): if dump_a[i] == dump_b[i]: streak_end = i else: if streak_end != streak_beg: streaks.append((streak_beg, streak_end)) streak_beg = i streak_end = i for b, e in streaks: eprint("equal bytes from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) b, e = max(streaks, key=lambda a: a[1]-a[0]) eprint( "longest equal bytes streak from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) return e-b def parse_nist_aead_test_vectors(test_file_path): with open(test_file_path, 'r') as test_file: lineprog = re.compile( r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$", re.IGNORECASE) m = b"" ad = b"" k = b"" npub = b"" c = b"" i = -1 for line in test_file.readlines(): line = line.strip() res = lineprog.match(line) if line == "": yield i, m, ad, k, npub, c m = b"" ad = b"" k = b"" npub = b"" c = b"" elif res is not None: if res[1].lower() == 'count': i = int(res[2], 10) elif res[1].lower() == 'key': k = bytes.fromhex(res[2]) elif res[1].lower() == 'nonce': npub = bytes.fromhex(res[2]) elif res[1].lower() == 'pt': m = bytes.fromhex(res[2]) elif res[1].lower() == 'ad': ad = bytes.fromhex(res[2]) elif res[1].lower() == 'ct': c = bytes.fromhex(res[2]) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % res) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % line) class SaleaeTimeMeasurements: __slots__ = ['sal'] def __init__(self): import saleae self.sal = saleae.Saleae() def begin_measurement(self): # Channel 0 is reset # Channel 1 is crypto_busy import time sal = self.sal sal.set_active_channels([0, 1], []) sal.set_sample_rate(sal.get_all_sample_rates()[0]) sal.set_capture_seconds(6000) sal.capture_start() time.sleep(1) if sal.is_processing_complete(): raise Exception("Capture didn't start successfully") def end_measurement(self): import time sal = self.sal if sal.is_processing_complete(): raise Exception("Capture finished before expected") time.sleep(1) sal.capture_stop() time.sleep(.1) for attempt in range(3): if not sal.is_processing_complete(): print("Waiting for capture to complete...") time.sleep(1) continue outfile = "measurement_%s.csv" % time.strftime("%Y%m%d-%H%M%S") outfile = os.path.join("measurements", outfile) if os.path.isfile(outfile): os.unlink(outfile) sal.export_data2(os.path.abspath(outfile)) print("Measurements written to '%s'" % outfile) mdbfile = os.path.join("measurements", "measurements.txt") mdbfile = open(mdbfile, "a") mdbfile.write("%s > %s\n" % (' '.join(sys.argv), outfile)) mdbfile.close() return 0 raise Exception("Capture didn't complete successfully") def main(argv): if len(argv) < 3: print("Usage: test_common.py port LWC_AEAD_KAT.txt") eprint(argv[0]) script_dir = os.path.split(argv[0])[0] if len(script_dir) > 0: os.chdir(script_dir) kat = list(parse_nist_aead_test_vectors(argv[2])) dev = argv[1] ser = serial.Serial(dev, baudrate=115200, timeout=5) dut = DeviceUnderTestAeadUARTP(ser) dut.flash() eprint("Flashed") dut.prepare() eprint("Board initialized properly") sys.stdout.write("Hello, World!\n") sys.stdout.flush() try: run_nist_aead_test(dut, kat) return 0 except Exception as ex: print("TEST FAILED") raise ex finally: sys.stdout.flush() sys.stderr.flush() if __name__ == "__main__": sys.exit(main(sys.argv))