Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit per JWT claim #17

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions juju/example/krakend.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
"endpoints": [
{
"endpoint": "/showrss/{id}",
"headers_to_pass": [
"x-user",
"x-tier"
],
"backend": [
{
"host": [
Expand Down Expand Up @@ -37,6 +41,30 @@
],
"extra_config": {
"github.com/devopsfaith/krakend-ratelimit/juju/router": {
"tierConfiguration": {
"headerTier": "x-tier",
"strategy": "header",
"key": "x-user",
"duration": "1m",
"tiers": [
{
"name": "unlimited",
"limit": 0
},
{
"name": "gold",
"limit": 50
},
{
"name": "silver",
"limit": 20
},
{
"name": "bronze",
"limit": 5
}
]
},
"maxRate": 50,
"clientMaxRate": 5,
"strategy": "ip"
Expand Down
17 changes: 17 additions & 0 deletions juju/juju.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ package juju
import (
"context"

"time"

"github.com/juju/ratelimit"

krakendrate "github.com/devopsfaith/krakend-ratelimit"
Expand All @@ -19,6 +21,10 @@ func NewLimiter(maxRate float64, capacity int64) Limiter {
return Limiter{ratelimit.NewBucketWithRate(maxRate, capacity)}
}

func NewLimiterDuration(fillInterval time.Duration, capacity int64) Limiter {
return Limiter{ratelimit.NewBucketWithQuantum(fillInterval, capacity, capacity)}
}

// Limiter is a simple wrapper over the ratelimit.Bucket struct
type Limiter struct {
limiter *ratelimit.Bucket
Expand All @@ -37,7 +43,18 @@ func NewLimiterStore(maxRate float64, capacity int64, backend krakendrate.Backen
}
}

func NewLimiterDurationStore(fillInterval time.Duration, capacity int64, backend krakendrate.Backend) krakendrate.LimiterStore {
f := func() interface{} { return NewLimiterDuration(fillInterval, capacity) }
return func(t string) krakendrate.Limiter {
return backend.Load(t, f).(Limiter)
}
}

// NewMemoryStore returns a LimiterStore using the memory backend
func NewMemoryStore(maxRate float64, capacity int64) krakendrate.LimiterStore {
return NewLimiterStore(maxRate, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background()))
}

func NewMemoryDurationStore(fillInterval time.Duration, capacity int64) krakendrate.LimiterStore {
return NewLimiterDurationStore(fillInterval, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background()))
}
84 changes: 83 additions & 1 deletion juju/router/gin/gin.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package gin

import (
"context"
"log"
"net"
"net/http"
"strings"
"time"

"github.com/devopsfaith/krakend/config"
"github.com/devopsfaith/krakend/proxy"
Expand All @@ -25,7 +28,7 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory
handlerFunc := next(remote, p)

cfg := router.ConfigGetter(remote.ExtraConfig).(router.Config)
if cfg == router.ZeroCfg || (cfg.MaxRate <= 0 && cfg.ClientMaxRate <= 0) {
if cfg == router.ZeroCfg || (cfg.MaxRate <= 0 && cfg.ClientMaxRate <= 0 || cfg.TierConfiguration == nil) {
return handlerFunc
}

Expand All @@ -40,6 +43,17 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory
handlerFunc = NewHeaderLimiterMw(cfg.Key, float64(cfg.ClientMaxRate), cfg.ClientMaxRate)(handlerFunc)
}
}
if cfg.TierConfiguration != nil {
strategy := strings.ToLower(cfg.TierConfiguration.Strategy)
duration, err := time.ParseDuration(cfg.TierConfiguration.Duration)
if err != nil {
log.Printf("%s => Tier Configuration will be ignored.", err)
} else if strategy != "ip" && strategy != "header" {
log.Printf("%s is not a valid strategy => Tier Configuration will be ignored", strategy)
} else {
handlerFunc = NewTierLimiterMw(cfg.TierConfiguration, duration)(handlerFunc)
}
}
return handlerFunc
}
}
Expand Down Expand Up @@ -78,6 +92,21 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo
return NewTokenLimiterMw(NewIPTokenExtractor(header), juju.NewMemoryStore(maxRate, capacity))
}

// NewIpLimiterWithKeyMw creates a token ratelimiter using the IP/header of the request and tier name as a token
func NewTierLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw {
var storesPerTier = krakendrate.NewShardedMemoryBackend(context.Background(), 2, fillInterval, krakendrate.PseudoFNV64a)
for _, tier := range tierConfiguration.Tiers {
if tier.Limit > 0 {
storesPerTier.Store(tier.Name, juju.NewMemoryDurationStore(fillInterval, tier.Limit))
}
}
return NewTokenLimiterPerTierMw(
NewConcatTokenExtractor(tierConfiguration.HeaderTier, strings.ToLower(tierConfiguration.Strategy), tierConfiguration.Key),
fillInterval,
storesPerTier,
)
}

// TokenExtractor defines the interface of the functions to use in order to extract a token for each request
type TokenExtractor func(*gin.Context) string

Expand All @@ -103,6 +132,33 @@ func HeaderTokenExtractor(header string) TokenExtractor {
return func(c *gin.Context) string { return c.Request.Header.Get(header) }
}

// ConcatTokenExtractor returns a TokenExtractor that concatenates all passed token extractors
func ConcatTokenExtractor(tokenExtractors []TokenExtractor) TokenExtractor {
return func(c *gin.Context) string {
var tokenValues = make([]string, len(tokenExtractors))
for i, tokenExtractor := range tokenExtractors {
tokenValues[i] = tokenExtractor(c)
}
return strings.Join(tokenValues, "-")
}
}

// NewConcatTokenExtractor generates a ConcatTokenExtractor using ip or header extractors depending on the strategy
func NewConcatTokenExtractor(headerTier string, strategy string, key string) TokenExtractor {
var tierTokenExtractor = HeaderTokenExtractor(headerTier)
var clientIdentifierTokenExtractor TokenExtractor
if strategy == "ip" {
if key == "" {
clientIdentifierTokenExtractor = IPTokenExtractor
} else {
clientIdentifierTokenExtractor = NewIPTokenExtractor(key)
}
} else if strategy == "header" {
clientIdentifierTokenExtractor = HeaderTokenExtractor(key)
}
return ConcatTokenExtractor([]TokenExtractor{tierTokenExtractor, clientIdentifierTokenExtractor})
}

// NewTokenLimiterMw returns a token based ratelimiting endpoint middleware with the received TokenExtractor and LimiterStore
func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.LimiterStore) EndpointMw {
return func(next gin.HandlerFunc) gin.HandlerFunc {
Expand All @@ -120,3 +176,29 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L
}
}
}

