From b2ecd5648c3626a8c79c206f5cd95d611d18d56e Mon Sep 17 00:00:00 2001 From: enaix Date: Thu, 16 Jan 2025 12:31:37 +0300 Subject: [PATCH 1/4] Implement server whitelist --- federationapi/federationapi.go | 16 ++++ federationapi/internal/api.go | 6 ++ federationapi/internal/federationclient.go | 36 +++++++ federationapi/internal/query.go | 4 + federationapi/statistics/statistics.go | 22 +++++ federationapi/storage/interface.go | 5 + federationapi/storage/postgres/storage.go | 5 + .../storage/postgres/whitelist_table.go | 93 ++++++++++++++++++ federationapi/storage/shared/storage.go | 29 ++++++ federationapi/storage/sqlite3/storage.go | 5 + .../storage/sqlite3/whitelist_table.go | 94 +++++++++++++++++++ federationapi/storage/tables/interface.go | 7 ++ setup/config/config_federationapi.go | 8 ++ setup/config/config_test.go | 2 + 14 files changed, 332 insertions(+) create mode 100644 federationapi/storage/postgres/whitelist_table.go create mode 100644 federationapi/storage/sqlite3/whitelist_table.go diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 075f673db..d6f5b140a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -107,6 +107,22 @@ func NewInternalAPI( stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays) + // Add servers to whitelist if enabled + if cfg.EnableWhitelist { + // We need to clear the list of the whitelisted servers during init + err = stats.DB.RemoveAllServersFromWhitelist() + if err != nil { + logrus.WithError(err).Panic("failed to clear whitelisted servers") + } + // Add each whitelisted server to the database + for _, server := range cfg.WhitelistedServers { + err = stats.DB.AddServerToWhitelist(server) + if err != nil { + logrus.WithError(err).Panicf("failed to add server %s to whitelist", server) + } + } + } + js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream) signingInfo := dendriteCfg.Global.SigningIdentities() diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 809cf2046..4b165bd05 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -112,6 +112,12 @@ func NewFederationInternalAPI( } } +// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server +func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool { + stats := a.statistics.ForServer(s) + return stats.Whitelisted() || !a.cfg.EnableWhitelist +} + func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) if stats.Blacklisted() { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b6bc7a5ed..ea073c39d 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -44,6 +44,9 @@ func (a *FederationInternalAPI) GetEventAuth( ) (res fclient.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return fclient.RespEventAuth{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) @@ -58,6 +61,9 @@ func (a *FederationInternalAPI) GetUserDevices( ) (fclient.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return fclient.RespUserDevices{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetUserDevices(ctx, origin, s, userID) }) @@ -72,6 +78,9 @@ func (a *FederationInternalAPI) ClaimKeys( ) (fclient.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return fclient.RespClaimKeys{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) @@ -84,6 +93,9 @@ func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) QueryKeys( ctx context.Context, origin, s spec.ServerName, keys map[string][]string, ) (fclient.RespQueryKeys, error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespQueryKeys{}, nil + } ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { return a.federation.QueryKeys(ctx, origin, s, keys) }) @@ -98,6 +110,9 @@ func (a *FederationInternalAPI) Backfill( ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return gomatrixserverlib.Transaction{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) @@ -112,6 +127,9 @@ func (a *FederationInternalAPI) LookupState( ) (res gomatrixserverlib.StateResponse, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespState{}, nil + } // TODO check & ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) @@ -127,6 +145,9 @@ func (a *FederationInternalAPI) LookupStateIDs( ) (res gomatrixserverlib.StateIDResponse, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return fclient.RespStateIDs{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) @@ -142,6 +163,9 @@ func (a *FederationInternalAPI) LookupMissingEvents( ) (res fclient.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return fclient.RespMissingEvents{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) @@ -156,6 +180,9 @@ func (a *FederationInternalAPI) GetEvent( ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return gomatrixserverlib.Transaction{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEvent(ctx, origin, s, eventID) }) @@ -170,6 +197,9 @@ func (a *FederationInternalAPI) LookupServerKeys( ) ([]gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return []gomatrixserverlib.ServerKeys{}, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupServerKeys(ctx, s, keyRequests) }) @@ -185,6 +215,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( ) (res fclient.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return res, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) @@ -199,6 +232,9 @@ func (a *FederationInternalAPI) RoomHierarchies( ) (res fclient.RoomHierarchyResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() + if !a.IsWhitelistedOrAny(s) { + return res, nil + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly) }) diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index 22b1eb44d..3f63dfad6 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys( } util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct") + if !a.IsWhitelistedOrAny(req.ServerName) { + return nil + } + serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName) if err != nil { // try to load as much as we can from the cache in a best effort basis diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 92aa92917..555233b4b 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics { server.blacklisted.Store(blacklisted) } + whitelisted, err := s.DB.IsServerWhitelisted(serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName) + } else { + server.whitelisted.Store(whitelisted) + } + // Don't bother hitting the database 2 additional times // if we don't want to use relays. if !s.enableRelays { @@ -118,6 +125,7 @@ type ServerStatistics struct { statistics *Statistics // serverName spec.ServerName // blacklisted atomic.Bool // is the node blacklisted + whitelisted atomic.Bool // is the node whitelisted assumedOffline atomic.Bool // is the node assumed to be offline backoffStarted atomic.Bool // is the backoff started backoffUntil atomic.Value // time.Time until this backoff interval ends @@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// Whitelisted returns true if the server is whitelisted and false +// otherwise. +func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() } + // AssumedOffline returns true if the server is assumed offline and false // otherwise. func (s *ServerStatistics) AssumedOffline() bool { @@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool { return wasBlacklisted } +// removeWhitelist removes the whitelisted status from the server. +// Returns whether the server was whitelisted. +func (s *ServerStatistics) removeWhitelist() bool { + if s.Whitelisted() { + _ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName) + return true + } + return false +} + // removeAssumedOffline removes the assumed offline status from the server. func (s *ServerStatistics) removeAssumedOffline() { if s.AssumedOffline() { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index cba701f86..222a8a7ae 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -49,6 +49,11 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName spec.ServerName) (bool, error) + AddServerToWhitelist(serverName spec.ServerName) error + RemoveServerFromWhitelist(serverName spec.ServerName) error + RemoveAllServersFromWhitelist() error + IsServerWhitelisted(serverName spec.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. // If the server already exists in the table, nothing happens and returns success. SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 4a5fc9777..4af7b151d 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, err } + whitelist, err := NewPostgresWhitelistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationWhitelist: whitelist, FederationAssumedOffline: assumedOffline, FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, diff --git a/federationapi/storage/postgres/whitelist_table.go b/federationapi/storage/postgres/whitelist_table.go new file mode 100644 index 000000000..81828b706 --- /dev/null +++ b/federationapi/storage/postgres/whitelist_table.go @@ -0,0 +1,93 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const whitelistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_whitelist ( + -- The whitelisted server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +` + +const insertWhitelistSQL = "" + + "INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectWhitelistSQL = "" + + "SELECT server_name FROM federationsender_whitelist WHERE server_name = $1" + +const deleteWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist WHERE server_name = $1" + +const deleteAllWhitelistSQL = "" + + "TRUNCATE federationsender_whitelist" + +type whitelistStatements struct { + db *sql.DB + insertWhitelistStmt *sql.Stmt + selectWhitelistStmt *sql.Stmt + deleteWhitelistStmt *sql.Stmt + deleteAllWhitelistStmt *sql.Stmt +} + +func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) { + s = &whitelistStatements{ + db: db, + } + _, err = db.Exec(whitelistSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertWhitelistStmt, insertWhitelistSQL}, + {&s.selectWhitelistStmt, selectWhitelistSQL}, + {&s.deleteWhitelistStmt, deleteWhitelistSQL}, + {&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL}, + }.Prepare(db) +} + +func (s *whitelistStatements) InsertWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) SelectWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is whitelisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *whitelistStatements) DeleteWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) DeleteAllWhitelist( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 19b870b27..48cf0afd1 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -31,6 +31,7 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationWhitelist tables.FederationWhitelist FederationAssumedOffline tables.FederationAssumedOffline FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks @@ -148,6 +149,14 @@ func (d *Database) AddServerToBlacklist( }) } +func (d *Database) AddServerToWhitelist( + serverName spec.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.InsertWhitelist(context.TODO(), txn, serverName) + }) +} + func (d *Database) RemoveServerFromBlacklist( serverName spec.ServerName, ) error { @@ -156,18 +165,38 @@ func (d *Database) RemoveServerFromBlacklist( }) } +func (d *Database) RemoveServerFromWhitelist( + serverName spec.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.DeleteWhitelist(context.TODO(), txn, serverName) + }) +} + func (d *Database) RemoveAllServersFromBlacklist() error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn) }) } +func (d *Database) RemoveAllServersFromWhitelist() error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.DeleteAllWhitelist(context.TODO(), txn) + }) +} + func (d *Database) IsServerBlacklisted( serverName spec.ServerName, ) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } +func (d *Database) IsServerWhitelisted( + serverName spec.ServerName, +) (bool, error) { + return d.FederationWhitelist.SelectWhitelist(context.TODO(), nil, serverName) +} + func (d *Database) SetServerAssumedOffline( ctx context.Context, serverName spec.ServerName, diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c0a06d120..8abae8333 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -36,6 +36,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, err } + whitelist, err := NewSQLiteWhitelistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -102,6 +106,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationWhitelist: whitelist, FederationAssumedOffline: assumedOffline, FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, diff --git a/federationapi/storage/sqlite3/whitelist_table.go b/federationapi/storage/sqlite3/whitelist_table.go new file mode 100644 index 000000000..755cd76e8 --- /dev/null +++ b/federationapi/storage/sqlite3/whitelist_table.go @@ -0,0 +1,94 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const whitelistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_whitelist ( + -- The whitelisted server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertWhitelistSQL = "" + + "INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectWhitelistSQL = "" + + "SELECT server_name FROM federationsender_whitelist WHERE server_name = $1" + +const deleteWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist WHERE server_name = $1" + +const deleteAllWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist" + +type whitelistStatements struct { + db *sql.DB + insertWhitelistStmt *sql.Stmt + selectWhitelistStmt *sql.Stmt + deleteWhitelistStmt *sql.Stmt + deleteAllWhitelistStmt *sql.Stmt +} + +func NewSQLiteWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) { + s = &whitelistStatements{ + db: db, + } + _, err = db.Exec(whitelistSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertWhitelistStmt, insertWhitelistSQL}, + {&s.selectWhitelistStmt, selectWhitelistSQL}, + {&s.deleteWhitelistStmt, deleteWhitelistSQL}, + {&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL}, + }.Prepare(db) +} + +func (s *whitelistStatements) InsertWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) SelectWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is whitelisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *whitelistStatements) DeleteWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) DeleteAllWhitelist( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2173a93f2..e0c314a9a 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -72,6 +72,13 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationWhitelist interface { + InsertWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + DeleteAllWhitelist(ctx context.Context, txn *sql.Tx) error +} + type FederationAssumedOffline interface { InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 073c46e03..d0703e62f 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -46,6 +46,12 @@ type FederationAPI struct { // Should we prefer direct key fetches over perspective ones? PreferDirectFetch bool `yaml:"prefer_direct_fetch"` + + // Enable servers whitelist function + EnableWhitelist bool `yaml:"enable_whitelist"` + + // The list of whitelisted servers + WhitelistedServers []spec.ServerName `yaml:"whitelisted_servers"` } func (c *FederationAPI) Defaults(opts DefaultOpts) { @@ -73,6 +79,8 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.ConnectionString = "file:federationapi.db" } } + c.EnableWhitelist = false + c.WhitelistedServers = make([]spec.ServerName, 0) } func (c *FederationAPI) Verify(configErrs *ConfigErrors) { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 263aa9f35..65865ea1c 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -96,6 +96,8 @@ client_api: federation_api: database: connection_string: file:federationapi.db + enable_whitelist: true + whitelisted_servers: ["https://matrix.org"] key_server: database: connection_string: file:keyserver.db From a7457d316b6bd310c429147e943af2783354988a Mon Sep 17 00:00:00 2001 From: enaix Date: Thu, 16 Jan 2025 14:05:31 +0300 Subject: [PATCH 2/4] Fix typo in database whitelist DDL --- federationapi/federationapi.go | 3 ++- federationapi/storage/postgres/whitelist_table.go | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index d6f5b140a..92231713d 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -114,7 +114,8 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panic("failed to clear whitelisted servers") } - // Add each whitelisted server to the database + + // Add each whitelisted server to the data for _, server := range cfg.WhitelistedServers { err = stats.DB.AddServerToWhitelist(server) if err != nil { diff --git a/federationapi/storage/postgres/whitelist_table.go b/federationapi/storage/postgres/whitelist_table.go index 81828b706..748fa5079 100644 --- a/federationapi/storage/postgres/whitelist_table.go +++ b/federationapi/storage/postgres/whitelist_table.go @@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS federationsender_whitelist ( -- The whitelisted server name server_name TEXT NOT NULL, UNIQUE (server_name) +); ` const insertWhitelistSQL = "" + From 94deed77ecce5fb1ea65ea598508872f34c59e59 Mon Sep 17 00:00:00 2001 From: enaix Date: Thu, 16 Jan 2025 20:36:36 +0300 Subject: [PATCH 3/4] Code cleanup --- federationapi/internal/federationclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index ea073c39d..45272f8e3 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -129,7 +129,7 @@ func (a *FederationInternalAPI) LookupState( defer cancel() if !a.IsWhitelistedOrAny(s) { return &fclient.RespState{}, nil - } // TODO check & + } ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) From b5ea31ff5c0666a94d3bb7dab9fc4acd640e1662 Mon Sep 17 00:00:00 2001 From: enaix Date: Fri, 17 Jan 2025 14:55:11 +0300 Subject: [PATCH 4/4] Fix federationclient whitelist checks, improve performance --- federationapi/internal/api.go | 7 +-- federationapi/internal/federationclient.go | 50 ++++++++++++---------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 4b165bd05..ba771b2db 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -112,10 +112,11 @@ func NewFederationInternalAPI( } } -// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server +// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server) func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool { - stats := a.statistics.ForServer(s) - return stats.Whitelisted() || !a.cfg.EnableWhitelist + // Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant + stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet + return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval } func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 45272f8e3..d3bc51fda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30 func (a *FederationInternalAPI) MakeJoin( ctx context.Context, origin, s spec.ServerName, roomID, userID string, ) (res gomatrixserverlib.MakeJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespMakeJoin{}, nil + } // Is thread-safe, so we can omit ctx call ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID) @@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin( func (a *FederationInternalAPI) SendJoin( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (res gomatrixserverlib.SendJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespSendJoin{}, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() ires, err := a.federation.SendJoin(ctx, origin, s, event) @@ -42,11 +48,11 @@ func (a *FederationInternalAPI) GetEventAuth( ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res fclient.RespEventAuth, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespEventAuth{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) @@ -59,11 +65,11 @@ func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetUserDevices( ctx context.Context, origin, s spec.ServerName, userID string, ) (fclient.RespUserDevices, error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespUserDevices{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetUserDevices(ctx, origin, s, userID) }) @@ -76,11 +82,11 @@ func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) ClaimKeys( ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ) (fclient.RespClaimKeys, error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespClaimKeys{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) @@ -108,11 +114,11 @@ func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) Backfill( ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return gomatrixserverlib.Transaction{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) @@ -125,11 +131,11 @@ func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) LookupState( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.StateResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return &fclient.RespState{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) @@ -143,11 +149,11 @@ func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupStateIDs( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ) (res gomatrixserverlib.StateIDResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespStateIDs{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) @@ -161,11 +167,11 @@ func (a *FederationInternalAPI) LookupMissingEvents( ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.RespMissingEvents, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespMissingEvents{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) @@ -178,11 +184,11 @@ func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) GetEvent( ctx context.Context, origin, s spec.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return gomatrixserverlib.Transaction{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEvent(ctx, origin, s, eventID) }) @@ -195,11 +201,11 @@ func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) LookupServerKeys( ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return []gomatrixserverlib.ServerKeys{}, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupServerKeys(ctx, s, keyRequests) }) @@ -213,11 +219,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.MSC2836EventRelationshipsResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return res, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) @@ -230,11 +236,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) RoomHierarchies( ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ) (res fclient.RoomHierarchyResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return res, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly) })