genkat_aead.c 3.94 KB
Newer Older
lwc-tester committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148

// disable deprecation for sprintf and fopen
#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS
#endif

#include <stdio.h>
#include <time.h>
#include <string.h>
#include <stdlib.h>
#include "crypto_aead.h"
#include "api.h"

/* Status codes returned by generate_test_vectors() and propagated by main(). */
#define KAT_SUCCESS          0
#define KAT_FILE_OPEN_ERROR (-1)
#define KAT_DATA_ERROR      (-3)
#define KAT_CRYPTO_FAILURE  (-4)

/* Buffer capacities used by the generator (bytes). */
#define MAX_FILE_NAME               256
#define MAX_MESSAGE_LENGTH          256
#define MAX_ASSOCIATED_DATA_LENGTH  32

/* Fill `numbytes` bytes of `buffer` with the test pattern selected by `type`. */
void init_buffer(unsigned char *buffer, int type, unsigned long long numbytes);

/* Print `label` then `length` bytes of `data` as hex, followed by a newline. */
void fprint_bstr(FILE *fp, const char *label, const unsigned char *data, unsigned long long length);

/* Generate the KAT file; returns one of the KAT_* status codes above. */
int generate_test_vectors(void);

/*
 * Entry point: generate the AEAD known-answer-test file.
 * Prints a diagnostic to stderr on failure and returns the KAT_* status
 * from generate_test_vectors() (KAT_SUCCESS == 0 on success).
 *
 * No PRNG seeding is done here: rand() is never called by this program's
 * generator, and known-answer output must be deterministic anyway.
 */
int main(void)
{
    int ret = generate_test_vectors();

    if (ret != KAT_SUCCESS) {
        fprintf(stderr, "test vector generation failed with code %d\n", ret);
    }

    return ret;
}

/*
 * Write LWC_AEAD_KAT_<keybits>_<noncebits>.txt containing 27 test vectors:
 * every combination of 3 key patterns x 3 nonce patterns x 3 message
 * patterns (the pattern index is passed to init_buffer as `type`), with
 * message lengths 64/128/256 bytes selected by the message pattern and
 * empty associated data. Each vector is encrypted, written out, then
 * decrypted again and checked against the original plaintext.
 *
 * Returns:
 *   KAT_SUCCESS          all vectors generated and round-tripped;
 *   KAT_FILE_OPEN_ERROR  the output file could not be opened, or could
 *                        not be flushed/closed (output may be incomplete);
 *   KAT_CRYPTO_FAILURE   encrypt/decrypt returned non-zero or the
 *                        decryption did not recover the plaintext (the
 *                        reason is also written into the KAT file).
 */
