From 1c17db1b2ede3c890d243c1db47cc26931a71c5e Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Sun, 22 Sep 2024 16:16:34 +0200 Subject: [PATCH] Use a mask for BLS aggregation and improve caching fixes #592 --- blssig/aggregation.go | 64 +++++++++++++++++------------ certs/certs.go | 12 ++++-- emulator/instance.go | 35 +++++++++------- emulator/signing.go | 53 ++++++++++++++++++------ gpbft/api.go | 20 +++++++-- gpbft/gpbft.go | 34 ++++++++-------- gpbft/mock_host_test.go | 83 ++++++++------------------------------ gpbft/participant.go | 26 ++++++++---- gpbft/participant_test.go | 5 +++ gpbft/powertable.go | 8 ++++ host.go | 11 +---- sim/adversary/decide.go | 28 +++++++------ sim/adversary/withhold.go | 10 +++-- sim/ec.go | 23 +++++++---- sim/justification.go | 9 ++++- sim/signing/fake.go | 74 +++++++++++++++++++++------------ test/signing_suite_test.go | 44 +++++++++++++------- 17 files changed, 319 insertions(+), 220 deletions(-) diff --git a/blssig/aggregation.go b/blssig/aggregation.go index 2fc56562..7d160d4e 100644 --- a/blssig/aggregation.go +++ b/blssig/aggregation.go @@ -17,7 +17,12 @@ import ( // Max size of the point cache. const maxPointCacheSize = 10_000 -func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg []byte, _err error) { +type aggregation struct { + mask *bdn.Mask + scheme *bdn.Scheme +} + +func (a *aggregation) Aggregate(mask []int, signatures [][]byte) (_agg []byte, _err error) { defer func() { status := measurements.AttrStatusSuccess if _err != nil { @@ -25,29 +30,31 @@ func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg } if perr := recover(); perr != nil { - _err = fmt.Errorf("panicked aggregating public keys: %v\n%s", + _err = fmt.Errorf("panicked aggregating signatures: %v\n%s", perr, string(debug.Stack())) log.Error(_err) status = measurements.AttrStatusPanic } metrics.aggregate.Record( - context.TODO(), int64(len(pubkeys)), + context.TODO(), int64(len(mask)), metric.WithAttributes(status), ) }() - if len(pubkeys) != len(signatures) { + if len(mask) != len(signatures) { return nil, fmt.Errorf("lengths of pubkeys and sigs does not match %d != %d", - len(pubkeys), len(signatures)) + len(mask), len(signatures)) } - mask, err := v.pubkeysToMask(pubkeys) - if err != nil { - return nil, fmt.Errorf("converting public keys to mask: %w", err) + bdnMask := a.mask.Clone() + for _, bit := range mask { + if err := bdnMask.SetBit(bit, true); err != nil { + return nil, err + } } - aggSigPoint, err := v.scheme.AggregateSignatures(signatures, mask) + aggSigPoint, err := a.scheme.AggregateSignatures(signatures, bdnMask) if err != nil { return nil, fmt.Errorf("computing aggregate signature: %w", err) } @@ -59,7 +66,7 @@ func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg return aggSig, nil } -func (v *Verifier) VerifyAggregate(msg []byte, signature []byte, pubkeys []gpbft.PubKey) (_err error) { +func (a *aggregation) VerifyAggregate(mask []int, msg []byte, signature []byte) (_err error) { defer func() { status := measurements.AttrStatusSuccess if _err != nil { @@ -75,25 +82,35 @@ func (v *Verifier) VerifyAggregate(msg []byte, signature []byte, pubkeys []gpbft } metrics.verifyAggregate.Record( - context.TODO(), int64(len(pubkeys)), + context.TODO(), int64(len(mask)), metric.WithAttributes(status), ) }() - mask, err := v.pubkeysToMask(pubkeys) - if err != nil { - return fmt.Errorf("converting public keys to mask: %w", err) + bdnMask := a.mask.Clone() + for _, bit := range mask { + if err := bdnMask.SetBit(bit, true); err != nil { + return err + } } - aggPubKey, err := v.scheme.AggregatePublicKeys(mask) + aggPubKey, err := a.scheme.AggregatePublicKeys(bdnMask) if err != nil { return fmt.Errorf("aggregating public keys: %w", err) } - return v.scheme.Verify(aggPubKey, msg, signature) + return a.scheme.Verify(aggPubKey, msg, signature) } -func (v *Verifier) pubkeysToMask(pubkeys []gpbft.PubKey) (*bdn.Mask, error) { +func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey) (_agg gpbft.Aggregate, _err error) { + defer func() { + if perr := recover(); perr != nil { + _err = fmt.Errorf("panicked aggregating public keys: %v\n%s", + perr, string(debug.Stack())) + log.Error(_err) + } + }() + kPubkeys := make([]kyber.Point, 0, len(pubkeys)) for i, p := range pubkeys { point, err := v.pubkeyToPoint(p) @@ -105,13 +122,10 @@ func (v *Verifier) pubkeysToMask(pubkeys []gpbft.PubKey) (*bdn.Mask, error) { mask, err := bdn.NewMask(v.keyGroup, kPubkeys, nil) if err != nil { - return nil, fmt.Errorf("creating bdn mask: %w", err) - } - for i := range kPubkeys { - err := mask.SetBit(i, true) - if err != nil { - return nil, fmt.Errorf("setting mask bit %d: %w", i, err) - } + return nil, fmt.Errorf("creating key mask: %w", err) } - return mask, nil + return &aggregation{ + mask: mask, + scheme: v.scheme, + }, nil } diff --git a/certs/certs.go b/certs/certs.go index b22ca346..40c3a7c2 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -150,7 +150,8 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf return fmt.Errorf("failed to scale power table: %w", err) } - signers := make([]gpbft.PubKey, 0, len(powerTable)) + keys := powerTable.PublicKeys() + mask := make([]int, 0, len(powerTable)) var signerPowers int64 if err := cert.Signers.ForEach(func(i uint64) error { if i >= uint64(len(powerTable)) { @@ -165,7 +166,7 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf cert.GPBFTInstance, powerTable[i].ID) } signerPowers += power - signers = append(signers, powerTable[i].PubKey) + mask = append(mask, int(i)) return nil }); err != nil { return err @@ -192,7 +193,12 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf signedBytes = payload.MarshalForSigning(nn) } - if err := verifier.VerifyAggregate(signedBytes, cert.Signature, signers); err != nil { + aggregate, err := verifier.Aggregate(keys) + if err != nil { + return err + } + + if err := aggregate.VerifyAggregate(mask, signedBytes, cert.Signature); err != nil { return fmt.Errorf("invalid signature on finality certificate for instance %d: %w", cert.GPBFTInstance, err) } return nil diff --git a/emulator/instance.go b/emulator/instance.go index d5882259..4d711159 100644 --- a/emulator/instance.go +++ b/emulator/instance.go @@ -14,14 +14,15 @@ import ( // Instance represents a GPBFT instance capturing all the information necessary // for GPBFT to function, along with the final decision reached if any. type Instance struct { - t *testing.T - id uint64 - supplementalData gpbft.SupplementalData - proposal gpbft.ECChain - powerTable *gpbft.PowerTable - beacon []byte - decision *gpbft.Justification - signing Signing + t *testing.T + id uint64 + supplementalData gpbft.SupplementalData + proposal gpbft.ECChain + powerTable *gpbft.PowerTable + beacon []byte + decision *gpbft.Justification + signing Signing + aggregateVerifier gpbft.Aggregate } // NewInstance instantiates a new Instance for emulation. If absent, the @@ -58,7 +59,8 @@ func NewInstance(t *testing.T, id uint64, powerEntries gpbft.PowerEntries, propo } proposalChain, err := gpbft.NewChain(proposal[0], proposal[1:]...) require.NoError(t, err) - return &Instance{ + + i := &Instance{ t: t, id: id, powerTable: pt, @@ -68,11 +70,18 @@ func NewInstance(t *testing.T, id uint64, powerEntries gpbft.PowerEntries, propo Commitments: [32]byte{}, PowerTable: ptCid, }, - signing: AdhocSigning(), } + + i.SetSigning(AdhocSigning()) + return i } -func (i *Instance) SetSigning(signing Signing) { i.signing = signing } +func (i *Instance) SetSigning(signing Signing) { + var err error + i.signing = signing + i.aggregateVerifier, err = signing.Aggregate(i.powerTable.Entries.PublicKeys()) + require.NoError(i.t, err) +} func (i *Instance) Proposal() gpbft.ECChain { return i.proposal } func (i *Instance) GetDecision() *gpbft.Justification { return i.decision } func (i *Instance) ID() uint64 { return i.id } @@ -140,7 +149,6 @@ func (i *Instance) NewJustificationWithPayload(payload gpbft.Payload, from ...gp msg := i.signing.MarshalPayloadForSigning(networkName, &payload) qr := gpbft.QuorumResult{ Signers: make([]int, len(from)), - PubKeys: make([]gpbft.PubKey, len(from)), Signatures: make([][]byte, len(from)), } for j, actor := range from { @@ -150,10 +158,9 @@ func (i *Instance) NewJustificationWithPayload(payload gpbft.Payload, from ...gp signature, err := i.signing.Sign(context.Background(), entry.PubKey, msg) require.NoError(i.t, err) qr.Signatures[j] = signature - qr.PubKeys[j] = entry.PubKey qr.Signers[j] = index } - aggregate, err := i.signing.Aggregate(qr.PubKeys, qr.Signatures) + aggregate, err := i.aggregateVerifier.Aggregate(qr.Signers, qr.Signatures) require.NoError(i.t, err) return &gpbft.Justification{ Vote: payload, diff --git a/emulator/signing.go b/emulator/signing.go index 6979f618..c715506e 100644 --- a/emulator/signing.go +++ b/emulator/signing.go @@ -3,6 +3,7 @@ package emulator import ( "bytes" "context" + "encoding/binary" "errors" "hash/crc32" @@ -58,13 +59,22 @@ func (s adhocSigning) Verify(sender gpbft.PubKey, msg, got []byte) error { } } -func (s adhocSigning) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - if len(signers) != len(sigs) { +type aggregate struct { + keys []gpbft.PubKey + signing adhocSigning +} + +// Aggregate implements gpbft.Aggregate. +func (a *aggregate) Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) { + if len(signerMask) != len(sigs) { return nil, errors.New("public keys and signatures length mismatch") } hasher := crc32.NewIEEE() - for i, signer := range signers { - if _, err := hasher.Write(signer); err != nil { + for i, bit := range signerMask { + if err := binary.Write(hasher, binary.BigEndian, uint64(bit)); err != nil { + return nil, err + } + if _, err := hasher.Write(a.keys[bit]); err != nil { return nil, err } if _, err := hasher.Write(sigs[i]); err != nil { @@ -74,16 +84,17 @@ func (s adhocSigning) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, return hasher.Sum(nil), nil } -func (s adhocSigning) VerifyAggregate(payload, got []byte, signers []gpbft.PubKey) error { - signatures := make([][]byte, len(signers)) +// VerifyAggregate implements gpbft.Aggregate. +func (a *aggregate) VerifyAggregate(signerMask []int, payload []byte, got []byte) error { + signatures := make([][]byte, len(signerMask)) var err error - for i, signer := range signers { - signatures[i], err = s.Sign(context.Background(), signer, payload) + for i, bit := range signerMask { + signatures[i], err = a.signing.Sign(context.Background(), a.keys[bit], payload) if err != nil { return err } } - want, err := s.Aggregate(signers, signatures) + want, err := a.Aggregate(signerMask, signatures) if err != nil { return err } @@ -93,23 +104,34 @@ func (s adhocSigning) VerifyAggregate(payload, got []byte, signers []gpbft.PubKe return nil } +func (s adhocSigning) Aggregate(keys []gpbft.PubKey) (gpbft.Aggregate, error) { + return &aggregate{keys: keys, + signing: s, + }, nil +} + func (s adhocSigning) MarshalPayloadForSigning(name gpbft.NetworkName, payload *gpbft.Payload) []byte { return payload.MarshalForSigning(name) } type erroneousSigning struct{} +type erroneousAggregate struct{} func (p erroneousSigning) Verify(gpbft.PubKey, []byte, []byte) error { return errors.New("err Verify") } -func (p erroneousSigning) VerifyAggregate([]byte, []byte, []gpbft.PubKey) error { +func (p erroneousAggregate) VerifyAggregate([]int, []byte, []byte) error { return errors.New("err VerifyAggregate") } -func (p erroneousSigning) Aggregate([]gpbft.PubKey, [][]byte) ([]byte, error) { +func (p erroneousAggregate) Aggregate([]int, [][]byte) ([]byte, error) { return nil, errors.New("err Aggregate") } + +func (p erroneousSigning) Aggregate([]gpbft.PubKey) (gpbft.Aggregate, error) { + return erroneousAggregate{}, nil +} func (p erroneousSigning) Sign(context.Context, gpbft.PubKey, []byte) ([]byte, error) { return nil, errors.New("err Sign") } @@ -119,9 +141,16 @@ func (p erroneousSigning) MarshalPayloadForSigning(gpbft.NetworkName, *gpbft.Pay } type panicSigning struct{} +type panicAggregate struct{} func (p panicSigning) Verify(gpbft.PubKey, []byte, []byte) error { panic("π") } func (p panicSigning) VerifyAggregate([]byte, []byte, []gpbft.PubKey) error { panic("π") } -func (p panicSigning) Aggregate([]gpbft.PubKey, [][]byte) ([]byte, error) { panic("π") } func (p panicSigning) Sign(context.Context, gpbft.PubKey, []byte) ([]byte, error) { panic("π") } func (p panicSigning) MarshalPayloadForSigning(gpbft.NetworkName, *gpbft.Payload) []byte { panic("π") } + +func (p panicSigning) Aggregate([]gpbft.PubKey) (gpbft.Aggregate, error) { + return panicAggregate{}, nil +} + +func (p panicAggregate) VerifyAggregate([]int, []byte, []byte) error { panic("π") } +func (p panicAggregate) Aggregate([]int, [][]byte) ([]byte, error) { panic("π") } diff --git a/gpbft/api.go b/gpbft/api.go index bc91118b..43e6b58a 100644 --- a/gpbft/api.go +++ b/gpbft/api.go @@ -100,15 +100,27 @@ type SigningMarshaler interface { MarshalPayloadForSigning(NetworkName, *Payload) []byte } +type Aggregate interface { + // Aggregates signatures from a participants. + // + // Implementations must be safe for concurrent use. + Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) + // VerifyAggregate verifies an aggregate signature. + // + // Implementations must be safe for concurrent use. + VerifyAggregate(signerMask []int, payload, aggSig []byte) error +} + type Verifier interface { // Verifies a signature for the given public key. + // // Implementations must be safe for concurrent use. Verify(pubKey PubKey, msg, sig []byte) error - // Aggregates signatures from a participants. - Aggregate(pubKeys []PubKey, sigs [][]byte) ([]byte, error) - // VerifyAggregate verifies an aggregate signature. + // Return an Aggregate that can aggregate and verify aggregate signatures made by the given + // public keys. + // // Implementations must be safe for concurrent use. - VerifyAggregate(payload, aggSig []byte, signers []PubKey) error + Aggregate(pubKeys []PubKey) (Aggregate, error) } type Signatures interface { diff --git a/gpbft/gpbft.go b/gpbft/gpbft.go index c024d8ba..0385bccb 100644 --- a/gpbft/gpbft.go +++ b/gpbft/gpbft.go @@ -156,6 +156,8 @@ type instance struct { input ECChain // The power table for the base chain, used for power in this instance. powerTable PowerTable + // The aggregate signature verifier/aggregator. + aggregateVerifier Aggregate // The beacon value from the base chain, used for tickets in this instance. beacon []byte // Current round number. @@ -218,6 +220,7 @@ func newInstance( input ECChain, data *SupplementalData, powerTable PowerTable, + aggregateVerifier Aggregate, beacon []byte) (*instance, error) { if input.IsZero() { return nil, fmt.Errorf("input is empty") @@ -228,16 +231,17 @@ func newInstance( metrics.currentRound.Record(context.TODO(), 0) return &instance{ - participant: participant, - instanceID: instanceID, - input: input, - powerTable: powerTable, - beacon: beacon, - round: 0, - phase: INITIAL_PHASE, - supplementalData: data, - proposal: input, - value: ECChain{}, + participant: participant, + instanceID: instanceID, + input: input, + powerTable: powerTable, + aggregateVerifier: aggregateVerifier, + beacon: beacon, + round: 0, + phase: INITIAL_PHASE, + supplementalData: data, + proposal: input, + value: ECChain{}, candidates: map[ChainKey]struct{}{ input.BaseChain().Key(): {}, }, @@ -986,7 +990,7 @@ func (i *instance) alarmAfterSynchrony() time.Time { // Builds a justification for a value from a quorum result. func (i *instance) buildJustification(quorum QuorumResult, round uint64, phase Phase, value ECChain) *Justification { - aggSignature, err := quorum.Aggregate(i.participant.host) + aggSignature, err := quorum.Aggregate(i.aggregateVerifier) if err != nil { panic(fmt.Errorf("aggregating for phase %v: %v", phase, err)) } @@ -1174,12 +1178,11 @@ func (q *quorumState) CouldReachStrongQuorumFor(key ChainKey, withAdversary bool type QuorumResult struct { // Signers is an array of indexes into the powertable, sorted in increasing order Signers []int - PubKeys []PubKey Signatures [][]byte } -func (q QuorumResult) Aggregate(v Verifier) ([]byte, error) { - return v.Aggregate(q.PubKeys, q.Signatures) +func (q QuorumResult) Aggregate(v Aggregate) ([]byte, error) { + return v.Aggregate(q.Signers, q.Signatures) } func (q QuorumResult) SignersBitfield() bitfield.BitField { @@ -1216,7 +1219,6 @@ func (q *quorumState) FindStrongQuorumFor(key ChainKey) (QuorumResult, bool) { // Accumulate signers and signatures until they reach a strong quorum. signatures := make([][]byte, 0, len(chainSupport.signatures)) - pubkeys := make([]PubKey, 0, len(signatures)) var justificationPower int64 for i, idx := range signers { if idx >= len(q.powerTable.Entries) { @@ -1226,11 +1228,9 @@ func (q *quorumState) FindStrongQuorumFor(key ChainKey) (QuorumResult, bool) { entry := q.powerTable.Entries[idx] justificationPower += power signatures = append(signatures, chainSupport.signatures[entry.ID]) - pubkeys = append(pubkeys, entry.PubKey) if IsStrongQuorum(justificationPower, q.powerTable.ScaledTotal) { return QuorumResult{ Signers: signers[:i+1], - PubKeys: pubkeys, Signatures: signatures, }, true } diff --git a/gpbft/mock_host_test.go b/gpbft/mock_host_test.go index 4bbbca06..36b50602 100644 --- a/gpbft/mock_host_test.go +++ b/gpbft/mock_host_test.go @@ -21,29 +21,29 @@ func (_m *MockHost) EXPECT() *MockHost_Expecter { return &MockHost_Expecter{mock: &_m.Mock} } -// Aggregate provides a mock function with given fields: pubKeys, sigs -func (_m *MockHost) Aggregate(pubKeys []PubKey, sigs [][]byte) ([]byte, error) { - ret := _m.Called(pubKeys, sigs) +// Aggregate provides a mock function with given fields: pubKeys +func (_m *MockHost) Aggregate(pubKeys []PubKey) (Aggregate, error) { + ret := _m.Called(pubKeys) if len(ret) == 0 { panic("no return value specified for Aggregate") } - var r0 []byte + var r0 Aggregate var r1 error - if rf, ok := ret.Get(0).(func([]PubKey, [][]byte) ([]byte, error)); ok { - return rf(pubKeys, sigs) + if rf, ok := ret.Get(0).(func([]PubKey) (Aggregate, error)); ok { + return rf(pubKeys) } - if rf, ok := ret.Get(0).(func([]PubKey, [][]byte) []byte); ok { - r0 = rf(pubKeys, sigs) + if rf, ok := ret.Get(0).(func([]PubKey) Aggregate); ok { + r0 = rf(pubKeys) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) + r0 = ret.Get(0).(Aggregate) } } - if rf, ok := ret.Get(1).(func([]PubKey, [][]byte) error); ok { - r1 = rf(pubKeys, sigs) + if rf, ok := ret.Get(1).(func([]PubKey) error); ok { + r1 = rf(pubKeys) } else { r1 = ret.Error(1) } @@ -58,24 +58,23 @@ type MockHost_Aggregate_Call struct { // Aggregate is a helper method to define mock.On call // - pubKeys []PubKey -// - sigs [][]byte -func (_e *MockHost_Expecter) Aggregate(pubKeys interface{}, sigs interface{}) *MockHost_Aggregate_Call { - return &MockHost_Aggregate_Call{Call: _e.mock.On("Aggregate", pubKeys, sigs)} +func (_e *MockHost_Expecter) Aggregate(pubKeys interface{}) *MockHost_Aggregate_Call { + return &MockHost_Aggregate_Call{Call: _e.mock.On("Aggregate", pubKeys)} } -func (_c *MockHost_Aggregate_Call) Run(run func(pubKeys []PubKey, sigs [][]byte)) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) Run(run func(pubKeys []PubKey)) *MockHost_Aggregate_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]PubKey), args[1].([][]byte)) + run(args[0].([]PubKey)) }) return _c } -func (_c *MockHost_Aggregate_Call) Return(_a0 []byte, _a1 error) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) Return(_a0 Aggregate, _a1 error) *MockHost_Aggregate_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockHost_Aggregate_Call) RunAndReturn(run func([]PubKey, [][]byte) ([]byte, error)) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) RunAndReturn(run func([]PubKey) (Aggregate, error)) *MockHost_Aggregate_Call { _c.Call.Return(run) return _c } @@ -536,54 +535,6 @@ func (_c *MockHost_Verify_Call) RunAndReturn(run func(PubKey, []byte, []byte) er return _c } -// VerifyAggregate provides a mock function with given fields: payload, aggSig, signers -func (_m *MockHost) VerifyAggregate(payload []byte, aggSig []byte, signers []PubKey) error { - ret := _m.Called(payload, aggSig, signers) - - if len(ret) == 0 { - panic("no return value specified for VerifyAggregate") - } - - var r0 error - if rf, ok := ret.Get(0).(func([]byte, []byte, []PubKey) error); ok { - r0 = rf(payload, aggSig, signers) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockHost_VerifyAggregate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAggregate' -type MockHost_VerifyAggregate_Call struct { - *mock.Call -} - -// VerifyAggregate is a helper method to define mock.On call -// - payload []byte -// - aggSig []byte -// - signers []PubKey -func (_e *MockHost_Expecter) VerifyAggregate(payload interface{}, aggSig interface{}, signers interface{}) *MockHost_VerifyAggregate_Call { - return &MockHost_VerifyAggregate_Call{Call: _e.mock.On("VerifyAggregate", payload, aggSig, signers)} -} - -func (_c *MockHost_VerifyAggregate_Call) Run(run func(payload []byte, aggSig []byte, signers []PubKey)) *MockHost_VerifyAggregate_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]byte), args[1].([]byte), args[2].([]PubKey)) - }) - return _c -} - -func (_c *MockHost_VerifyAggregate_Call) Return(_a0 error) *MockHost_VerifyAggregate_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockHost_VerifyAggregate_Call) RunAndReturn(run func([]byte, []byte, []PubKey) error) *MockHost_VerifyAggregate_Call { - _c.Call.Return(run) - return _c -} - // NewMockHost creates a new instance of MockHost. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockHost(t interface { diff --git a/gpbft/participant.go b/gpbft/participant.go index 07a1d8ee..8ec6144e 100644 --- a/gpbft/participant.go +++ b/gpbft/participant.go @@ -333,7 +333,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro // Check justification power and signature. var justificationPower int64 - signers := make([]PubKey, 0) + signers := make([]int, 0) if err := msg.Justification.Signers.ForEach(func(bit uint64) error { if int(bit) >= len(comt.power.Entries) { return fmt.Errorf("invalid signer index: %d", bit) @@ -343,7 +343,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro return fmt.Errorf("signer with ID %d has no power", comt.power.Entries[bit].ID) } justificationPower += power - signers = append(signers, comt.power.Entries[bit].PubKey) + signers = append(signers, int(bit)) return nil }); err != nil { return fmt.Errorf("failed to iterate over signers: %w", err) @@ -354,7 +354,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro } payload := p.host.MarshalPayloadForSigning(p.host.NetworkName(), &msg.Justification.Vote) - if err := p.host.VerifyAggregate(payload, msg.Justification.Signature, signers); err != nil { + if err := comt.aggregateVerifier.VerifyAggregate(signers, payload, msg.Justification.Signature); err != nil { return fmt.Errorf("verification of the aggregate failed: %+v: %w", msg.Justification, err) } @@ -445,7 +445,7 @@ func (p *Participant) beginInstance() error { if err != nil { return err } - if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, *comt.power, comt.beacon); err != nil { + if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, *comt.power, comt.aggregateVerifier, comt.beacon); err != nil { return fmt.Errorf("failed creating new gpbft instance: %w", err) } if err := p.gpbft.Start(); err != nil { @@ -490,7 +490,18 @@ func (p *Participant) fetchCommittee(instance uint64, phase Phase) (*committee, if err := power.Validate(); err != nil { return nil, fmt.Errorf("instance %d: %w: invalid power: %w", instance, ErrValidationNoCommittee, err) } - comt = &committee{power: power, beacon: beacon} + + // NOTE: we're intentionally keeping participants here even if they have no + // effective power (after rounding power) to simplify things. The runtime cost is + // minimal and it means that the keys can be aggregated before any rounding is done. + // TODO: this is slow and under a lock, but we only want to do it once per + // instance... ideally we'd have a per-instance lock/once, but that probably isn't + // worth it. + agg, err := p.host.Aggregate(power.Entries.PublicKeys()) + if err != nil { + return nil, fmt.Errorf("failed to pre-compute aggregate mask for instance %d: %w: %w", instance, ErrValidationNoCommittee, err) + } + comt = &committee{power: power, beacon: beacon, aggregateVerifier: agg} p.committees[instance] = comt } return comt, nil @@ -564,8 +575,9 @@ func (p *Participant) trace(format string, args ...any) { // A power table and beacon value used as the committee inputs to an instance. type committee struct { - power *PowerTable - beacon []byte + power *PowerTable + beacon []byte + aggregateVerifier Aggregate } // A collection of messages queued for delivery for a future instance. diff --git a/gpbft/participant_test.go b/gpbft/participant_test.go index 07344a60..349692d5 100644 --- a/gpbft/participant_test.go +++ b/gpbft/participant_test.go @@ -96,9 +96,12 @@ func (pt *participantTestSubject) Log(format string, args ...any) { } func (pt *participantTestSubject) expectBeginInstance() { + publicKeys := pt.powerTable.Entries.PublicKeys() + // Prepare the test host. pt.host.On("GetProposalForInstance", pt.instance).Return(pt.supplementalData, pt.canonicalChain, nil) pt.host.On("GetCommitteeForInstance", pt.instance).Return(pt.powerTable, pt.beacon, nil).Once() + pt.host.On("Aggregate", publicKeys).Return(nil, nil) pt.host.On("Time").Return(pt.time) pt.host.On("NetworkName").Return(pt.networkName).Maybe() // We need to use `Maybe` here because `MarshalPayloadForSigning` may be called @@ -111,6 +114,7 @@ func (pt *participantTestSubject) expectBeginInstance() { // Expect calls to get the host state prior to beginning of an instance. pt.host.EXPECT().GetProposalForInstance(pt.instance) pt.host.EXPECT().GetCommitteeForInstance(pt.instance) + pt.host.EXPECT().Aggregate(publicKeys) pt.host.EXPECT().Time() // Expect alarm is set to 2X of configured delta. @@ -194,6 +198,7 @@ func (pt *participantTestSubject) mockValidTicket(target gpbft.PubKey, ticket gp func (pt *participantTestSubject) mockCommitteeForInstance(instance uint64, powerTable *gpbft.PowerTable, beacon []byte) { pt.host.On("GetCommitteeForInstance", instance).Return(powerTable, beacon, nil).Once() + pt.host.On("Aggregate", powerTable.Entries.PublicKeys()).Return(nil, nil) } func (pt *participantTestSubject) mockCommitteeUnavailableForInstance(instance uint64) { diff --git a/gpbft/powertable.go b/gpbft/powertable.go index c8a34f71..c052dfac 100644 --- a/gpbft/powertable.go +++ b/gpbft/powertable.go @@ -33,6 +33,14 @@ type PowerTable struct { ScaledTotal int64 } +func (e PowerEntries) PublicKeys() []PubKey { + keys := make([]PubKey, len(e)) + for i, e := range e { + keys[i] = e.PubKey + } + return keys +} + func (p *PowerEntry) Equal(o *PowerEntry) bool { return p.ID == o.ID && p.Power == o.Power && bytes.Equal(p.PubKey, o.PubKey) } diff --git a/host.go b/host.go index 1d5bdf87..37d5c9d8 100644 --- a/host.go +++ b/host.go @@ -726,13 +726,6 @@ func (h *gpbftHost) Verify(pubKey gpbft.PubKey, msg []byte, sig []byte) error { return h.verifier.Verify(pubKey, msg, sig) } -// Aggregates signatures from a participants. -func (h *gpbftHost) Aggregate(pubKeys []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - return h.verifier.Aggregate(pubKeys, sigs) -} - -// VerifyAggregate verifies an aggregate signature. -// Implementations must be safe for concurrent use. -func (h *gpbftHost) VerifyAggregate(payload []byte, aggSig []byte, signers []gpbft.PubKey) error { - return h.verifier.VerifyAggregate(payload, aggSig, signers) +func (h *gpbftHost) Aggregate(pubKeys []gpbft.PubKey) (gpbft.Aggregate, error) { + return h.verifier.Aggregate(pubKeys) } diff --git a/sim/adversary/decide.go b/sim/adversary/decide.go index 41b6317f..4413382c 100644 --- a/sim/adversary/decide.go +++ b/sim/adversary/decide.go @@ -97,30 +97,32 @@ func (i *ImmediateDecide) StartInstanceAt(instance uint64, _when time.Time) erro } var ( - pubkeys []gpbft.PubKey - sigs [][]byte + mask []int + sigs [][]byte ) if err := signers.ForEach(func(j uint64) error { - pubkey := gpbft.PubKey("fake pubkeyaaaaa") - sig := []byte("fake sig") - if j < uint64(len(powertable.Entries)) { - pubkey = powertable.Entries[j].PubKey - var err error - sig, err = i.host.Sign(context.Background(), pubkey, sigPayload) - if err != nil { - return err - } + if j >= uint64(len(powertable.Entries)) { + return nil + } + pubkey := powertable.Entries[j].PubKey + sig, err := i.host.Sign(context.Background(), pubkey, sigPayload) + if err != nil { + return err } - pubkeys = append(pubkeys, pubkey) + mask = append(mask, int(j)) sigs = append(sigs, sig) return nil }); err != nil { panic(err) } - aggregatedSig, err := i.host.Aggregate(pubkeys, sigs) + agg, err := i.host.Aggregate(powertable.Entries.PublicKeys()) + if err != nil { + panic(err) + } + aggregatedSig, err := agg.Aggregate(mask, sigs) if err != nil { panic(err) } diff --git a/sim/adversary/withhold.go b/sim/adversary/withhold.go index 89a1cd89..b46e7831 100644 --- a/sim/adversary/withhold.go +++ b/sim/adversary/withhold.go @@ -106,15 +106,19 @@ func (w *WithholdCommit) StartInstanceAt(instance uint64, _when time.Time) error sort.Ints(signers) signatures := make([][]byte, 0) - pubKeys := make([]gpbft.PubKey, 0) + mask := make([]int, 0) prepareMarshalled := w.host.MarshalPayloadForSigning(w.host.NetworkName(), &preparePayload) for _, signerIndex := range signers { entry := powertable.Entries[signerIndex] signatures = append(signatures, w.sign(entry.PubKey, prepareMarshalled)) - pubKeys = append(pubKeys, entry.PubKey) + mask = append(mask, signerIndex) justification.Signers.Set(uint64(signerIndex)) } - justification.Signature, err = w.host.Aggregate(pubKeys, signatures) + agg, err := w.host.Aggregate(powertable.Entries.PublicKeys()) + if err != nil { + panic(err) + } + justification.Signature, err = agg.Aggregate(mask, signatures) if err != nil { panic(err) } diff --git a/sim/ec.go b/sim/ec.go index e6b9c81b..3d39453f 100644 --- a/sim/ec.go +++ b/sim/ec.go @@ -36,8 +36,9 @@ type ECInstance struct { // SupplementalData is the additional data for this instance. SupplementalData *gpbft.SupplementalData - ec *simEC - decisions map[gpbft.ActorID]*gpbft.Justification + ec *simEC + decisions map[gpbft.ActorID]*gpbft.Justification + aggregateVerifier gpbft.Aggregate } type errGroup []error @@ -64,6 +65,12 @@ func (ec *simEC) BeginInstance(baseChain gpbft.ECChain, pt *gpbft.PowerTable) *E // Note a real beacon value will come from a finalised chain with some lookback. beacon := baseChain.Head().Key nextInstanceID := uint64(ec.Len()) + + agg, err := ec.verifier.Aggregate(pt.Entries.PublicKeys()) + if err != nil { + panic(err) + } + instance := &ECInstance{ Instance: nextInstanceID, BaseChain: baseChain, @@ -72,8 +79,9 @@ func (ec *simEC) BeginInstance(baseChain gpbft.ECChain, pt *gpbft.PowerTable) *E SupplementalData: &gpbft.SupplementalData{ PowerTable: gpbft.MakeCid([]byte(fmt.Sprintf("supp-data-pt@%d", nextInstanceID))), }, - ec: ec, - decisions: make(map[gpbft.ActorID]*gpbft.Justification), + ec: ec, + aggregateVerifier: agg, + decisions: make(map[gpbft.ActorID]*gpbft.Justification), } ec.instances = append(ec.instances, instance) return instance @@ -123,14 +131,14 @@ func (eci *ECInstance) validateDecision(decision *gpbft.Justification) error { // Extract signers. justificationPower := gpbft.NewStoragePower(0) - signers := make([]gpbft.PubKey, 0) + signers := make([]int, 0) powerTable := eci.PowerTable if err := decision.Signers.ForEach(func(bit uint64) error { if int(bit) >= len(powerTable.Entries) { return fmt.Errorf("invalid signer index: %d", bit) } justificationPower = big.Add(justificationPower, powerTable.Entries[bit].Power) - signers = append(signers, powerTable.Entries[bit].PubKey) + signers = append(signers, int(bit)) return nil }); err != nil { return fmt.Errorf("failed to iterate over signers: %w", err) @@ -144,7 +152,8 @@ func (eci *ECInstance) validateDecision(decision *gpbft.Justification) error { } // Verify aggregate signature payload := eci.ec.verifier.MarshalPayloadForSigning(eci.ec.networkName, &decision.Vote) - if err := eci.ec.verifier.VerifyAggregate(payload, decision.Signature, signers); err != nil { + + if err := eci.aggregateVerifier.VerifyAggregate(signers, payload, decision.Signature); err != nil { return fmt.Errorf("invalid aggregate signature: %v: %w", decision, err) } diff --git a/sim/justification.go b/sim/justification.go index 336c2f02..d75e2f53 100644 --- a/sim/justification.go +++ b/sim/justification.go @@ -67,6 +67,8 @@ func MakeJustification(backend signing.Backend, nn gpbft.NetworkName, chain gpbf slices.SortFunc(votes, func(a, b vote) int { return cmp.Compare(a.index, b.index) }) + signers = signers[:len(votes)] + slices.Sort(signers) pks := make([]gpbft.PubKey, len(votes)) sigs := make([][]byte, len(votes)) for i, vote := range votes { @@ -74,7 +76,12 @@ func MakeJustification(backend signing.Backend, nn gpbft.NetworkName, chain gpbf sigs[i] = vote.sig } - sig, err := backend.Aggregate(pks, sigs) + agg, err := backend.Aggregate(powerTable.PublicKeys()) + if err != nil { + return nil, err + } + + sig, err := agg.Aggregate(signers, sigs) if err != nil { return nil, err } diff --git a/sim/signing/fake.go b/sim/signing/fake.go index c962101e..2d107e89 100644 --- a/sim/signing/fake.go +++ b/sim/signing/fake.go @@ -76,35 +76,18 @@ func (s *FakeBackend) Verify(signer gpbft.PubKey, msg, sig []byte) error { } } -func (*FakeBackend) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - if len(signers) != len(sigs) { - return nil, errors.New("public keys and signatures length mismatch") - } - hasher := sha256.New() - for i, signer := range signers { +func (s *FakeBackend) Aggregate(keys []gpbft.PubKey) (gpbft.Aggregate, error) { + for i, signer := range keys { if len(signer) != 16 { - return nil, fmt.Errorf("wrong signer pubkey length: %d != 16", len(signer)) + return nil, fmt.Errorf("wrong signer %d pubkey length: %d != 16", i, len(signer)) } - hasher.Write(signer) - hasher.Write(sigs[i]) } - return hasher.Sum(nil), nil -} -func (s *FakeBackend) VerifyAggregate(payload, aggSig []byte, signers []gpbft.PubKey) error { - hasher := sha256.New() - for _, signer := range signers { - sig, err := s.generateSignature(signer, payload) - if err != nil { - return err - } - hasher.Write(signer) - hasher.Write(sig) - } - if !bytes.Equal(aggSig, hasher.Sum(nil)) { - return errors.New("signature is not valid") - } - return nil + return &fakeAggregate{ + keys: keys, + backend: s, + }, nil + } func (v *FakeBackend) MarshalPayloadForSigning(nn gpbft.NetworkName, p *gpbft.Payload) []byte { @@ -142,3 +125,44 @@ func (v *FakeBackend) MarshalPayloadForSigning(nn gpbft.NetworkName, p *gpbft.Pa } return buf.Bytes() } + +type fakeAggregate struct { + keys []gpbft.PubKey + backend *FakeBackend +} + +// Aggregate implements gpbft.Aggregate. +func (f *fakeAggregate) Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) { + if len(signerMask) != len(sigs) { + return nil, errors.New("public keys and signatures length mismatch") + } + hasher := sha256.New() + for i, bit := range signerMask { + if bit >= len(f.keys) { + return nil, fmt.Errorf("signer %d out of range", bit) + } + binary.Write(hasher, binary.BigEndian, int64(bit)) + hasher.Write(f.keys[bit]) + hasher.Write(sigs[i]) + } + return hasher.Sum(nil), nil +} + +// VerifyAggregate implements gpbft.Aggregate. +func (f *fakeAggregate) VerifyAggregate(signerMask []int, payload []byte, aggSig []byte) error { + hasher := sha256.New() + for _, bit := range signerMask { + signer := f.keys[bit] + sig, err := f.backend.generateSignature(signer, payload) + if err != nil { + return err + } + binary.Write(hasher, binary.BigEndian, int64(bit)) + hasher.Write(signer) + hasher.Write(sig) + } + if !bytes.Equal(aggSig, hasher.Sum(nil)) { + return errors.New("signature is not valid") + } + return nil +} diff --git a/test/signing_suite_test.go b/test/signing_suite_test.go index 4d5680a3..9d26a3f1 100644 --- a/test/signing_suite_test.go +++ b/test/signing_suite_test.go @@ -90,44 +90,60 @@ func (s *SigningTestSuite) TestAggregateAndVerify() { pubKey2, signer2 := s.signerTestSubject(s.T()) pubKeys := []gpbft.PubKey{pubKey1, pubKey2} + aggregator, err := s.verifier.Aggregate(pubKeys) + require.NoError(s.T(), err) + + mask := []int{0, 1} sigs := make([][]byte, len(pubKeys)) - var err error sigs[0], err = signer1.Sign(ctx, pubKey1, msg) require.NoError(s.T(), err) sigs[1], err = signer2.Sign(ctx, pubKey2, msg) require.NoError(s.T(), err) - aggSig, err := s.verifier.Aggregate(pubKeys, sigs) + aggSig, err := aggregator.Aggregate(mask, sigs) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.NoError(t, err) - aggSig, err = s.verifier.Aggregate(pubKeys[0:1], sigs[0:1]) + aggSig, err = aggregator.Aggregate(mask[0:1], sigs[0:1]) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.Error(t, err) - aggSig, err = s.verifier.Aggregate(pubKeys, [][]byte{sigs[0], sigs[0]}) + aggSig, err = aggregator.Aggregate(mask, [][]byte{sigs[0], sigs[0]}) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.Error(t, err) - err = s.verifier.VerifyAggregate(msg, []byte("bad sig"), pubKeys) + err = aggregator.VerifyAggregate(mask, msg, []byte("bad sig")) require.Error(t, err) - _, err = s.verifier.Aggregate(pubKeys, [][]byte{sigs[0]}) + _, err = aggregator.Aggregate(mask, [][]byte{sigs[0]}) require.Error(t, err, "Missmatched pubkeys and sigs lengths should fail") { pubKeys2 := slices.Clone(pubKeys) - pubKeys2[0] = slices.Clone(pubKeys2[0]) - pubKeys2[0] = pubKeys2[0][1:len(pubKeys2)] - _, err = s.verifier.Aggregate(pubKeys2, sigs) - require.Error(t, err, "damaged pubkey should error") + pubKey3, _ := s.signerTestSubject(s.T()) + pubKeys2[0] = pubKey3 + wrongKeyAggregator, err := s.verifier.Aggregate(pubKeys2) + require.NoError(t, err) - require.Error(t, s.verifier.VerifyAggregate(msg, aggSig, pubKeys2), "damaged pubkey should error") + require.Error(t, wrongKeyAggregator.VerifyAggregate(mask, msg, aggSig), "wrong pubkey should error") } + + t.Run("mask out of range", func(t *testing.T) { + _, err = aggregator.Aggregate([]int{0, 3}, [][]byte{sigs[0]}) + require.Error(t, err, "mask out of range") + }) + + t.Run("empty signature is always valid", func(t *testing.T) { + sig, err := aggregator.Aggregate([]int{}, [][]byte{}) + require.NoError(t, err) + + err = aggregator.VerifyAggregate([]int{}, []byte("anything"), sig) + require.NoError(t, err) + }) }