#!/usr/bin/env python3 import os import re import sys import glob import json import time import fcntl import struct import socket import subprocess import numpy as np def eprint(*args, **kargs): print(*args, file=sys.stderr, **kargs) class DeviceUnderTest: def __init__(self): pass def flash(self): """ This method should be overridden to flash the DUT. """ pass def dump_ram(self): """ This should be overridden to return the RAM dump as a bytes string. """ return None class DeviceUnderTestAeadUARTP(DeviceUnderTest): def __init__(self, ser=None): self.ser = ser self.firmware_path = None def prepare(self): exp_hello = b"Hello, World!" hello = self.ser.read(13) if hello[-13:] != exp_hello: time.sleep(2) hello += self.ser.read(self.ser.in_waiting) if hello[-13:] != exp_hello: raise Exception( "Improper board initialization message: %s" % hello) self.uartp = UARTP(self.ser) def send_var(self, key, value): self.uartp.send(struct.pack("B", key) + value) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != key: raise Exception("Unacknowledged variable transfer") def obtain_var(self, key): c = struct.pack("B", key) self.uartp.send(c) v = self.uartp.recv() if len(v) < 1 or v[0] != key: raise Exception("Could not obtain variable from board") return v[1:] def do_cmd(self, action): c = struct.pack("B", action) self.uartp.send(c) ack = self.uartp.recv() if len(ack) != 1 or ack[0] != action: raise Exception("Unacknowledged command") class UARTP: def __init__(self, ser): UARTP.SYN = 0xf9 UARTP.FIN = 0xf3 self.ser = ser def uart_read(self): r = self.ser.read(1) if len(r) != 1: raise Exception("Serial read error") return r[0] def uart_write(self, c): b = struct.pack("B", c) r = self.ser.write(b) if r != len(b): raise Exception("Serial write error") return r def send(self, buf): self.uart_write(UARTP.SYN) len_ind_0 = 0xff & len(buf) len_ind_1 = 0xff & (len(buf) >> 7) if len(buf) < 128: self.uart_write(len_ind_0) else: self.uart_write(len_ind_0 | 0x80) self.uart_write(len_ind_1) fcs = 0 for i in range(len(buf)): info = buf[i] fcs = (fcs + info) & 0xff self.uart_write(buf[i]) fcs = (0xff - fcs) & 0xff self.uart_write(fcs) self.uart_write(UARTP.FIN) # eprint("sent frame '%s'" % buf.hex()) def recv(self): tag_old = UARTP.FIN while 1: tag = tag_old while 1: if tag_old == UARTP.FIN: if tag == UARTP.SYN: break tag_old = tag tag = self.uart_read() tag_old = tag pkt = self.uart_read() if pkt & 0x80: pkt &= 0x7f pkt |= self.uart_read() << 7 fcs = 0 buf = [] for i in range(pkt): info = self.uart_read() buf.append(info) fcs = (fcs + info) & 0xff fcs = (fcs + self.uart_read()) & 0xff tag = self.uart_read() if fcs == 0xff: if tag == UARTP.FIN: buf = bytes(buf) # eprint("rcvd frame '%s'" % buf.hex()) if len(buf) >= 1 and buf[0] == 0xde: sys.stderr.buffer.write(buf[1:]) sys.stderr.flush() else: return buf def run_nist_aead_test_line(dut, i, m, ad, k, npub, c): eprint() eprint("Count = %d" % i) eprint(" m = %s" % m.hex()) eprint(" ad = %s" % ad.hex()) eprint("npub = %s" % npub.hex()) eprint(" k = %s" % k.hex()) eprint(" c = %s" % c.hex()) dut.send_var(ord('c'), b"\0" * (len(m) + 32)) dut.send_var(ord('s'), b"") dut.send_var(ord('m'), m) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('e')) output = dut.obtain_var(ord('C')) print(" c = %s" % output.hex()) if c != output: raise Exception( ("output of encryption (%s) is different from " + "expected ciphertext (%s)") % (output.hex(), c.hex())) dut.send_var(ord('m'), b"\0" * len(c)) dut.send_var(ord('s'), b"") dut.send_var(ord('c'), c) dut.send_var(ord('a'), ad) dut.send_var(ord('k'), k) dut.send_var(ord('p'), npub) dut.do_cmd(ord('d')) output = dut.obtain_var(ord('M')) print(" m = %s" % output.hex()) if m != output: raise Exception( ("output of decryption (%s) is different from " + "expected plaintext (%s)") % (output.hex(), m.hex())) def parse_nist_aead_test_vectors(test_file_path): with open(test_file_path, 'r') as test_file: lineprog = re.compile( r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$", re.IGNORECASE) m = b"" ad = b"" k = b"" npub = b"" c = b"" i = -1 for line in test_file.readlines(): line = line.strip() res = lineprog.match(line) if line == "": yield i, m, ad, k, npub, c i = -1 m = b"" ad = b"" k = b"" npub = b"" c = b"" elif res is not None: if res[1].lower() == 'count': i = int(res[2], 10) elif res[1].lower() == 'key': k = bytes.fromhex(res[2]) elif res[1].lower() == 'nonce': npub = bytes.fromhex(res[2]) elif res[1].lower() == 'pt': m = bytes.fromhex(res[2]) elif res[1].lower() == 'ad': ad = bytes.fromhex(res[2]) elif res[1].lower() == 'ct': c = bytes.fromhex(res[2]) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % res) else: raise Exception( "ERROR: unparsed line in test vectors file: '%s'" % line) if i >= 0: yield i, m, ad, k, npub, c class TimeMeasurementTool: def begin_measurement(self): pass def arm(self): pass def unarm(self): pass def end_measurement(self): pass class LogicMultiplexerTimeMeasurements(TimeMeasurementTool): def __init__(self, mask=0xffffffffffffffff): self.mask = mask self.sock = None self.capture = [] def recv_samples(self): import socket capture = [] while 1: try: rcvd = self.sock.recv(16) except socket.timeout: break except BlockingIOError: break if len(rcvd) != 16: raise Exception("Could not receive 16 bytes of logic sample!") time, value = struct.unpack(" %s\n" % (' '.join(sys.argv), outfile)) mdbfile.close() return 0 raise Exception("Capture didn't complete successfully") class FileMutex: def __init__(self, lock_path): self.lock_path = lock_path self.locked = False self.lock_fd = None eprint("Locking %s mutex" % lock_path) self.lock_fd = open(lock_path, 'w') fcntl.lockf(self.lock_fd, fcntl.LOCK_EX) self.locked = True print('%d' % os.getpid(), file=self.lock_fd) self.lock_fd.flush() eprint("%s mutex locked." % lock_path) def __del__(self): if not self.locked: return eprint("Releasing %s mutex" % self.lock_path) self.lock_fd.close() self.locked = False eprint("%s mutex released." % self.lock_path) class OpenOcd: def __init__(self, config_file, tcl_port=6666, verbose=False): self.verbose = verbose self.tclRpcIp = "127.0.0.1" self.tclRpcPort = tcl_port self.bufferSize = 4096 self.process = subprocess.Popen([ 'openocd', '-f', config_file, '-c', 'tcl_port %d' % tcl_port, '-c', 'gdb_port disabled', '-c', 'telnet_port disabled', ], stderr=sys.stderr, stdout=sys.stderr) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) while 1: try: self.sock.connect((self.tclRpcIp, self.tclRpcPort)) break except Exception: time.sleep(.1) def __del__(self): self.send('exit') self.sock.close() self.process.kill() time.sleep(.1) self.process.send_signal(9) def send(self, cmd): """ Send a command string to TCL RPC. Return the result that was read. """ data = cmd.encode('ascii') if self.verbose: print("<- ", data) self.sock.send(data + b"\x1a") res = self._recv() return res def _recv(self): """ Read from the stream until the token (\x1a) was received. """ data = b'' while len(data) < 1 or data[-1] != 0x1a: chunk = self.sock.recv(self.bufferSize) data += chunk data = data[:-1] # strip trailing \x1a if self.verbose: print("-> ", data) return data.decode('ascii') def pack_results(job, platform): build_dir = job.path subprocess.call( ['rm', 'results.zip', 'results.json'], cwd=build_dir) logic_path = os.path.join(build_dir, 'logic_trace.csv') logic_trace = [] with open(logic_path, 'rt') as f: f.readline() # skip header for line in f.readlines(): parts = line.split(',') t = float(parts[0].strip()) v = int(parts[1].strip(), 0) logic_trace.append((t, v)) dips = find_dips(logic_trace) dips_durations = [rais-fall for fall, rais in dips] ram_dumps = {} for dump_path in glob.glob(os.path.join(build_dir, "ram_dump.*.bin")): dump_name = os.path.basename(dump_path) m = re.match(r"ram_dump.(\d+).bin", dump_name) if not m: raise Exception("RAM dump has an unexpected name %s" % dump_name) idx = int(m[1], 0) with open(dump_path, 'rb') as f: ram_dumps[idx] = f.read() print(list(ram_dumps.keys())) if 0 in ram_dumps and 1 in ram_dumps: total_memory = len(ram_dumps[0]) untouched_memory = compare_dumps(ram_dumps[0], ram_dumps[1]) print(" longest chunk of untouched memory = %d" % untouched_memory) memory_utilization = total_memory - untouched_memory else: memory_utilization = None with open(os.path.join(build_dir, 'firmware_size.txt'), 'rt') as f: firmware_size = int(f.readline(), 0) cipher_family, cipher_variant, cipher_impl = tuple( job.cipher.split('.', 2) ) test_vectors_path = os.path.join(build_dir, 'LWC_AEAD_KAT.txt') test_vector = identify_test_vector(test_vectors_path) results = { 'format_version': '1.0', 'test_timestamp': job.time_started, 'avg_enc_time': np.mean(dips_durations[0::2]), 'avg_dec_time': np.mean(dips_durations[1::2]), 'firmware_size': firmware_size, 'memory_utilization': memory_utilization, 'cipher': { 'family': cipher_family, 'variant': cipher_variant, 'implementation': cipher_impl, 'timestamp': job.cipher_timestamp, }, 'template': { 'name': job.template, 'timestamp': job.template_timestamp, 'commit': job.template_commit, }, 'platform': { 'name': platform, }, 'test_vector': test_vector, 'dips': dips_durations, } json_path = os.path.join(build_dir, 'results.json') with open(json_path, 'wt') as f: json.dump(results, f) subprocess.check_call( ['zip', '-r', 'results.zip', '.'], cwd=build_dir) def find_dips(logic_trace): # There should be an even number of edges (2 edges for each dip) assert 0 != len(logic_trace) assert 0 == len(logic_trace) % 2 # First record should be a negative edge, last should be a positive one assert 0 == logic_trace[0][1] assert 0 != logic_trace[-1][1] # Record the start and end times of every dip dips = [] for i in range(0, len(logic_trace), 2): assert 0 == logic_trace[i][1] assert 0 != logic_trace[i+1][1] dips.append((logic_trace[i][0], logic_trace[i+1][0])) # Debounce dips by assuming that a data transfer # between two dips takes at least 1 microsecond THERESHOLD = 1e-6 debounced = [] i = 0 while i < len(dips)-1: fall, rais = dips[i] next_fall, next_rais = dips[i+1] xfer_time = next_fall - rais if xfer_time < THERESHOLD: # Merge current dip with the next dips[i] = (fall, next_rais) dips.pop(i+1) else: # Save current dip debounced.append((fall, rais)) i += 1 # Add the last dip debounced.append((dips[i][0], dips[i][1])) dips = debounced # There should be an even number of dips (encryption and decryption) assert 0 == len(dips) % 2 return dips def compare_dumps(dump_a, dump_b): """ Gets the length of the longes streaks of equal bytes in two RAM dumps """ streaks = [] streak_beg = 0 streak_end = 0 for i in range(len(dump_a)): if dump_a[i] == dump_b[i]: streak_end = i else: if streak_end != streak_beg: streaks.append((streak_beg, streak_end)) streak_beg = i streak_end = i for b, e in streaks: print( "equal bytes from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) b, e = max(streaks, key=lambda a: a[1]-a[0]) print( "longest equal bytes streak from 0x%x to 0x%x (length: %d)" % (b, e, e-b)) return e-b def identify_test_vector(kat_path): # Check if a provided test vector is the official NIST LWC or not kat = list(parse_nist_aead_test_vectors(kat_path)) def is_nist_aead_kat(kat): if len(kat) != 1089: return False def genstr(length): return bytes([b % 256 for b in range(length)]) expected_k = genstr(len(kat[0][3])) expected_npub = genstr(len(kat[0][4])) expected_i = 0 for i, m, ad, k, npub, c in kat: expected_m = genstr((i-1) // 33) expected_ad = genstr((i-1) % 33) expected_i += 1 if not (expected_i == i and expected_m == m and expected_k == k and expected_ad == ad and expected_npub == npub): return False return True def is_nist_aead_kat_no_ad(kat): if len(kat) != 16: return False def genstr(length): return bytes([b % 256 for b in range(length)]) expected_k = genstr(len(kat[0][3])) expected_npub = genstr(len(kat[0][4])) expected_i = 0 expected_ad = b"" for i, m, ad, k, npub, c in kat: expected_m = genstr(i-1) expected_i += 1 if not (expected_i == i and expected_m == m and expected_k == k and expected_ad == ad and expected_npub == npub): return False return True if is_nist_aead_kat(kat): return "NIST AEAD KAT" if is_nist_aead_kat_no_ad(kat): return "NIST AEAD KAT NO AD" return None def run_nist_lws_aead_test(dut, vectors_file, build_dir, logic_mask=0xffff): kat = list(parse_nist_aead_test_vectors(vectors_file)) firmware_size = dut.firmware_size() path = os.path.join(build_dir, 'firmware_size.txt') with open(path, 'wt') as f: print(firmware_size, file=f) dut.flash() dut.prepare() sys.stdout.write("Board prepared\n") sys.stdout.flush() ram_dumps = [dut.dump_ram()] tool = LogicMultiplexerTimeMeasurements(logic_mask) try: tool.begin_measurement() for i, m, ad, k, npub, c in kat: tool.arm() run_nist_aead_test_line(dut, i, m, ad, k, npub, c) tool.unarm() if i == 1 and ram_dumps[0] is not None: ram_dumps.append(dut.dump_ram()) except Exception as ex: print("TEST FAILED") raise ex finally: tool.end_measurement() for i, d in enumerate(ram_dumps): path = os.path.join(build_dir, 'ram_dump.%d.bin' % i) if d is not None: with open(path, 'wb') as f: f.write(d) logic_trace = tool.capture path = os.path.join(build_dir, 'logic_trace.csv') with open(path, 'wt') as f: print("TIME,VALUE", file=f) for t, v in logic_trace: print("%.10f,0x%x" % (t, v), file=f)