Skip to content

Commit

Permalink
Tweak auth event handling
Browse files Browse the repository at this point in the history
  • Loading branch information
neilalexander committed Apr 10, 2024
1 parent 4569284 commit 01954ff
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 48 deletions.
10 changes: 8 additions & 2 deletions clientapi/routing/sendevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,14 @@ func generateSendEvent(
for i := range queryRes.StateEvents {
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
provider, err := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden(err.Error()),
}
}
if err = gomatrixserverlib.Allowed(e.PDU, provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID)
}); err != nil {
return nil, &util.JSONResponse{
Expand Down
4 changes: 2 additions & 2 deletions internal/eventutil/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func addPrevEventsToEvent(

builder.Depth = queryRes.Depth

authEvents := gomatrixserverlib.NewAuthEvents(nil)
authEvents, _ := gomatrixserverlib.NewAuthEvents(nil)

for i := range queryRes.StateEvents {
err := authEvents.AddEvent(queryRes.StateEvents[i].PDU)
Expand All @@ -140,7 +140,7 @@ func addPrevEventsToEvent(
}
}

refs, err := eventsNeeded.AuthEventReferences(&authEvents)
refs, err := eventsNeeded.AuthEventReferences(authEvents)
if err != nil {
return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
}
Expand Down
14 changes: 10 additions & 4 deletions internal/gomatrixserverlib/authstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ func checkAllowedByAuthEvents(
event PDU, eventsByID map[string]PDU,
missingAuth EventProvider, userIDForSender spec.UserIDForSender,
) error {
authEvents := NewAuthEvents(nil)
authEvents, err := NewAuthEvents(nil)
if err != nil {
return err
}

for _, ae := range event.AuthEventIDs() {
retryEvent:
Expand Down Expand Up @@ -214,7 +217,7 @@ func checkAllowedByAuthEvents(

// If we made it this far then we've successfully got as many of the auth events as
// as described by AuthEventIDs(). Check if they allow the event.
if err := Allowed(event, &authEvents, userIDForSender); err != nil {
if err := Allowed(event, authEvents, userIDForSender); err != nil {
return fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s",
event.EventID(), err.Error(),
Expand Down Expand Up @@ -335,7 +338,10 @@ func CheckSendJoinResponse(
}

eventsByID := map[string]PDU{}
authEventProvider := NewAuthEvents(nil)
authEventProvider, err := NewAuthEvents(nil)
if err != nil {
return nil, err
}

// Since checkAllowedByAuthEvents needs to be able to look up any of the
// auth events by ID only, we will build a map which contains references
Expand Down Expand Up @@ -369,7 +375,7 @@ func CheckSendJoinResponse(
}

// Now check that the join event is valid against the supplied state.
if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil {
if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil {
return nil, fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by the current room state: %w",
joinEvent.EventID(), err,
Expand Down
8 changes: 5 additions & 3 deletions internal/gomatrixserverlib/eventauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() {

// NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by
// calling AddEvent().
func NewAuthEvents(events []PDU) AuthEvents {
func NewAuthEvents(events []PDU) (*AuthEvents, error) {
a := AuthEvents{
events: make(map[StateKeyTuple]PDU, len(events)),
roomIDs: make(map[string]struct{}),
}
for _, e := range events {
a.AddEvent(e) // nolint: errcheck
if err := a.AddEvent(e); err != nil {
return nil, err
}
}
return a
return &a, nil
}

// A NotAllowed error is returned if an event does not pass the auth checks.
Expand Down
2 changes: 1 addition & 1 deletion internal/gomatrixserverlib/eventauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) {
if err != nil {
t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err)
}
a := NewAuthEvents([]PDU{power})
a, _ := NewAuthEvents([]PDU{power})
var e PDU
if e, err = a.PowerLevels(); err != nil || e != power {
t.Errorf("TestAuthEvents: failed to get same power_levels event")
Expand Down
9 changes: 6 additions & 3 deletions internal/gomatrixserverlib/handleinvite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,20 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv
return nil, fmt.Errorf("failed getting auth provider")
}

eventProvider := AuthEvents{}
var eventProvider *AuthEvents
if r.createEvent != nil {
eventProvider = NewAuthEvents([]PDU{r.createEvent})
var err error
if eventProvider, err = NewAuthEvents([]PDU{r.createEvent}); err != nil {
return nil, err
}
if r.inviterMemberEvent != nil {
err := eventProvider.AddEvent(r.inviterMemberEvent)
if err != nil {
return nil, err
}
}
}
return &eventProvider, nil
return eventProvider, nil
}

func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) {
Expand Down
7 changes: 5 additions & 2 deletions internal/gomatrixserverlib/handlejoin.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error)
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(state)
if err = Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(state)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
7 changes: 5 additions & 2 deletions internal/gomatrixserverlib/handleleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(stateEvents)
if err := Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(stateEvents)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
7 changes: 5 additions & 2 deletions internal/gomatrixserverlib/performinvite.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede

input.EventTemplate.Depth = latestEvents.Depth

authEvents := NewAuthEvents(nil)
authEvents, err := NewAuthEvents(nil)
if err != nil {
return nil, err
}

for _, event := range latestEvents.StateEvents {
err := authEvents.AddEvent(event)
Expand All @@ -132,7 +135,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede
}
}

refs, err := stateNeeded.AuthEventReferences(&authEvents)
refs, err := stateNeeded.AuthEventReferences(authEvents)
if err != nil {
return nil, fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/gomatrixserverlib/stateresolutionv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (

type stateResolverV2 struct {
allower *allowerContext // Used to auth and apply events
authProvider AuthEvents // Used in the allower
authProvider *AuthEvents // Used in the allower
authEventMap map[string]PDU // Map of all provided auth events
conflictedEventMap map[string]PDU // Map of all provided conflicted events
powerLevelContents map[string]*PowerLevelContent // A cache of all power level contents
Expand Down Expand Up @@ -64,9 +64,10 @@ func ResolveStateConflictsV2(
// Prepare the state resolver.
conflictedControlEvents := make([]PDU, 0, len(conflicted))
conflictedOthers := make([]PDU, 0, len(conflicted))
authProvider, _ := NewAuthEvents(nil)
r := stateResolverV2{
authEventMap: eventMapFromEvents(authEvents),
authProvider: NewAuthEvents(nil),
authProvider: authProvider,
conflictedEventMap: eventMapFromEvents(conflicted),
powerLevelContents: make(map[string]*PowerLevelContent),
powerLevelMainlinePos: make(map[string]int),
Expand Down Expand Up @@ -95,7 +96,7 @@ func ResolveStateConflictsV2(
return r.result
}

r.allower = newAllowerContext(&r.authProvider, userIDForSender, *roomID)
r.allower = newAllowerContext(r.authProvider, userIDForSender, *roomID)

// This is a map to help us determine if an event already belongs to the
// unconflicted set. If it does then we shouldn't add it back into the
Expand Down Expand Up @@ -481,7 +482,7 @@ func (r *stateResolverV2) authAndApplyEvents(events []PDU) {
}

// Check if the event is allowed based on the current partial state.
r.allower.update(&r.authProvider)
r.allower.update(r.authProvider)
if err := r.allower.allowed(event); err != nil {
// The event was not allowed by the partial state and/or relevant
// auth events from the event, so skip it.
Expand Down
16 changes: 10 additions & 6 deletions roomserver/internal/input/input_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (r *Inputer) processRoomEvent(
// Sort all of the servers into a map so that we can randomise
// their order. Then make sure that the input origin and the
// event origin are first on the list.
servers := map[spec.ServerName]struct{}{}
servers := make(map[spec.ServerName]struct{}, len(serverRes.ServerNames))
for _, server := range serverRes.ServerNames {
servers[server] = struct{}{}
}
Expand All @@ -210,9 +210,9 @@ func (r *Inputer) processRoomEvent(

// Check that the auth events of the event are known.
// If they aren't then we will ask the federation API for them.
authEvents := gomatrixserverlib.NewAuthEvents(nil)
authEvents, _ := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err)
}

Expand All @@ -221,7 +221,7 @@ func (r *Inputer) processRoomEvent(

// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(event, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
isRejected = true
Expand Down Expand Up @@ -643,10 +643,14 @@ func (r *Inputer) processStateBefore(
// At this point, stateBeforeEvent should be populated either by
// the supplied state in the input request, or from the prev events.
// Check whether the event is allowed or not.
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
stateBeforeAuth, err := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if err != nil {
rejectionErr = fmt.Errorf("NewAuthEvents failed: %w", err)
return
}
if rejectionErr = gomatrixserverlib.Allowed(event, stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
Expand Down
4 changes: 2 additions & 2 deletions roomserver/internal/input/input_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func Test_EventAuth(t *testing.T) {
}, test.WithStateKey(bob.ID), test.WithAuthIDs(authEventIDs))

// Add the auth events to the allower
allower := gomatrixserverlib.NewAuthEvents(nil)
allower, _ := gomatrixserverlib.NewAuthEvents(nil)
for _, a := range authEvents {
if err := allower.AddEvent(a); err != nil {
t.Fatalf("allower.AddEvent failed: %v", err)
}
}

// Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if err := gomatrixserverlib.Allowed(ev.PDU, allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}); err == nil {
t.Fatalf("event should not be allowed, but it was")
Expand Down
4 changes: 2 additions & 2 deletions roomserver/internal/input/input_missing.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,14 @@ serverLoop:
}

func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
authUsingState, _ := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i])
if err != nil {
return err
}
}
return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
return gomatrixserverlib.Allowed(e, authUsingState, userIDForSender)
}

func (t *missingStateReq) hadEvent(eventID string) {
Expand Down
6 changes: 3 additions & 3 deletions roomserver/internal/perform/perform_create_room.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// TODO: 3pid invite events

var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil)
authEvents, _ := gomatrixserverlib.NewAuthEvents(nil)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Expand Down Expand Up @@ -380,7 +380,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
builder.PrevEvents = []string{builtEvents[i-1].EventID()}
}
var ev gomatrixserverlib.PDU
if err = builder.AddAuthEvents(&authEvents); err != nil {
if err = builder.AddAuthEvents(authEvents); err != nil {
util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
Expand All @@ -396,7 +396,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}

if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(ev, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
Expand Down
13 changes: 8 additions & 5 deletions roomserver/internal/perform/perform_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
var err error
var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil)
authEvents, _ := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1

Expand All @@ -503,7 +503,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
return err
}
builder := verImpl.NewEventBuilderFromProtoEvent(&proto)
if err = builder.AddAuthEvents(&authEvents); err != nil {
if err = builder.AddAuthEvents(authEvents); err != nil {
return err
}

Expand All @@ -514,7 +514,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send

}

if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(event, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
Expand Down Expand Up @@ -594,8 +594,11 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
for i := range queryRes.StateEvents {
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
provider, err := gomatrixserverlib.NewAuthEvents(stateEvents)
if err != nil {
return nil, err
}
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
Expand Down
Loading

0 comments on commit 01954ff

Please sign in to comment.