#include "api.h"
#include "ascon.h"
#include "permutations.h"
#include "printstate.h"

/*
 * Ascon AEAD exposed through the crypto_aead_encrypt / crypto_aead_decrypt
 * API. The variant is chosen at compile time through CRYPTO_KEYBYTES (16 or
 * 20) and ASCON_RATE (8 or 16); the word_t type and the XOR/LOAD/STORE/PB/P12
 * helpers come from the project headers included above. The `if`s on those
 * compile-time constants are expected to be folded away by the compiler.
 */

/*
 * Load the secret key k into the three key words K0, K1, K2.
 * KINIT (project macro) prepares the words before the key bytes are XORed
 * in. For a 20-byte key the leading 4 bytes are merged into K0 via KEYROT;
 * the remaining 16 bytes always go into K1 and K2.
 */
__forceinline void loadkey(word_t* K0, word_t* K1, word_t* K2,
                           const uint8_t* k) {
  KINIT(K0, K1, K2);
  if (CRYPTO_KEYBYTES == 20) {
    *K0 = XOR(*K0, KEYROT(WORD_T(0), LOAD(k, 4)));
    k += 4;
  }
  *K1 = XOR(*K1, LOAD64(k));
  *K2 = XOR(*K2, LOAD64(k + 8));
}

/*
 * Initialization phase: build IV || key || nonce in the state, apply P12,
 * then XOR the key words into the trailing state words.
 * npub must point to at least 16 nonce bytes (two LOAD64 reads).
 */
__forceinline void init(state_t* s, const uint8_t* npub, word_t K0, word_t K1,
                        word_t K2) {
  word_t N0, N1;
  /* load nonce */
  N0 = LOAD64(npub);
  N1 = LOAD64(npub + 8);
  /* initialization */
  PINIT(s);
  s->x0 = XOR(s->x0, IV);
  if (CRYPTO_KEYBYTES == 20) {
    s->x0 = XOR(s->x0, K0);
  }
  s->x1 = XOR(s->x1, K1);
  s->x2 = XOR(s->x2, K2);
  s->x3 = XOR(s->x3, N0);
  s->x4 = XOR(s->x4, N1);
  P12(s);
  if (CRYPTO_KEYBYTES == 20) {
    s->x2 = XOR(s->x2, K0);
  }
  s->x3 = XOR(s->x3, K1);
  s->x4 = XOR(s->x4, K2);
  printstate("initialization", s);
}

/*
 * Absorb adlen bytes of associated data into the rate words (x0, and x1 when
 * ASCON_RATE == 16), permuting with PB after every block including the padded
 * final one. Nothing is absorbed (and no permutation runs) when adlen == 0.
 * Afterwards, 1 is always XORed into x4 to separate the AD phase from the
 * message phase.
 */
__forceinline void absorb(state_t* s, const uint8_t* ad, uint64_t adlen) {
  /* process associated data */
  if (adlen) {
    /* Not restrict-qualified: the word px selects is also accessed via s. */
    word_t* px;
    /* full associated data blocks */
    while (adlen >= ASCON_RATE) {
      s->x0 = XOR(s->x0, LOAD64(ad));
      if (ASCON_RATE == 16) {
        s->x1 = XOR(s->x1, LOAD64(ad + 8));
      }
      PB(s);
      ad += ASCON_RATE;
      adlen -= ASCON_RATE;
    }
    /* final associated data block: px selects the word holding the tail */
    px = &s->x0;
    if (ASCON_RATE == 16 && adlen >= 8) {
      s->x0 = XOR(s->x0, LOAD64(ad));
      px = &s->x1;
      ad += 8;
      adlen -= 8;
    }
    if (adlen) {
      *px = XOR(*px, LOAD(ad, adlen));
    }
    *px = XOR(*px, PAD(adlen));
    PB(s);
  }
  /* domain separation */
  s->x4 = XOR(s->x4, WORD_T(1));
  printstate("process associated data", s);
}

/*
 * Encrypt mlen bytes of m into c (c receives exactly mlen bytes).
 * Each ciphertext block is the rate word(s) after XORing in the plaintext.
 * The padded final block is XORed in but not permuted; finalization follows.
 */
