diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 63b4abea..c0287961 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -422,8 +422,14 @@ func generateSendEvent( for i := range queryRes.StateEvents { stateEvents[i] = queryRes.StateEvents[i].PDU } - provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + provider, err := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(err.Error()), + } + } + if err = gomatrixserverlib.Allowed(e.PDU, provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID) }); err != nil { return nil, &util.JSONResponse{ diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 88a38f58..8b5e8eb2 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -131,7 +131,7 @@ func addPrevEventsToEvent( builder.Depth = queryRes.Depth - authEvents := gomatrixserverlib.NewAuthEvents(nil) + authEvents, _ := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { err := authEvents.AddEvent(queryRes.StateEvents[i].PDU) @@ -140,7 +140,7 @@ func addPrevEventsToEvent( } } - refs, err := eventsNeeded.AuthEventReferences(&authEvents) + refs, err := eventsNeeded.AuthEventReferences(authEvents) if err != nil { return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } diff --git a/internal/gomatrixserverlib/authstate.go b/internal/gomatrixserverlib/authstate.go index ead14872..10be227d 100644 --- a/internal/gomatrixserverlib/authstate.go +++ b/internal/gomatrixserverlib/authstate.go @@ -165,7 +165,10 @@ func checkAllowedByAuthEvents( event PDU, eventsByID map[string]PDU, missingAuth EventProvider, userIDForSender spec.UserIDForSender, ) error { - authEvents := NewAuthEvents(nil) + authEvents, err := NewAuthEvents(nil) + if err != nil { + return err + } for _, ae := range event.AuthEventIDs() { retryEvent: @@ -214,7 +217,7 @@ func checkAllowedByAuthEvents( // If we made it this far then we've successfully got as many of the auth events as // as described by AuthEventIDs(). Check if they allow the event. - if err := Allowed(event, &authEvents, userIDForSender); err != nil { + if err := Allowed(event, authEvents, userIDForSender); err != nil { return fmt.Errorf( "gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s", event.EventID(), err.Error(), @@ -335,7 +338,10 @@ func CheckSendJoinResponse( } eventsByID := map[string]PDU{} - authEventProvider := NewAuthEvents(nil) + authEventProvider, err := NewAuthEvents(nil) + if err != nil { + return nil, err + } // Since checkAllowedByAuthEvents needs to be able to look up any of the // auth events by ID only, we will build a map which contains references @@ -369,7 +375,7 @@ func CheckSendJoinResponse( } // Now check that the join event is valid against the supplied state. - if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil { + if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil { return nil, fmt.Errorf( "gomatrixserverlib: event with ID %q is not allowed by the current room state: %w", joinEvent.EventID(), err, diff --git a/internal/gomatrixserverlib/eventauth.go b/internal/gomatrixserverlib/eventauth.go index 8fc586ce..a0c90837 100644 --- a/internal/gomatrixserverlib/eventauth.go +++ b/internal/gomatrixserverlib/eventauth.go @@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() { // NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by // calling AddEvent(). -func NewAuthEvents(events []PDU) AuthEvents { +func NewAuthEvents(events []PDU) (*AuthEvents, error) { a := AuthEvents{ events: make(map[StateKeyTuple]PDU, len(events)), roomIDs: make(map[string]struct{}), } for _, e := range events { - a.AddEvent(e) // nolint: errcheck + if err := a.AddEvent(e); err != nil { + return nil, err + } } - return a + return &a, nil } // A NotAllowed error is returned if an event does not pass the auth checks. diff --git a/internal/gomatrixserverlib/eventauth_test.go b/internal/gomatrixserverlib/eventauth_test.go index 7698196d..bd98a6fa 100644 --- a/internal/gomatrixserverlib/eventauth_test.go +++ b/internal/gomatrixserverlib/eventauth_test.go @@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) { if err != nil { t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err) } - a := NewAuthEvents([]PDU{power}) + a, _ := NewAuthEvents([]PDU{power}) var e PDU if e, err = a.PowerLevels(); err != nil || e != power { t.Errorf("TestAuthEvents: failed to get same power_levels event") diff --git a/internal/gomatrixserverlib/handleinvite_test.go b/internal/gomatrixserverlib/handleinvite_test.go index 9afd1103..f23b0430 100644 --- a/internal/gomatrixserverlib/handleinvite_test.go +++ b/internal/gomatrixserverlib/handleinvite_test.go @@ -38,9 +38,12 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv return nil, fmt.Errorf("failed getting auth provider") } - eventProvider := AuthEvents{} + var eventProvider *AuthEvents if r.createEvent != nil { - eventProvider = NewAuthEvents([]PDU{r.createEvent}) + var err error + if eventProvider, err = NewAuthEvents([]PDU{r.createEvent}); err != nil { + return nil, err + } if r.inviterMemberEvent != nil { err := eventProvider.AddEvent(r.inviterMemberEvent) if err != nil { @@ -48,7 +51,7 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv } } } - return &eventProvider, nil + return eventProvider, nil } func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) { diff --git a/internal/gomatrixserverlib/handlejoin.go b/internal/gomatrixserverlib/handlejoin.go index 4254c660..d11bb184 100644 --- a/internal/gomatrixserverlib/handlejoin.go +++ b/internal/gomatrixserverlib/handlejoin.go @@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error) return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())} } - provider := NewAuthEvents(state) - if err = Allowed(event, &provider, input.UserIDQuerier); err != nil { + provider, err := NewAuthEvents(state) + if err != nil { + return nil, spec.Forbidden(err.Error()) + } + if err = Allowed(event, provider, input.UserIDQuerier); err != nil { return nil, spec.Forbidden(err.Error()) } diff --git a/internal/gomatrixserverlib/handleleave.go b/internal/gomatrixserverlib/handleleave.go index 6202f78e..8487097b 100644 --- a/internal/gomatrixserverlib/handleleave.go +++ b/internal/gomatrixserverlib/handleleave.go @@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())} } - provider := NewAuthEvents(stateEvents) - if err := Allowed(event, &provider, input.UserIDQuerier); err != nil { + provider, err := NewAuthEvents(stateEvents) + if err != nil { + return nil, spec.Forbidden(err.Error()) + } + if err = Allowed(event, provider, input.UserIDQuerier); err != nil { return nil, spec.Forbidden(err.Error()) } diff --git a/internal/gomatrixserverlib/performinvite.go b/internal/gomatrixserverlib/performinvite.go index 8370cb35..9de7e898 100644 --- a/internal/gomatrixserverlib/performinvite.go +++ b/internal/gomatrixserverlib/performinvite.go @@ -123,7 +123,10 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede input.EventTemplate.Depth = latestEvents.Depth - authEvents := NewAuthEvents(nil) + authEvents, err := NewAuthEvents(nil) + if err != nil { + return nil, err + } for _, event := range latestEvents.StateEvents { err := authEvents.AddEvent(event) @@ -132,7 +135,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede } } - refs, err := stateNeeded.AuthEventReferences(&authEvents) + refs, err := stateNeeded.AuthEventReferences(authEvents) if err != nil { return nil, fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } diff --git a/internal/gomatrixserverlib/stateresolutionv2.go b/internal/gomatrixserverlib/stateresolutionv2.go index de8923a7..8711bed6 100644 --- a/internal/gomatrixserverlib/stateresolutionv2.go +++ b/internal/gomatrixserverlib/stateresolutionv2.go @@ -34,7 +34,7 @@ const ( type stateResolverV2 struct { allower *allowerContext // Used to auth and apply events - authProvider AuthEvents // Used in the allower + authProvider *AuthEvents // Used in the allower authEventMap map[string]PDU // Map of all provided auth events conflictedEventMap map[string]PDU // Map of all provided conflicted events powerLevelContents map[string]*PowerLevelContent // A cache of all power level contents @@ -64,9 +64,10 @@ func ResolveStateConflictsV2( // Prepare the state resolver. conflictedControlEvents := make([]PDU, 0, len(conflicted)) conflictedOthers := make([]PDU, 0, len(conflicted)) + authProvider, _ := NewAuthEvents(nil) r := stateResolverV2{ authEventMap: eventMapFromEvents(authEvents), - authProvider: NewAuthEvents(nil), + authProvider: authProvider, conflictedEventMap: eventMapFromEvents(conflicted), powerLevelContents: make(map[string]*PowerLevelContent), powerLevelMainlinePos: make(map[string]int), @@ -95,7 +96,7 @@ func ResolveStateConflictsV2( return r.result } - r.allower = newAllowerContext(&r.authProvider, userIDForSender, *roomID) + r.allower = newAllowerContext(r.authProvider, userIDForSender, *roomID) // This is a map to help us determine if an event already belongs to the // unconflicted set. If it does then we shouldn't add it back into the @@ -481,7 +482,7 @@ func (r *stateResolverV2) authAndApplyEvents(events []PDU) { } // Check if the event is allowed based on the current partial state. - r.allower.update(&r.authProvider) + r.allower.update(r.authProvider) if err := r.allower.allowed(event); err != nil { // The event was not allowed by the partial state and/or relevant // auth events from the event, so skip it. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index e97cc852..a7ef8e9d 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -184,7 +184,7 @@ func (r *Inputer) processRoomEvent( // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the // event origin are first on the list. - servers := map[spec.ServerName]struct{}{} + servers := make(map[spec.ServerName]struct{}, len(serverRes.ServerNames)) for _, server := range serverRes.ServerNames { servers[server] = struct{}{} } @@ -210,9 +210,9 @@ func (r *Inputer) processRoomEvent( // Check that the auth events of the event are known. // If they aren't then we will ask the federation API for them. - authEvents := gomatrixserverlib.NewAuthEvents(nil) + authEvents, _ := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -221,7 +221,7 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true @@ -643,10 +643,14 @@ func (r *Inputer) processStateBefore( // At this point, stateBeforeEvent should be populated either by // the supplied state in the input request, or from the prev events. // Check whether the event is allowed or not. - stateBeforeAuth := gomatrixserverlib.NewAuthEvents( + stateBeforeAuth, err := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if err != nil { + rejectionErr = fmt.Errorf("NewAuthEvents failed: %w", err) + return + } + if rejectionErr = gomatrixserverlib.Allowed(event, stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 8441ac33..80a1bf15 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -50,7 +50,7 @@ func Test_EventAuth(t *testing.T) { }, test.WithStateKey(bob.ID), test.WithAuthIDs(authEventIDs)) // Add the auth events to the allower - allower := gomatrixserverlib.NewAuthEvents(nil) + allower, _ := gomatrixserverlib.NewAuthEvents(nil) for _, a := range authEvents { if err := allower.AddEvent(a); err != nil { t.Fatalf("allower.AddEvent failed: %v", err) @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if err := gomatrixserverlib.Allowed(ev.PDU, allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) }); err == nil { t.Fatalf("event should not be allowed, but it was") diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index db8830bc..9064194b 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -931,14 +931,14 @@ serverLoop: } func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { - authUsingState := gomatrixserverlib.NewAuthEvents(nil) + authUsingState, _ := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) if err != nil { return err } } - return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) + return gomatrixserverlib.Allowed(e, authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index db0f4d9b..d16dcd06 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -350,7 +350,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo // TODO: 3pid invite events var builtEvents []*types.HeaderedEvent - authEvents := gomatrixserverlib.NewAuthEvents(nil) + authEvents, _ := gomatrixserverlib.NewAuthEvents(nil) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ @@ -380,7 +380,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo builder.PrevEvents = []string{builtEvents[i-1].EventID()} } var ev gomatrixserverlib.PDU - if err = builder.AddAuthEvents(&authEvents); err != nil { + if err = builder.AddAuthEvents(authEvents); err != nil { util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, @@ -396,7 +396,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(ev, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 8fea22c0..3d6c659c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -478,7 +478,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { var err error var builtEvents []*types.HeaderedEvent - authEvents := gomatrixserverlib.NewAuthEvents(nil) + authEvents, _ := gomatrixserverlib.NewAuthEvents(nil) for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 @@ -503,7 +503,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send return err } builder := verImpl.NewEventBuilderFromProtoEvent(&proto) - if err = builder.AddAuthEvents(&authEvents); err != nil { + if err = builder.AddAuthEvents(authEvents); err != nil { return err } @@ -514,7 +514,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) @@ -594,8 +594,11 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send for i := range queryRes.StateEvents { stateEvents[i] = queryRes.StateEvents[i].PDU } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + provider, err := gomatrixserverlib.NewAuthEvents(stateEvents) + if err != nil { + return nil, err + } + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? diff --git a/test/room.go b/test/room.go index 0ffaa853..cfa5ebee 100644 --- a/test/room.go +++ b/test/room.go @@ -51,7 +51,7 @@ type Room struct { visibility gomatrixserverlib.HistoryVisibility creator *User - authEvents gomatrixserverlib.AuthEvents + authEvents *gomatrixserverlib.AuthEvents currentState map[string]*rstypes.HeaderedEvent events []*rstypes.HeaderedEvent } @@ -63,10 +63,11 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { if creator.srvName == "" { t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator) } + authEvents, _ := gomatrixserverlib.NewAuthEvents(nil) r := &Room{ ID: fmt.Sprintf("!%d:%s", counter, creator.srvName), creator: creator, - authEvents: gomatrixserverlib.NewAuthEvents(nil), + authEvents: authEvents, preset: PresetPublicChat, Version: gomatrixserverlib.RoomVersionV9, currentState: make(map[string]*rstypes.HeaderedEvent), @@ -81,7 +82,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []string { t.Helper() - a, err := needed.AuthEventReferences(&r.authEvents) + a, err := needed.AuthEventReferences(r.authEvents) if err != nil { t.Fatalf("MustGetAuthEvents: %v", err) } @@ -183,7 +184,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten builder.PrevEvents = []string{r.events[len(r.events)-1].EventID()} } - err = builder.AddAuthEvents(&r.authEvents) + err = builder.AddAuthEvents(r.authEvents) if err != nil { t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err) } @@ -199,7 +200,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err != nil { t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) } - if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil { + if err = gomatrixserverlib.Allowed(ev, r.authEvents, UserIDForSender); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } headeredEvent := &rstypes.HeaderedEvent{PDU: ev}