diff --git a/handleleave.go b/handleleave.go index bcd73fa6..4632bb2b 100644 --- a/handleleave.go +++ b/handleleave.go @@ -15,9 +15,11 @@ package gomatrixserverlib import ( + "context" "fmt" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" ) type HandleMakeLeaveResponse struct { @@ -39,6 +41,7 @@ type HandleMakeLeaveInput struct { BuildEventTemplate func(*ProtoEvent) (PDU, []PDU, error) } +// HandleMakeLeave handles requests to `/make_leave` func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, error) { if input.UserID.Domain() != input.RequestOrigin { @@ -100,3 +103,124 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro } return &makeLeaveResponse, nil } + +type CurrentStateQuerier interface { + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) +} + +// HandleSendLeave handles requests to `/send_leave +// Returns the parsed event or an error. +func HandleSendLeave(ctx context.Context, + content []byte, + origin spec.ServerName, + roomVersion RoomVersion, + eventID, roomID string, + querier CurrentStateQuerier, + verifier JSONVerifier, +) (PDU, error) { + + rID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + + verImpl, err := GetRoomVersion(roomVersion) + if err != nil { + return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("Room version %s is not supported by this server", roomVersion)) + } + + event, err := verImpl.NewEventFromUntrustedJSON(content) + switch err.(type) { + case BadJSONError: + return nil, spec.BadJSON(err.Error()) + case nil: + default: + return nil, spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()) + } + + // Check that the room ID is correct. + if event.RoomID().String() != roomID { + return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON") + } + + // Check that the event ID is correct. + if event.EventID() != eventID { + return nil, spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON") + } + + // Sanity check that we really received a state event + if event.StateKey() == nil || event.StateKeyEquals("") { + return nil, spec.BadJSON("No state key was provided in the leave event.") + } + if !event.StateKeyEquals(event.SenderID().ToUserID().String()) { + return nil, spec.BadJSON("Event state key must match the event sender.") + } + + // check membership is set to leave + mem, err := event.Membership() + if err != nil { + util.GetLogger(ctx).WithError(err).Error("event.Membership failed") + return nil, spec.BadJSON("missing content.membership key") + } + if mem != spec.Leave { + return nil, spec.BadJSON("The membership in the event content must be set to leave") + } + + leavingUser, err := spec.NewUserID(*event.StateKey(), true) + if err != nil { + return nil, spec.Forbidden("The leaving user ID is invalid") + } + + // Check that the sender belongs to the server that is sending us + // the request. By this point we've already asserted that the sender + // and the state key are equal so we don't need to check both. + sender, err := spec.NewUserID(event.SenderID().ToUserID().String(), true) + if err != nil { + return nil, spec.Forbidden("The sender of the join is invalid") + } + if sender.Domain() != origin { + return nil, spec.Forbidden("The sender does not match the server that originated the request") + } + + // Check the current membership of this user and + // if we maybe can just no-op this request. + stateEvent, err := querier.CurrentStateEvent(ctx, *rID, spec.MRoomMember, leavingUser.String()) + if err != nil { + return nil, err + } + // The user isn't joined at all + if stateEvent == nil { + return nil, nil + } + // The user has already left. + if mem, merr := stateEvent.Membership(); merr == nil && mem == spec.Leave { + return nil, nil + } + // we already processed this event + if event.EventID() == stateEvent.EventID() { + return nil, nil + } + + // Check that the event is signed by the server sending the request. + redacted, err := verImpl.RedactEventJSON(event.JSON()) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("unable to redact event") + return nil, spec.BadJSON("The event JSON could not be redacted") + } + verifyRequests := []VerifyJSONRequest{{ + ServerName: sender.Domain(), + Message: redacted, + AtTS: event.OriginServerTS(), + ValidityCheckingFunc: StrictValiditySignatureCheck, + }} + verifyResults, err := verifier.VerifyJSONs(ctx, verifyRequests) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") + return nil, spec.InternalServerError{} + } + if verifyResults[0].Error != nil { + return nil, spec.Forbidden("The leave must be signed by the server it originated on") + } + + return event, nil +} diff --git a/handleleave_test.go b/handleleave_test.go index a7df3024..1a9bffe4 100644 --- a/handleleave_test.go +++ b/handleleave_test.go @@ -1,6 +1,7 @@ package gomatrixserverlib import ( + "context" "crypto/rand" "fmt" "testing" @@ -228,3 +229,177 @@ func TestHandleMakeLeave(t *testing.T) { }) } } + +type dummyQuerier struct { + pdu PDU +} + +func (d dummyQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) { + return d.pdu, nil +} + +type noopJSONVerifier struct { + err error + results []VerifyJSONResult +} + +func (v *noopJSONVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { + return v.results, v.err +} + +func TestHandleSendLeave(t *testing.T) { + type args struct { + ctx context.Context + requestContent []byte + origin spec.ServerName + roomVersion RoomVersion + eventID string + roomID string + querier CurrentStateQuerier + verifier JSONVerifier + } + + _, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed generating key: %v", err) + } + keyID := KeyID("ed25519:1234") + + validUser, _ := spec.NewUserID("@valid:localhost", true) + + stateKey := "" + eb := MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + SenderID: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomCreate, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"creator":"@user:local","m.federate":true,"room_version":"10"}`), + Unsigned: spec.RawJSON(""), + }) + createEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + stateKey = validUser.String() + eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + SenderID: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomMember, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"membership":"leave"}`), + Unsigned: spec.RawJSON(""), + }) + leaveEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{ + SenderID: validUser.String(), + RoomID: "!valid:localhost", + Type: spec.MRoomMember, + StateKey: &stateKey, + PrevEvents: []interface{}{}, + AuthEvents: []interface{}{}, + Depth: 0, + Content: spec.RawJSON(`{"membership":"join"}`), + Unsigned: spec.RawJSON(""), + }) + joinEvent, err := eb.Build(time.Now(), "localhost", keyID, sk) + if err != nil { + t.Fatalf("Failed building create event: %v", err) + } + + tests := []struct { + name string + args args + want PDU + wantErr assert.ErrorAssertionFunc + }{ + { + name: "invalid roomID", + args: args{roomID: "@notvalid:localhost"}, + wantErr: assert.Error, + }, + { + name: "invalid room version", + args: args{roomID: "!notvalid:localhost", roomVersion: "-1"}, + wantErr: assert.Error, + }, + { + name: "invalid content body", + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV1, requestContent: []byte("{")}, + wantErr: assert.Error, + }, + { + name: "not canonical JSON", + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json + wantErr: assert.Error, + }, + { + name: "wrong roomID in request", + args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "wrong eventID in request", + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "empty statekey", + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), requestContent: createEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "wrong request origin", + args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "never joined the room no-ops", + args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + { + name: "already left the room no-ops", + args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + { + name: "JSON validation fails", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "JSON validation fails 2", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "membership not set to leave", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()}, + wantErr: assert.Error, + }, + { + name: "membership set to leave", + args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()}, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := HandleSendLeave(tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier) + if !tt.wantErr(t, err, fmt.Sprintf("HandleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) { + return + } + }) + } +}