diff --git a/gjkr/member.go b/gjkr/member.go index d21199c..3ee9c64 100644 --- a/gjkr/member.go +++ b/gjkr/member.go @@ -11,8 +11,8 @@ type memberIndex uint16 // phase of the protocol. type member struct { memberIndex memberIndex + sessionID string group *group - evidenceLog evidenceLog logger Logger diff --git a/gjkr/message.go b/gjkr/message.go index fb23a94..8779c82 100644 --- a/gjkr/message.go +++ b/gjkr/message.go @@ -17,10 +17,15 @@ import "threshold.network/roast/ephemeral" // within the group. type ephemeralPublicKeyMessage struct { senderIndex memberIndex // i + sessionID string ephemeralPublicKeys map[memberIndex]*ephemeral.PublicKey // j -> Y_ij } -func (m *ephemeralPublicKeyMessage) senderIdx() memberIndex { +func (m *ephemeralPublicKeyMessage) getSenderIndex() memberIndex { return m.senderIndex } + +func (m *ephemeralPublicKeyMessage) getSessionID() string { + return m.sessionID +} diff --git a/gjkr/message_filter.go b/gjkr/message_filter.go index 7d553f1..1c8ea77 100644 --- a/gjkr/message_filter.go +++ b/gjkr/message_filter.go @@ -1,15 +1,33 @@ package gjkr +// filterForSession goes through the messages passed as a parameter and finds +// all messages sent for the given session ID. +func filterForSession[T interface{ getSessionID() string }]( + sessionID string, + list []T, +) []T { + result := make([]T, 0) + + for _, msg := range list { + if msg.getSessionID() == sessionID { + result = append(result, msg) + } + } + + return result +} + // findInactive goes through the messages passed as a parameter and finds all // inactive members for this set of messages. The function does not care if // the given member was already marked as inactive before. The function makes no // assumptions about the ordering of the list elements. -func findInactive[T interface{ senderIdx() memberIndex }]( - groupSize uint16, list []T, +func findInactive[T interface{ getSenderIndex() memberIndex }]( + groupSize uint16, + list []T, ) []memberIndex { senders := make(map[memberIndex]bool) for _, item := range list { - senders[item.senderIdx()] = true + senders[item.getSenderIndex()] = true } inactive := make([]memberIndex, 0) @@ -25,16 +43,16 @@ func findInactive[T interface{ senderIdx() memberIndex }]( // deduplicateBySender removes duplicated items for the given sender. It always // takes the first item that occurs for the given sender and ignores the // subsequent ones. -func deduplicateBySender[T interface{ senderIdx() memberIndex }]( +func deduplicateBySender[T interface{ getSenderIndex() memberIndex }]( list []T, ) []T { senders := make(map[memberIndex]bool) result := make([]T, 0) - for _, item := range list { - if _, exists := senders[item.senderIdx()]; !exists { - senders[item.senderIdx()] = true - result = append(result, item) + for _, msg := range list { + if _, exists := senders[msg.getSenderIndex()]; !exists { + senders[msg.getSenderIndex()] = true + result = append(result, msg) } } @@ -44,12 +62,12 @@ func deduplicateBySender[T interface{ senderIdx() memberIndex }]( func (m *symmetricKeyGeneratingMember) preProcessMessages( ephemeralPubKeyMessages []*ephemeralPublicKeyMessage, ) []*ephemeralPublicKeyMessage { - inactiveMembers := findInactive(m.group.groupSize, ephemeralPubKeyMessages) + forThisSession := filterForSession(m.sessionID, ephemeralPubKeyMessages) + + inactiveMembers := findInactive(m.group.groupSize, forThisSession) for _, ia := range inactiveMembers { m.group.markMemberAsInactive(ia) } - // TODO: validate session ID - - return deduplicateBySender(ephemeralPubKeyMessages) + return deduplicateBySender(forThisSession) } diff --git a/gjkr/message_filter_test.go b/gjkr/message_filter_test.go index 7f4f6f4..74e0224 100644 --- a/gjkr/message_filter_test.go +++ b/gjkr/message_filter_test.go @@ -6,6 +6,24 @@ import ( "threshold.network/roast/internal/testutils" ) +func TestFilterForSession(t *testing.T) { + msg0 := &ephemeralPublicKeyMessage{sessionID: "session-2", senderIndex: 1} + msg1 := &ephemeralPublicKeyMessage{sessionID: "session-1", senderIndex: 1} + msg2 := &ephemeralPublicKeyMessage{sessionID: "session-1", senderIndex: 2} + msg3 := &ephemeralPublicKeyMessage{sessionID: "session-2", senderIndex: 3} + + filtered := filterForSession("session-1", []*ephemeralPublicKeyMessage{ + msg0, msg1, msg2, msg3, + }) + + testutils.AssertDeepEqual( + t, + "filtered messages", + []*ephemeralPublicKeyMessage{msg1, msg2}, + filtered, + ) +} + func TestFindInactive(t *testing.T) { var tests = map[string]struct { groupSize uint16 @@ -80,7 +98,7 @@ func TestDeduplicateBySender(t *testing.T) { deduplicatedSenders := make([]memberIndex, 0) for _, msg := range deduplicateBySender(messages) { - deduplicatedSenders = append(deduplicatedSenders, msg.senderIdx()) + deduplicatedSenders = append(deduplicatedSenders, msg.getSenderIndex()) } testutils.AssertUint16SlicesEqual( t, diff --git a/internal/testutils/assert.go b/internal/testutils/assert.go index 8c0a0b5..e8523ce 100644 --- a/internal/testutils/assert.go +++ b/internal/testutils/assert.go @@ -3,6 +3,7 @@ package testutils import ( "fmt" "math/big" + "reflect" "testing" "golang.org/x/exp/slices" @@ -134,3 +135,19 @@ func AssertUint16SlicesEqual[T ~uint16]( ) } } + +func AssertDeepEqual( + t *testing.T, + description string, + expected any, + actual any, +) { + if !reflect.DeepEqual(expected, actual) { + t.Errorf( + "unexpected %s\nexpected: %v\nactual: %v\n", + description, + expected, + actual, + ) + } +}