__forceinline void encrypt(state_t* s, uint8_t* c, const uint8_t* m,
                           uint64_t mlen) {
  /* Not restrict-qualified: the word px selects is also accessed via s. */
  word_t* px;
  /* full plaintext blocks */
  while (mlen >= ASCON_RATE) {
    s->x0 = XOR(s->x0, LOAD64(m));
    STORE64(c, s->x0);
    if (ASCON_RATE == 16) {
      s->x1 = XOR(s->x1, LOAD64(m + 8));
      STORE64(c + 8, s->x1);
    }
    PB(s);
    m += ASCON_RATE;
    c += ASCON_RATE;
    mlen -= ASCON_RATE;
  }
  /* final plaintext block: px selects the word holding the tail */
  px = &s->x0;
  if (ASCON_RATE == 16 && mlen >= 8) {
    s->x0 = XOR(s->x0, LOAD64(m));
    STORE64(c, s->x0);
    px = &s->x1;
    m += 8;
    c += 8;
    mlen -= 8;
  }
  if (mlen) {
    *px = XOR(*px, LOAD(m, mlen));
    STORE(c, *px, mlen);
  }
  *px = XOR(*px, PAD(mlen));
  printstate("process plaintext", s);
}

/*
 * Decrypt clen bytes of c into m (m receives exactly clen bytes).
 * The plaintext is rate word XOR ciphertext. The rate word(s) are then
 * overwritten with the ciphertext itself, so the state matches what
 * encrypt() produced.
 */
__forceinline void decrypt(state_t* s, uint8_t* m, const uint8_t* c,
                           uint64_t clen) {
  /* Not restrict-qualified: the word px selects is also accessed via s. */
  word_t* px;
  word_t cx;
  /* full ciphertext blocks */
  while (clen >= ASCON_RATE) {
    cx = LOAD64(c);
    s->x0 = XOR(s->x0, cx);
    STORE64(m, s->x0);
    s->x0 = cx;
    if (ASCON_RATE == 16) {
      cx = LOAD64(c + 8);
      s->x1 = XOR(s->x1, cx);
      STORE64(m + 8, s->x1);
      s->x1 = cx;
    }
    PB(s);
    m += ASCON_RATE;
    c += ASCON_RATE;
    clen -= ASCON_RATE;
  }
  /* final ciphertext block: px selects the word holding the tail */
  px = &s->x0;
  if (ASCON_RATE == 16 && clen >= 8) {
    cx = LOAD64(c);
    s->x0 = XOR(s->x0, cx);
    STORE64(m, s->x0);
    s->x0 = cx;
    px = &s->x1;
    m += 8;
    c += 8;
    clen -= 8;
  }
  if (clen) {
    cx = LOAD(c, clen);
    *px = XOR(*px, cx);
    STORE(m, *px, clen);
    /* presumably CLEAR zeroes the clen consumed bytes so the XOR below
     * leaves exactly the ciphertext bytes there -- see CLEAR's definition */
    *px = CLEAR(*px, clen);
    *px = XOR(*px, cx);
  }
  *px = XOR(*px, PAD(clen));
  printstate("process ciphertext", s);
}

/*
 * Finalization phase: XOR the key into the state words just past the rate,
 * which depend on the key size and rate. Apply P12, then XOR K1/K2 into x3/x4.
 * The tag is then read from x3 || x4 by the callers.
 */
__forceinline void final(state_t* s, word_t K0, word_t K1, word_t K2) {
  /* finalization */
  if (CRYPTO_KEYBYTES == 16 && ASCON_RATE == 8) {
    s->x1 = XOR(s->x1, K1);
    s->x2 = XOR(s->x2, K2);
  }
  if (CRYPTO_KEYBYTES == 16 && ASCON_RATE == 16) {
    s->x2 = XOR(s->x2, K1);
    s->x3 = XOR(s->x3, K2);
  }
  if (CRYPTO_KEYBYTES == 20) {
    s->x1 = XOR(s->x1, KEYROT(K0, K1));
    s->x2 = XOR(s->x2, KEYROT(K1, K2));
    s->x3 = XOR(s->x3, KEYROT(K2, WORD_T(0)));
  }
  P12(s);
  s->x3 = XOR(s->x3, K1);
  s->x4 = XOR(s->x4, K2);
  printstate("finalization", s);
}

/*
 * Phase dispatch. INIT and FINAL receive both the raw key pointer and the
 * pre-loaded key words, and each mode forwards only what it needs:
 *   - inline mode calls the __forceinline helpers with the key words;
 *   - non-inline mode calls the exported ascon_* wrappers, which take the key
 *     pointer and reload the words themselves.
 * (Previously the non-inline branch passed the key words to wrappers
 * expecting a key pointer, which did not compile.)
 */
