diff --git a/authstate.go b/authstate.go index bec5f93e..046ed804 100644 --- a/authstate.go +++ b/authstate.go @@ -65,7 +65,7 @@ type FederatedStateProvider struct { // StateIDsBeforeEvent implements StateProvider func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event PDU) ([]string, error) { - res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID(), event.EventID()) + res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID().String(), event.EventID()) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event // StateBeforeEvent implements StateProvider func (p *FederatedStateProvider) StateBeforeEvent(ctx context.Context, roomVer RoomVersion, event PDU, eventIDs []string) (map[string]PDU, error) { - res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID(), event.EventID(), roomVer) + res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID().String(), event.EventID(), roomVer) if err != nil { return nil, err } diff --git a/eventV1.go b/eventV1.go index a4c04b2a..c01fd851 100644 --- a/eventV1.go +++ b/eventV1.go @@ -104,8 +104,12 @@ func (e *eventV1) Version() RoomVersion { return e.roomVersion } -func (e *eventV1) RoomID() string { - return e.eventFields.RoomID +func (e *eventV1) RoomID() spec.RoomID { + roomID, err := spec.NewRoomID(e.eventFields.RoomID) + if err != nil { + panic(fmt.Errorf("RoomID is invalid: %w", err)) + } + return *roomID } func (e *eventV1) Redacts() string { @@ -280,6 +284,10 @@ func newEventFromUntrustedJSONV1(eventJSON []byte, roomVersion IRoomVersion) (PD return nil, err } + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, err + } + // We know the JSON must be valid here. eventJSON = CanonicalJSONAssumeValid(eventJSON) @@ -323,6 +331,11 @@ func newEventFromTrustedJSONV1(eventJSON []byte, redacted bool, roomVersion IRoo if err := json.Unmarshal(eventJSON, &res); err != nil { return nil, err } + + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, fmt.Errorf("RoomID is invalid: %w", err) + } + res.eventJSON = eventJSON res.roomVersion = roomVersion.Version() res.redacted = redacted @@ -334,6 +347,11 @@ func newEventFromTrustedJSONWithEventIDV1(eventID string, eventJSON []byte, reda if err := json.Unmarshal(eventJSON, &res); err != nil { return nil, err } + + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, err + } + res.EventIDRaw = eventID res.eventJSON = eventJSON res.roomVersion = roomVersion.Version() diff --git a/eventV2.go b/eventV2.go index f2d406ed..4be5e931 100644 --- a/eventV2.go +++ b/eventV2.go @@ -140,6 +140,11 @@ func newEventFromUntrustedJSONV2(eventJSON []byte, roomVersion IRoomVersion) (PD if err = json.Unmarshal(eventJSON, &res); err != nil { return nil, err } + + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, err + } + res.roomVersion = roomVersion.Version() // We know the JSON must be valid here. @@ -245,10 +250,6 @@ func CheckFields(input PDU) error { // nolint: gocyclo } } - if err := checkID(input.RoomID(), "room", '!'); err != nil { - return err - } - switch input.Version() { case RoomVersionPseudoIDs: default: @@ -265,6 +266,11 @@ func newEventFromTrustedJSONV2(eventJSON []byte, redacted bool, roomVersion IRoo if err := json.Unmarshal(eventJSON, &res); err != nil { return nil, err } + + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, err + } + res.roomVersion = roomVersion.Version() res.redacted = redacted res.eventJSON = eventJSON @@ -276,6 +282,11 @@ func newEventFromTrustedJSONWithEventIDV2(eventID string, eventJSON []byte, reda if err := json.Unmarshal(eventJSON, &res); err != nil { return nil, err } + + if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil { + return nil, err + } + res.roomVersion = roomVersion.Version() res.eventJSON = eventJSON res.EventIDRaw = eventID diff --git a/eventV2_test.go b/eventV2_test.go index 57e63263..b9d5a6a8 100644 --- a/eventV2_test.go +++ b/eventV2_test.go @@ -152,9 +152,11 @@ func TestCheckFields(t *testing.T) { t.Run(tt.name+"-"+string(roomVersion), func(t *testing.T) { ev, err := MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&tt.input).Build(time.Now(), "localhost", "ed25519:1", sk) tt.wantErr(t, err) - err = CheckFields(ev) - tt.wantErr(t, err, fmt.Sprintf("CheckFields(%v)", tt.input)) - t.Logf("%v", err) + if ev != nil { + err = CheckFields(ev) + tt.wantErr(t, err, fmt.Sprintf("CheckFields(%v)", tt.input)) + t.Logf("%v", err) + } switch e := err.(type) { case EventValidationError: assert.Equalf(t, tt.wantPersistable, e.Persistable, "unexpected persistable") diff --git a/eventauth.go b/eventauth.go index 7952ac6b..9318b4c6 100644 --- a/eventauth.go +++ b/eventauth.go @@ -259,7 +259,7 @@ func (a *AuthEvents) AddEvent(event PDU) error { if event.StateKey() == nil { return fmt.Errorf("AddEvent: event %q does not have a state key", event.Type()) } - a.roomIDs[event.RoomID()] = struct{}{} + a.roomIDs[event.RoomID().String()] = struct{}{} a.events[StateKeyTuple{event.Type(), *event.StateKey()}] = event return nil } @@ -412,11 +412,7 @@ func Allowed(event PDU, authEvents AuthEventProvider, userIDQuerier spec.UserIDF if !authEvents.Valid() { return errorf("authEvents contains events from different rooms") } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - return newAllowerContext(authEvents, userIDQuerier, *validRoomID).allowed(event) + return newAllowerContext(authEvents, userIDQuerier, event.RoomID()).allowed(event) } // createEventAllowed checks whether the m.room.create event is allowed. @@ -428,16 +424,12 @@ func (a *allowerContext) createEventAllowed(event PDU) error { if len(event.PrevEventIDs()) > 0 { return errorf("create event must be the first event in the room: found %d prev_events", len(event.PrevEventIDs())) } - roomIDDomain, err := domainFromID(event.RoomID()) - if err != nil { - return err - } sender, err := a.userIDQuerier(a.roomID, event.SenderID()) if err != nil { return err } - if string(sender.Domain()) != roomIDDomain { - return errorf("create event room ID domain does not match sender: %q != %q", roomIDDomain, sender.String()) + if sender.Domain() != event.RoomID().Domain() { + return errorf("create event room ID domain does not match sender: %q != %q", event.RoomID().Domain(), sender.String()) } c := struct { Creator *string `json:"creator"` @@ -479,10 +471,10 @@ func (a *allowerContext) aliasEventAllowed(event PDU) error { return err } - if event.RoomID() != a.create.roomID { + if event.RoomID().String() != a.create.roomID { return errorf( "create event has different roomID: %q (%s) != %q (%s)", - event.RoomID(), event.EventID(), a.create.roomID, a.create.eventID, + event.RoomID().String(), event.EventID(), a.create.roomID, a.create.eventID, ) } @@ -876,10 +868,10 @@ func (a *allowerContext) newEventAllower(senderID spec.SenderID) (e eventAllower // commonChecks does the checks that are applied to all events types other than // m.room.create, m.room.member, or m.room.alias. func (e *eventAllower) commonChecks(event PDU) error { - if event.RoomID() != e.create.roomID { + if event.RoomID().String() != e.create.roomID { return errorf( "create event has different roomID: %q (%s) != %q (%s)", - event.RoomID(), event.EventID(), e.create.roomID, e.create.eventID, + event.RoomID().String(), event.EventID(), e.create.roomID, e.create.eventID, ) } @@ -889,7 +881,7 @@ func (e *eventAllower) commonChecks(event PDU) error { return err } if userID == nil { - return errorf("userID not found for sender %q in room %q", event.SenderID(), event.RoomID()) + return errorf("userID not found for sender %q in room %q", event.SenderID(), event.RoomID().String()) } if err := e.create.UserIDAllowed(*userID); err != nil { return err @@ -983,10 +975,10 @@ func (a *allowerContext) newMembershipAllower(authEvents AuthEventProvider, even // membershipAllowed checks whether the membership event is allowed func (m *membershipAllower) membershipAllowed(event PDU) error { // nolint: gocyclo - if m.create.roomID != event.RoomID() { + if m.create.roomID != event.RoomID().String() { return errorf( "create event has different roomID: %q (%s) != %q (%s)", - event.RoomID(), event.EventID(), m.create.roomID, m.create.eventID, + event.RoomID().String(), event.EventID(), m.create.roomID, m.create.eventID, ) } @@ -1013,7 +1005,7 @@ func (m *membershipAllower) membershipAllowed(event PDU) error { // nolint: gocy } if sender == nil { - return errorf("userID not found for sender %q in room %q", m.senderID, event.RoomID()) + return errorf("userID not found for sender %q in room %q", m.senderID, event.RoomID().String()) } if err := m.create.UserIDAllowed(*sender); err != nil { return err diff --git a/eventauth_test.go b/eventauth_test.go index 3269c882..88773fc4 100644 --- a/eventauth_test.go +++ b/eventauth_test.go @@ -93,7 +93,7 @@ func testStateNeededForAuth(t *testing.T, eventdata string, protoEvent *ProtoEve func TestStateNeededForCreate(t *testing.T) { // Create events don't need anything. skey := "" - testStateNeededForAuth(t, `[{"type": "m.room.create"}]`, &ProtoEvent{ + testStateNeededForAuth(t, `[{"type": "m.room.create", "room_id": "!r1:a"}]`, &ProtoEvent{ Type: "m.room.create", StateKey: &skey, }, StateNeeded{}) @@ -103,7 +103,8 @@ func TestStateNeededForMessage(t *testing.T) { // Message events need the create event, the sender and the power_levels. testStateNeededForAuth(t, `[{ "type": "m.room.message", - "sender": "@u1:a" + "sender": "@u1:a", + "room_id": "!r1:a" }]`, &ProtoEvent{ Type: "m.room.message", SenderID: "@u1:a", @@ -116,7 +117,7 @@ func TestStateNeededForMessage(t *testing.T) { func TestStateNeededForAlias(t *testing.T) { // Alias events need only the create event. - testStateNeededForAuth(t, `[{"type": "m.room.aliases"}]`, &ProtoEvent{ + testStateNeededForAuth(t, `[{"type": "m.room.aliases", "room_id": "!r1:a"}]`, &ProtoEvent{ Type: "m.room.aliases", }, StateNeeded{ Create: true, @@ -137,7 +138,8 @@ func TestStateNeededForJoin(t *testing.T) { "type": "m.room.member", "state_key": "@u1:a", "sender": "@u1:a", - "content": {"membership": "join"} + "content": {"membership": "join"}, + "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, JoinRules: true, @@ -160,7 +162,8 @@ func TestStateNeededForInvite(t *testing.T) { "type": "m.room.member", "state_key": "@u2:b", "sender": "@u1:a", - "content": {"membership": "invite"} + "content": {"membership": "invite"}, + "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, PowerLevels: true, @@ -195,7 +198,8 @@ func TestStateNeededForInvite3PID(t *testing.T) { "token": "my_token" } } - } + }, + "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, PowerLevels: true, @@ -295,10 +299,12 @@ func testEventAllowed(t *testing.T, testCaseJSON string) { for _, data := range tc.NotAllowed { event, err := MustGetRoomVersion(RoomVersionV1).NewEventFromTrustedJSON(data, false) if err != nil { - panic(err) + continue } - if err := Allowed(event, &tc.AuthEvents, UserIDForSenderTest); err == nil { - t.Fatalf("Expected %q to not be allowed but it was", string(data)) + if event != nil { + if err := Allowed(event, &tc.AuthEvents, UserIDForSenderTest); err == nil { + t.Fatalf("Expected %q to not be allowed but it was", string(data)) + } } } } diff --git a/eventcontent.go b/eventcontent.go index 12856eb9..83eb787a 100644 --- a/eventcontent.go +++ b/eventcontent.go @@ -69,14 +69,9 @@ func NewCreateContentFromAuthEvents(authEvents AuthEventProvider, userIDForSende err = errorf("unparseable create event content: %s", err.Error()) return } - c.roomID = createEvent.RoomID() + c.roomID = createEvent.RoomID().String() c.eventID = createEvent.EventID() - validRoomID, err := spec.NewRoomID(createEvent.RoomID()) - if err != nil { - err = errorf("roomID is invalid: %s", err.Error()) - return - } - sender, err := userIDForSender(*validRoomID, createEvent.SenderID()) + sender, err := userIDForSender(createEvent.RoomID(), createEvent.SenderID()) if err != nil { err = errorf("invalid sender userID: %s", err.Error()) return diff --git a/eventcrypto.go b/eventcrypto.go index ec83aed3..cc30f2c1 100644 --- a/eventcrypto.go +++ b/eventcrypto.go @@ -54,11 +54,7 @@ func VerifyEventSignatures(ctx context.Context, e PDU, verifier JSONVerifier, us case RoomVersionPseudoIDs: needed[spec.ServerName(e.SenderID())] = struct{}{} default: - validRoomID, err := spec.NewRoomID(e.RoomID()) - if err != nil { - return err - } - sender, err := userIDForSender(*validRoomID, e.SenderID()) + sender, err := userIDForSender(e.RoomID(), e.SenderID()) if err != nil { return fmt.Errorf("invalid sender userID: %w", err) } diff --git a/fclient/federationclient.go b/fclient/federationclient.go index 32303d9a..90d7e240 100644 --- a/fclient/federationclient.go +++ b/fclient/federationclient.go @@ -242,7 +242,7 @@ func (ac *federationClient) sendJoin( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, partialState bool, ) (res RespSendJoin, err error) { path := federationPathPrefixV2 + "/send_join/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) if partialState { path += "?omit_members=true" @@ -257,7 +257,7 @@ func (ac *federationClient) sendJoin( if ok && gerr.Code == 404 { // fallback to v1 which returns [200, body] v1path := federationPathPrefixV1 + "/send_join/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) v1req := NewFederationRequest("PUT", origin, s, v1path) if err = v1req.SetContent(event); err != nil { @@ -301,7 +301,7 @@ func (ac *federationClient) SendKnock( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (res RespSendKnock, err error) { path := federationPathPrefixV1 + "/send_knock/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) req := NewFederationRequest("PUT", origin, s, path) @@ -336,7 +336,7 @@ func (ac *federationClient) SendLeave( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (err error) { path := federationPathPrefixV2 + "/send_leave/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { @@ -348,7 +348,7 @@ func (ac *federationClient) SendLeave( if ok && gerr.Code == 404 { // fallback to v1 which returns [200, body] v1path := federationPathPrefixV1 + "/send_leave/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) v1req := NewFederationRequest("PUT", origin, s, v1path) if err = v1req.SetContent(event); err != nil { @@ -369,7 +369,7 @@ func (ac *federationClient) SendInvite( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (res RespInvite, err error) { path := federationPathPrefixV1 + "/invite/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { @@ -386,7 +386,7 @@ func (ac *federationClient) SendInviteV2( ) (res RespInviteV2, err error) { event := request.Event() path := federationPathPrefixV2 + "/invite/" + - url.PathEscape(event.RoomID()) + "/" + + url.PathEscape(event.RoomID().String()) + "/" + url.PathEscape(event.EventID()) req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(request); err != nil { diff --git a/handleinvite.go b/handleinvite.go index 1e027c85..e99c5ab7 100644 --- a/handleinvite.go +++ b/handleinvite.go @@ -74,7 +74,7 @@ func HandleInvite(ctx context.Context, input HandleInviteInput) (PDU, error) { } // Check that the room ID is correct. - if input.InviteEvent.RoomID() != input.RoomID.String() { + if input.InviteEvent.RoomID().String() != input.RoomID.String() { return nil, spec.BadJSON("The room ID in the request path must match the room ID in the invite event JSON") } diff --git a/handlejoin.go b/handlejoin.go index dcc2d3d4..468d209d 100644 --- a/handlejoin.go +++ b/handlejoin.go @@ -385,11 +385,11 @@ func HandleSendJoin(input HandleSendJoinInput) (*HandleSendJoinResponse, error) } // Check that the room ID is correct. - if event.RoomID() != input.RoomID.String() { + if event.RoomID().String() != input.RoomID.String() { return nil, spec.BadJSON( fmt.Sprintf( "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)", - input.RoomID.String(), event.RoomID(), + input.RoomID.String(), event.RoomID().String(), ), ) } diff --git a/pdu.go b/pdu.go index d96c59d2..25b7926c 100644 --- a/pdu.go +++ b/pdu.go @@ -27,7 +27,7 @@ type PDU interface { Membership() (string, error) PowerLevels() (*PowerLevelContent, error) Version() RoomVersion - RoomID() string + RoomID() spec.RoomID Redacts() string // Redacted returns whether the event is redacted. Redacted() bool diff --git a/performjoin.go b/performjoin.go index f0f319fd..ce76bf59 100644 --- a/performjoin.go +++ b/performjoin.go @@ -356,7 +356,7 @@ func isWellFormedJoinMemberEvent(event PDU, roomID *spec.RoomID, senderID spec.S } else if membership != spec.Join { return false } - if event.RoomID() != roomID.String() { + if event.RoomID().String() != roomID.String() { return false } if !event.StateKeyEquals(string(senderID)) { diff --git a/spec/roomid.go b/spec/roomid.go index eca2ad25..9ee6a401 100644 --- a/spec/roomid.go +++ b/spec/roomid.go @@ -20,17 +20,17 @@ func NewRoomID(id string) (*RoomID, error) { } // Returns the full roomID string including leading sigil -func (room *RoomID) String() string { +func (room RoomID) String() string { return room.raw } // Returns just the localpart of the roomID -func (room *RoomID) OpaqueID() string { +func (room RoomID) OpaqueID() string { return room.opaqueID } // Returns just the domain of the roomID -func (room *RoomID) Domain() ServerName { +func (room RoomID) Domain() ServerName { return ServerName(room.domain) } diff --git a/stateresolution.go b/stateresolution.go index 31931fe1..767a29ef 100644 --- a/stateresolution.go +++ b/stateresolution.go @@ -159,10 +159,10 @@ func (r *stateResolver) addConflicted(events []PDU) { // nolint: gocyclo // Add an event to the resolved auth events. func (r *stateResolver) addAuthEvent(event PDU) { - if event.RoomID() != "" && r.roomID == "" { - r.roomID = event.RoomID() + if event.RoomID().String() != "" && r.roomID == "" { + r.roomID = event.RoomID().String() } - if r.roomID != event.RoomID() { + if r.roomID != event.RoomID().String() { r.valid = false } switch event.Type() { diff --git a/stateresolutionv2.go b/stateresolutionv2.go index 9bda07c5..a90b8fbf 100644 --- a/stateresolutionv2.go +++ b/stateresolutionv2.go @@ -71,24 +71,17 @@ func ResolveStateConflictsV2( result: make([]PDU, 0, len(conflicted)+len(unconflicted)), } var roomID *spec.RoomID - var err error if len(conflicted) > 0 { - roomID, err = spec.NewRoomID(conflicted[0].RoomID()) - if err != nil { - panic(err) - } + validRoomID := conflicted[0].RoomID() + roomID = &validRoomID } if len(unconflicted) > 0 { - roomID, err = spec.NewRoomID(unconflicted[0].RoomID()) - if err != nil { - panic(err) - } + validRoomID := unconflicted[0].RoomID() + roomID = &validRoomID } if len(authEvents) > 0 { - roomID, err = spec.NewRoomID(authEvents[0].RoomID()) - if err != nil { - panic(err) - } + validRoomID := authEvents[0].RoomID() + roomID = &validRoomID } // If we still don't have a roomID, we don't have conflicted, unconflicted // or any authEvents, which in theory shouldn't happen.