decrypt.c 2.08 KB
Newer Older
Martin Schläffer committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
#include <string.h>

#include "api.h"
#include "endian.h"
#include "permutations.h"

/* Round counts of the two permutation variants; both are packed into IV.
 * (Presumably these correspond to the P12()/P8() calls used during
 * decryption — NOTE(review): confirm against permutations.h.) */
#define PA_ROUNDS 12
#define PB_ROUNDS 8

/* Initial value of state word x0, built from four fields (high to low byte):
 *   byte 7: key size in bits   (8 * CRYPTO_KEYBYTES)
 *   byte 6: rate in bits       (8 * ASCON_RATE)
 *   byte 5: PA_ROUNDS
 *   byte 4: PB_ROUNDS
 * The low four bytes are zero. */
#define IV                                     \
  (((uint64_t)(8 * (CRYPTO_KEYBYTES)) << 56) | \
   ((uint64_t)(8 * (ASCON_RATE)) << 48) |      \
   ((uint64_t)(PA_ROUNDS) << 40) |             \
   ((uint64_t)(PB_ROUNDS) << 32))

/*
 * Read 8 bytes starting at p and pass them through U64BIG().
 *
 * Same result as the former `U64BIG(*(uint64_t*)p)` idiom, but the bytes are
 * copied with memcpy instead of dereferencing a casted pointer. The key,
 * nonce and tag live in caller-supplied uint8_t buffers with no alignment
 * guarantee. Reading them through a uint64_t lvalue is a strict-aliasing
 * violation and can fault on strict-alignment targets. Compilers typically
 * lower this fixed-size memcpy to a single load.
 */
static inline uint64_t load_u64big(const uint8_t* p) {
  uint64_t w;
  memcpy(&w, p, sizeof w);
  return U64BIG(w);
}

/*
 * Ascon authenticated decryption using the crypto_aead_decrypt calling
 * convention (SUPERCOP / NIST LWC API).
 *
 * m      output buffer; receives clen - CRYPTO_ABYTES plaintext bytes.
 *        On authentication failure those bytes are zeroed so no unverified
 *        plaintext is released.
 * mlen   receives the plaintext length; set to 0 on failure.
 * nsec   unused.
 * c      ciphertext immediately followed by the authentication tag
 *        (16 bytes are read as the tag).
 * clen   total length of c, including the tag.
 * ad     associated data (adlen bytes); authenticated only.
 * npub   public nonce; 16 bytes are read.
 * k      key; 16 bytes are read.
 *
 * Returns 0 if the tag verifies, and -1 if clen is shorter than
 * CRYPTO_ABYTES or the tag does not match.
 */
int crypto_aead_decrypt(uint8_t* m, uint64_t* mlen, uint8_t* nsec,
                        const uint8_t* c, uint64_t clen, const uint8_t* ad,
                        uint64_t adlen, const uint8_t* npub, const uint8_t* k) {
  if (clen < CRYPTO_ABYTES) {
    *mlen = 0;
    return -1;
  }

  const uint64_t K0 = load_u64big(k);
  const uint64_t K1 = load_u64big(k + 8);
  const uint64_t N0 = load_u64big(npub);
  const uint64_t N1 = load_u64big(npub + 8);
  /* m is advanced while decrypting; keep its start so it can be wiped. */
  uint8_t* const m_start = m;
  const uint64_t plaintext_len = clen - CRYPTO_ABYTES;
  state_t s;
  uint32_t i;
  (void)nsec;

  /* set plaintext size */
  *mlen = plaintext_len;

  /* initialization: load IV, key and nonce, permute, then add the key */
  s.x0 = IV;
  s.x1 = K0;
  s.x2 = K1;
  s.x3 = N0;
  s.x4 = N1;
  P12();
  s.x3 ^= K0;
  s.x4 ^= K1;

  /* process associated data */
  if (adlen) {
    /* NOTE(review): AD() is defined elsewhere. It is assumed to absorb all
     * full-rate blocks and leave fewer than 16 bytes in ad/adlen for the
     * byte-wise tail below. Confirm against its definition. */
    AD();
    for (i = 0; i < adlen; ++i, ++ad) {
      if (i < 8) {
        s.x0 ^= SETBYTE(*ad, i);
      } else {
        s.x1 ^= SETBYTE(*ad, i % 8);
      }
    }
    /* pad the final (possibly empty) block with a single 0x80 byte */
    if (adlen < 8) {
      s.x0 ^= SETBYTE(0x80, adlen);
    } else {
      s.x1 ^= SETBYTE(0x80, adlen % 8);
    }
    P8();
  }
  /* domain separation between associated data and ciphertext */
  s.x4 ^= 1;

  /* process ciphertext */
  clen -= CRYPTO_ABYTES;
  /* NOTE(review): CT() is assumed to decrypt all full-rate blocks,
   * advancing m/c and reducing clen. Confirm against its definition. */
  CT();
  for (i = 0; i < clen; ++i, ++m, ++c) {
    if (i < 8) {
      *m = GETBYTE(s.x0, i) ^ *c;
      /* replace the state byte with the ciphertext byte */
      s.x0 &= ~SETBYTE(0xff, i);
      s.x0 |= SETBYTE(*c, i);
    } else {
      *m = GETBYTE(s.x1, i % 8) ^ *c;
      s.x1 &= ~SETBYTE(0xff, i % 8);
      s.x1 |= SETBYTE(*c, i % 8);
    }
  }
  /* pad the final (possibly empty) ciphertext block with 0x80 */
  if (clen < 8) {
    s.x0 ^= SETBYTE(0x80, clen);
  } else {
    s.x1 ^= SETBYTE(0x80, clen % 8);
  }

  /* finalization */
  s.x2 ^= K0;
  s.x3 ^= K1;
  P12();
  s.x3 ^= K0;
  s.x4 ^= K1;

  /* verify tag: c is expected to point at the tag here. OR-ing the XOR
   * differences and branching once avoids an early-exit comparison
   * (should be constant time, check compiler output). */
  const uint64_t diff = (s.x3 ^ load_u64big(c)) | (s.x4 ^ load_u64big(c + 8));
  if (diff != 0) {
    /* Do not release unverified plaintext. The length guard avoids
     * memset(NULL, 0, 0) when the caller passes no output buffer for an
     * empty message. */
    if (plaintext_len != 0) {
      memset(m_start, 0, (size_t)plaintext_len);
    }
    *mlen = 0;
    return -1;
  }

  return 0;
}