#!/usr/bin/env python3 import os import re import sys import time import fcntl import struct import socket import subprocess def eprint(*args, **kargs): print(*args, file=sys.stderr, **kargs) class DeviceUnderTest: def __init__(self): pass def flash(self): """ This method should be overridden to flash the DUT. """ pass def dump_ram(self): """ This should be overridden to return the RAM dump as a bytes string. """ return None class DeviceUnderTestAeadUARTP(DeviceUnderTest): def __init__(self, ser=None): self.ser = ser self.firmware_path = None def prepare(self): exp_hello = b"Hello, World!" hello = self.ser.read(13) if hello[-13:] != exp_hello: time.sleep(2) hello += self.ser.read(self.ser.in_waiting) if hello[-13:] != exp_hello: raise Exception( "Improper board initialization message: %s" % hello) self.uartp = UARTP(self.ser) def send_var(self, key, value): self.uartp.send(struct.pack("B", key) + value) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != key: raise Exception("Unacknowledged variable transfer") def obtain_var(self, key): c = struct.pack("B", key) self.uartp.send(c) v = self.uartp.recv() if len(v) < 1 or v[0] != key: raise Exception("Could not obtain variable from board") return v[1:] def do_cmd(self, action): c = struct.pack("B", action) self.uartp.send(c) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != action: raise Exception("Unacknowledged command") class UARTP: def __init__(self, ser): UARTP.SYN = 0xf9 UARTP.FIN = 0xf3 self.ser = ser def uart_read(self): r = self.ser.read(1) if len(r) != 1: raise Exception("Serial read error") return r[0] def uart_write(self, c): b = struct.pack("B", c) r = self.ser.write(b) if r != len(b): raise Exception("Serial write error") return r def send(self, buf): self.uart_write(UARTP.SYN) len_ind_0 = 0xff & len(buf) len_ind_1 = 0xff & (len(buf) >> 7) if len(buf) < 128: self.uart_write(len_ind_0) else: self.uart_write(len_ind_0 | 0x80) self.uart_write(len_ind_1) fcs = 0 for i in range(len(buf)): info = buf[i] fcs = (fcs + info) & 0xff self.uart_write(buf[i]) fcs = (0xff - fcs) & 0xff self.uart_write(fcs) self.uart_write(UARTP.FIN) # eprint("sent frame '%s'" % buf.hex()) def recv(self): tag_old = UARTP.FIN while 1: tag = tag_old while 1: if tag_old == UARTP.FIN: if tag == UARTP.SYN: break tag_old = tag tag = self.uart_read() tag_old = tag pkt = self.uart_read() if pkt & 0x80: pkt &= 0x7f pkt |= self.uart_read() << 7 fcs = 0 buf = [] for i in range(pkt): info = self.uart_read() buf.append(info) fcs = (fcs + info) & 0xff fcs = (fcs + self.uart_read()) & 0xff tag = self.uart_read() if fcs == 0xff: if tag == UARTP.FIN: buf = bytes(buf) # eprint("rcvd frame '%s'" % buf.hex()) if len(buf) >= 1 and buf[0] == 0xde: sys.stderr.buffer.write(buf[1:]) sys.stderr.flush() else: return buf def run_nist_aead_test_line(dut, i, m, ad, k, npub, c): eprint() eprint("Count = %d" % i) eprint(" m = %s" % m.hex()) eprint(" ad = %s" % ad.hex()) eprint("npub = %s" % npub.hex()) eprint(" k = %s" % k.hex()) eprint(" c = %s" % c.hex()) dut.send_var(ord('c'), b"\0" * (len(m) + 32)) dut.send_var(ord('s'), b"") dut.send_var(ord('m'), m) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('e')) output = dut.obtain_var(ord('C')) print(" c = %s" % output.hex()) if c != output: raise Exception("output of encryption is different from " + "expected ciphertext") dut.send_var(ord('m'), b"\0" * len(c)) dut.send_var(ord('s'), b"") dut.send_var(ord('c'), c) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('d')) output = dut.obtain_var(ord('M')) print(" m = %s" % output.hex()) if m != output: raise Exception("output of encryption is different from " + "expected ciphertext") def compare_dumps(dump_a, dump_b): """ Gets the length of the longes streaks of equal bytes in two RAM dumps """ streaks = [] streak_beg = 0 streak_end = 0 for i in range(len(dump_a)): if dump_a[i] == dump_b[i]: streak_end = i else: if streak_end != streak_beg: streaks.append((streak_beg, streak_end)) streak_beg = i streak_end = i for b, e in streaks: eprint("equal bytes from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) b, e = max(streaks, key=lambda a: a[1]-a[0]) eprint( "longest equal bytes streak from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) return e-b def parse_nist_aead_test_vectors(test_file_path): with open(test_file_path, 'r') as test_file: lineprog = re.compile( r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$", re.IGNORECASE) m = b"" ad = b"" k = b"" npub = b"" c = b"" i = -1 for line in test_file.readlines(): line = line.strip() res = lineprog.match(line) if line == "": yield i, m, ad, k, npub, c m = b"" ad = b"" k = b"" npub = b"" c = b"" elif res is not None: if res[1].lower() == 'count': i = int(res[2], 10) elif res[1].lower() == 'key': k = bytes.fromhex(res[2]) elif res[1].lower() == 'nonce': npub = bytes.fromhex(res[2]) elif res[1].lower() == 'pt': m = bytes.fromhex(res[2]) elif res[1].lower() == 'ad': ad = bytes.fromhex(res[2]) elif res[1].lower() == 'ct': c = bytes.fromhex(res[2]) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % res) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % line) class TimeMeasurementTool: def begin_measurement(self): pass def arm(self): pass def unarm(self): pass def end_measurement(self): pass class LogicMultiplexerTimeMeasurements(TimeMeasurementTool): def __init__(self, mask=0xffffffffffffffff): self.mask = mask self.sock = None self.capture = [] def recv_samples(self): import socket capture = [] while 1: try: rcvd = self.sock.recv(16) except socket.timeout: break except BlockingIOError: break if len(rcvd) != 16: raise Exception("Could not receive 16 bytes of logic sample!") time, value = struct.unpack(" %s\n" % (' '.join(sys.argv), outfile)) mdbfile.close() return 0 raise Exception("Capture didn't complete successfully") class FileMutex: def __init__(self, lock_path): self.lock_path = lock_path self.locked = False self.lock_fd = None eprint("Locking %s mutex" % lock_path) self.lock_fd = open(lock_path, 'w') fcntl.lockf(self.lock_fd, fcntl.LOCK_EX) self.locked = True print('%d' % os.getpid(), file=self.lock_fd) self.lock_fd.flush() eprint("%s mutex locked." % lock_path) def __del__(self): if not self.locked: return eprint("Releasing %s mutex" % self.lock_path) self.lock_fd.close() self.locked = False eprint("%s mutex released." % self.lock_path) class OpenOcd: def __init__(self, config_file, tcl_port=6666, verbose=False): self.verbose = verbose self.tclRpcIp = "127.0.0.1" self.tclRpcPort = tcl_port self.bufferSize = 4096 self.process = subprocess.Popen([ 'openocd', '-f', config_file, '-c', 'tcl_port %d' % tcl_port, '-c', 'gdb_port disabled', '-c', 'telnet_port disabled', ], stderr=sys.stderr, stdout=sys.stderr) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) while 1: try: self.sock.connect((self.tclRpcIp, self.tclRpcPort)) break except Exception: time.sleep(.1) def __del__(self): self.send('exit') self.sock.close() self.process.kill() time.sleep(.1) self.process.send_signal(9) def send(self, cmd): """ Send a command string to TCL RPC. Return the result that was read. """ data = cmd.encode('ascii') if self.verbose: print("<- ", data) self.sock.send(data + b"\x1a") res = self._recv() return res def _recv(self): """ Read from the stream until the token (\x1a) was received. """ data = b'' while len(data) < 1 or data[-1] != 0x1a: chunk = self.sock.recv(self.bufferSize) data += chunk data = data[:-1] # strip trailing \x1a if self.verbose: print("-> ", data) return data.decode('ascii') def run_nist_lws_aead_test(dut, vectors_file, build_dir, logic_mask=0xffff): kat = list(parse_nist_aead_test_vectors(vectors_file)) dut.flash() dut.prepare() sys.stdout.write("Board prepared\n") sys.stdout.flush() ram_dumps = [dut.dump_ram()] tool = LogicMultiplexerTimeMeasurements(logic_mask) try: tool.begin_measurement() for i, m, ad, k, npub, c in kat: tool.arm() run_nist_aead_test_line(dut, i, m, ad, k, npub, c) tool.unarm() if i == 1 and ram_dumps[0] is not None: ram_dumps.append(dut.dump_ram()) longest = compare_dumps(ram_dumps[0], ram_dumps[1]) print(" longest chunk of untouched memory = %d" % longest) except Exception as ex: print("TEST FAILED") raise ex finally: tool.end_measurement() for i, d in enumerate(ram_dumps): path = os.path.join(build_dir, 'ram_dump.%d.bin' % i) if d is not None: with open(path, 'wb') as f: f.write(d) logic_trace = tool.capture path = os.path.join(build_dir, 'logic_trace.csv') with open(path, 'wt') as f: print("TIME,VALUE", file=f) for t, v in logic_trace: print("%f,0x%x" % (t, v), file=f)