Commit f06d7886 by Enrico Pozzobon

work in progress in new test setup

parent 30d6aa6a
......@@ -4,17 +4,11 @@ import os
import sys
import stat
import shutil
import random
import subprocess
def build(algo_dir, template_dir, build_dir):
# create a new directory for the build
while build_dir is None:
r = "%09d" % random.randint(0, 999999999)
d = os.path.join("build", r)
if not os.path.isdir(d):
build_dir = d
print("Building in %s" % build_dir)
# copy all the files from the submitted algorithm into the build directory
......@@ -69,7 +63,6 @@ def build(algo_dir, template_dir, build_dir):
p.wait()
assert p.returncode == 0
finally:
sys.stdout.flush()
sys.stderr.flush()
......@@ -101,6 +94,7 @@ def find_test_vectors(d):
def main(argv):
submissions_dir = "all-lwc-submission-files"
template_dir = "templates/linux"
build_dir = 'build'
include_list = None
if len(argv) > 1:
template_dir = argv[1]
......@@ -115,7 +109,8 @@ def main(argv):
# get all the submissions by looking for files named "api.h"
subfiles = []
for submission in subs:
implementations_dir = os.path.join(submissions_dir, submission, "Implementations", "crypto_aead")
implementations_dir = os.path.join(
submissions_dir, submission, "Implementations", "crypto_aead")
if not os.path.isdir(implementations_dir):
continue
......@@ -160,27 +155,25 @@ def main(argv):
# if include_list was provided, skip elements not in the list
if include_list is not None:
if not n in include_list:
if n not in include_list:
continue
# Put all in a tuple and count
files.append((t, d, n))
# For testing, we only do the first 1
#files = files[:1]
# files = files[:1]
print("%d algorithms will be compiled" % len(files))
# Clear the build directory as it is a leftover from the previous execution
if os.path.isdir('build'):
shutil.rmtree('build')
os.mkdir('build')
if os.path.isdir(build_dir):
shutil.rmtree(build_dir)
os.mkdir(build_dir)
print()
# Write a script that executes all the tests one after the other
test_script_path = os.path.join("build", "test_all.sh")
test_script_path = os.path.join(build_dir, "test_all.sh")
with open(test_script_path, 'w') as test_script:
test_script.write("#!/bin/sh\n")
test_script.write("mkdir -p logs\n")
......@@ -189,10 +182,13 @@ def main(argv):
print()
print(d)
try:
build_dir = os.path.join("build", name)
b = build(d, template_dir, build_dir)
test_script.write("\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d))
test_script.write("python3 -u ./test.py %s %s 2> %s | tee %s\n" % (
b = build(d, template_dir, os.path.join(build_dir, name))
if b is None:
continue
test_script.write(
"\n\necho \"TEST NUMBER %03d: TESTING %s\"\n" % (i, d))
test_script.write(
"python3 -u ./test.py %s %s 2> %s | tee %s\n" % (
t,
os.path.join(b, 'test'),
os.path.join(b, 'test_stderr.log'),
......
......@@ -137,6 +137,7 @@ class UARTP:
else:
return buf
def main(argv):
eprint(argv[0])
script_dir = os.path.split(argv[0])[0]
......@@ -239,7 +240,6 @@ def main(argv):
else:
raise Exception("Unknown action %c" % action)
return 0
......
#!/usr/bin/env python3
import os
import re
import sys
import time
import struct
import serial
def eprint(*args, **kargs):
print(*args, file=sys.stderr, **kargs)
class DeviceUnderTest:
def __init__(self):
pass
def flash(self):
"""
This method should be overridden to flash the DUT.
"""
pass
def ram_dump(self):
"""
This should be overridden to return the RAM dump as a bytes string.
"""
return None
class DeviceUnderTestAeadUARTP(DeviceUnderTest):
def __init__(self, ser):
self.ser = ser
def prepare(self):
time.sleep(0.1)
exp_hello = b"Hello, World!"
hello = self.ser.read(self.ser.in_waiting)
if hello[-13:] != exp_hello:
raise Exception(
"Improper board initialization message: %s" % hello)
self.uartp = UARTP(self.ser)
def send_var(self, key, value):
self.uartp.send(struct.pack("B", key) + value)
ack = self.uartp.recv()
if len(ack) != 1 or ack[0] != key:
raise Exception("Unacknowledged variable transfer")
def obtain_var(self, key):
c = struct.pack("B", key)
self.uartp.send(c)
v = self.uartp.recv()
if len(v) < 1 or v[0] != key:
raise Exception("Could not obtain variable from board")
return v[1:]
def do_cmd(self, action):
c = struct.pack("B", action)
self.uartp.send(c)
ack = self.uartp.recv()
if len(ack) != 1 or ack[0] != action:
raise Exception("Unacknowledged command")
class UARTP:
def __init__(self, ser):
UARTP.SYN = 0xf9
UARTP.FIN = 0xf3
self.ser = ser
def uart_read(self):
r = self.ser.read(1)
if len(r) != 1:
raise Exception("Serial read error")
return r[0]
def uart_write(self, c):
b = struct.pack("B", c)
r = self.ser.write(b)
if r != len(b):
raise Exception("Serial write error")
return r
def send(self, buf):
self.uart_write(UARTP.SYN)
len_ind_0 = 0xff & len(buf)
len_ind_1 = 0xff & (len(buf) >> 7)
if len(buf) < 128:
self.uart_write(len_ind_0)
else:
self.uart_write(len_ind_0 | 0x80)
self.uart_write(len_ind_1)
fcs = 0
for i in range(len(buf)):
info = buf[i]
fcs = (fcs + info) & 0xff
self.uart_write(buf[i])
fcs = (0xff - fcs) & 0xff
self.uart_write(fcs)
self.uart_write(UARTP.FIN)
eprint("sent frame '%s'" % buf.hex())
def recv(self):
tag_old = UARTP.FIN
while 1:
tag = tag_old
while 1:
if tag_old == UARTP.FIN:
if tag == UARTP.SYN:
break
tag_old = tag
tag = self.uart_read()
tag_old = tag
pkt = self.uart_read()
if pkt & 0x80:
pkt &= 0x7f
pkt |= self.uart_read() << 7
fcs = 0
buf = []
for i in range(pkt):
info = self.uart_read()
buf.append(info)
fcs = (fcs + info) & 0xff
fcs = (fcs + self.uart_read()) & 0xff
tag = self.uart_read()
if fcs == 0xff:
if tag == UARTP.FIN:
buf = bytes(buf)
eprint("rcvd frame '%s'" % buf.hex())
if len(buf) >= 1 and buf[0] == 0xde:
sys.stderr.buffer.write(buf[1:])
sys.stderr.flush()
else:
return buf
def run_nist_aead_test(dut, kat):
dump_a = dut.dump_ram()
for i, m, ad, k, npub, c in kat:
tool = SaleaeTimeMeasurements()
tool.begin_measurement()
try:
run_nist_aead_test_line(dut, i, m, ad, k, npub, c)
finally:
tool.end_measurement()
if dump_a is not None and i == 0:
dump_b = dut.dump_ram()
longest = compare_dumps(dump_a, dump_b)
print(" longest chunk of untouched memory = %d" % longest)
def run_nist_aead_test_line(dut, i, m, ad, k, npub, c):
eprint()
eprint("Count = %d" % i)
eprint(" m = %s" % m.hex())
eprint(" ad = %s" % ad.hex())
eprint("npub = %s" % npub.hex())
eprint(" k = %s" % k.hex())
eprint(" c = %s" % c.hex())
dut.send_var(ord('c'), b"\0" * (len(m) + 32))
dut.send_var(ord('s'), b"")
dut.send_var(ord('m'), m)
dut.send_var(ord('a'), ad)
dut.send_var(ord('k'), k)
dut.send_var(ord('p'), npub)
dut.do_cmd(ord('e'))
output = dut.get_var(ord('C'))
print(" c = %s" % output.hex())
if c != output:
raise Exception("output of encryption is different from " +
"expected ciphertext")
dut.send_var(ord('m'), b"\0" * len(c))
dut.send_var(ord('s'), b"")
dut.send_var(ord('c'), c)
dut.send_var(ord('a'), ad)
dut.send_var(ord('k'), k)
dut.send_var(ord('p'), npub)
dut.do_cmd(ord('d'))
output = dut.get_var(ord('M'))
print(" m = %s" % output.hex())
if m != output:
raise Exception("output of encryption is different from " +
"expected ciphertext")
def compare_dumps(dump_a, dump_b):
"""
Gets the length of the longes streaks of equal bytes in two RAM dumps
"""
streaks = []
streak_beg = 0
streak_end = 0
for i in range(len(dump_a)):
if dump_a[i] == dump_b[i]:
streak_end = i
else:
if streak_end != streak_beg:
streaks.append((streak_beg, streak_end))
streak_beg = i
streak_end = i
for b, e in streaks:
eprint("equal bytes from 0x%x to 0x%x (length: %d)" %
(b, e, e-b))
b, e = max(streaks, key=lambda a: a[1]-a[0])
eprint(
"longest equal bytes streak from 0x%x to 0x%x (length: %d)" %
(b, e, e-b))
return e-b
def parse_nist_aead_test_vectors(test_file_path):
with open(test_file_path, 'r') as test_file:
lineprog = re.compile(
r"^\s*([A-Z]+)\s*=\s*(([0-9a-f])*)\s*$",
re.IGNORECASE)
m = b""
ad = b""
k = b""
npub = b""
c = b""
i = -1
for line in test_file.readlines():
line = line.strip()
res = lineprog.match(line)
if line == "":
yield i, m, ad, k, npub, c
m = b""
ad = b""
k = b""
npub = b""
c = b""
elif res is not None:
if res[1].lower() == 'count':
i = int(res[2], 10)
elif res[1].lower() == 'key':
k = bytes.fromhex(res[2])
elif res[1].lower() == 'nonce':
npub = bytes.fromhex(res[2])
elif res[1].lower() == 'pt':
m = bytes.fromhex(res[2])
elif res[1].lower() == 'ad':
ad = bytes.fromhex(res[2])
elif res[1].lower() == 'ct':
c = bytes.fromhex(res[2])
else:
raise Exception(
"ERROR: unparsed line in test vectors file: '%s'"
% res)
else:
raise Exception(
"ERROR: unparsed line in test vectors file: '%s'" % line)
class SaleaeTimeMeasurements:
__slots__ = ['sal']
def __init__(self):
import saleae
self.sal = saleae.Saleae()
def begin_measurement(self):
# Channel 0 is reset
# Channel 1 is crypto_busy
import time
sal = self.sal
sal.set_active_channels([0, 1], [])
sal.set_sample_rate(sal.get_all_sample_rates()[0])
sal.set_capture_seconds(6000)
sal.capture_start()
time.sleep(1)
if sal.is_processing_complete():
raise Exception("Capture didn't start successfully")
def end_measurement(self):
import time
sal = self.sal
if sal.is_processing_complete():
raise Exception("Capture finished before expected")
time.sleep(1)
sal.capture_stop()
time.sleep(.1)
for attempt in range(3):
if not sal.is_processing_complete():
print("Waiting for capture to complete...")
time.sleep(1)
continue
outfile = "measurement_%s.csv" % time.strftime("%Y%m%d-%H%M%S")
outfile = os.path.join("measurements", outfile)
if os.path.isfile(outfile):
os.unlink(outfile)
sal.export_data2(os.path.abspath(outfile))
print("Measurements written to '%s'" % outfile)
mdbfile = os.path.join("measurements", "measurements.txt")
mdbfile = open(mdbfile, "a")
mdbfile.write("%s > %s\n" % (' '.join(sys.argv), outfile))
mdbfile.close()
return 0
raise Exception("Capture didn't complete successfully")
def main(argv):
if len(argv) < 3:
print("Usage: test_common.py port LWC_AEAD_KAT.txt")
eprint(argv[0])
script_dir = os.path.split(argv[0])[0]
if len(script_dir) > 0:
os.chdir(script_dir)
kat = list(parse_nist_aead_test_vectors(argv[2]))
dev = argv[1]
ser = serial.Serial(dev, baudrate=115200, timeout=5)
dut = DeviceUnderTestAeadUARTP(ser)
dut.flash()
eprint("Flashed")
dut.prepare()
eprint("Board initialized properly")
sys.stdout.write("Hello, World!\n")
sys.stdout.flush()
try:
run_nist_aead_test(dut, kat)
return 0
except Exception as ex:
print("TEST FAILED")
raise ex
finally:
sys.stdout.flush()
sys.stderr.flush()
if __name__ == "__main__":
sys.exit(main(sys.argv))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment