Faster arithmetic modulo 2^255-19 on 64-bit

Cuts the time to compute DHF from ~760k to ~360k cycles.
This commit is contained in:
Ilari Liusvaara 2014-05-22 14:03:12 +03:00
parent adb1b29c67
commit 2d811d2b4e

View file

@ -13,6 +13,210 @@ namespace
do { memset(ptr, 0, size); } while(vptr[vidx]); do { memset(ptr, 0, size); } while(vptr[vidx]);
} }
#if defined(__x86_64__)
//Generic (slow).
struct element
{
typedef uint64_t smallval_t;
typedef uint64_t cond_t;
typedef uint64_t limb_t;
typedef __uint128_t wide_t;
const static int shift = 51;
const static limb_t mask = (1ULL << shift) - 1;
//a^(2^count) -> self
inline void square(const element& a, unsigned count = 1)
{
wide_t s[5];
limb_t x[5];
memcpy(x, a.n, sizeof(x));
for(unsigned i = 0; i < count; i++) {
s[0] = (wide_t)x[0] * (wide_t)x[0] + 19 * (((wide_t)x[1] * (wide_t)x[4] << 1) +
((wide_t)x[2] * (wide_t)x[3] << 1));
s[1] = ((wide_t)x[0] * (wide_t)x[1] << 1) + 19 * (((wide_t)x[2] * (wide_t)x[4] << 1) +
((wide_t)x[3] * (wide_t)x[3]));
s[2] = ((wide_t)x[0] * (wide_t)x[2] << 1) + ((wide_t)x[1] * (wide_t)x[1]) +
19 * ((wide_t)x[3] * (wide_t)x[4] << 1);
s[3] = ((wide_t)x[0] * (wide_t)x[3] << 1) + ((wide_t)x[1] * (wide_t)x[2] << 1) +
19 * ((wide_t)x[4] * (wide_t)x[4]);
s[4] = ((wide_t)x[0] * (wide_t)x[4] << 1) + ((wide_t)x[1] * (wide_t)x[3] << 1) +
((wide_t)x[2] * (wide_t)x[2]);
s[1] += (s[0] >> shift);
s[0] &= mask;
s[2] += (s[1] >> shift);
s[1] &= mask;
s[3] += (s[2] >> shift);
s[2] &= mask;
s[4] += (s[3] >> shift);
s[3] &= mask;
s[0] += 19 * (s[4] >> shift);
s[4] &= mask;
s[1] += s[0] >> shift;
s[0] &= mask;
x[0] = s[0];
x[1] = s[1];
x[2] = s[2];
x[3] = s[3];
x[4] = s[4];
}
memcpy(n, x, sizeof(x));
zeroize(x, sizeof(x));
zeroize(s, sizeof(s));
}
//a * b -> self
inline void multiply(const element& a, const element& b)
{
wide_t s[5];
s[0] = (wide_t)a.n[0] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[1] +
(wide_t)a.n[3] * (wide_t)b.n[2] + (wide_t)a.n[2] * (wide_t)b.n[3] +
(wide_t)a.n[1] * (wide_t)b.n[4]);
s[1] = (wide_t)a.n[0] * (wide_t)b.n[1] + (wide_t)a.n[1] * (wide_t)b.n[0] +
19 * ((wide_t)a.n[4] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[3] +
(wide_t)a.n[2] * (wide_t)b.n[4]);
s[2] = (wide_t)a.n[0] * (wide_t)b.n[2] + (wide_t)a.n[1] * (wide_t)b.n[1] +
(wide_t)a.n[2] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[3] +
(wide_t)a.n[3] * (wide_t)b.n[4]);
s[3] = (wide_t)a.n[0] * (wide_t)b.n[3] + (wide_t)a.n[1] * (wide_t)b.n[2] +
(wide_t)a.n[2] * (wide_t)b.n[1] + (wide_t)a.n[3] * (wide_t)b.n[0] +
19 * (wide_t)a.n[4] * (wide_t)b.n[4];
s[4] = (wide_t)a.n[0] * (wide_t)b.n[4] + (wide_t)a.n[1] * (wide_t)b.n[3] +
(wide_t)a.n[2] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[1] +
(wide_t)a.n[4] * (wide_t)b.n[0];
s[1] += (s[0] >> shift);
s[0] &= mask;
s[2] += (s[1] >> shift);
s[1] &= mask;
s[3] += (s[2] >> shift);
s[2] &= mask;
s[4] += (s[3] >> shift);
s[3] &= mask;
s[0] += 19 * (s[4] >> shift);
s[4] &= mask;
s[1] += s[0] >> shift;
s[0] &= mask;
n[0] = s[0];
n[1] = s[1];
n[2] = s[2];
n[3] = s[3];
n[4] = s[4];
zeroize(s, sizeof(s));
}
//e - self -> self
inline void diff_back(const element& e)
{
limb_t C1 = 2 * mask - 2 * (19 - 1);
limb_t C2 = 2 * mask;
n[0] = e.n[0] + C1 - n[0];
for(unsigned i = 1; i < 5; i++)
n[i] = e.n[i] + C2 - n[i];
limb_t carry = 0;
for(unsigned i = 0; i < 5; i++) {
n[i] += carry;
carry = n[i] >> shift;
n[i] &= mask;
}
carry *= 19;
n[0] += carry;
}
//a * b -> self (with constant b).
inline void multiply(const element& a, smallval_t b)
{
limb_t carry = 0;
for(unsigned i = 0; i < 5; i++) {
wide_t x = (wide_t)a.n[i] * b + carry;
n[i] = x & mask;
carry = x >> shift;
}
carry *= 19;
n[0] += carry;
}
//Reduce mod 2^255-19 and store to buffer.
inline void store(uint8_t* buffer)
{
limb_t carry = 19;
for(int i = 0; i < 5; i++) {
n[i] = n[i] + carry;
carry = n[i] >> shift;
n[i] = n[i] & mask;
}
carry = 19 - carry * 19;
for(int i = 0; i < 5; i++) {
n[i] = n[i] - carry;
carry = (n[i] >> shift) & 1;
n[i] = n[i] & mask;
}
for(unsigned i = 0; i < 32; i++) {
buffer[i] = n[8 * i / shift] >> (8 * i % shift);
if(8 * i % shift > shift - 8)
buffer[i] |= n[8 * i / shift + 1] << (shift - 8 * i % shift);
}
}
//Load from buffer.
inline explicit element(const uint8_t* buffer)
{
memset(n, 0, sizeof(n));
for(unsigned i = 0; i < 32; i++) {
n[8 * i / shift] |= (limb_t)buffer[i] << (8 * i % shift);
n[8 * i / shift] &= mask;
if(8 * i % shift > shift - 8) {
n[8 * i / shift + 1] |= (limb_t)buffer[i] >> (shift - 8 * i % shift);
}
}
}
//Construct 0.
inline element()
{
memset(n, 0, sizeof(n));
}
//Construct small value.
inline element(smallval_t sval)
{
memset(n, 0, sizeof(n));
n[0] = sval;
}
//self + e -> self.
inline void sum(const element& e)
{
limb_t carry = 0;
for(int i = 0; i < 5; i++) {
n[i] = n[i] + e.n[i] + carry;
carry = n[i] >> shift;
n[i] = n[i] & mask;
}
n[0] += carry * 19;
}
//If condition=1, swap self,e.
inline void swap_cond(element& e, cond_t condition)
{
condition = -condition;
for(int i = 0; i < 5; i++) {
limb_t t = condition & (n[i] ^ e.n[i]);
n[i] ^= t;
e.n[i] ^= t;
}
}
inline ~element()
{
zeroize(n, sizeof(n));
}
void debug(const char* pfx) const
{
uint8_t buf[34];
std::cerr << pfx << ": ";
memset(buf, 0, 34);
for(unsigned i = 0; i < 5*64; i++) {
unsigned rbit = shift*(i>>6)+(i&63);
if((n[i>>6] >> (i&63)) & 1)
buf[rbit>>3]|=(1<<(rbit&7));
}
for(unsigned i = 33; i < 34; i--)
std::cerr << std::setw(2) << std::setfill('0') << std::hex << std::uppercase
<< (int)buf[i];
std::cerr << std::endl;
}
private:
limb_t n[5];
};
#else
//Generic (slow). //Generic (slow).
struct element struct element
{ {
@ -219,6 +423,7 @@ namespace
private: private:
uint32_t n[10]; uint32_t n[10];
}; };
#endif
} }
static void montgomery(element& dblx, element& dblz, element& sumx, element& sumz, static void montgomery(element& dblx, element& dblz, element& sumx, element& sumz,
@ -320,6 +525,33 @@ void curve25519_clamp(uint8_t* key)
const uint8_t curve25519_base[32] = {9}; const uint8_t curve25519_base[32] = {9};
uint64_t arch_get_tsc()
{
uint32_t a, b;
asm volatile("rdtsc" : "=a"(a), "=d"(b));
return ((uint64_t)b << 32) | a;
}
int main()
{
uint8_t buf[128] = {0};
FILE* fd = fopen("/dev/urandom", "rb");
uint64_t ctr = 0;
buf[32] = 9;
fread(buf, 1, 32, fd);
buf[0] &= 248;
buf[31] &= 127;
buf[31] |= 64;
uint64_t t = arch_get_tsc();
for(unsigned i = 0; i < 10000; i++) {
curve25519(buf+64, buf, buf+32);
}
t = arch_get_tsc() - t;
std::cerr << "Avg: " << t / 10000 << std::endl;
return 0;
}
/* /*
//For comparision //For comparision
extern "C" extern "C"
@ -368,5 +600,4 @@ int main()
std::cerr << "Passed " << ctr << " tests." << std::endl; std::cerr << "Passed " << ctr << " tests." << std::endl;
} }
} }
*/ */