Skip to content

Commit

Permalink
Change PDU to return validated RoomID (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
devonh authored Sep 15, 2023
1 parent 47bceff commit 095d10f
Show file tree
Hide file tree
Showing 16 changed files with 96 additions and 83 deletions.
4 changes: 2 additions & 2 deletions authstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type FederatedStateProvider struct {

// StateIDsBeforeEvent implements StateProvider
func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event PDU) ([]string, error) {
res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID(), event.EventID())
res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID().String(), event.EventID())
if err != nil {
return nil, err
}
Expand All @@ -77,7 +77,7 @@ func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event

// StateBeforeEvent implements StateProvider
func (p *FederatedStateProvider) StateBeforeEvent(ctx context.Context, roomVer RoomVersion, event PDU, eventIDs []string) (map[string]PDU, error) {
res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID(), event.EventID(), roomVer)
res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID().String(), event.EventID(), roomVer)
if err != nil {
return nil, err
}
Expand Down
22 changes: 20 additions & 2 deletions eventV1.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ func (e *eventV1) Version() RoomVersion {
return e.roomVersion
}

func (e *eventV1) RoomID() string {
return e.eventFields.RoomID
func (e *eventV1) RoomID() spec.RoomID {
roomID, err := spec.NewRoomID(e.eventFields.RoomID)
if err != nil {
panic(fmt.Errorf("RoomID is invalid: %w", err))
}
return *roomID
}

func (e *eventV1) Redacts() string {
Expand Down Expand Up @@ -280,6 +284,10 @@ func newEventFromUntrustedJSONV1(eventJSON []byte, roomVersion IRoomVersion) (PD
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, err
}

// We know the JSON must be valid here.
eventJSON = CanonicalJSONAssumeValid(eventJSON)

Expand Down Expand Up @@ -323,6 +331,11 @@ func newEventFromTrustedJSONV1(eventJSON []byte, redacted bool, roomVersion IRoo
if err := json.Unmarshal(eventJSON, &res); err != nil {
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, fmt.Errorf("RoomID is invalid: %w", err)
}

res.eventJSON = eventJSON
res.roomVersion = roomVersion.Version()
res.redacted = redacted
Expand All @@ -334,6 +347,11 @@ func newEventFromTrustedJSONWithEventIDV1(eventID string, eventJSON []byte, reda
if err := json.Unmarshal(eventJSON, &res); err != nil {
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, err
}

res.EventIDRaw = eventID
res.eventJSON = eventJSON
res.roomVersion = roomVersion.Version()
Expand Down
19 changes: 15 additions & 4 deletions eventV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ func newEventFromUntrustedJSONV2(eventJSON []byte, roomVersion IRoomVersion) (PD
if err = json.Unmarshal(eventJSON, &res); err != nil {
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, err
}

res.roomVersion = roomVersion.Version()

// We know the JSON must be valid here.
Expand Down Expand Up @@ -245,10 +250,6 @@ func CheckFields(input PDU) error { // nolint: gocyclo
}
}

if err := checkID(input.RoomID(), "room", '!'); err != nil {
return err
}

switch input.Version() {
case RoomVersionPseudoIDs:
default:
Expand All @@ -265,6 +266,11 @@ func newEventFromTrustedJSONV2(eventJSON []byte, redacted bool, roomVersion IRoo
if err := json.Unmarshal(eventJSON, &res); err != nil {
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, err
}

res.roomVersion = roomVersion.Version()
res.redacted = redacted
res.eventJSON = eventJSON
Expand All @@ -276,6 +282,11 @@ func newEventFromTrustedJSONWithEventIDV2(eventID string, eventJSON []byte, reda
if err := json.Unmarshal(eventJSON, &res); err != nil {
return nil, err
}

if err := checkID(res.eventFields.RoomID, "room", '!'); err != nil {
return nil, err
}

res.roomVersion = roomVersion.Version()
res.eventJSON = eventJSON
res.EventIDRaw = eventID
Expand Down
8 changes: 5 additions & 3 deletions eventV2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ func TestCheckFields(t *testing.T) {
t.Run(tt.name+"-"+string(roomVersion), func(t *testing.T) {
ev, err := MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&tt.input).Build(time.Now(), "localhost", "ed25519:1", sk)
tt.wantErr(t, err)
err = CheckFields(ev)
tt.wantErr(t, err, fmt.Sprintf("CheckFields(%v)", tt.input))
t.Logf("%v", err)
if ev != nil {
err = CheckFields(ev)
tt.wantErr(t, err, fmt.Sprintf("CheckFields(%v)", tt.input))
t.Logf("%v", err)
}
switch e := err.(type) {
case EventValidationError:
assert.Equalf(t, tt.wantPersistable, e.Persistable, "unexpected persistable")
Expand Down
32 changes: 12 additions & 20 deletions eventauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (a *AuthEvents) AddEvent(event PDU) error {
if event.StateKey() == nil {
return fmt.Errorf("AddEvent: event %q does not have a state key", event.Type())
}
a.roomIDs[event.RoomID()] = struct{}{}
a.roomIDs[event.RoomID().String()] = struct{}{}
a.events[StateKeyTuple{event.Type(), *event.StateKey()}] = event
return nil
}
Expand Down Expand Up @@ -412,11 +412,7 @@ func Allowed(event PDU, authEvents AuthEventProvider, userIDQuerier spec.UserIDF
if !authEvents.Valid() {
return errorf("authEvents contains events from different rooms")
}
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
return newAllowerContext(authEvents, userIDQuerier, *validRoomID).allowed(event)
return newAllowerContext(authEvents, userIDQuerier, event.RoomID()).allowed(event)
}

// createEventAllowed checks whether the m.room.create event is allowed.
Expand All @@ -428,16 +424,12 @@ func (a *allowerContext) createEventAllowed(event PDU) error {
if len(event.PrevEventIDs()) > 0 {
return errorf("create event must be the first event in the room: found %d prev_events", len(event.PrevEventIDs()))
}
roomIDDomain, err := domainFromID(event.RoomID())
if err != nil {
return err
}
sender, err := a.userIDQuerier(a.roomID, event.SenderID())
if err != nil {
return err
}
if string(sender.Domain()) != roomIDDomain {
return errorf("create event room ID domain does not match sender: %q != %q", roomIDDomain, sender.String())
if sender.Domain() != event.RoomID().Domain() {
return errorf("create event room ID domain does not match sender: %q != %q", event.RoomID().Domain(), sender.String())
}
c := struct {
Creator *string `json:"creator"`
Expand Down Expand Up @@ -479,10 +471,10 @@ func (a *allowerContext) aliasEventAllowed(event PDU) error {
return err
}

if event.RoomID() != a.create.roomID {
if event.RoomID().String() != a.create.roomID {
return errorf(
"create event has different roomID: %q (%s) != %q (%s)",
event.RoomID(), event.EventID(), a.create.roomID, a.create.eventID,
event.RoomID().String(), event.EventID(), a.create.roomID, a.create.eventID,
)
}

Expand Down Expand Up @@ -876,10 +868,10 @@ func (a *allowerContext) newEventAllower(senderID spec.SenderID) (e eventAllower
// commonChecks does the checks that are applied to all events types other than
// m.room.create, m.room.member, or m.room.alias.
func (e *eventAllower) commonChecks(event PDU) error {
if event.RoomID() != e.create.roomID {
if event.RoomID().String() != e.create.roomID {
return errorf(
"create event has different roomID: %q (%s) != %q (%s)",
event.RoomID(), event.EventID(), e.create.roomID, e.create.eventID,
event.RoomID().String(), event.EventID(), e.create.roomID, e.create.eventID,
)
}

Expand All @@ -889,7 +881,7 @@ func (e *eventAllower) commonChecks(event PDU) error {
return err
}
if userID == nil {
return errorf("userID not found for sender %q in room %q", event.SenderID(), event.RoomID())
return errorf("userID not found for sender %q in room %q", event.SenderID(), event.RoomID().String())
}
if err := e.create.UserIDAllowed(*userID); err != nil {
return err
Expand Down Expand Up @@ -983,10 +975,10 @@ func (a *allowerContext) newMembershipAllower(authEvents AuthEventProvider, even

// membershipAllowed checks whether the membership event is allowed
func (m *membershipAllower) membershipAllowed(event PDU) error { // nolint: gocyclo
if m.create.roomID != event.RoomID() {
if m.create.roomID != event.RoomID().String() {
return errorf(
"create event has different roomID: %q (%s) != %q (%s)",
event.RoomID(), event.EventID(), m.create.roomID, m.create.eventID,
event.RoomID().String(), event.EventID(), m.create.roomID, m.create.eventID,
)
}

Expand All @@ -1013,7 +1005,7 @@ func (m *membershipAllower) membershipAllowed(event PDU) error { // nolint: gocy
}

if sender == nil {
return errorf("userID not found for sender %q in room %q", m.senderID, event.RoomID())
return errorf("userID not found for sender %q in room %q", m.senderID, event.RoomID().String())
}
if err := m.create.UserIDAllowed(*sender); err != nil {
return err
Expand Down
24 changes: 15 additions & 9 deletions eventauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func testStateNeededForAuth(t *testing.T, eventdata string, protoEvent *ProtoEve
func TestStateNeededForCreate(t *testing.T) {
// Create events don't need anything.
skey := ""
testStateNeededForAuth(t, `[{"type": "m.room.create"}]`, &ProtoEvent{
testStateNeededForAuth(t, `[{"type": "m.room.create", "room_id": "!r1:a"}]`, &ProtoEvent{
Type: "m.room.create",
StateKey: &skey,
}, StateNeeded{})
Expand All @@ -103,7 +103,8 @@ func TestStateNeededForMessage(t *testing.T) {
// Message events need the create event, the sender and the power_levels.
testStateNeededForAuth(t, `[{
"type": "m.room.message",
"sender": "@u1:a"
"sender": "@u1:a",
"room_id": "!r1:a"
}]`, &ProtoEvent{
Type: "m.room.message",
SenderID: "@u1:a",
Expand All @@ -116,7 +117,7 @@ func TestStateNeededForMessage(t *testing.T) {

func TestStateNeededForAlias(t *testing.T) {
// Alias events need only the create event.
testStateNeededForAuth(t, `[{"type": "m.room.aliases"}]`, &ProtoEvent{
testStateNeededForAuth(t, `[{"type": "m.room.aliases", "room_id": "!r1:a"}]`, &ProtoEvent{
Type: "m.room.aliases",
}, StateNeeded{
Create: true,
Expand All @@ -137,7 +138,8 @@ func TestStateNeededForJoin(t *testing.T) {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"content": {"membership": "join"}
"content": {"membership": "join"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
JoinRules: true,
Expand All @@ -160,7 +162,8 @@ func TestStateNeededForInvite(t *testing.T) {
"type": "m.room.member",
"state_key": "@u2:b",
"sender": "@u1:a",
"content": {"membership": "invite"}
"content": {"membership": "invite"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
PowerLevels: true,
Expand Down Expand Up @@ -195,7 +198,8 @@ func TestStateNeededForInvite3PID(t *testing.T) {
"token": "my_token"
}
}
}
},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
PowerLevels: true,
Expand Down Expand Up @@ -295,10 +299,12 @@ func testEventAllowed(t *testing.T, testCaseJSON string) {
for _, data := range tc.NotAllowed {
event, err := MustGetRoomVersion(RoomVersionV1).NewEventFromTrustedJSON(data, false)
if err != nil {
panic(err)
continue
}
if err := Allowed(event, &tc.AuthEvents, UserIDForSenderTest); err == nil {
t.Fatalf("Expected %q to not be allowed but it was", string(data))
if event != nil {
if err := Allowed(event, &tc.AuthEvents, UserIDForSenderTest); err == nil {
t.Fatalf("Expected %q to not be allowed but it was", string(data))
}
}
}
}
Expand Down
9 changes: 2 additions & 7 deletions eventcontent.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,9 @@ func NewCreateContentFromAuthEvents(authEvents AuthEventProvider, userIDForSende
err = errorf("unparseable create event content: %s", err.Error())
return
}
c.roomID = createEvent.RoomID()
c.roomID = createEvent.RoomID().String()
c.eventID = createEvent.EventID()
validRoomID, err := spec.NewRoomID(createEvent.RoomID())
if err != nil {
err = errorf("roomID is invalid: %s", err.Error())
return
}
sender, err := userIDForSender(*validRoomID, createEvent.SenderID())
sender, err := userIDForSender(createEvent.RoomID(), createEvent.SenderID())
if err != nil {
err = errorf("invalid sender userID: %s", err.Error())
return
Expand Down
6 changes: 1 addition & 5 deletions eventcrypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,7 @@ func VerifyEventSignatures(ctx context.Context, e PDU, verifier JSONVerifier, us
case RoomVersionPseudoIDs:
needed[spec.ServerName(e.SenderID())] = struct{}{}
default:
validRoomID, err := spec.NewRoomID(e.RoomID())
if err != nil {
return err
}
sender, err := userIDForSender(*validRoomID, e.SenderID())
sender, err := userIDForSender(e.RoomID(), e.SenderID())
if err != nil {
return fmt.Errorf("invalid sender userID: %w", err)
}
Expand Down
14 changes: 7 additions & 7 deletions fclient/federationclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (ac *federationClient) sendJoin(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, partialState bool,
) (res RespSendJoin, err error) {
path := federationPathPrefixV2 + "/send_join/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
if partialState {
path += "?omit_members=true"
Expand All @@ -257,7 +257,7 @@ func (ac *federationClient) sendJoin(
if ok && gerr.Code == 404 {
// fallback to v1 which returns [200, body]
v1path := federationPathPrefixV1 + "/send_join/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
v1req := NewFederationRequest("PUT", origin, s, v1path)
if err = v1req.SetContent(event); err != nil {
Expand Down Expand Up @@ -301,7 +301,7 @@ func (ac *federationClient) SendKnock(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res RespSendKnock, err error) {
path := federationPathPrefixV1 + "/send_knock/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())

req := NewFederationRequest("PUT", origin, s, path)
Expand Down Expand Up @@ -336,7 +336,7 @@ func (ac *federationClient) SendLeave(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (err error) {
path := federationPathPrefixV2 + "/send_leave/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
req := NewFederationRequest("PUT", origin, s, path)
if err = req.SetContent(event); err != nil {
Expand All @@ -348,7 +348,7 @@ func (ac *federationClient) SendLeave(
if ok && gerr.Code == 404 {
// fallback to v1 which returns [200, body]
v1path := federationPathPrefixV1 + "/send_leave/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
v1req := NewFederationRequest("PUT", origin, s, v1path)
if err = v1req.SetContent(event); err != nil {
Expand All @@ -369,7 +369,7 @@ func (ac *federationClient) SendInvite(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res RespInvite, err error) {
path := federationPathPrefixV1 + "/invite/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
req := NewFederationRequest("PUT", origin, s, path)
if err = req.SetContent(event); err != nil {
Expand All @@ -386,7 +386,7 @@ func (ac *federationClient) SendInviteV2(
) (res RespInviteV2, err error) {
event := request.Event()
path := federationPathPrefixV2 + "/invite/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.RoomID().String()) + "/" +
url.PathEscape(event.EventID())
req := NewFederationRequest("PUT", origin, s, path)
if err = req.SetContent(request); err != nil {
Expand Down
Loading

0 comments on commit 095d10f

Please sign in to comment.