diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index ebcc0151..6dfa7f09 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -318,14 +318,48 @@ type Modulus struct { // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). func rr(m *Modulus) *Nat { rr := NewNat().ExpandFor(m) - // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the - // most significant limb to 1. We then get to R*R by shifting left by _W - // n + 1 times. - n := len(rr.limbs) - rr.limbs[n-1] = 1 - for i := n - 1; i < 2*n; i++ { - rr.shiftIn(0, m) // x = x * 2^_W mod m + n := uint(len(rr.limbs)) + mLen := uint(m.BitLen()) + logR := _W * n + + // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to + // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce. + rr.limbs[n-1] = 1 << ((mLen - 1) % _W) + // Then we double until we reach 2^(_W * n). + for i := mLen - 1; i < logR; i++ { + rr.Add(rr, m) + } + + // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in + // the Montgomery domain, meaning we can use Montgomery multiplication now). + // We could do that by doubling _W * n times, or with a square-and-double + // chain log2(_W * n) long. Turns out the fastest thing is to start out with + // doublings, and switch to square-and-double once the exponent is large + // enough to justify the cost of the multiplications. + + // The threshold is selected experimentally as a linear function of n. + threshold := n / 4 + + // We calculate how many of the most-significant bits of the exponent we can + // compute before crossing the threshold, and we do it with doublings. + i := bits.UintSize + for logR>>i <= threshold { + i-- + } + for k := uint(0); k < logR>>i; k++ { + rr.Add(rr, m) + } + + // Then we process the remaining bits of the exponent with a + // square-and-double chain. + for i > 0 { + rr.montgomeryMul(rr, rr, m) + i-- + if logR>>i&1 != 0 { + rr.Add(rr, m) + } } + return rr } @@ -775,26 +809,21 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { return out.montgomeryReduction(m) } -// ExpShort calculates out = x^e mod m. +// ExpShortVarTime calculates out = x^e mod m. // // The output will be resized to the size of m and overwritten. x must already -// be reduced modulo m. This leaks the exact bit size of the exponent. -func (out *Nat) ExpShort(x *Nat, e uint, m *Modulus) *Nat { - xR := NewNat().Set(x).montgomeryRepresentation(m) - - out.resetFor(m) - out.limbs[0] = 1 - out.montgomeryRepresentation(m) - +// be reduced modulo m. This leaks the exponent through timing side-channels. +func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { // For short exponents, precomputing a table and using a window like in Exp - // doesn't pay off. Instead, we do a simple constant-time conditional - // square-and-multiply chain, skipping the initial run of zeroes. - tmp := NewNat().ExpandFor(m) - for i := bits.UintSize - bitLen(e); i < bits.UintSize; i++ { + // doesn't pay off. Instead, we do a simple conditional square-and-multiply + // chain, skipping the initial run of zeroes. + xR := NewNat().Set(x).montgomeryRepresentation(m) + out.Set(xR) + for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ { out.montgomeryMul(out, out, m) - k := (e >> (bits.UintSize - i - 1)) & 1 - tmp.montgomeryMul(out, xR, m) - out.assign(ctEq(k, 1), tmp) + if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { + out.montgomeryMul(out, xR, m) + } } return out.montgomeryReduction(m) } diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 5d9474db..7c731443 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -303,7 +303,7 @@ func TestExpShort(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{3}} out := &Nat{[]uint{0}} - out.ExpShort(x, 12, m) + out.ExpShortVarTime(x, 12, m) expected := &Nat{[]uint{1}} if out.Equal(expected) != 1 { t.Errorf("%+v != %+v", out, expected) @@ -383,6 +383,13 @@ func makeBenchmarkExponent() []byte { return e } +func BenchmarkRR256(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + makeBenchmarkModulus(4) + } +} + func BenchmarkModAdd(b *testing.B) { x := makeBenchmarkValue(32) y := makeBenchmarkValue(32)