// NewTokenLimiterPerTierMw returns a token based ratelimiting endpoint middleware with the received TokenExtractor and different LimiterStores per tier
func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.Duration, storesPerTier *krakendrate.ShardedMemoryBackend) EndpointMw {
var noResult = func() interface{} { return nil }
return func(next gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
tokenKey := tokenExtractor(c)
if tokenKey == "" {
c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited)
return
}
tokenKeyParts := strings.Split(tokenKey, "-")
tierName, clientIdentifier := tokenKeyParts[0], tokenKeyParts[1]
tierLimiter := storesPerTier.Load(tierName, noResult)
if tierLimiter != nil {
if !tierLimiter.(krakendrate.LimiterStore)(clientIdentifier).Allow() {
c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited)
return
}
} else {
log.Printf("Tier %s does not exist.", tierName)
}
next(c)
}
}
}
102 changes: 102 additions & 0 deletions juju/router/gin/gin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,108 @@ func TestNewRateLimiterMw_DefaultIP(t *testing.T) {
testRateLimiterMw(t, rd, cfg)
}

func TestNewRateLimiterMw_TierCustomHeader(t *testing.T) {
headerTier := "X-Tier"
headerUser := "X-User"

cfg := &config.EndpointConfig{
ExtraConfig: map[string]interface{}{
router.Namespace: map[string]interface{}{
"tierConfiguration": map[string]interface{}{
"headerTier": headerTier,
"strategy": "header",
"key": headerUser,
"duration": "1s",
"tiers": []map[string]interface{}{
{
"name": "tier1",
"limit": 100,
},
{
"name": "tier2",
"limit": 200,
},
},
},
},
},
}

rd := func(req *http.Request) {
req.Header.Add(headerTier, "tier1")
req.Header.Add(headerUser, "1234567890")
}

testRateLimiterMw(t, rd, cfg)
}

