From bc7260c4187e13be881527279d0c7093bb6a44f4 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 24 May 2023 19:18:40 +0200 Subject: [PATCH 01/11] Add send_leave --- handleleave.go | 124 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/handleleave.go b/handleleave.go index 4cc582bf..bbb87f17 100644 --- a/handleleave.go +++ b/handleleave.go @@ -15,9 +15,12 @@ package gomatrixserverlib import ( + "context" "fmt" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) type HandleMakeLeaveResponse struct { @@ -37,6 +40,7 @@ type HandleMakeLeaveInput struct { BuildEventTemplate func(*ProtoEvent) (PDU, []PDU, error) } +// HandleMakeLeave handles requests to `/make_leave` func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, error) { if input.UserID.Domain() != input.RequestOrigin { @@ -98,3 +102,123 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro } return &makeLeaveResponse, nil } + +type LatestStateQuerier interface { + LatestState(ctx context.Context, roomID spec.RoomID, userID spec.UserID) ([]PDU, error) +} + +// HandleSendLeave handles requests to `/send_leave +// Returns the parsed event and an error. +func HandleSendLeave(ctx context.Context, + requestContent []byte, + origin spec.ServerName, + roomVersion RoomVersion, + eventID, roomID string, + querier LatestStateQuerier, + verifier JSONVerifier, +) (PDU, error) { + + rID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + + verImpl, err := GetRoomVersion(roomVersion) + if err != nil { + return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion)) + + } + + // Decode the event JSON from the request. + event, err := verImpl.NewEventFromUntrustedJSON(requestContent) + switch err.(type) { + case BadJSONError: + return nil, spec.BadJSON(err.Error()) + case nil: + default: + return nil, spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()) + } + + // Check that the room ID is correct. + if (event.RoomID()) != roomID { + return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") + } + + // Check that the event ID is correct. + if event.EventID() != eventID { + return nil, spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON") + + } + + if event.StateKey() == nil || event.StateKeyEquals("") { + return nil, spec.BadJSON("No state key was provided in the leave event.") + } + if !event.StateKeyEquals(event.Sender()) { + return nil, spec.BadJSON("Event state key must match the event sender.") + } + + leavingUser, err := spec.NewUserID(*event.StateKey(), true) + if err != nil { + return nil, spec.Forbidden("The leaving user ID is invalid") + } + + // Check that the sender belongs to the server that is sending us + // the request. By this point we've already asserted that the sender + // and the state key are equal so we don't need to check both. + sender, err := spec.NewUserID(event.Sender(), true) + if err != nil { + return nil, spec.Forbidden("The sender of the join is invalid") + } + if sender.Domain() != origin { + return nil, spec.Forbidden("The sender does not match the server that originated the request") + } + + stateEvents, err := querier.LatestState(ctx, *rID, *leavingUser) + if err != nil { + return nil, err + } + // handle cases we can no-op + switch { + case len(stateEvents) == 0: + return nil, nil + case len(stateEvents) == 1: + if mem, merr := stateEvents[0].Membership(); merr == nil && mem == spec.Leave { + return nil, nil + } + case event.EventID() == stateEvents[0].EventID(): + return nil, nil + } + + // Check that the event is signed by the server sending the request. + redacted, err := verImpl.RedactEventJSON(event.JSON()) + if err != nil { + logrus.WithError(err).Errorf("XXX: leave.go") + return nil, spec.BadJSON("The event JSON could not be redacted") + } + verifyRequests := []VerifyJSONRequest{{ + ServerName: sender.Domain(), + Message: redacted, + AtTS: event.OriginServerTS(), + StrictValidityChecking: true, + }} + verifyResults, err := verifier.VerifyJSONs(ctx, verifyRequests) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") + return nil, spec.InternalServerError{} + } + if verifyResults[0].Error != nil { + return nil, spec.Forbidden("The leave must be signed by the server it originated on") + } + + // check membership is set to leave + mem, err := event.Membership() + if err != nil { + util.GetLogger(ctx).WithError(err).Error("event.Membership failed") + return nil, spec.BadJSON("missing content.membership key") + } + if mem != spec.Leave { + return nil, spec.BadJSON("The membership in the event content must be set to leave") + } + + return event, nil +} From 9e1377a918b611654648ca96f132cb4b1a715e36 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 25 May 2023 15:14:08 +0200 Subject: [PATCH 02/11] Add HandleSendLeave tests --- handleleave.go | 8 +- handleleave_test.go | 180 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 3 deletions(-) diff --git a/handleleave.go b/handleleave.go index bbb87f17..516f1f52 100644 --- a/handleleave.go +++ b/handleleave.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) type HandleMakeLeaveResponse struct { @@ -126,7 +125,6 @@ func HandleSendLeave(ctx context.Context, verImpl, err := GetRoomVersion(roomVersion) if err != nil { return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion)) - } // Decode the event JSON from the request. @@ -150,6 +148,7 @@ func HandleSendLeave(ctx context.Context, } + // Sanity check that we really received a state event if event.StateKey() == nil || event.StateKeyEquals("") { return nil, spec.BadJSON("No state key was provided in the leave event.") } @@ -180,19 +179,22 @@ func HandleSendLeave(ctx context.Context, // handle cases we can no-op switch { case len(stateEvents) == 0: + // we weren't joined at all return nil, nil case len(stateEvents) == 1: + // We are/were joined/invited/banned or something if mem, merr := stateEvents[0].Membership(); merr == nil && mem == spec.Leave { return nil, nil } case event.EventID() == stateEvents[0].EventID(): + // we already processed this event return nil, nil } // Check that the event is signed by the server sending the request. redacted, err := verImpl.RedactEventJSON(event.JSON()) if err != nil { - logrus.WithError(err).Errorf("XXX: leave.go") + util.GetLogger(ctx).WithError(err).Errorf("unable to redact event") return nil, spec.BadJSON("The event JSON could not be redacted") } verifyRequests := []VerifyJSONRequest{{ diff --git a/handleleave_test.go b/handleleave_test.go index 197eadea..05f121a2 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -1,6 +1,7 @@ package gomatrixserverlib import ( + "context" "crypto/rand" "fmt" "testing" @@ -221,3 +222,182 @@ func TestHandleMakeLeave(t *testing.T) { }) } } + +type dummyQuerier struct { + pdus []PDU +} + +func (d dummyQuerier) LatestState(ctx context.Context, roomID spec.RoomID, userID spec.UserID) ([]PDU, error) { + return d.pdus, nil +} + +type noopJSONVerifier struct { + err error + results []VerifyJSONResult +} + +func (v *noopJSONVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { + return v.results, v.err +} + +func TestHandleSendLeave(t *testing.T) { + type args struct { + ctx context.Context + requestContent []byte + origin spec.ServerName + roomVersion RoomVersion + eventID string + roomID string + querier LatestStateQuerier + verifier JSONVerifier + } + + _, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed generating key: %v", err) + } + keyID := KeyID("ed25519:1234") + + validUser, _ := spec.NewUserID("@valid:localhost", true) + + stateKey := "" + eb := MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + Sender: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomCreate, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"creator":"@user:local","m.federate":true,"room_version":"10"}`), + Unsigned: spec.RawJSON(""), + }) + createEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + stateKey = validUser.String() + eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + Sender: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomMember, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"membership":"leave"}`), + Unsigned: spec.RawJSON(""), + }) + leaveEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + Sender: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomMember, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"membership":"join"}`), + Unsigned: spec.RawJSON(""), + }) + joinEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + tests := []struct { + name string + args args + want PDU + wantErr assert.ErrorAssertionFunc + }{ + { + name: "invalid roomID", + args: args{roomID: "@notvalid:localhost"}, + wantErr: assert.Error, + }, + { + name: "invalid room version", + args: args{roomID: "!notvalid:localhost", roomVersion: "-1"}, + wantErr: assert.Error, + }, + { + name: "invalid content body", + args: args{roomID: "!notvalid:localhost", roomVersion: "1", requestContent: []byte("{")}, + wantErr: assert.Error, + }, + { + name: "not canonical JSON", + args: args{roomID: "!notvalid:localhost", roomVersion: "10", requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json + wantErr: assert.Error, + }, + { + name: "wrong roomID in request", + args: args{roomID: "!notvalid:localhost", roomVersion: "10", requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "wrong eventID in request", + args: args{roomID: "!valid:localhost", roomVersion: "10", requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "empty statekey", + args: args{roomID: "!valid:localhost", roomVersion: "10", eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "wrong request origin", + args: args{roomID: "!valid:localhost", roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "never joined the room no-ops", + args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + { + name: "already left the the room no-ops", + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{leaveEvent}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + { + name: "already left the the room no-ops 2", + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{leaveEvent, leaveEvent}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + { + name: "JSON validation fails", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "JSON validation fails 2", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "membership not set to leave", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "membership set to leave", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := HandleSendLeave(tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) + if !tt.wantErr(t, err, fmt.Sprintf("HandleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { + return + } + }) + } +} From 2640a7efaaaa139273769b54af58ed4643c2581f Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 18:00:30 +0200 Subject: [PATCH 03/11] Update interface --- handleleave.go | 28 +++++++++++++--------------- handleleave_test.go | 20 ++++++++++---------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/handleleave.go b/handleleave.go index 516f1f52..118b350b 100644 --- a/handleleave.go +++ b/handleleave.go @@ -102,8 +102,8 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro return &makeLeaveResponse, nil } -type LatestStateQuerier interface { - LatestState(ctx context.Context, roomID spec.RoomID, userID spec.UserID) ([]PDU, error) +type CurrentStateQuerier interface { + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) } // HandleSendLeave handles requests to `/send_leave @@ -113,7 +113,7 @@ func HandleSendLeave(ctx context.Context, origin spec.ServerName, roomVersion RoomVersion, eventID, roomID string, - querier LatestStateQuerier, + querier CurrentStateQuerier, verifier JSONVerifier, ) (PDU, error) { @@ -172,22 +172,20 @@ func HandleSendLeave(ctx context.Context, return nil, spec.Forbidden("The sender does not match the server that originated the request") } - stateEvents, err := querier.LatestState(ctx, *rID, *leavingUser) + stateEvent, err := querier.CurrentStateEvent(ctx, *rID, spec.MRoomMember, leavingUser.String()) if err != nil { return nil, err } - // handle cases we can no-op - switch { - case len(stateEvents) == 0: - // we weren't joined at all + // we weren't joined at all + if stateEvent == nil { return nil, nil - case len(stateEvents) == 1: - // We are/were joined/invited/banned or something - if mem, merr := stateEvents[0].Membership(); merr == nil && mem == spec.Leave { - return nil, nil - } - case event.EventID() == stateEvents[0].EventID(): - // we already processed this event + } + // We are/were joined/invited/banned or something + if mem, merr := stateEvent.Membership(); merr == nil && mem == spec.Leave { + return nil, nil + } + // we already processed this event + if event.EventID() == stateEvent.EventID() { return nil, nil } diff --git a/handleleave_test.go b/handleleave_test.go index 05f121a2..d6c5c91e 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -224,11 +224,11 @@ func TestHandleMakeLeave(t *testing.T) { } type dummyQuerier struct { - pdus []PDU + pdu PDU } -func (d dummyQuerier) LatestState(ctx context.Context, roomID spec.RoomID, userID spec.UserID) ([]PDU, error) { - return d.pdus, nil +func (d dummyQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) { + return d.pdu, nil } type noopJSONVerifier struct { @@ -248,7 +248,7 @@ func TestHandleSendLeave(t *testing.T) { roomVersion RoomVersion eventID string roomID string - querier LatestStateQuerier + querier CurrentStateQuerier verifier JSONVerifier } @@ -363,32 +363,32 @@ func TestHandleSendLeave(t *testing.T) { }, { name: "already left the the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{leaveEvent}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "already left the the room no-ops 2", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{leaveEvent, leaveEvent}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "JSON validation fails", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "JSON validation fails 2", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "membership not set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, wantErr: assert.Error, }, { name: "membership set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdus: []PDU{createEvent}}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, } From 645785dce8bfce990d9906f6d7ecd2eba289581c Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 18:09:26 +0200 Subject: [PATCH 04/11] Remove dupe test --- handleleave_test.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/handleleave_test.go b/handleleave_test.go index d6c5c91e..9494faab 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -362,12 +362,7 @@ func TestHandleSendLeave(t *testing.T) { wantErr: assert.NoError, }, { - name: "already left the the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, - wantErr: assert.NoError, - }, - { - name: "already left the the room no-ops 2", + name: "already left the room no-ops", args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, From b7a785d518d159fcb528f62c011233b7e8fd2998 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 18:10:48 +0200 Subject: [PATCH 05/11] Use version consts --- handleleave_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/handleleave_test.go b/handleleave_test.go index 9494faab..596377d8 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -328,62 +328,62 @@ func TestHandleSendLeave(t *testing.T) { }, { name: "invalid content body", - args: args{roomID: "!notvalid:localhost", roomVersion: "1", requestContent: []byte("{")}, + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV1, requestContent: []byte("{")}, wantErr: assert.Error, }, { name: "not canonical JSON", - args: args{roomID: "!notvalid:localhost", roomVersion: "10", requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json wantErr: assert.Error, }, { name: "wrong roomID in request", - args: args{roomID: "!notvalid:localhost", roomVersion: "10", requestContent: createEvent.JSON()}, + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "wrong eventID in request", - args: args{roomID: "!valid:localhost", roomVersion: "10", requestContent: createEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "empty statekey", - args: args{roomID: "!valid:localhost", roomVersion: "10", eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "wrong request origin", - args: args{roomID: "!valid:localhost", roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "never joined the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "already left the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "JSON validation fails", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "JSON validation fails 2", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "membership not set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, wantErr: assert.Error, }, { name: "membership set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: "10", eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, } From 540062a7516da38bbbaf11137d241c5f64116b59 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Jun 2023 12:59:22 +0200 Subject: [PATCH 06/11] PR review comments, let IRoomVersion handle send leave --- eventversion.go | 27 +++++++++++++++++++++ handleleave.go | 31 +++++++----------------- handleleave_test.go | 58 +++++++++++++++++---------------------------- 3 files changed, 57 insertions(+), 59 deletions(-) diff --git a/eventversion.go b/eventversion.go index 1b5e0283..5062d2a4 100644 --- a/eventversion.go +++ b/eventversion.go @@ -1,6 +1,7 @@ package gomatrixserverlib import ( + "context" "fmt" "github.com/matrix-org/gomatrixserverlib/spec" @@ -30,6 +31,8 @@ type IRoomVersion interface { NewEventFromUntrustedJSON(eventJSON []byte) (result PDU, err error) NewEventBuilder() *EventBuilder NewEventBuilderFromProtoEvent(pe *ProtoEvent) *EventBuilder + + HandleSendLeave(ctx context.Context, event PDU, origin spec.ServerName, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier) (PDU, error) } // StateResAlgorithm refers to a version of the state resolution algorithm. @@ -112,6 +115,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV2: RoomVersionImpl{ ver: RoomVersionV2, @@ -126,6 +130,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV3: RoomVersionImpl{ ver: RoomVersionV3, @@ -140,6 +145,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV4: RoomVersionImpl{ ver: RoomVersionV4, @@ -154,6 +160,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV5: RoomVersionImpl{ ver: RoomVersionV5, @@ -168,6 +175,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV6: RoomVersionImpl{ ver: RoomVersionV6, @@ -182,6 +190,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV7: RoomVersionImpl{ ver: RoomVersionV7, @@ -196,6 +205,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV8: RoomVersionImpl{ ver: RoomVersionV8, @@ -210,6 +220,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: RestrictedOnly, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV9: RoomVersionImpl{ ver: RoomVersionV9, @@ -224,6 +235,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: RestrictedOnly, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, RoomVersionV10: RoomVersionImpl{ ver: RoomVersionV10, @@ -238,6 +250,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOrKnockRestricted, allowRestrictedJoinsInEventAuth: RestrictedOrKnockRestricted, requireIntegerPowerLevels: true, + handleSendLeaveFunc: handleSendLeave, }, "org.matrix.msc3667": RoomVersionImpl{ // based on room version 7 ver: RoomVersion("org.matrix.msc3667"), @@ -252,6 +265,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: true, + handleSendLeaveFunc: handleSendLeave, }, "org.matrix.msc3787": RoomVersionImpl{ // roughly, the union of v7 and v9 ver: RoomVersion("org.matrix.msc3787"), @@ -266,6 +280,7 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOrKnockRestricted, allowRestrictedJoinsInEventAuth: RestrictedOrKnockRestricted, requireIntegerPowerLevels: false, + handleSendLeaveFunc: handleSendLeave, }, } @@ -334,6 +349,7 @@ type RoomVersionImpl struct { powerLevelsIncludeNotifications bool requireIntegerPowerLevels bool stable bool + handleSendLeaveFunc func(ctx context.Context, event PDU, origin spec.ServerName, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier) (PDU, error) } func (v RoomVersionImpl) Version() RoomVersion { @@ -431,6 +447,17 @@ func (v RoomVersionImpl) RedactEventJSON(eventJSON []byte) ([]byte, error) { return v.redactionAlgorithm(eventJSON) } +// HandleSendLeave handles requests to `/send_leave` +func (v RoomVersionImpl) HandleSendLeave(ctx context.Context, + event PDU, + origin spec.ServerName, + eventID, roomID string, + querier CurrentStateQuerier, + verifier JSONVerifier, +) (PDU, error) { + return v.handleSendLeaveFunc(ctx, event, origin, eventID, roomID, querier, verifier) +} + func (v RoomVersionImpl) NewEventFromTrustedJSON(eventJSON []byte, redacted bool) (result PDU, err error) { return newEventFromTrustedJSON(eventJSON, redacted, v) } diff --git a/handleleave.go b/handleleave.go index 118b350b..f81b81d3 100644 --- a/handleleave.go +++ b/handleleave.go @@ -106,12 +106,11 @@ type CurrentStateQuerier interface { CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) } -// HandleSendLeave handles requests to `/send_leave -// Returns the parsed event and an error. -func HandleSendLeave(ctx context.Context, - requestContent []byte, +// handleSendLeave handles requests to `/send_leave +// Returns the parsed event or an error. +func handleSendLeave(ctx context.Context, + event PDU, origin spec.ServerName, - roomVersion RoomVersion, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier, @@ -122,21 +121,6 @@ func HandleSendLeave(ctx context.Context, return nil, err } - verImpl, err := GetRoomVersion(roomVersion) - if err != nil { - return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion)) - } - - // Decode the event JSON from the request. - event, err := verImpl.NewEventFromUntrustedJSON(requestContent) - switch err.(type) { - case BadJSONError: - return nil, spec.BadJSON(err.Error()) - case nil: - default: - return nil, spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()) - } - // Check that the room ID is correct. if (event.RoomID()) != roomID { return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") @@ -190,14 +174,15 @@ func HandleSendLeave(ctx context.Context, } // Check that the event is signed by the server sending the request. - redacted, err := verImpl.RedactEventJSON(event.JSON()) + resultEvent := event + event.Redact() if err != nil { util.GetLogger(ctx).WithError(err).Errorf("unable to redact event") return nil, spec.BadJSON("The event JSON could not be redacted") } verifyRequests := []VerifyJSONRequest{{ ServerName: sender.Domain(), - Message: redacted, + Message: event.JSON(), AtTS: event.OriginServerTS(), StrictValidityChecking: true, }} @@ -220,5 +205,5 @@ func HandleSendLeave(ctx context.Context, return nil, spec.BadJSON("The membership in the event content must be set to leave") } - return event, nil + return resultEvent, nil } diff --git a/handleleave_test.go b/handleleave_test.go index 596377d8..a8ff8a44 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -242,14 +242,14 @@ func (v *noopJSONVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSO func TestHandleSendLeave(t *testing.T) { type args struct { - ctx context.Context - requestContent []byte - origin spec.ServerName - roomVersion RoomVersion - eventID string - roomID string - querier CurrentStateQuerier - verifier JSONVerifier + ctx context.Context + event PDU + origin spec.ServerName + roomVersion RoomVersion + eventID string + roomID string + querier CurrentStateQuerier + verifier JSONVerifier } _, sk, err := ed25519.GenerateKey(rand.Reader) @@ -318,79 +318,65 @@ func TestHandleSendLeave(t *testing.T) { }{ { name: "invalid roomID", - args: args{roomID: "@notvalid:localhost"}, - wantErr: assert.Error, - }, - { - name: "invalid room version", - args: args{roomID: "!notvalid:localhost", roomVersion: "-1"}, - wantErr: assert.Error, - }, - { - name: "invalid content body", - args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV1, requestContent: []byte("{")}, - wantErr: assert.Error, - }, - { - name: "not canonical JSON", - args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json + args: args{roomID: "@notvalid:localhost", roomVersion: RoomVersionV10}, wantErr: assert.Error, }, { name: "wrong roomID in request", - args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, event: createEvent}, wantErr: assert.Error, }, { name: "wrong eventID in request", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, event: createEvent}, wantErr: assert.Error, }, { name: "empty statekey", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), event: createEvent}, wantErr: assert.Error, }, { name: "wrong request origin", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.Error, }, { name: "never joined the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.NoError, }, { name: "already left the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.NoError, }, { name: "JSON validation fails", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.Error, }, { name: "JSON validation fails 2", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.Error, }, { name: "membership not set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), event: joinEvent}, wantErr: assert.Error, }, { name: "membership set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := HandleSendLeave(tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) - if !tt.wantErr(t, err, fmt.Sprintf("HandleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { + verImpl := MustGetRoomVersion(tt.args.roomVersion) + _, err := verImpl.HandleSendLeave(tt.args.ctx, tt.args.event, tt.args.origin, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) + if !tt.wantErr(t, err, fmt.Sprintf("handleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.event, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { return } }) From 1f821508d72b4dcfc66ebb9ca843c87c8eb99a3a Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 22 Nov 2023 16:03:42 +0100 Subject: [PATCH 07/11] Revert "PR review comments, let IRoomVersion handle send leave" This reverts commit 540062a7516da38bbbaf11137d241c5f64116b59. --- eventversion.go | 27 --------------------- handleleave.go | 31 +++++++++++++++++------- handleleave_test.go | 58 ++++++++++++++++++++++++++++----------------- 3 files changed, 59 insertions(+), 57 deletions(-) diff --git a/eventversion.go b/eventversion.go index 5062d2a4..1b5e0283 100644 --- a/eventversion.go +++ b/eventversion.go @@ -1,7 +1,6 @@ package gomatrixserverlib import ( - "context" "fmt" "github.com/matrix-org/gomatrixserverlib/spec" @@ -31,8 +30,6 @@ type IRoomVersion interface { NewEventFromUntrustedJSON(eventJSON []byte) (result PDU, err error) NewEventBuilder() *EventBuilder NewEventBuilderFromProtoEvent(pe *ProtoEvent) *EventBuilder - - HandleSendLeave(ctx context.Context, event PDU, origin spec.ServerName, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier) (PDU, error) } // StateResAlgorithm refers to a version of the state resolution algorithm. @@ -115,7 +112,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV2: RoomVersionImpl{ ver: RoomVersionV2, @@ -130,7 +126,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV3: RoomVersionImpl{ ver: RoomVersionV3, @@ -145,7 +140,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV4: RoomVersionImpl{ ver: RoomVersionV4, @@ -160,7 +154,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV5: RoomVersionImpl{ ver: RoomVersionV5, @@ -175,7 +168,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV6: RoomVersionImpl{ ver: RoomVersionV6, @@ -190,7 +182,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnocksForbidden, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV7: RoomVersionImpl{ ver: RoomVersionV7, @@ -205,7 +196,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV8: RoomVersionImpl{ ver: RoomVersionV8, @@ -220,7 +210,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: RestrictedOnly, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV9: RoomVersionImpl{ ver: RoomVersionV9, @@ -235,7 +224,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: RestrictedOnly, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, RoomVersionV10: RoomVersionImpl{ ver: RoomVersionV10, @@ -250,7 +238,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOrKnockRestricted, allowRestrictedJoinsInEventAuth: RestrictedOrKnockRestricted, requireIntegerPowerLevels: true, - handleSendLeaveFunc: handleSendLeave, }, "org.matrix.msc3667": RoomVersionImpl{ // based on room version 7 ver: RoomVersion("org.matrix.msc3667"), @@ -265,7 +252,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOnly, allowRestrictedJoinsInEventAuth: NoRestrictedJoins, requireIntegerPowerLevels: true, - handleSendLeaveFunc: handleSendLeave, }, "org.matrix.msc3787": RoomVersionImpl{ // roughly, the union of v7 and v9 ver: RoomVersion("org.matrix.msc3787"), @@ -280,7 +266,6 @@ var roomVersionMeta = map[RoomVersion]IRoomVersion{ allowKnockingInEventAuth: KnockOrKnockRestricted, allowRestrictedJoinsInEventAuth: RestrictedOrKnockRestricted, requireIntegerPowerLevels: false, - handleSendLeaveFunc: handleSendLeave, }, } @@ -349,7 +334,6 @@ type RoomVersionImpl struct { powerLevelsIncludeNotifications bool requireIntegerPowerLevels bool stable bool - handleSendLeaveFunc func(ctx context.Context, event PDU, origin spec.ServerName, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier) (PDU, error) } func (v RoomVersionImpl) Version() RoomVersion { @@ -447,17 +431,6 @@ func (v RoomVersionImpl) RedactEventJSON(eventJSON []byte) ([]byte, error) { return v.redactionAlgorithm(eventJSON) } -// HandleSendLeave handles requests to `/send_leave` -func (v RoomVersionImpl) HandleSendLeave(ctx context.Context, - event PDU, - origin spec.ServerName, - eventID, roomID string, - querier CurrentStateQuerier, - verifier JSONVerifier, -) (PDU, error) { - return v.handleSendLeaveFunc(ctx, event, origin, eventID, roomID, querier, verifier) -} - func (v RoomVersionImpl) NewEventFromTrustedJSON(eventJSON []byte, redacted bool) (result PDU, err error) { return newEventFromTrustedJSON(eventJSON, redacted, v) } diff --git a/handleleave.go b/handleleave.go index f81b81d3..118b350b 100644 --- a/handleleave.go +++ b/handleleave.go @@ -106,11 +106,12 @@ type CurrentStateQuerier interface { CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) } -// handleSendLeave handles requests to `/send_leave -// Returns the parsed event or an error. -func handleSendLeave(ctx context.Context, - event PDU, +// HandleSendLeave handles requests to `/send_leave +// Returns the parsed event and an error. +func HandleSendLeave(ctx context.Context, + requestContent []byte, origin spec.ServerName, + roomVersion RoomVersion, eventID, roomID string, querier CurrentStateQuerier, verifier JSONVerifier, @@ -121,6 +122,21 @@ func handleSendLeave(ctx context.Context, return nil, err } + verImpl, err := GetRoomVersion(roomVersion) + if err != nil { + return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion)) + } + + // Decode the event JSON from the request. + event, err := verImpl.NewEventFromUntrustedJSON(requestContent) + switch err.(type) { + case BadJSONError: + return nil, spec.BadJSON(err.Error()) + case nil: + default: + return nil, spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()) + } + // Check that the room ID is correct. if (event.RoomID()) != roomID { return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") @@ -174,15 +190,14 @@ func handleSendLeave(ctx context.Context, } // Check that the event is signed by the server sending the request. - resultEvent := event - event.Redact() + redacted, err := verImpl.RedactEventJSON(event.JSON()) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("unable to redact event") return nil, spec.BadJSON("The event JSON could not be redacted") } verifyRequests := []VerifyJSONRequest{{ ServerName: sender.Domain(), - Message: event.JSON(), + Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, }} @@ -205,5 +220,5 @@ func handleSendLeave(ctx context.Context, return nil, spec.BadJSON("The membership in the event content must be set to leave") } - return resultEvent, nil + return event, nil } diff --git a/handleleave_test.go b/handleleave_test.go index a8ff8a44..596377d8 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -242,14 +242,14 @@ func (v *noopJSONVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSO func TestHandleSendLeave(t *testing.T) { type args struct { - ctx context.Context - event PDU - origin spec.ServerName - roomVersion RoomVersion - eventID string - roomID string - querier CurrentStateQuerier - verifier JSONVerifier + ctx context.Context + requestContent []byte + origin spec.ServerName + roomVersion RoomVersion + eventID string + roomID string + querier CurrentStateQuerier + verifier JSONVerifier } _, sk, err := ed25519.GenerateKey(rand.Reader) @@ -318,65 +318,79 @@ func TestHandleSendLeave(t *testing.T) { }{ { name: "invalid roomID", - args: args{roomID: "@notvalid:localhost", roomVersion: RoomVersionV10}, + args: args{roomID: "@notvalid:localhost"}, + wantErr: assert.Error, + }, + { + name: "invalid room version", + args: args{roomID: "!notvalid:localhost", roomVersion: "-1"}, + wantErr: assert.Error, + }, + { + name: "invalid content body", + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV1, requestContent: []byte("{")}, + wantErr: assert.Error, + }, + { + name: "not canonical JSON", + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json wantErr: assert.Error, }, { name: "wrong roomID in request", - args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, event: createEvent}, + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "wrong eventID in request", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, event: createEvent}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "empty statekey", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), event: createEvent}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, wantErr: assert.Error, }, { name: "wrong request origin", - args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "never joined the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "already left the room no-ops", - args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, { name: "JSON validation fails", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "JSON validation fails 2", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.Error, }, { name: "membership not set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), event: joinEvent}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, wantErr: assert.Error, }, { name: "membership set to leave", - args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), event: leaveEvent}, + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, wantErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - verImpl := MustGetRoomVersion(tt.args.roomVersion) - _, err := verImpl.HandleSendLeave(tt.args.ctx, tt.args.event, tt.args.origin, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) - if !tt.wantErr(t, err, fmt.Sprintf("handleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.event, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { + _, err := HandleSendLeave(tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) + if !tt.wantErr(t, err, fmt.Sprintf("HandleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { return } }) From e3b02c98f7ebfd8d5e146a6807e6bba0247869cb Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 22 Nov 2023 16:17:21 +0100 Subject: [PATCH 08/11] Make tests/lint happy for now --- handleleave.go | 14 +++++++------- handleleave_test.go | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/handleleave.go b/handleleave.go index 1a85f8c0..abbe530f 100644 --- a/handleleave.go +++ b/handleleave.go @@ -140,7 +140,7 @@ func HandleSendLeave(ctx context.Context, } // Check that the room ID is correct. - if (event.RoomID()) != roomID { + if (event.RoomID().String()) != roomID { return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") } @@ -154,7 +154,7 @@ func HandleSendLeave(ctx context.Context, if event.StateKey() == nil || event.StateKeyEquals("") { return nil, spec.BadJSON("No state key was provided in the leave event.") } - if !event.StateKeyEquals(event.Sender()) { + if !event.StateKeyEquals(event.SenderID().ToUserID().String()) { return nil, spec.BadJSON("Event state key must match the event sender.") } @@ -166,7 +166,7 @@ func HandleSendLeave(ctx context.Context, // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - sender, err := spec.NewUserID(event.Sender(), true) + sender, err := spec.NewUserID(event.SenderID().ToUserID().String(), true) if err != nil { return nil, spec.Forbidden("The sender of the join is invalid") } @@ -198,10 +198,10 @@ func HandleSendLeave(ctx context.Context, return nil, spec.BadJSON("The event JSON could not be redacted") } verifyRequests := []VerifyJSONRequest{{ - ServerName: sender.Domain(), - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, + ServerName: sender.Domain(), + Message: redacted, + AtTS: event.OriginServerTS(), + ValidityCheckingFunc: StrictValiditySignatureCheck, }} verifyResults, err := verifier.VerifyJSONs(ctx, verifyRequests) if err != nil { diff --git a/handleleave_test.go b/handleleave_test.go index 4cdf5bfd..1a9bffe4 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -269,7 +269,7 @@ func TestHandleSendLeave(t *testing.T) { stateKey := "" eb := MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ - Sender: validUser.String(), + SenderID: validUser.String(), RoomID: "!valid:localhost", Type: spec.MRoomCreate, StateKey: &stateKey, @@ -286,7 +286,7 @@ func TestHandleSendLeave(t *testing.T) { stateKey = validUser.String() eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ - Sender: validUser.String(), + SenderID: validUser.String(), RoomID: "!valid:localhost", Type: spec.MRoomMember, StateKey: &stateKey, @@ -302,7 +302,7 @@ func TestHandleSendLeave(t *testing.T) { } eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ - Sender: validUser.String(), + SenderID: validUser.String(), RoomID: "!valid:localhost", Type: spec.MRoomMember, StateKey: &stateKey, From 2369defc080e0f6548c94c0abefe308310fa983d Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 23 Nov 2023 12:21:37 +0100 Subject: [PATCH 09/11] Some tweaks --- handleleave.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/handleleave.go b/handleleave.go index abbe530f..8da004d1 100644 --- a/handleleave.go +++ b/handleleave.go @@ -109,9 +109,9 @@ type CurrentStateQuerier interface { } // HandleSendLeave handles requests to `/send_leave -// Returns the parsed event and an error. +// Returns the parsed event or an error. func HandleSendLeave(ctx context.Context, - requestContent []byte, + content []byte, origin spec.ServerName, roomVersion RoomVersion, eventID, roomID string, @@ -126,11 +126,10 @@ func HandleSendLeave(ctx context.Context, verImpl, err := GetRoomVersion(roomVersion) if err != nil { - return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion)) + return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("Room version %s is not supported by this server", roomVersion)) } - // Decode the event JSON from the request. - event, err := verImpl.NewEventFromUntrustedJSON(requestContent) + event, err := verImpl.NewEventFromUntrustedJSON(content) switch err.(type) { case BadJSONError: return nil, spec.BadJSON(err.Error()) @@ -140,23 +139,19 @@ func HandleSendLeave(ctx context.Context, } // Check that the room ID is correct. - if (event.RoomID().String()) != roomID { + if event.RoomID().String() != roomID { return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") } // Check that the event ID is correct. if event.EventID() != eventID { return nil, spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON") - } // Sanity check that we really received a state event if event.StateKey() == nil || event.StateKeyEquals("") { return nil, spec.BadJSON("No state key was provided in the leave event.") } - if !event.StateKeyEquals(event.SenderID().ToUserID().String()) { - return nil, spec.BadJSON("Event state key must match the event sender.") - } leavingUser, err := spec.NewUserID(*event.StateKey(), true) if err != nil { @@ -174,15 +169,17 @@ func HandleSendLeave(ctx context.Context, return nil, spec.Forbidden("The sender does not match the server that originated the request") } + // Check the current membership of this user and + // if we maybe can just no-op this request. stateEvent, err := querier.CurrentStateEvent(ctx, *rID, spec.MRoomMember, leavingUser.String()) if err != nil { return nil, err } - // we weren't joined at all + // The user isn't joined at all if stateEvent == nil { return nil, nil } - // We are/were joined/invited/banned or something + // The user has already left. if mem, merr := stateEvent.Membership(); merr == nil && mem == spec.Leave { return nil, nil } From 498c48650dc9278a956692e4c8dd5941b27138c8 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 23 Nov 2023 12:37:08 +0100 Subject: [PATCH 10/11] Fix oops --- handleleave.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/handleleave.go b/handleleave.go index 8da004d1..6c391364 100644 --- a/handleleave.go +++ b/handleleave.go @@ -152,6 +152,9 @@ func HandleSendLeave(ctx context.Context, if event.StateKey() == nil || event.StateKeyEquals("") { return nil, spec.BadJSON("No state key was provided in the leave event.") } + if !event.StateKeyEquals(event.SenderID().ToUserID().String()) { + return nil, spec.BadJSON("Event state key must match the event sender.") + } leavingUser, err := spec.NewUserID(*event.StateKey(), true) if err != nil { From 185a49a8c72d9e71e308dc17471ccd76f253abe0 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 23 Nov 2023 13:58:15 +0100 Subject: [PATCH 11/11] Move membership check further to the top, so we can avoid unneeded DB calls --- handleleave.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/handleleave.go b/handleleave.go index 6c391364..4632bb2b 100644 --- a/handleleave.go +++ b/handleleave.go @@ -156,6 +156,16 @@ func HandleSendLeave(ctx context.Context, return nil, spec.BadJSON("Event state key must match the event sender.") } + // check membership is set to leave + mem, err := event.Membership() + if err != nil { + util.GetLogger(ctx).WithError(err).Error("event.Membership failed") + return nil, spec.BadJSON("missing content.membership key") + } + if mem != spec.Leave { + return nil, spec.BadJSON("The membership in the event content must be set to leave") + } + leavingUser, err := spec.NewUserID(*event.StateKey(), true) if err != nil { return nil, spec.Forbidden("The leaving user ID is invalid") @@ -212,15 +222,5 @@ func HandleSendLeave(ctx context.Context, return nil, spec.Forbidden("The leave must be signed by the server it originated on") } - // check membership is set to leave - mem, err := event.Membership() - if err != nil { - util.GetLogger(ctx).WithError(err).Error("event.Membership failed") - return nil, spec.BadJSON("missing content.membership key") - } - if mem != spec.Leave { - return nil, spec.BadJSON("The membership in the event content must be set to leave") - } - return event, nil }