diff --git a/.github/workflows/docker-publish-tags.yml b/.github/workflows/docker-publish-tags.yml index 2ae6b8b..065c248 100644 --- a/.github/workflows/docker-publish-tags.yml +++ b/.github/workflows/docker-publish-tags.yml @@ -22,7 +22,6 @@ env: jobs: build: - runs-on: ubuntu-latest permissions: contents: read diff --git a/.gitignore b/.gitignore index fea24ab..b6a39e3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ nirn-proxy.* nirn-proxy .env *.txt -*.log \ No newline at end of file +*.log +k6_tests/node_modules \ No newline at end of file diff --git a/CONFIG.md b/CONFIG.md index 124a5ec..0ce57ee 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -48,6 +48,11 @@ If using Kubernetes, create a headless service and use it here for easy clusteri Example: `nirn-headless.default.svc.cluster.local` or `nirn.mydomain.com` +##### MAX_BEARER_COUNT +Maximum number of bearer token queues. Internally, bearer queues are kept in an LRU map; this env var sets the maximum number of entries in that map. +Requests are never interrupted midway, even when an entry is evicted. A low LRU size may cause increased 429s if a bearer token has too many requests queued and fires another one after eviction. 
+Default: 1024 + ## Unstable env vars Collection of env vars that may be removed at any time, mainly used for Discord introducing new behaviour on their edge api versions diff --git a/README.md b/README.md index 6ffc7de..dded248 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ It is designed to be minimally invasive and exploits common library patterns to - Works with any API version (Also supports using two or more versions for the same bot) - Small resource footprint - Works with webhooks +- Works with Bearer tokens - Prometheus metrics exported out of the box - No hardcoded routes, therefore no need of updates for new routes introduced by Discord @@ -56,8 +57,6 @@ The proxy may return a 408 Request Timeout if Discord takes more than $REQUEST_T The ratelimiting only works with `X-RateLimit-Precision` set to `seconds`. If you are using Discord API v8+, that is the only possible behaviour. For users on v6 or v7, please refer to your library docs for information on which precision it uses and how to change it to seconds. -Bearer tokens should work, however this was not at all tested and is not the main use case for this project - ### Why? As projects grow, it's desirable to break them into multiple pieces, each responsible for its own domain. Discord provides gateway sharding on their end but REST can get tricky once you start moving logic out of the shards themselves and lose the guild affinity that shards inherently have, thus a centralized place for handling ratelimits is a must to prevent cloudflare bans and prevent avoidable 429s. At the time this project was created, there was no alternative that fully satisfied our requirements like multi-bot support. We are also early adopters of Discord features, so we need a proxy that supports new routes without us having to manually update it. Thus, this project was born. 
@@ -99,6 +98,10 @@ Global ratelimits are handled by a single node on the cluster, however this affi The best deployment strategy for the cluster is to kill nodes one at a time, preferably with the replacement node already up. +### Bearer Tokens + +Bearer tokens are first-class citizens. They are treated differently from bot tokens: while bot queues are long-lived and never get evicted, bearer queues are kept in an LRU cache and are spread out by their token hash instead of by the path hash. This provides a more even spread of bearer queues across nodes in the cluster. In addition, bearer globals are always handled locally. You can control how many bearer queues to keep at any time with the MAX_BEARER_COUNT env var. + ### Profiling The proxy can be profiled at runtime by enabling the ENABLE_PPROF flag and browsing to `http://ip:7654/debug/pprof/` diff --git a/go.mod b/go.mod index 0f5af6e..256b20e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/Clever/leakybucket v1.2.0 github.com/bwmarrin/snowflake v0.3.0 // indirect + github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/memberlist v0.3.1 // indirect github.com/joho/godotenv v1.4.0 // indirect github.com/prometheus/client_golang v1.11.0 diff --git a/go.sum b/go.sum index 2e8eb38..9704e57 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,8 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/memberlist v0.3.1 h1:MXgUXLqva1QvpVEDQW1IQLG0wivQAtmFlHRQ+1vWZfM= 
github.com/hashicorp/memberlist v0.3.1/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOnAH9VT3Sh9MUE= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= diff --git a/k6_tests/loadtest1.js b/k6_tests/loadtest1.js new file mode 100644 index 0000000..f14350c --- /dev/null +++ b/k6_tests/loadtest1.js @@ -0,0 +1,21 @@ +import http from 'k6/http'; +import {check, sleep} from 'k6'; +export const options = { + noConnectionReuse: true, + vus: 50, + iterations: 50 +}; + +export default function() { + const params = { + headers: { 'Authorization': __ENV.TOKEN }, + }; + let res = http.get('http://localhost:8080/api/v9/gateway', params); + check(res, { 'success': (r) => r.status >= 200 && r.status < 400 }); + + let res2 = http.get('http://localhost:8080/api/v9/guilds/203039963636301824', params); + check(res2, { 'success': (r) => r.status >= 200 && r.status < 400 }); + + let res3 = http.get('http://localhost:8080/api/v9/guilds/203039963636301824/channels', params); + check(res3, { 'success': (r) => r.status >= 200 && r.status < 400 }); +} \ No newline at end of file diff --git a/lib/discord.go b/lib/discord.go index a94797f..1cc108c 100644 --- a/lib/discord.go +++ b/lib/discord.go @@ -10,6 +10,7 @@ import ( "math" "net" "net/http" + "strings" "time" ) @@ -79,6 +80,10 @@ func GetBotGlobalLimit(token string) (uint, error) { return math.MaxUint32, nil } + if strings.HasPrefix(token, "Bearer") { + return 50, nil + } + bot, err := doDiscordReq(context.Background(), "/api/v9/gateway/bot", "GET", nil, map[string][]string{"Authorization": {token}}, "") if err != nil { diff --git a/lib/queue.go b/lib/queue.go index ff1b1c8..b7224a1 100644 --- a/lib/queue.go +++ b/lib/queue.go @@ -39,6 +39,7 @@ type RequestQueue struct { identifier string isTokenInvalid *int64 botLimit uint + queueType QueueType } @@ -66,16 +67,27 @@ func NewRequestQueue(processor func(ctx context.Context, item *QueueItem) (*http return nil, err } - user, err 
:= GetBotUser(token) - if err != nil && token != "" { - return nil, err + queueType := NoAuth + var user *BotUserResponse + if !strings.HasPrefix(token, "Bearer") { + user, err = GetBotUser(token) + if err != nil && token != "" { + return nil, err + } + } else { + queueType = Bearer } identifier := "NoAuth" if user != nil { + queueType = Bot identifier = user.Username + "#" + user.Discrim } + if queueType == Bearer { + identifier = "Bearer" + } + ret := &RequestQueue{ queues: make(map[uint64]*QueueChannel), processor: processor, @@ -85,15 +97,31 @@ func NewRequestQueue(processor func(ctx context.Context, item *QueueItem) (*http user: user, identifier: identifier, isTokenInvalid: new(int64), - botLimit: limit, + botLimit: limit, + queueType: queueType, } - logger.WithFields(logrus.Fields{ "globalLimit": limit, "identifier": identifier, "bufferSize": bufferSize }).Info("Created new queue") + if queueType != Bearer { + logger.WithFields(logrus.Fields{"globalLimit": limit, "identifier": identifier, "bufferSize": bufferSize}).Info("Created new queue") + // Only sweep bot queues, bearer queues get completely destroyed and hold way less endpoints + go ret.tickSweep() + } else { + logger.WithFields(logrus.Fields{"globalLimit": limit, "identifier": identifier, "bufferSize": bufferSize}).Debug("Created new bearer queue") + } - go ret.tickSweep() return ret, nil } +func (q *RequestQueue) destroy() { + q.Lock() + defer q.Unlock() + logger.Debug("Destroying queue") + for _, val := range q.queues { + close(val.ch) + } + q.queues = nil +} + func (q *RequestQueue) sweep() { q.Lock() defer q.Unlock() @@ -124,11 +152,13 @@ func (q *RequestQueue) Queue(req *http.Request, res *http.ResponseWriter, path s "method": req.Method, }).Trace("Inbound request") + q.Lock() ch := q.getQueueChannel(path, pathHash) doneChan := make(chan *http.Response) errChan := make(chan error) ch.ch <- &QueueItem{req, res, doneChan, errChan } + q.Unlock() select { case resp := <-doneChan: return path, resp, 
nil @@ -138,8 +168,6 @@ func (q *RequestQueue) Queue(req *http.Request, res *http.ResponseWriter, path s } func (q *RequestQueue) getQueueChannel(path string, pathHash uint64) *QueueChannel { - q.Lock() - defer q.Unlock() t := time.Now() ch, ok := q.queues[pathHash] if !ok { @@ -206,34 +234,6 @@ func parseHeaders(headers *http.Header) (int64, int64, time.Duration, bool, erro return limitParsed, remainingParsed, reset, isGlobal, nil } -func (q *RequestQueue) takeGlobal(path string) { -takeGlobal: - waitTime := atomic.LoadInt64(q.globalLockedUntil) - - if waitTime > 0 { - logger.WithFields(logrus.Fields{ - "bucket": path, - "waitTime": waitTime, - }).Trace("Waiting for existing global to clear") - time.Sleep(time.Until(time.Unix(0, waitTime))) - sw := atomic.CompareAndSwapInt64(q.globalLockedUntil, waitTime, 0) - if sw { - logger.Info("Unlocked global bucket") - } - } - - _, err := q.globalBucket.Add(1) - if err != nil { - reset := q.globalBucket.Reset() - logger.WithFields(logrus.Fields{ - "bucket": path, - "waitTime": time.Until(reset), - }).Trace("Failed to grab global token, sleeping for a bit") - time.Sleep(time.Until(reset)) - goto takeGlobal - } -} - func return404webhook(item *QueueItem) { res := *item.Res res.WriteHeader(404) @@ -286,7 +286,6 @@ func (q *RequestQueue) subscribe(ch *QueueChannel, path string, pathHash uint64) continue } - q.takeGlobal(path) if atomic.LoadInt64(q.isTokenInvalid) > 0 { return401(item) @@ -345,7 +344,7 @@ func (q *RequestQueue) subscribe(ch *QueueChannel, path string, pathHash uint64) ret404 = true } - if resp.StatusCode == 401 && !isInteraction(item.Req.URL.String()) && q.identifier != "NoAuth" { + if resp.StatusCode == 401 && !isInteraction(item.Req.URL.String()) && q.queueType != NoAuth { // Permanently lock this queue logger.WithFields(logrus.Fields{ "bucket": path, diff --git a/lib/queue_manager.go b/lib/queue_manager.go index 155eeb7..ef56f75 100644 --- a/lib/queue_manager.go +++ b/lib/queue_manager.go @@ -1,18 +1,30 @@ 
package lib import ( + lru "github.com/hashicorp/golang-lru" "github.com/hashicorp/memberlist" "github.com/sirupsen/logrus" "net/http" "sort" "strconv" + "strings" "sync" "time" ) +type QueueType int64 + +const ( + Bot QueueType = iota + NoAuth + Bearer +) + type QueueManager struct { sync.RWMutex queues map[string]*RequestQueue + bearerQueues *lru.Cache + bearerMu sync.RWMutex bufferSize int cluster *memberlist.Memberlist clusterGlobalRateLimiter *ClusterGlobalRateLimiter @@ -23,9 +35,20 @@ type QueueManager struct { localNodeProxyListenAddr string } -func NewQueueManager(bufferSize int) *QueueManager { +func onEvictLruItem(key interface{}, value interface{}) { + go value.(*RequestQueue).destroy() +} + +func NewQueueManager(bufferSize int, maxBearerLruSize int) *QueueManager { + bearerMap, err := lru.NewWithEvict(maxBearerLruSize, onEvictLruItem) + + if err != nil { + panic(err) + } + q := &QueueManager{ queues: make(map[string]*RequestQueue), + bearerQueues: bearerMap, bufferSize: bufferSize, cluster: nil, clusterGlobalRateLimiter: NewClusterGlobalRateLimiter(), @@ -48,6 +71,8 @@ func (m *QueueManager) reindexMembers() { m.Lock() defer m.Unlock() + m.bearerMu.Lock() + defer m.bearerMu.Unlock() members := m.cluster.Members() var orderedMembers []string @@ -92,8 +117,14 @@ func (m *QueueManager) calculateRoute(pathHash uint64) string { return "" } + if pathHash == 0 { + return "" + } + m.RLock() defer m.RUnlock() + m.bearerMu.RLock() + defer m.bearerMu.RUnlock() members := m.orderedClusterMembers count := uint64(len(members)) @@ -148,18 +179,14 @@ func (m *QueueManager) Generate429(resp *http.ResponseWriter) { writer.Write([]byte("{\n\t\"global\": false,\n\t\"message\": \"You are being rate limited.\",\n\t\"retry_after\": 1\n}")) } -func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http.Request) { - ConnectionsOpen.Inc() - defer ConnectionsOpen.Dec() - - token := req.Header.Get("Authorization") - +func (m *QueueManager) 
getOrCreateBotQueue(token string) (*RequestQueue, error) { m.RLock() q, ok := m.queues[token] m.RUnlock() if !ok { m.Lock() + defer m.Unlock() // Check if it wasn't created while we didn't hold the lock q, ok = m.queues[token] if !ok { @@ -167,27 +194,66 @@ func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http q, err = NewRequestQueue(ProcessRequest, token, m.bufferSize) if err != nil { - resp.WriteHeader(500) - resp.Write([]byte(err.Error())) - logger.Error(err) - m.Unlock() - return + return nil, err } m.queues[token] = q } - m.Unlock() } - path := GetOptimisticBucketPath(req.URL.Path, req.Method) - pathHash := HashCRC64(path) - var botHash uint64 = 0 - if q.user != nil { - botHash = HashCRC64(q.user.Id) + return q, nil +} + +func (m *QueueManager) getOrCreateBearerQueue(token string) (*RequestQueue, error) { + m.bearerMu.RLock() + q, ok := m.bearerQueues.Get(token) + m.bearerMu.RUnlock() + + if !ok { + m.bearerMu.Lock() + defer m.bearerMu.Unlock() + // Check if it wasn't created while we didn't hold the lock + q, ok = m.bearerQueues.Get(token) + if !ok { + var err error + q, err = NewRequestQueue(ProcessRequest, token, 5) + + if err != nil { + return nil, err + } + + m.bearerQueues.Add(token, q) + } + } + + return q.(*RequestQueue), nil +} + +func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http.Request) { + ConnectionsOpen.Inc() + defer ConnectionsOpen.Dec() + + token := req.Header.Get("Authorization") + routingHash, path, queueType := m.GetRequestRoutingInfo(req, token) + + m.fulfillRequest(&resp, req, queueType, path, routingHash, token) +} + +func (m *QueueManager) GetRequestRoutingInfo(req *http.Request, token string) (routingHash uint64, path string, queueType QueueType) { + path = GetOptimisticBucketPath(req.URL.Path, req.Method) + queueType = NoAuth + if strings.HasPrefix(token, "Bearer") { + queueType = Bearer + routingHash = HashCRC64(token) + } else { + queueType = Bot + routingHash = 
HashCRC64(path) } + return +} +func (m *QueueManager) fulfillRequest(resp *http.ResponseWriter, req *http.Request, queueType QueueType, path string, pathHash uint64, token string) { routeTo := m.calculateRoute(pathHash) - globalRouteTo := m.calculateRoute(botHash) routeToHeader := req.Header.Get("nirn-routed-to") req.Header.Del("nirn-routed-to") @@ -198,22 +264,42 @@ func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http var err error if routeTo == "" || routeToHeader != "" { + var q *RequestQueue + var err error + if queueType == Bearer { + q, err = m.getOrCreateBearerQueue(token) + } else { + q, err = m.getOrCreateBotQueue(token) + } + + if err != nil { + (*resp).WriteHeader(500) + (*resp).Write([]byte(err.Error())) + logger.Error(err) + } + if q.identifier != "NoAuth" && m.cluster != nil { + var botHash uint64 = 0 + if q.user != nil { + botHash = HashCRC64(q.user.Id) + } + botLimit := q.botLimit + globalRouteTo := m.calculateRoute(botHash) - if globalRouteTo == "" { + if globalRouteTo == "" || queueType == Bearer { m.clusterGlobalRateLimiter.Take(botHash, botLimit) } else { err = m.clusterGlobalRateLimiter.FireGlobalRequest(req.Context(), globalRouteTo, botHash, botLimit) if err != nil { logger.WithFields(logrus.Fields{"function": "FireGlobalRequest"}).Error(err) ErrorCounter.Inc() - m.Generate429(&resp) + m.Generate429(resp) return } } } - _, _, err = q.Queue(req, &resp, path, pathHash) + _, _, err = q.Queue(req, resp, path, pathHash) if err != nil { logger.WithFields(logrus.Fields{"function": "Queue"}).Error(err) } @@ -221,13 +307,13 @@ func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http var res *http.Response res, err = m.routeRequest(routeTo, req) if err == nil { - err = CopyResponseToResponseWriter(res, &resp) + err = CopyResponseToResponseWriter(res, resp) if err != nil { logger.WithFields(logrus.Fields{"function": "CopyResponseToResponseWriter"}).Error(err) } } else { 
logger.WithFields(logrus.Fields{"function": "routeRequest"}).Error(err) - m.Generate429(&resp) + m.Generate429(resp) } } diff --git a/main.go b/main.go index 2d4a7fc..dfffcec 100644 --- a/main.go +++ b/main.go @@ -77,8 +77,9 @@ func main() { setupLogger() bufferSize = lib.EnvGetInt("BUFFER_SIZE", 50) + maxBearerLruSize := lib.EnvGetInt("MAX_BEARER_COUNT", 1024) - manager := lib.NewQueueManager(bufferSize) + manager := lib.NewQueueManager(bufferSize, maxBearerLruSize) initCluster(port, manager)