From 865159d86a7d256c0e084fddb52a812ec6829669 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 6 Dec 2024 08:54:47 +0800 Subject: [PATCH] internal/bigmod: add more //go:norace annotations and refactoring --- internal/bigmod/nat.go | 113 ++++++++++++++----------------- internal/bigmod/nat_extension.go | 31 +++++++++ internal/bigmod/nat_test.go | 10 +-- 3 files changed, 88 insertions(+), 66 deletions(-) create mode 100644 internal/bigmod/nat_extension.go diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index ab98ed1..6cd801e 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -18,6 +18,15 @@ const ( _S = _W / 8 ) +// Note: These functions make many loops over all the words in a Nat. +// These loops used to be in assembly, invisible to -race, -asan, and -msan, +// but now they are in Go and incur significant overhead in those modes. +// To bring the old performance back, we mark all functions that loop +// over Nat words with //go:norace. Because //go:norace does not +// propagate across inlining, we must also mark functions that inline +// //go:norace functions - specifically, those that inline add, addMulVVW, +// assign, cmpGeq, rshift1, and sub. + // choice represents a constant-time boolean. The value of choice is always // either 1 or 0. We use an int instead of bool in order to make decisions in // constant time by turning it into a mask. @@ -40,14 +49,6 @@ func ctEq(x, y uint) choice { return not(choice(c1 | c2)) } -// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this -// function does not depend on its inputs. -func ctGeq(x, y uint) choice { - // If x < y, then x - y generates a carry. - _, carry := bits.Sub(x, y, 0) - return not(choice(carry)) -} - // Nat represents an arbitrary natural number // // Each Nat has an announced length, which is the number of limbs it has stored. 
@@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat { return x } extraLimbs := x.limbs[len(x.limbs):n] + // clear(extraLimbs) for i := range extraLimbs { extraLimbs[i] = 0 } @@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat { x.limbs = make([]uint, n) return x } + // clear(x.limbs) for i := range x.limbs { x.limbs[i] = 0 } @@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat { } // set assigns x = y, optionally resizing x to the appropriate size. -func (x *Nat) Set(y *Nat) *Nat { +func (x *Nat) set(y *Nat) *Nat { x.reset(len(y.limbs)) copy(x.limbs, y.limbs) return x @@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte { // SetBytes returns an error if b >= m. // // The output will be resized to the size of m and overwritten. +// +//go:norace func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { x.resetFor(m) if err := x.setBytes(b); err != nil { return nil, err } - if x.CmpGeq(m.nat) == yes { + if x.cmpGeq(m.nat) == yes { return nil, errors.New("input overflows the modulus") } return x, nil @@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { return x, nil } -// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes. -// -// The output will be resized to the size of m and overwritten. -func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat { - mMinusOne := NewNat().Set(m.nat) - mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1 - one := NewNat().resetFor(m) - one.limbs[0] = 1 - x.resetToBytes(b) - x = NewNat().modNat(x, mMinusOne) // x = x mod (m-1) - x.add(one) // we can safely add 1, no need to check overflow - return x -} - // bigEndianUint returns the contents of buf interpreted as a // big-endian encoded uint value. func bigEndianUint(buf []byte) uint { @@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice { } // IsOdd returns 1 if x is odd, and 0 otherwise. 
-// -//go:norace func (x *Nat) IsOdd() choice { if len(x.limbs) == 0 { return no @@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint { return t } -// CmpGeq returns 1 if x >= y, and 0 otherwise. +// cmpGeq returns 1 if x >= y, and 0 otherwise. // // Both operands must have the same announced length. // //go:norace -func (x *Nat) CmpGeq(y *Nat) choice { +func (x *Nat) cmpGeq(y *Nat) choice { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] @@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) { // NewModulusProduct creates a new Modulus from the product of two numbers // represented as big-endian byte slices. The result must be greater than one. +// +//go:norace func NewModulusProduct(a, b []byte) (*Modulus, error) { x := NewNat().resetToBytes(a) y := NewNat().resetToBytes(b) @@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat { // Make a copy so that the caller can't modify m.nat or alias it with // another Nat in a modulus operation. n := NewNat() - n.Set(m.nat) + n.set(m.nat) return n } -// shiftIn calculates x = x << _W + y mod m. -// -// This assumes that x is already reduced mod m, and that y < 2^_W. -func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { - return x.shiftInNat(y, m.nat) -} - // shiftIn calculates x = x << _W + y mod m. // // This assumes that x is already reduced mod m, and that y < 2^_W. // //go:norace -func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { - d := NewNat().reset(len(m.limbs)) +func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { + d := NewNat().resetFor(m) // Eliminate bounds checks in the loop. - size := len(m.limbs) + size := len(m.nat.limbs) xLimbs := x.limbs[:size] dLimbs := d.limbs[:size] - mLimbs := m.limbs[:size] + mLimbs := m.nat.limbs[:size] // Each iteration of this loop computes x = 2x + b mod m, where b is a bit // from y. 
Effectively, it left-shifts x and adds y one bit at a time, @@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { // This works regardless how large the value of x is. // // The output will be resized to the size of m and overwritten. -func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { - return out.modNat(x, m.nat) -} - -// Mod calculates out = x mod m. // -// This works regardless how large the value of x is. -// -// The output will be resized to the size of m and overwritten. -func (out *Nat) modNat(x *Nat, m *Nat) *Nat { - out.reset(len(m.limbs)) +//go:norace +func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { + out.resetFor(m) // Working our way from the most significant to the least significant limb, // we can insert each limb at the least significant position, shifting all // previous limbs left by _W. This way each limb will get shifted by the @@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat { i := len(x.limbs) - 1 // For the first N - 1 limbs we can skip the actual shifting and position // them at the shifted position, which starts at min(N - 2, i). - start := len(m.limbs) - 2 + start := len(m.nat.limbs) - 2 if i < start { start = i } @@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat { } // We shift in the remaining limbs, reducing modulo m each time. for i >= 0 { - out.shiftInNat(x.limbs[i], m) + out.shiftIn(x.limbs[i], m) i-- } return out @@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat { // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. // // x and m operands must have the same announced length. +// +//go:norace func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { - t := NewNat().Set(x) + t := NewNat().set(x) underflow := t.sub(m.nat) // We keep the result if x - m didn't underflow (meaning x >= m) // or if always was set. 
@@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. +// +//go:norace func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { underflow := x.sub(y) // If the subtraction underflowed, add m. - t := NewNat().Set(x) + t := NewNat().set(x) t.add(m.nat) x.assign(choice(underflow), t) return x @@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat { // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. +// +//go:norace func (x *Nat) Add(y *Nat, m *Modulus) *Nat { overflow := x.add(y) x.maybeSubtractModulus(choice(overflow), m) @@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat { // // All inputs should be the same length and already reduced modulo m. // x will be resized to the size of m and overwritten. +// +//go:norace func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { n := len(m.nat.limbs) mLimbs := m.nat.limbs[:n] @@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) { // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. +// +//go:norace func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { if m.odd { // A Montgomery multiplication by a value out of the Montgomery domain // takes the result out of Montgomery representation. - xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m + xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m } n := len(m.nat.limbs) @@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { // to the size of m and overwritten. x must already be reduced modulo m. // // m must be odd, or Exp will panic. 
+// +//go:norace func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { if !m.odd { panic("bigmod: modulus for Exp must be odd") @@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), } - table[0].Set(x).montgomeryRepresentation(m) + table[0].set(x).montgomeryRepresentation(m) for i := 1; i < len(table); i++ { table[i].montgomeryMul(table[i-1], table[0], m) } @@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { // For short exponents, precomputing a table and using a window like in Exp // doesn't pay off. Instead, we do a simple conditional square-and-multiply // chain, skipping the initial run of zeroes. - xR := NewNat().Set(x).montgomeryRepresentation(m) - out.Set(xR) + xR := NewNat().set(x).montgomeryRepresentation(m) + out.set(xR) for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ { out.montgomeryMul(out, out, m) if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { @@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { // // a must be reduced modulo m, but doesn't need to have the same size. The // output will be resized to the size of m and overwritten. +// +//go:norace func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { // This is the extended binary GCD algorithm described in the Handbook of // Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound @@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { return x, false } - u := NewNat().Set(a).ExpandFor(m) + u := NewNat().set(a).ExpandFor(m) v := m.Nat() A := NewNat().reset(len(m.nat.limbs)) @@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { // If both u and v are odd, subtract the smaller from the larger. // If u = v, we need to subtract from v to hit the modified exit condition. 
if u.IsOdd() == yes && v.IsOdd() == yes { - if v.CmpGeq(u) == no { + if v.cmpGeq(u) == no { u.sub(v) A.Add(C, m) B.Add(D, &Modulus{nat: a}) @@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { if u.IsOne() == no { return x, false } - return x.Set(A), true + return x.set(A), true } } } diff --git a/internal/bigmod/nat_extension.go b/internal/bigmod/nat_extension.go new file mode 100644 index 0000000..114de88 --- /dev/null +++ b/internal/bigmod/nat_extension.go @@ -0,0 +1,31 @@ +package bigmod + +func (x *Nat) Set(y *Nat) *Nat { + return x.set(y) +} + +// SetOverflowedBytes assigns x = (b mod (m-1)) + 1, where b is a slice of big-endian bytes. +// +// The output will be resized to the size of m and overwritten. +// +//go:norace +func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat { + mMinusOne := NewNat().set(m.nat) + mMinusOne.limbs[0]-- // m is odd, so we can safely subtract 1 + mMinusOneM, _ := NewModulus(mMinusOne.Bytes(m)) + one := NewNat().resetFor(m) + one.limbs[0] = 1 + x.resetToBytes(b) + x = NewNat().Mod(x, mMinusOneM) // x = x mod (m-1) + x.add(one) // we can safely add 1, no need to check overflow + return x +} + +// CmpGeq returns 1 if x >= y, and 0 otherwise. +// +// Both operands must have the same announced length. 
+// +//go:norace +func (x *Nat) CmpGeq(y *Nat) choice { + return x.cmpGeq(y) +} diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 78e9481..9ddba89 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { func testModAddCommutative(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - aPlusB := new(Nat).Set(a) + aPlusB := new(Nat).set(a) aPlusB.Add(b, m) - bPlusA := new(Nat).Set(b) + bPlusA := new(Nat).set(b) bPlusA.Add(a, m) return aPlusB.Equal(bPlusA) == 1 } @@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) { func testModSubThenAddIdentity(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - original := new(Nat).Set(a) + original := new(Nat).set(a) a.Sub(b, m) a.Add(b, m) return a.Equal(original) == 1 @@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) { aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) m, _ := NewModulus(aPlusOne.Bytes()) - monty := new(Nat).Set(a) + monty := new(Nat).set(a) monty.montgomeryRepresentation(m) - aAgain := new(Nat).Set(monty) + aAgain := new(Nat).set(monty) aAgain.montgomeryMul(monty, one, m) if a.Equal(aAgain) != 1 { t.Errorf("%v != %v", a, aAgain)