#if ASCON_INLINE_MODE
#define INIT(st, nonce, key, k0, k1, k2) init(st, nonce, k0, k1, k2)
#define ABSORB absorb
#define ENCRYPT encrypt
#define DECRYPT decrypt
#define FINAL(st, key, k0, k1, k2) final(st, k0, k1, k2)
#else
#define INIT(st, nonce, key, k0, k1, k2) ascon_init(st, nonce, key)
#define ABSORB ascon_absorb
#define ENCRYPT ascon_encrypt
#define DECRYPT ascon_decrypt
#define FINAL(st, key, k0, k1, k2) ascon_final(st, key)

/* Out-of-line initialization: loads the key from k, then runs init(). */
void ascon_init(state_t* s, const uint8_t* npub, const uint8_t* k) {
  word_t K0, K1, K2;
  loadkey(&K0, &K1, &K2, k);
  init(s, npub, K0, K1, K2);
}

/* Out-of-line associated-data absorption; see absorb(). */
void ascon_absorb(state_t* s, const uint8_t* ad, uint64_t adlen) {
  absorb(s, ad, adlen);
}

/* Out-of-line plaintext processing; see encrypt(). */
void ascon_encrypt(state_t* s, uint8_t* c, const uint8_t* m, uint64_t mlen) {
  encrypt(s, c, m, mlen);
}

/* Out-of-line ciphertext processing; see decrypt(). */
void ascon_decrypt(state_t* s, uint8_t* m, const uint8_t* c, uint64_t clen) {
  decrypt(s, m, c, clen);
}

/* Out-of-line finalization: loads the key from k, then runs final(). */
void ascon_final(state_t* s, const uint8_t* k) {
  word_t K0, K1, K2;
  loadkey(&K0, &K1, &K2, k);
  final(s, K0, K1, K2);
}
#endif

/*
 * Authenticated encryption.
 * Writes mlen ciphertext bytes followed by the CRYPTO_ABYTES-byte tag to c
 * (c must hold mlen + CRYPTO_ABYTES bytes) and sets *clen accordingly.
 * nsec is unused. Always returns 0.
 */
int crypto_aead_encrypt(uint8_t* c, uint64_t* clen, const uint8_t* m,
                        uint64_t mlen, const uint8_t* ad, uint64_t adlen,
                        const uint8_t* nsec, const uint8_t* npub,
                        const uint8_t* k) {
  word_t K0, K1, K2;
  state_t s;
  (void)nsec;

  *clen = mlen + CRYPTO_ABYTES;

  /* perform ascon computation */
  loadkey(&K0, &K1, &K2, k);
  INIT(&s, npub, k, K0, K1, K2);
  ABSORB(&s, ad, adlen);
  ENCRYPT(&s, c, m, mlen);
  FINAL(&s, k, K0, K1, K2);

  /* set tag: x3 || x4 after finalization */
  c += mlen;
  STOREBYTES(c, s.x3, 8);
  STOREBYTES(c + 8, s.x4, 8);
  return 0;
}

/*
 * Authenticated decryption.
 * c holds clen bytes: ciphertext followed by the CRYPTO_ABYTES-byte tag.
 * Returns -1 with *mlen = 0 if clen is shorter than the tag or the tag does
 * not verify; otherwise returns 0 with *mlen = clen - CRYPTO_ABYTES.
 * NOTE: the plaintext is written to m before the tag is checked, so on a -1
 * return m still holds unverified plaintext and must not be used.
 */
int crypto_aead_decrypt(uint8_t* m, uint64_t* mlen, uint8_t* nsec,
                        const uint8_t* c, uint64_t clen, const uint8_t* ad,
                        uint64_t adlen, const uint8_t* npub,
                        const uint8_t* k) {
  word_t K0, K1, K2;
  state_t s;
  (void)nsec;

  if (clen < CRYPTO_ABYTES) {
    *mlen = 0;
    return -1;
  }

  *mlen = clen = clen - CRYPTO_ABYTES;

  /* perform ascon computation */
  loadkey(&K0, &K1, &K2, k);
  INIT(&s, npub, k, K0, K1, K2);
  ABSORB(&s, ad, adlen);
  DECRYPT(&s, m, c, clen);
  FINAL(&s, k, K0, K1, K2);

  /* verify tag (should be constant time, check compiler output) */
  c += clen;
  s.x3 = XOR(s.x3, LOADBYTES(c, 8));
  s.x4 = XOR(s.x4, LOADBYTES(c + 8, 8));
  if (NOTZERO(s.x3, s.x4)) {
    *mlen = 0;
    return -1;
  }

  return 0;
}