func TestNewRateLimiterMw_TierDefaultIP(t *testing.T) {
headerTier := "X-Tier"

cfg := &config.EndpointConfig{
ExtraConfig: map[string]interface{}{
router.Namespace: map[string]interface{}{
"tierConfiguration": map[string]interface{}{
"headerTier": headerTier,
"strategy": "ip",
"duration": "1s",
"tiers": []map[string]interface{}{
{
"name": "tier1",
"limit": 100,
},
{
"name": "tier2",
"limit": 200,
},
},
},
},
},
}

rd := func(req *http.Request) {
req.Header.Add(headerTier, "tier1")
}

testRateLimiterMw(t, rd, cfg)
}

func TestNewRateLimiterMw_TierCustomHeaderIP(t *testing.T) {
headerTier := "X-Tier"
headerIP := "X-Custom-Forwarded-For"

cfg := &config.EndpointConfig{
ExtraConfig: map[string]interface{}{
router.Namespace: map[string]interface{}{
"tierConfiguration": map[string]interface{}{
"headerTier": headerTier,
"strategy": "ip",
"key": headerIP,
"duration": "1s",
"tiers": []map[string]interface{}{
{
"name": "tier1",
"limit": 100,
},
{
"name": "tier2",
"limit": 200,
},
},
},
},
},
}

rd := func(req *http.Request) {
req.Header.Add(headerTier, "tier1")
req.Header.Add(headerIP, "1.1.1.1,2.2.2.2,3.3.3.3")
}

testRateLimiterMw(t, rd, cfg)
}

type requestDecorator func(*http.Request)

func testRateLimiterMw(t *testing.T, rd requestDecorator, cfg *config.EndpointConfig) {
Expand Down
37 changes: 33 additions & 4 deletions juju/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and http://en.wikipedia.org/wiki/Token_bucket for more details.
package router

import (
"encoding/json"
"fmt"

"github.com/devopsfaith/krakend/config"
Expand All @@ -32,10 +33,24 @@ const Namespace = "github.com/devopsfaith/krakend-ratelimit/juju/router"

// Config is the custom config struct containing the params for the router middlewares
type Config struct {
MaxRate int64
Strategy string
ClientMaxRate int64
Key string
MaxRate int64
Strategy string
ClientMaxRate int64
Key string
TierConfiguration *TierConfiguration
}

type TierConfiguration struct {
Strategy string
Key string
HeaderTier string
Duration string
Tiers []Tier
}

type Tier struct {
Name string
Limit int64
}

// ZeroCfg is the zero value for the Config struct
Expand Down Expand Up @@ -79,5 +94,19 @@ func ConfigGetter(e config.ExtraConfig) interface{} {
if v, ok := tmp["key"]; ok {
cfg.Key = fmt.Sprintf("%v", v)
}
if v, ok := tmp["tierConfiguration"]; ok {
jsonbody, err := json.Marshal(v)
if err != nil {
fmt.Println(err)
return ZeroCfg
}

tierConfiguration := TierConfiguration{}
if err := json.Unmarshal(jsonbody, &tierConfiguration); err != nil {
fmt.Println(err)
return ZeroCfg
}
cfg.TierConfiguration = &tierConfiguration
}
return cfg
}
3 changes: 3 additions & 0 deletions juju/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ func TestConfigGetter(t *testing.T) {
if cfg.Key != "" {
t.Errorf("wrong value for Key. Want: '', have: %s", cfg.Key)
}
if cfg.TierConfiguration != nil {
t.Errorf("wrong value for TierConfiguration. Want: <nil>, have: %+v", cfg.TierConfiguration)
}
}