diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 075f673d..92231713 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -107,6 +107,23 @@ func NewInternalAPI( stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays) + // Add servers to whitelist if enabled + if cfg.EnableWhitelist { + // We need to clear the list of the whitelisted servers during init + err = stats.DB.RemoveAllServersFromWhitelist() + if err != nil { + logrus.WithError(err).Panic("failed to clear whitelisted servers") + } + + // Add each whitelisted server to the data + for _, server := range cfg.WhitelistedServers { + err = stats.DB.AddServerToWhitelist(server) + if err != nil { + logrus.WithError(err).Panicf("failed to add server %s to whitelist", server) + } + } + } + js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream) signingInfo := dendriteCfg.Global.SigningIdentities() diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 809cf204..ba771b2d 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -112,6 +112,13 @@ func NewFederationInternalAPI( } } +// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server) +func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool { + // Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant + stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet + return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval +} + func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) if stats.Blacklisted() { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b6bc7a5e..d3bc51fd 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30 func (a *FederationInternalAPI) MakeJoin( ctx context.Context, origin, s spec.ServerName, roomID, userID string, ) (res gomatrixserverlib.MakeJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespMakeJoin{}, nil + } // Is thread-safe, so we can omit ctx call ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID) @@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin( func (a *FederationInternalAPI) SendJoin( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (res gomatrixserverlib.SendJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespSendJoin{}, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() ires, err := a.federation.SendJoin(ctx, origin, s, event) @@ -42,6 +48,9 @@ func (a *FederationInternalAPI) GetEventAuth( ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res fclient.RespEventAuth, err error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespEventAuth{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -56,6 +65,9 @@ func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetUserDevices( ctx context.Context, origin, s spec.ServerName, userID string, ) (fclient.RespUserDevices, error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespUserDevices{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -70,6 +82,9 @@ func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) ClaimKeys( ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ) (fclient.RespClaimKeys, error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespClaimKeys{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -84,6 +99,9 @@ func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) QueryKeys( ctx context.Context, origin, s spec.ServerName, keys map[string][]string, ) (fclient.RespQueryKeys, error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespQueryKeys{}, nil + } ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { return a.federation.QueryKeys(ctx, origin, s, keys) }) @@ -96,6 +114,9 @@ func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) Backfill( ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { + if !a.IsWhitelistedOrAny(s) { + return gomatrixserverlib.Transaction{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -110,6 +131,9 @@ func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) LookupState( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.StateResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespState{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -125,6 +149,9 @@ func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupStateIDs( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ) (res gomatrixserverlib.StateIDResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespStateIDs{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -140,6 +167,9 @@ func (a *FederationInternalAPI) LookupMissingEvents( ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.RespMissingEvents, err error) { + if !a.IsWhitelistedOrAny(s) { + return fclient.RespMissingEvents{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -154,6 +184,9 @@ func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) GetEvent( ctx context.Context, origin, s spec.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { + if !a.IsWhitelistedOrAny(s) { + return gomatrixserverlib.Transaction{}, nil + } ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -168,6 +201,9 @@ func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) LookupServerKeys( ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { + if !a.IsWhitelistedOrAny(s) { + return []gomatrixserverlib.ServerKeys{}, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -183,6 +219,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.MSC2836EventRelationshipsResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return res, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { @@ -197,6 +236,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) RoomHierarchies( ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ) (res fclient.RoomHierarchyResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return res, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index 22b1eb44..3f63dfad 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys( } util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct") + if !a.IsWhitelistedOrAny(req.ServerName) { + return nil + } + serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName) if err != nil { // try to load as much as we can from the cache in a best effort basis diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 92aa9291..555233b4 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics { server.blacklisted.Store(blacklisted) } + whitelisted, err := s.DB.IsServerWhitelisted(serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName) + } else { + server.whitelisted.Store(whitelisted) + } + // Don't bother hitting the database 2 additional times // if we don't want to use relays. if !s.enableRelays { @@ -118,6 +125,7 @@ type ServerStatistics struct { statistics *Statistics // serverName spec.ServerName // blacklisted atomic.Bool // is the node blacklisted + whitelisted atomic.Bool // is the node whitelisted assumedOffline atomic.Bool // is the node assumed to be offline backoffStarted atomic.Bool // is the backoff started backoffUntil atomic.Value // time.Time until this backoff interval ends @@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// Whitelisted returns true if the server is whitelisted and false +// otherwise. +func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() } + // AssumedOffline returns true if the server is assumed offline and false // otherwise. func (s *ServerStatistics) AssumedOffline() bool { @@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool { return wasBlacklisted } +// removeWhitelist removes the whitelisted status from the server. +// Returns whether the server was whitelisted. +func (s *ServerStatistics) removeWhitelist() bool { + if s.Whitelisted() { + _ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName) + return true + } + return false +} + // removeAssumedOffline removes the assumed offline status from the server. func (s *ServerStatistics) removeAssumedOffline() { if s.AssumedOffline() { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index cba701f8..222a8a7a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -49,6 +49,11 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName spec.ServerName) (bool, error) + AddServerToWhitelist(serverName spec.ServerName) error + RemoveServerFromWhitelist(serverName spec.ServerName) error + RemoveAllServersFromWhitelist() error + IsServerWhitelisted(serverName spec.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. // If the server already exists in the table, nothing happens and returns success. SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 4a5fc977..4af7b151 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, err } + whitelist, err := NewPostgresWhitelistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationWhitelist: whitelist, FederationAssumedOffline: assumedOffline, FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, diff --git a/federationapi/storage/postgres/whitelist_table.go b/federationapi/storage/postgres/whitelist_table.go new file mode 100644 index 00000000..748fa507 --- /dev/null +++ b/federationapi/storage/postgres/whitelist_table.go @@ -0,0 +1,94 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const whitelistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_whitelist ( + -- The whitelisted server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertWhitelistSQL = "" + + "INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectWhitelistSQL = "" + + "SELECT server_name FROM federationsender_whitelist WHERE server_name = $1" + +const deleteWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist WHERE server_name = $1" + +const deleteAllWhitelistSQL = "" + + "TRUNCATE federationsender_whitelist" + +type whitelistStatements struct { + db *sql.DB + insertWhitelistStmt *sql.Stmt + selectWhitelistStmt *sql.Stmt + deleteWhitelistStmt *sql.Stmt + deleteAllWhitelistStmt *sql.Stmt +} + +func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) { + s = &whitelistStatements{ + db: db, + } + _, err = db.Exec(whitelistSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertWhitelistStmt, insertWhitelistSQL}, + {&s.selectWhitelistStmt, selectWhitelistSQL}, + {&s.deleteWhitelistStmt, deleteWhitelistSQL}, + {&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL}, + }.Prepare(db) +} + +func (s *whitelistStatements) InsertWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) SelectWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is whitelisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *whitelistStatements) DeleteWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) DeleteAllWhitelist( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 19b870b2..48cf0afd 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -31,6 +31,7 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationWhitelist tables.FederationWhitelist FederationAssumedOffline tables.FederationAssumedOffline FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks @@ -148,6 +149,14 @@ func (d *Database) AddServerToBlacklist( }) } +func (d *Database) AddServerToWhitelist( + serverName spec.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.InsertWhitelist(context.TODO(), txn, serverName) + }) +} + func (d *Database) RemoveServerFromBlacklist( serverName spec.ServerName, ) error { @@ -156,18 +165,38 @@ func (d *Database) RemoveServerFromBlacklist( }) } +func (d *Database) RemoveServerFromWhitelist( + serverName spec.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.DeleteWhitelist(context.TODO(), txn, serverName) + }) +} + func (d *Database) RemoveAllServersFromBlacklist() error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn) }) } +func (d *Database) RemoveAllServersFromWhitelist() error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationWhitelist.DeleteAllWhitelist(context.TODO(), txn) + }) +} + func (d *Database) IsServerBlacklisted( serverName spec.ServerName, ) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } +func (d *Database) IsServerWhitelisted( + serverName spec.ServerName, +) (bool, error) { + return d.FederationWhitelist.SelectWhitelist(context.TODO(), nil, serverName) +} + func (d *Database) SetServerAssumedOffline( ctx context.Context, serverName spec.ServerName, diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c0a06d12..8abae833 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -36,6 +36,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, err } + whitelist, err := NewSQLiteWhitelistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -102,6 +106,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationWhitelist: whitelist, FederationAssumedOffline: assumedOffline, FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, diff --git a/federationapi/storage/sqlite3/whitelist_table.go b/federationapi/storage/sqlite3/whitelist_table.go new file mode 100644 index 00000000..755cd76e --- /dev/null +++ b/federationapi/storage/sqlite3/whitelist_table.go @@ -0,0 +1,94 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const whitelistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_whitelist ( + -- The whitelisted server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertWhitelistSQL = "" + + "INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectWhitelistSQL = "" + + "SELECT server_name FROM federationsender_whitelist WHERE server_name = $1" + +const deleteWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist WHERE server_name = $1" + +const deleteAllWhitelistSQL = "" + + "DELETE FROM federationsender_whitelist" + +type whitelistStatements struct { + db *sql.DB + insertWhitelistStmt *sql.Stmt + selectWhitelistStmt *sql.Stmt + deleteWhitelistStmt *sql.Stmt + deleteAllWhitelistStmt *sql.Stmt +} + +func NewSQLiteWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) { + s = &whitelistStatements{ + db: db, + } + _, err = db.Exec(whitelistSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertWhitelistStmt, insertWhitelistSQL}, + {&s.selectWhitelistStmt, selectWhitelistSQL}, + {&s.deleteWhitelistStmt, deleteWhitelistSQL}, + {&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL}, + }.Prepare(db) +} + +func (s *whitelistStatements) InsertWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) SelectWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is whitelisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *whitelistStatements) DeleteWhitelist( + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *whitelistStatements) DeleteAllWhitelist( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2173a93f..e0c314a9 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -72,6 +72,13 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationWhitelist interface { + InsertWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + DeleteAllWhitelist(ctx context.Context, txn *sql.Tx) error +} + type FederationAssumedOffline interface { InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index ed417a74..1ada41a0 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -47,7 +47,13 @@ type FederationAPI struct { // Should we prefer direct key fetches over perspective ones? PreferDirectFetch bool `yaml:"prefer_direct_fetch"` - // Deny/Allow lists used for restricting request scopes. + // Enable servers whitelist function + EnableWhitelist bool `yaml:"enable_whitelist"` + + // The list of whitelisted servers + WhitelistedServers []spec.ServerName `yaml:"whitelisted_servers"` + + // Deny/Allow lists used for restricting request scopes. DenyNetworkCIDRs []string `yaml:"deny_networks"` AllowNetworkCIDRs []string `yaml:"allow_networks"` } @@ -91,6 +97,8 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.ConnectionString = "file:federationapi.db" } } + c.EnableWhitelist = false + c.WhitelistedServers = make([]spec.ServerName, 0) } func (c *FederationAPI) Verify(configErrs *ConfigErrors) { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 263aa9f3..65865ea1 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -96,6 +96,8 @@ client_api: federation_api: database: connection_string: file:federationapi.db + enable_whitelist: true + whitelisted_servers: ["https://matrix.org"] key_server: database: connection_string: file:keyserver.db