diff --git a/src/library/curve25519.cpp b/src/library/curve25519.cpp index e63673c7..94bb98c5 100644 --- a/src/library/curve25519.cpp +++ b/src/library/curve25519.cpp @@ -13,6 +13,210 @@ namespace do { memset(ptr, 0, size); } while(vptr[vidx]); } +#if defined(__x86_64__) + //Generic (slow). + struct element + { + typedef uint64_t smallval_t; + typedef uint64_t cond_t; + typedef uint64_t limb_t; + typedef __uint128_t wide_t; + const static int shift = 51; + const static limb_t mask = (1ULL << shift) - 1; + //a^(2^count) -> self + inline void square(const element& a, unsigned count = 1) + { + wide_t s[5]; + limb_t x[5]; + memcpy(x, a.n, sizeof(x)); + for(unsigned i = 0; i < count; i++) { + s[0] = (wide_t)x[0] * (wide_t)x[0] + 19 * (((wide_t)x[1] * (wide_t)x[4] << 1) + + ((wide_t)x[2] * (wide_t)x[3] << 1)); + s[1] = ((wide_t)x[0] * (wide_t)x[1] << 1) + 19 * (((wide_t)x[2] * (wide_t)x[4] << 1) + + ((wide_t)x[3] * (wide_t)x[3])); + s[2] = ((wide_t)x[0] * (wide_t)x[2] << 1) + ((wide_t)x[1] * (wide_t)x[1]) + + 19 * ((wide_t)x[3] * (wide_t)x[4] << 1); + s[3] = ((wide_t)x[0] * (wide_t)x[3] << 1) + ((wide_t)x[1] * (wide_t)x[2] << 1) + + 19 * ((wide_t)x[4] * (wide_t)x[4]); + s[4] = ((wide_t)x[0] * (wide_t)x[4] << 1) + ((wide_t)x[1] * (wide_t)x[3] << 1) + + ((wide_t)x[2] * (wide_t)x[2]); + s[1] += (s[0] >> shift); + s[0] &= mask; + s[2] += (s[1] >> shift); + s[1] &= mask; + s[3] += (s[2] >> shift); + s[2] &= mask; + s[4] += (s[3] >> shift); + s[3] &= mask; + s[0] += 19 * (s[4] >> shift); + s[4] &= mask; + s[1] += s[0] >> shift; + s[0] &= mask; + x[0] = s[0]; + x[1] = s[1]; + x[2] = s[2]; + x[3] = s[3]; + x[4] = s[4]; + } + memcpy(n, x, sizeof(x)); + zeroize(x, sizeof(x)); + zeroize(s, sizeof(s)); + } + //a * b -> self + inline void multiply(const element& a, const element& b) + { + wide_t s[5]; + s[0] = (wide_t)a.n[0] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[1] + + (wide_t)a.n[3] * (wide_t)b.n[2] + (wide_t)a.n[2] * (wide_t)b.n[3] + + (wide_t)a.n[1] * (wide_t)b.n[4]); + s[1] = (wide_t)a.n[0] * (wide_t)b.n[1] + (wide_t)a.n[1] * (wide_t)b.n[0] + + 19 * ((wide_t)a.n[4] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[3] + + (wide_t)a.n[2] * (wide_t)b.n[4]); + s[2] = (wide_t)a.n[0] * (wide_t)b.n[2] + (wide_t)a.n[1] * (wide_t)b.n[1] + + (wide_t)a.n[2] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[3] + + (wide_t)a.n[3] * (wide_t)b.n[4]); + s[3] = (wide_t)a.n[0] * (wide_t)b.n[3] + (wide_t)a.n[1] * (wide_t)b.n[2] + + (wide_t)a.n[2] * (wide_t)b.n[1] + (wide_t)a.n[3] * (wide_t)b.n[0] + + 19 * (wide_t)a.n[4] * (wide_t)b.n[4]; + s[4] = (wide_t)a.n[0] * (wide_t)b.n[4] + (wide_t)a.n[1] * (wide_t)b.n[3] + + (wide_t)a.n[2] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[1] + + (wide_t)a.n[4] * (wide_t)b.n[0]; + s[1] += (s[0] >> shift); + s[0] &= mask; + s[2] += (s[1] >> shift); + s[1] &= mask; + s[3] += (s[2] >> shift); + s[2] &= mask; + s[4] += (s[3] >> shift); + s[3] &= mask; + s[0] += 19 * (s[4] >> shift); + s[4] &= mask; + s[1] += s[0] >> shift; + s[0] &= mask; + n[0] = s[0]; + n[1] = s[1]; + n[2] = s[2]; + n[3] = s[3]; + n[4] = s[4]; + zeroize(s, sizeof(s)); + } + //e - self -> self + inline void diff_back(const element& e) + { + limb_t C1 = 2 * mask - 2 * (19 - 1); + limb_t C2 = 2 * mask; + n[0] = e.n[0] + C1 - n[0]; + for(unsigned i = 1; i < 5; i++) + n[i] = e.n[i] + C2 - n[i]; + limb_t carry = 0; + for(unsigned i = 0; i < 5; i++) { + n[i] += carry; + carry = n[i] >> shift; + n[i] &= mask; + } + carry *= 19; + n[0] += carry; + } + //a * b -> self (with constant b). + inline void multiply(const element& a, smallval_t b) + { + limb_t carry = 0; + for(unsigned i = 0; i < 5; i++) { + wide_t x = (wide_t)a.n[i] * b + carry; + n[i] = x & mask; + carry = x >> shift; + } + carry *= 19; + n[0] += carry; + } + //Reduce mod 2^255-19 and store to buffer. + inline void store(uint8_t* buffer) + { + limb_t carry = 19; + for(int i = 0; i < 5; i++) { + n[i] = n[i] + carry; + carry = n[i] >> shift; + n[i] = n[i] & mask; + } + carry = 19 - carry * 19; + for(int i = 0; i < 5; i++) { + n[i] = n[i] - carry; + carry = (n[i] >> shift) & 1; + n[i] = n[i] & mask; + } + for(unsigned i = 0; i < 32; i++) { + buffer[i] = n[8 * i / shift] >> (8 * i % shift); + if(8 * i % shift > shift - 8) + buffer[i] |= n[8 * i / shift + 1] << (shift - 8 * i % shift); + } + } + //Load from buffer. + inline explicit element(const uint8_t* buffer) + { + memset(n, 0, sizeof(n)); + for(unsigned i = 0; i < 32; i++) { + n[8 * i / shift] |= (limb_t)buffer[i] << (8 * i % shift); + n[8 * i / shift] &= mask; + if(8 * i % shift > shift - 8) { + n[8 * i / shift + 1] |= (limb_t)buffer[i] >> (shift - 8 * i % shift); + } + } + } + //Construct 0. + inline element() + { + memset(n, 0, sizeof(n)); + } + //Construct small value. + inline element(smallval_t sval) + { + memset(n, 0, sizeof(n)); + n[0] = sval; + } + //self + e -> self. + inline void sum(const element& e) + { + limb_t carry = 0; + for(int i = 0; i < 5; i++) { + n[i] = n[i] + e.n[i] + carry; + carry = n[i] >> shift; + n[i] = n[i] & mask; + } + n[0] += carry * 19; + } + //If condition=1, swap self,e. + inline void swap_cond(element& e, cond_t condition) + { + condition = -condition; + for(int i = 0; i < 5; i++) { + limb_t t = condition & (n[i] ^ e.n[i]); + n[i] ^= t; + e.n[i] ^= t; + } + } + inline ~element() + { + zeroize(n, sizeof(n)); + } + void debug(const char* pfx) const + { + uint8_t buf[34]; + std::cerr << pfx << ": "; + memset(buf, 0, 34); + for(unsigned i = 0; i < 5*64; i++) { + unsigned rbit = shift*(i>>6)+(i&63); + if((n[i>>6] >> (i&63)) & 1) + buf[rbit>>3]|=(1<<(rbit&7)); + } + for(unsigned i = 33; i < 34; i--) + std::cerr << std::setw(2) << std::setfill('0') << std::hex << std::uppercase + << (int)buf[i]; + std::cerr << std::endl; + } + private: + limb_t n[5]; + }; +#else //Generic (slow). struct element { @@ -219,6 +423,7 @@ namespace private: uint32_t n[10]; }; +#endif } static void montgomery(element& dblx, element& dblz, element& sumx, element& sumz, @@ -320,6 +525,33 @@ void curve25519_clamp(uint8_t* key) const uint8_t curve25519_base[32] = {9}; + +uint64_t arch_get_tsc() +{ + uint32_t a, b; + asm volatile("rdtsc" : "=a"(a), "=d"(b)); + return ((uint64_t)b << 32) | a; +} + +int main() +{ + uint8_t buf[128] = {0}; + FILE* fd = fopen("/dev/urandom", "rb"); + uint64_t ctr = 0; + buf[32] = 9; + fread(buf, 1, 32, fd); + buf[0] &= 248; + buf[31] &= 127; + buf[31] |= 64; + uint64_t t = arch_get_tsc(); + for(unsigned i = 0; i < 10000; i++) { + curve25519(buf+64, buf, buf+32); + } + t = arch_get_tsc() - t; + std::cerr << "Avg: " << t / 10000 << std::endl; + return 0; +} + /* //For comparision extern "C" @@ -368,5 +600,4 @@ int main() std::cerr << "Passed " << ctr << " tests." << std::endl; } } - -*/ +*/ \ No newline at end of file