From 3d8b3555142a679e338f7c12d86e206e1f182cf4 Mon Sep 17 00:00:00 2001 From: mantre Date: Thu, 28 Nov 2024 21:06:01 +0800 Subject: [PATCH] fix(consensus): verify vote if not exists --- consensus/log/log_test.go | 4 ++-- consensus/voteset/binary_voteset.go | 22 +++++++++++++--------- consensus/voteset/block_voteset.go | 22 +++++++++++++--------- types/vote/vote_type.go | 6 +++++- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/consensus/log/log_test.go b/consensus/log/log_test.go index 971b960f2..ab5b86da0 100644 --- a/consensus/log/log_test.go +++ b/consensus/log/log_test.go @@ -64,14 +64,14 @@ func TestAddInvalidVoteType(t *testing.T) { log := NewLog() log.MoveToNewHeight(cmt.Validators()) - data, _ := hex.DecodeString("A701050218320301045820BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" + + data, _ := hex.DecodeString("A7010F0218320301045820BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" + "055501AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA06f607f6") invVote := new(vote.Vote) err := invVote.UnmarshalCBOR(data) assert.NoError(t, err) added, err := log.AddVote(invVote) - assert.Error(t, err) + assert.ErrorContains(t, err, "unexpected vote type: 15") assert.False(t, added) assert.False(t, log.HasVote(invVote.Hash())) } diff --git a/consensus/voteset/binary_voteset.go b/consensus/voteset/binary_voteset.go index aca2838bd..066144ce0 100644 --- a/consensus/voteset/binary_voteset.go +++ b/consensus/voteset/binary_voteset.go @@ -90,29 +90,33 @@ func (vs *BinaryVoteSet) AllVotes() []*vote.Vote { // AddVote attempts to add a vote to the VoteSet. Returns an error if the vote is invalid. func (vs *BinaryVoteSet) AddVote(vote *vote.Vote) (bool, error) { - power, err := vs.voteSet.verifyVote(vote) - if err != nil { - return false, err - } + var dupErr error roundVotes := vs.mustGetRoundVotes(vote.CPRound()) - existingVote, ok := roundVotes.allVotes[vote.Signer()] - if ok { + existingVote, exists := roundVotes.allVotes[vote.Signer()] + if exists { if existingVote.Hash() == vote.Hash() { // The vote is already added return false, nil } // It is a duplicated vote - err = ErrDuplicatedVote - } else { + dupErr = ErrDuplicatedVote + } + + power, err := vs.voteSet.verifyVote(vote) + if err != nil { + return false, err + } + + if !exists { roundVotes.allVotes[vote.Signer()] = vote roundVotes.votedPower += power } roundVotes.addVote(vote, power) - return true, err + return true, dupErr } func (vs *BinaryVoteSet) HasOneThirdOfTotalPower(cpRound int16) bool { diff --git a/consensus/voteset/block_voteset.go b/consensus/voteset/block_voteset.go index eb5e6cd0e..e38eb5d43 100644 --- a/consensus/voteset/block_voteset.go +++ b/consensus/voteset/block_voteset.go @@ -70,21 +70,25 @@ func (vs *BlockVoteSet) AllVotes() []*vote.Vote { // AddVote attempts to add a vote to the VoteSet. Returns an error if the vote is invalid. func (vs *BlockVoteSet) AddVote(vote *vote.Vote) (bool, error) { - power, err := vs.voteSet.verifyVote(vote) - if err != nil { - return false, err - } + var dupErr error - existingVote, ok := vs.allVotes[vote.Signer()] - if ok { + existingVote, exists := vs.allVotes[vote.Signer()] + if exists { if existingVote.Hash() == vote.Hash() { // The vote is already added return false, nil } // It is a duplicated vote - err = ErrDuplicatedVote - } else { + dupErr = ErrDuplicatedVote + } + + power, err := vs.voteSet.verifyVote(vote) + if err != nil { + return false, err + } + + if !exists { vs.allVotes[vote.Signer()] = vote } @@ -95,7 +99,7 @@ func (vs *BlockVoteSet) AddVote(vote *vote.Vote) (bool, error) { vs.quorumHash = &h } - return true, err + return true, dupErr } // HasQuorumHash checks if there is a block that has received quorum votes (2/3+ of total power). diff --git a/types/vote/vote_type.go b/types/vote/vote_type.go index 1760101a0..fa26db3f5 100644 --- a/types/vote/vote_type.go +++ b/types/vote/vote_type.go @@ -1,5 +1,9 @@ package vote +import ( + "fmt" +) + type Type int const ( @@ -33,6 +37,6 @@ func (t Type) String() string { case VoteTypeCPDecided: return "DECIDED" default: - return ("invalid vote type") + return fmt.Sprintf("%d", t) } }