int generate_test_vectors(void)
{
    /* Number of fill patterns per input; vectors cover every combination. */
    enum {
        NUM_PATTERNS = 3,
        NUM_VECTORS  = NUM_PATTERNS * NUM_PATTERNS * NUM_PATTERNS
    };
    /* Message length per message pattern; none may exceed MAX_MESSAGE_LENGTH. */
    static const unsigned long long msg_lengths[NUM_PATTERNS] = { 64, 128, 256 };

    FILE                *fp;
    char                fileName[MAX_FILE_NAME];
    unsigned char       key[CRYPTO_KEYBYTES];
    unsigned char       nonce[CRYPTO_NPUBBYTES];
    unsigned char       msg[MAX_MESSAGE_LENGTH];
    unsigned char       msg2[MAX_MESSAGE_LENGTH];
    unsigned char       ad[MAX_ASSOCIATED_DATA_LENGTH] = {0};
    unsigned char       ct[MAX_MESSAGE_LENGTH + CRYPTO_ABYTES];
    unsigned long long  clen, mlen, mlen2;
    unsigned long long  adlen = 0;     /* associated data is always empty */
    int                 count, func_ret, ret_val = KAT_SUCCESS;

    snprintf(fileName, sizeof fileName, "LWC_AEAD_KAT_%d_%d.txt",
             (CRYPTO_KEYBYTES * 8), (CRYPTO_NPUBBYTES * 8));

    if ((fp = fopen(fileName, "w")) == NULL) {
        fprintf(stderr, "Couldn't open <%s> for write\n", fileName);
        return KAT_FILE_OPEN_ERROR;
    }

    for (count = 0; count < NUM_VECTORS; count++) {
        /* Decompose count into base-3 digits: key, nonce, message pattern. */
        int typek = count / (NUM_PATTERNS * NUM_PATTERNS);
        int typen = (count / NUM_PATTERNS) % NUM_PATTERNS;
        int typem = count % NUM_PATTERNS;

        mlen = msg_lengths[typem];

        init_buffer(key, typek, sizeof(key));
        init_buffer(nonce, typen, sizeof(nonce));
        init_buffer(msg, typem, sizeof(msg));

        fprintf(fp, "Count = %d\n", count);
        fprint_bstr(fp, "Key = ", key, CRYPTO_KEYBYTES);
        fprint_bstr(fp, "Nonce = ", nonce, CRYPTO_NPUBBYTES);
        fprint_bstr(fp, "PT = ", msg, mlen);
        /* mlen is unsigned long long, so %llu (not %lld) is the matching specifier. */
        fprintf(fp, "Mlen = %llu\n", mlen);

        if ((func_ret = crypto_aead_encrypt(ct, &clen, msg, mlen, ad, adlen, NULL, nonce, key)) != 0) {
            fprintf(fp, "crypto_aead_encrypt returned <%d>\n", func_ret);
            ret_val = KAT_CRYPTO_FAILURE;
            break;
        }

        fprint_bstr(fp, "CT = ", ct, clen);
        fprintf(fp, "\n");

        /* Round-trip check: decryption must succeed and reproduce the plaintext. */
        if ((func_ret = crypto_aead_decrypt(msg2, &mlen2, NULL, ct, clen, ad, adlen, nonce, key)) != 0) {
            fprintf(fp, "crypto_aead_decrypt returned <%d>\n", func_ret);
            ret_val = KAT_CRYPTO_FAILURE;
            break;
        }

        if (mlen != mlen2) {
            fprintf(fp, "crypto_aead_decrypt returned bad 'mlen': Got <%llu>, expected <%llu>\n", mlen2, mlen);
            ret_val = KAT_CRYPTO_FAILURE;
            break;
        }

        if (memcmp(msg, msg2, mlen)) {
            fprintf(fp, "crypto_aead_decrypt did not recover the plaintext\n");
            ret_val = KAT_CRYPTO_FAILURE;
            break;
        }
    }

    /* fclose flushes buffered output; a failure here means the KAT file is incomplete. */
    if (fclose(fp) != 0) {
        fprintf(stderr, "Couldn't finish writing <%s>\n", fileName);
        if (ret_val == KAT_SUCCESS)
            ret_val = KAT_FILE_OPEN_ERROR;
    }

    return ret_val;
}


/*
 * Emit `label` verbatim, then each of the `length` bytes at `data` as two
 * uppercase hex digits (no separators), then a single newline.
 * `data` is not read when `length` is 0.
 */
void fprint_bstr(FILE *fp, const char *label, const unsigned char *data, unsigned long long length)
{
    unsigned long long remaining = length;

    fputs(label, fp);

    while (remaining-- > 0)
        fprintf(fp, "%02X", *data++);

    fputc('\n', fp);
}

/*
 * Fill the first `numbytes` bytes of `buffer` according to `type`:
 *   0 -> all 0x00, 1 -> all 0xFF, 2 -> byte at offset k is (k mod 256).
 * Any other `type` leaves the buffer unmodified.
 */
void init_buffer(unsigned char *buffer, int type, unsigned long long numbytes)
{
    unsigned long long pos;

    switch (type) {
    case 0:
        for (pos = 0; pos < numbytes; pos++)
            buffer[pos] = 0x00;
        break;
    case 1:
        for (pos = 0; pos < numbytes; pos++)
            buffer[pos] = 0xFF;
        break;
    case 2:
        for (pos = 0; pos < numbytes; pos++)
            buffer[pos] = (unsigned char)(pos & 0xFF);
        break;
    default:
        /* unknown pattern: leave the buffer as it is */
        break;
    }
}