Skip to content

Commit

Permalink
internal/bigmod: add more //go:norace annotations and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun authored Dec 6, 2024
1 parent 0d56114 commit 865159d
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 66 deletions.
113 changes: 52 additions & 61 deletions internal/bigmod/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ const (
_S = _W / 8
)

// Note: These functions make many loops over all the words in a Nat.
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
// but now they are in Go and incur significant overhead in those modes.
// To bring the old performance back, we mark all functions that loop
// over Nat words with //go:norace. Because //go:norace does not
// propagate across inlining, we must also mark functions that inline
// //go:norace functions - specifically, those that inline add, addMulVVW,
// assign, cmpGeq, rshift1, and sub.

// choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use an int instead of bool in order to make decisions in
// constant time by turning it into a mask.
Expand All @@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
return not(choice(c1 | c2))
}

// ctGeq reports whether x >= y as a choice: 1 when x >= y, 0 otherwise.
// Its execution time does not depend on the values of x and y.
func ctGeq(x, y uint) choice {
	// Subtracting y from x produces a borrow exactly when x < y.
	_, borrow := bits.Sub(x, y, 0)
	return not(choice(borrow))
}

// Nat represents an arbitrary natural number
//
// Each Nat has an announced length, which is the number of limbs it has stored.
Expand Down Expand Up @@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
return x
}
extraLimbs := x.limbs[len(x.limbs):n]
// clear(extraLimbs)
for i := range extraLimbs {
extraLimbs[i] = 0
}
Expand All @@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
x.limbs = make([]uint, n)
return x
}
// clear(x.limbs)
for i := range x.limbs {
x.limbs[i] = 0
}
Expand Down Expand Up @@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
}

// set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) Set(y *Nat) *Nat {
func (x *Nat) set(y *Nat) *Nat {
x.reset(len(y.limbs))
copy(x.limbs, y.limbs)
return x
Expand Down Expand Up @@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
// SetBytes returns an error if b >= m.
//
// The output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
x.resetFor(m)
if err := x.setBytes(b); err != nil {
return nil, err
}
if x.CmpGeq(m.nat) == yes {
if x.cmpGeq(m.nat) == yes {
return nil, errors.New("input overflows the modulus")
}
return x, nil
Expand All @@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
return x, nil
}

// SetOverflowedBytes assigns x = (b mod (m-1)) + 1, where b is a slice of
// big-endian bytes, and returns x. The result is always in the range [1, m-1].
//
// b may be of any length. m must be odd, so that m-1 can be computed by
// decrementing the lowest limb without borrowing.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
	// m is odd, so its lowest limb is nonzero and the decrement cannot borrow.
	mMinusOne := NewNat().Set(m.nat)
	mMinusOne.limbs[0]--
	one := NewNat().resetFor(m)
	one.limbs[0] = 1
	// Load b into a scratch Nat rather than into x: modNat resets its
	// receiver before reading its input, so the two must not alias. Writing
	// the reduction into x (instead of rebinding x to a fresh Nat) makes the
	// receiver actually hold the result, as documented.
	t := NewNat().resetToBytes(b)
	x.modNat(t, mMinusOne) // x = b mod (m-1)
	x.add(one)             // x <= m-2 here, so adding 1 cannot overflow
	return x
}

// bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value.
func bigEndianUint(buf []byte) uint {
Expand Down Expand Up @@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
}

// IsOdd returns 1 if x is odd, and 0 otherwise.
//
//go:norace
func (x *Nat) IsOdd() choice {
if len(x.limbs) == 0 {
return no
Expand All @@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
return t
}

// CmpGeq returns 1 if x >= y, and 0 otherwise.
// cmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) CmpGeq(y *Nat) choice {
func (x *Nat) cmpGeq(y *Nat) choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
Expand Down Expand Up @@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {

// NewModulusProduct creates a new Modulus from the product of two numbers
// represented as big-endian byte slices. The result must be greater than one.
//
//go:norace
func NewModulusProduct(a, b []byte) (*Modulus, error) {
x := NewNat().resetToBytes(a)
y := NewNat().resetToBytes(b)
Expand Down Expand Up @@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
// Make a copy so that the caller can't modify m.nat or alias it with
// another Nat in a modulus operation.
n := NewNat()
n.Set(m.nat)
n.set(m.nat)
return n
}

// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
//
// It forwards to shiftInNat, passing the Nat that backs m.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
return x.shiftInNat(y, m.nat)
}

// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
//
//go:norace
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
d := NewNat().reset(len(m.limbs))
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().resetFor(m)

// Eliminate bounds checks in the loop.
size := len(m.limbs)
size := len(m.nat.limbs)
xLimbs := x.limbs[:size]
dLimbs := d.limbs[:size]
mLimbs := m.limbs[:size]
mLimbs := m.nat.limbs[:size]

// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time,
Expand Down Expand Up @@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
// This works regardless of how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
//
// It forwards to modNat, passing the Nat that backs m.
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
return out.modNat(x, m.nat)
}

// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
out.reset(len(m.limbs))
//go:norace
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
out.resetFor(m)
// Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the
Expand All @@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
i := len(x.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i).
start := len(m.limbs) - 2
start := len(m.nat.limbs) - 2
if i < start {
start = i
}
Expand All @@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
}
// We shift in the remaining limbs, reducing modulo m each time.
for i >= 0 {
out.shiftInNat(x.limbs[i], m)
out.shiftIn(x.limbs[i], m)
i--
}
return out
Expand Down Expand Up @@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
//
// x and m operands must have the same announced length.
//
//go:norace
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
t := NewNat().Set(x)
t := NewNat().set(x)
underflow := t.sub(m.nat)
// We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set.
Expand All @@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(y)
// If the subtraction underflowed, add m.
t := NewNat().Set(x)
t := NewNat().set(x)
t.add(m.nat)
x.assign(choice(underflow), t)
return x
Expand All @@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
overflow := x.add(y)
x.maybeSubtractModulus(choice(overflow), m)
Expand Down Expand Up @@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
//
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
n := len(m.nat.limbs)
mLimbs := m.nat.limbs[:n]
Expand Down Expand Up @@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
if m.odd {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
n := len(m.nat.limbs)
Expand Down Expand Up @@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// to the size of m and overwritten. x must already be reduced modulo m.
//
// m must be odd, or Exp will panic.
//
//go:norace
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
if !m.odd {
panic("bigmod: modulus for Exp must be odd")
Expand All @@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
table[0].Set(x).montgomeryRepresentation(m)
table[0].set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
table[i].montgomeryMul(table[i-1], table[0], m)
}
Expand Down Expand Up @@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes.
xR := NewNat().Set(x).montgomeryRepresentation(m)
out.Set(xR)
xR := NewNat().set(x).montgomeryRepresentation(m)
out.set(xR)
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m)
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
Expand All @@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
//
// a must be reduced modulo m, but doesn't need to have the same size. The
// output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
// This is the extended binary GCD algorithm described in the Handbook of
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
Expand Down Expand Up @@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
return x, false
}

u := NewNat().Set(a).ExpandFor(m)
u := NewNat().set(a).ExpandFor(m)
v := m.Nat()

A := NewNat().reset(len(m.nat.limbs))
Expand All @@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
// If both u and v are odd, subtract the smaller from the larger.
// If u = v, we need to subtract from v to hit the modified exit condition.
if u.IsOdd() == yes && v.IsOdd() == yes {
if v.CmpGeq(u) == no {
if v.cmpGeq(u) == no {
u.sub(v)
A.Add(C, m)
B.Add(D, &Modulus{nat: a})
Expand Down Expand Up @@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
if u.IsOne() == no {
return x, false
}
return x.Set(A), true
return x.set(A), true
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions internal/bigmod/nat_extension.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package bigmod

// Set assigns x = y, resizing x as needed, and returns x.
//
// It is the exported form of set, to which it delegates.
func (x *Nat) Set(y *Nat) *Nat {
return x.set(y)
}

// SetOverflowedBytes assigns x = (b mod (m-1)) + 1, where b is a slice of
// big-endian bytes, and returns x. The result is always in the range [1, m-1].
//
// b may be of any length. m must be odd, so that m-1 can be computed by
// decrementing the lowest limb without borrowing. SetOverflowedBytes panics
// if m-1 is not a valid modulus.
//
// The output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
	// m is odd, so its lowest limb is nonzero and the decrement cannot borrow.
	mMinusOne := NewNat().set(m.nat)
	mMinusOne.limbs[0]--
	mMinusOneM, err := NewModulus(mMinusOne.Bytes(m))
	if err != nil {
		// Previously the error was discarded and a nil modulus was passed on
		// to Mod, which crashed with an unhelpful nil dereference.
		panic("bigmod: invalid modulus for SetOverflowedBytes: " + err.Error())
	}
	one := NewNat().resetFor(m)
	one.limbs[0] = 1
	// Load b into a scratch Nat rather than into x: Mod resets its receiver
	// before reading its input, so the two must not alias. Writing the
	// reduction into x (instead of rebinding x to a fresh Nat) makes the
	// receiver actually hold the result, as documented.
	t := NewNat().resetToBytes(b)
	x.Mod(t, mMinusOneM) // x = b mod (m-1)
	x.add(one)           // x <= m-2 here, so adding 1 cannot overflow
	return x
}

// CmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
//
// It is the exported form of cmpGeq, to which it delegates.
//
//go:norace
func (x *Nat) CmpGeq(y *Nat) choice {
return x.cmpGeq(y)
}
10 changes: 5 additions & 5 deletions internal/bigmod/nat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {

func testModAddCommutative(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
aPlusB := new(Nat).Set(a)
aPlusB := new(Nat).set(a)
aPlusB.Add(b, m)
bPlusA := new(Nat).Set(b)
bPlusA := new(Nat).set(b)
bPlusA.Add(a, m)
return aPlusB.Equal(bPlusA) == 1
}
Expand All @@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) {

func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
original := new(Nat).Set(a)
original := new(Nat).set(a)
a.Sub(b, m)
a.Add(b, m)
return a.Equal(original) == 1
Expand All @@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1))
m, _ := NewModulus(aPlusOne.Bytes())
monty := new(Nat).Set(a)
monty := new(Nat).set(a)
monty.montgomeryRepresentation(m)
aAgain := new(Nat).Set(monty)
aAgain := new(Nat).set(monty)
aAgain.montgomeryMul(monty, one, m)
if a.Equal(aAgain) != 1 {
t.Errorf("%v != %v", a, aAgain)
Expand Down

0 comments on commit 865159d

Please sign in to comment.