Some more optimization, some test code for this thing

This commit is contained in:
Ilari Liusvaara 2014-05-23 11:49:50 +03:00
parent c5372a4826
commit 16663f7786

View file

@ -27,51 +27,53 @@ namespace
//a^(2^count) -> self: squares a, then squares the result again, count times in total.
//NOTE(review): this listing is a rendered diff hunk whose +/- markers were lost, so the
//removed lines (working on the local copy x[]) and the added lines (working directly on
//n[], presumably this element's own limb array) appear interleaved below.
//shift, mask, limb_t, wide_t and zeroize are defined outside this hunk.
inline void square(const element& a, unsigned count = 1)
{
//Wide accumulators, one per output limb.
wide_t s[5];
//Old form: scratch copy of the limbs of a.
limb_t x[5];
memcpy(x, a.n, sizeof(x));
//New form: copy the limbs of a straight into n and square in place.
memcpy(n, a.n, sizeof(n));
for(unsigned i = 0; i < count; i++) {
//Old form: every cross product is widened first and then doubled / scaled by 19 (or 38)
//in wide_t arithmetic.
//Products whose limb indices sum to 5 or more are folded back with a factor of 19 --
//presumably the usual reduction modulo 2^255 - 19; TODO confirm against shift.
s[0] = (wide_t)x[0] * (wide_t)x[0] + 19 * (((wide_t)x[1] * (wide_t)x[4] << 1) +
((wide_t)x[2] * (wide_t)x[3] << 1));
s[1] = ((wide_t)x[0] * (wide_t)x[1] << 1) + 19 * (((wide_t)x[2] * (wide_t)x[4] << 1) +
((wide_t)x[3] * (wide_t)x[3]));
s[2] = ((wide_t)x[0] * (wide_t)x[2] << 1) + ((wide_t)x[1] * (wide_t)x[1]) +
38 * ((wide_t)x[3] * (wide_t)x[4]);
s[3] = ((wide_t)x[0] * (wide_t)x[3] << 1) + ((wide_t)x[1] * (wide_t)x[2] << 1) +
19 * ((wide_t)x[4] * (wide_t)x[4]);
s[4] = ((wide_t)x[0] * (wide_t)x[4] << 1) + ((wide_t)x[1] * (wide_t)x[3] << 1) +
((wide_t)x[2] * (wide_t)x[2]);
//New form: where written as (n[i] << 1) or (n[i] * 19), the doubling and the factor 19 are
//applied to a single limb in limb_t before widening, leaving one wide multiply per term.
//Assumes the limbs are small enough that n[i] << 1 and n[i] * 19 cannot overflow
//limb_t -- TODO confirm against the limb bounds maintained elsewhere.
s[0] = (wide_t)n[0] * (wide_t)n[0] +
(wide_t)(n[1] << 1) * (wide_t)(n[4] * 19) +
(wide_t)(n[2] << 1) * (wide_t)(n[3] * 19);
s[1] = (wide_t)n[0] * (wide_t)(n[1] << 1) +
(wide_t)(n[2] << 1) * (wide_t)(n[4] * 19) +
(wide_t)n[3] * (wide_t)(n[3] * 19);
s[2] = ((wide_t)n[0] * (wide_t)(n[2] << 1) +
(wide_t)n[1] * (wide_t)n[1]) +
(wide_t)(n[3] << 1) * (wide_t)(n[4] * 19);
//s[3] and s[4] still double the cross products after widening; only the 19 * n[4] term is
//pre-scaled.
s[3] = ((wide_t)n[0] * (wide_t)n[3] << 1) +
((wide_t)n[1] * (wide_t)n[2] << 1) +
((wide_t)n[4] * (wide_t)(n[4] * 19));
s[4] = ((wide_t)n[0] * (wide_t)n[4] << 1) +
((wide_t)n[1] * (wide_t)n[3] << 1) +
((wide_t)n[2] * (wide_t)n[2]);
//Carry propagation: the bits above `shift` of each accumulator are added to the next one,
//and the low bits (selected by mask) become the output limb.
s[1] += (s[0] >> shift);
s[2] += (s[1] >> shift);
x[2] = (limb_t)s[2] & mask;
n[2] = (limb_t)s[2] & mask;
s[3] += (s[2] >> shift);
x[3] = (limb_t)s[3] & mask;
n[3] = (limb_t)s[3] & mask;
s[4] += (s[3] >> shift);
x[4] = (limb_t)s[4] & mask;
n[4] = (limb_t)s[4] & mask;
//The carry out of the top limb wraps around into limb 0, multiplied by 19.
s[0] = ((limb_t)s[0] & mask) + 19 * (limb_t)(s[4] >> shift);
x[0] = (limb_t)s[0] & mask;
n[0] = (limb_t)s[0] & mask;
//Limb 1 absorbs the carry from the wrapped limb 0 and is stored without masking, so it is
//not reduced below mask here.
s[1] = ((limb_t)s[1] & mask) + (limb_t)(s[0] >> shift);
x[1] = (limb_t)s[1];
n[1] = (limb_t)s[1];
}
//Old form only: copy the result out of the scratch limbs and then wipe them.
memcpy(n, x, sizeof(x));
zeroize(x, sizeof(x));
//Wipe the wide accumulators as well (zeroize presumably clears the buffer; it is defined
//outside this hunk).
zeroize(s, sizeof(s));
}
//a * b -> self
inline void multiply(const element& a, const element& b)
{
wide_t s[5];
s[0] = (wide_t)a.n[0] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[1] +
(wide_t)a.n[3] * (wide_t)b.n[2] + (wide_t)a.n[2] * (wide_t)b.n[3] +
(wide_t)a.n[1] * (wide_t)b.n[4]);
s[0] = (wide_t)a.n[0] * (wide_t)b.n[0] + (wide_t)(a.n[4] * 19) * (wide_t)b.n[1] +
(wide_t)(a.n[3] * 19) * (wide_t)b.n[2] + (wide_t)a.n[2] * (wide_t)(b.n[3] * 19) +
(wide_t)a.n[1] * (wide_t)(b.n[4] * 19);
s[1] = (wide_t)a.n[0] * (wide_t)b.n[1] + (wide_t)a.n[1] * (wide_t)b.n[0] +
19 * ((wide_t)a.n[4] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[3] +
(wide_t)a.n[2] * (wide_t)b.n[4]);
(wide_t)(a.n[4] * 19) * (wide_t)b.n[2] + (wide_t)(a.n[3] * 19) * (wide_t)b.n[3] +
(wide_t)a.n[2] * (wide_t)(b.n[4] * 19);
s[2] = (wide_t)a.n[0] * (wide_t)b.n[2] + (wide_t)a.n[1] * (wide_t)b.n[1] +
(wide_t)a.n[2] * (wide_t)b.n[0] + 19 * ((wide_t)a.n[4] * (wide_t)b.n[3] +
(wide_t)a.n[3] * (wide_t)b.n[4]);
(wide_t)a.n[2] * (wide_t)b.n[0] + (wide_t)(a.n[4] * 19) * (wide_t)b.n[3] +
(wide_t)(a.n[3] * 19) * (wide_t)b.n[4];
s[3] = (wide_t)a.n[0] * (wide_t)b.n[3] + (wide_t)a.n[1] * (wide_t)b.n[2] +
(wide_t)a.n[2] * (wide_t)b.n[1] + (wide_t)a.n[3] * (wide_t)b.n[0] +
19 * (wide_t)a.n[4] * (wide_t)b.n[4];
(wide_t)(a.n[4] * 19) * (wide_t)b.n[4];
s[4] = (wide_t)a.n[0] * (wide_t)b.n[4] + (wide_t)a.n[1] * (wide_t)b.n[3] +
(wide_t)a.n[2] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[1] +
(wide_t)a.n[4] * (wide_t)b.n[0];
@ -445,25 +447,26 @@ static void cmultiply(element& ox, element& oz, const uint8_t* key, const elemen
{
element x1a(1), z1a, x2a(base), z2a(1), x1b, z1b(1), x2b, z2b(1);
element::cond_t lbit = 0;
for(unsigned i = 31; i < 32; i--) {
uint8_t x = key[i];
for(unsigned j = 0; j < 4; j++) {
element::cond_t bit = (x >> 7);
x1a.swap_cond(x2a, bit);
z1a.swap_cond(z2a, bit);
x1a.swap_cond(x2a, bit ^ lbit);
z1a.swap_cond(z2a, bit ^ lbit);
montgomery(x1b, z1b, x2b, z2b, x1a, z1a, x2a, z2a, base);
x1b.swap_cond(x2b, bit);
z1b.swap_cond(z2b, bit);
lbit = bit;
x <<= 1;
bit = (x >> 7);
x1b.swap_cond(x2b, bit);
z1b.swap_cond(z2b, bit);
x1b.swap_cond(x2b, bit ^ lbit);
z1b.swap_cond(z2b, bit ^ lbit);
montgomery(x1a, z1a, x2a, z2a, x1b, z1b, x2b, z2b, base);
x1a.swap_cond(x2a, bit);
z1a.swap_cond(z2a, bit);
x <<= 1;
lbit = bit;
}
}
x1a.swap_cond(x2a, lbit);
z1a.swap_cond(z2a, lbit);
ox = x1a;
oz = z1a;
};
@ -514,7 +517,8 @@ void curve25519_clamp(uint8_t* key)
const uint8_t curve25519_base[32] = {9};
/*
#ifdef CURVE25519_TEST_MODE
#include <cmath>
uint64_t arch_get_tsc()
{
uint32_t a, b;
@ -533,17 +537,27 @@ int main()
uint8_t buf[128] = {0};
FILE* fd = fopen("/dev/urandom", "rb");
uint64_t ctr = 0;
uint64_t _t;
// Running sum and sum of squares of the per-call cycle counts. The timing loop only ever
// adds to them, so they must start at zero; without an initializer they hold indeterminate
// values and the reported mean and deviation would be meaningless.
double tsum = 0;
double tsqr = 0;
uint64_t tmin = 999999999;
buf[32] = 9;
fread(buf, 1, 32, fd);
buf[0] &= 248;
buf[31] &= 127;
buf[31] |= 64;
uint64_t t = arch_get_tsc();
for(unsigned i = 0; i < 10000; i++) {
for(unsigned i = 0; i < 32768; i++) {
_t = arch_get_tsc();
curve25519(buf+64, buf, buf+32);
_t = arch_get_tsc() - _t;
tsum += _t;
tsqr += _t * _t;
if(_t < tmin) tmin = _t;
}
t = arch_get_tsc() - t;
std::cerr << "Avg: " << t / 10000 << std::endl;
tsum /= 32768;
tsqr /= 32768;
std::cerr << "Time: " << tsum << "+-" << sqrt(tsqr - tsum * tsum) << " >=" << tmin << std::endl;
while(true) {
fread(buf, 1, 32, fd);
buf[0] &= 248;
@ -579,4 +593,4 @@ int main()
std::cerr << "Passed " << ctr << " tests." << std::endl;
}
}
*/
#endif