Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add federation whitelist for servers #3498

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions federationapi/federationapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ func NewInternalAPI(

stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays)

// Add servers to whitelist if enabled
if cfg.EnableWhitelist {
// We need to clear the list of the whitelisted servers during init
err = stats.DB.RemoveAllServersFromWhitelist()
if err != nil {
logrus.WithError(err).Panic("failed to clear whitelisted servers")
}

// Add each whitelisted server to the data
for _, server := range cfg.WhitelistedServers {
err = stats.DB.AddServerToWhitelist(server)
if err != nil {
logrus.WithError(err).Panicf("failed to add server %s to whitelist", server)
}
}
}

js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream)

signingInfo := dendriteCfg.Global.SigningIdentities()
Expand Down
7 changes: 7 additions & 0 deletions federationapi/internal/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ func NewFederationInternalAPI(
}
}

// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server)
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
// Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant
stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet
return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval
}

func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
stats := a.statistics.ForServer(s)
if stats.Blacklisted() {
Expand Down
42 changes: 42 additions & 0 deletions federationapi/internal/federationclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30
func (a *FederationInternalAPI) MakeJoin(
ctx context.Context, origin, s spec.ServerName, roomID, userID string,
) (res gomatrixserverlib.MakeJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespMakeJoin{}, nil
} // Is thread-safe, so we can omit ctx call
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID)
Expand All @@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin(
func (a *FederationInternalAPI) SendJoin(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res gomatrixserverlib.SendJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespSendJoin{}, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()
ires, err := a.federation.SendJoin(ctx, origin, s, event)
Expand All @@ -42,6 +48,9 @@ func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, origin, s spec.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res fclient.RespEventAuth, err error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespEventAuth{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -56,6 +65,9 @@ func (a *FederationInternalAPI) GetEventAuth(
func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, origin, s spec.ServerName, userID string,
) (fclient.RespUserDevices, error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespUserDevices{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -70,6 +82,9 @@ func (a *FederationInternalAPI) GetUserDevices(
func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string,
) (fclient.RespClaimKeys, error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespClaimKeys{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -84,6 +99,9 @@ func (a *FederationInternalAPI) ClaimKeys(
func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, origin, s spec.ServerName, keys map[string][]string,
) (fclient.RespQueryKeys, error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespQueryKeys{}, nil
}
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, origin, s, keys)
})
Expand All @@ -96,6 +114,9 @@ func (a *FederationInternalAPI) QueryKeys(
func (a *FederationInternalAPI) Backfill(
ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) {
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -110,6 +131,9 @@ func (a *FederationInternalAPI) Backfill(
func (a *FederationInternalAPI) LookupState(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.StateResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespState{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -125,6 +149,9 @@ func (a *FederationInternalAPI) LookupState(
func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string,
) (res gomatrixserverlib.StateIDResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespStateIDs{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -140,6 +167,9 @@ func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, origin, s spec.ServerName, roomID string,
missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.RespMissingEvents, err error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespMissingEvents{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -154,6 +184,9 @@ func (a *FederationInternalAPI) LookupMissingEvents(
func (a *FederationInternalAPI) GetEvent(
ctx context.Context, origin, s spec.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) {
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -168,6 +201,9 @@ func (a *FederationInternalAPI) GetEvent(
func (a *FederationInternalAPI) LookupServerKeys(
ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
if !a.IsWhitelistedOrAny(s) {
return []gomatrixserverlib.ServerKeys{}, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -183,6 +219,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.MSC2836EventRelationshipsResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand All @@ -197,6 +236,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
func (a *FederationInternalAPI) RoomHierarchies(
ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool,
) (res fclient.RoomHierarchyResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
Expand Down
4 changes: 4 additions & 0 deletions federationapi/internal/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys(
}
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")

if !a.IsWhitelistedOrAny(req.ServerName) {
return nil
}

serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
if err != nil {
// try to load as much as we can from the cache in a best effort basis
Expand Down
22 changes: 22 additions & 0 deletions federationapi/statistics/statistics.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics {
server.blacklisted.Store(blacklisted)
}

whitelisted, err := s.DB.IsServerWhitelisted(serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName)
} else {
server.whitelisted.Store(whitelisted)
}

// Don't bother hitting the database 2 additional times
// if we don't want to use relays.
if !s.enableRelays {
Expand Down Expand Up @@ -118,6 +125,7 @@ type ServerStatistics struct {
statistics *Statistics //
serverName spec.ServerName //
blacklisted atomic.Bool // is the node blacklisted
whitelisted atomic.Bool // is the node whitelisted
assumedOffline atomic.Bool // is the node assumed to be offline
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
Expand Down Expand Up @@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}

// Whitelisted returns true if the server is whitelisted and false
// otherwise.
func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() }

// AssumedOffline returns true if the server is assumed offline and false
// otherwise.
func (s *ServerStatistics) AssumedOffline() bool {
Expand All @@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool {
return wasBlacklisted
}

// removeWhitelist removes the whitelisted status from the server.
// Returns whether the server was whitelisted.
func (s *ServerStatistics) removeWhitelist() bool {
if s.Whitelisted() {
_ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName)
return true
}
return false
}

// removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() {
Expand Down
5 changes: 5 additions & 0 deletions federationapi/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ type Database interface {
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName spec.ServerName) (bool, error)

AddServerToWhitelist(serverName spec.ServerName) error
RemoveServerFromWhitelist(serverName spec.ServerName) error
RemoveAllServersFromWhitelist() error
IsServerWhitelisted(serverName spec.ServerName) (bool, error)

// Adds the server to the list of assumed offline servers.
// If the server already exists in the table, nothing happens and returns success.
SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error
Expand Down
5 changes: 5 additions & 0 deletions federationapi/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
if err != nil {
return nil, err
}
whitelist, err := NewPostgresWhitelistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil {
return nil, err
Expand Down Expand Up @@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationWhitelist: whitelist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks,
Expand Down
94 changes: 94 additions & 0 deletions federationapi/storage/postgres/whitelist_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package postgres

import (
"context"
"database/sql"

"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib/spec"
)

const whitelistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
-- The whitelisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
);
`

const insertWhitelistSQL = "" +
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"

const selectWhitelistSQL = "" +
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"

const deleteWhitelistSQL = "" +
"DELETE FROM federationsender_whitelist WHERE server_name = $1"

const deleteAllWhitelistSQL = "" +
"TRUNCATE federationsender_whitelist"

type whitelistStatements struct {
db *sql.DB
insertWhitelistStmt *sql.Stmt
selectWhitelistStmt *sql.Stmt
deleteWhitelistStmt *sql.Stmt
deleteAllWhitelistStmt *sql.Stmt
}

func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
s = &whitelistStatements{
db: db,
}
_, err = db.Exec(whitelistSchema)
if err != nil {
return
}

return s, sqlutil.StatementList{
{&s.insertWhitelistStmt, insertWhitelistSQL},
{&s.selectWhitelistStmt, selectWhitelistSQL},
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
}.Prepare(db)
}

func (s *whitelistStatements) InsertWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}

func (s *whitelistStatements) SelectWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is whitelisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}

func (s *whitelistStatements) DeleteWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}

func (s *whitelistStatements) DeleteAllWhitelist(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
_, err := stmt.ExecContext(ctx)
return err
}
Loading