Skip to content

Commit

Permalink
Response limit middleware (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
andresmgot authored Mar 11, 2024
1 parent 88020f3 commit 31dbdd2
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 0 deletions.
13 changes: 13 additions & 0 deletions backend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
SQLMaxOpenConnsDefault = "GF_SQL_MAX_OPEN_CONNS_DEFAULT"
SQLMaxIdleConnsDefault = "GF_SQL_MAX_IDLE_CONNS_DEFAULT"
SQLMaxConnLifetimeSecondsDefault = "GF_SQL_MAX_CONN_LIFETIME_SECONDS_DEFAULT"
ResponseLimit = "GF_RESPONSE_LIMIT"
)

type configKey struct{}
Expand Down Expand Up @@ -225,6 +226,18 @@ func (c *GrafanaCfg) UserFacingDefaultError() (string, error) {
return value, nil
}

func (c *GrafanaCfg) ResponseLimit() int64 {
count, ok := c.config[ResponseLimit]
if !ok {
return 0
}
i, err := strconv.ParseInt(count, 10, 64)
if err != nil {
return 0
}
return i
}

type userAgentKey struct{}

// UserAgentFromContext returns user agent from context.
Expand Down
66 changes: 66 additions & 0 deletions backend/httpclient/max_bytes_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package httpclient

import (
"errors"
"fmt"
"io"
)

// Similar implementation to http/net MaxBytesReader
// https://pkg.go.dev/net/http#MaxBytesReader
// What's happening differently here, is that the field that
// is limited is the response and not the request, thus
// the error handling/message needed to be accurate.

// ErrResponseBodyTooLarge indicates response body is too large
var ErrResponseBodyTooLarge = errors.New("http: response body too large")

// MaxBytesReader is similar to io.LimitReader but is intended for
// limiting the size of incoming request bodies. In contrast to
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
// non-EOF error for a Read beyond the limit, and closes the
// underlying reader when its Close method is called.
//
// MaxBytesReader prevents clients from accidentally or maliciously
// sending a large request and wasting server resources.
func MaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{r: r, n: n}
}

type maxBytesReader struct {
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}

func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)

if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}

n = int(l.n)
l.n = 0

l.err = fmt.Errorf("error: %w, response limit is set to: %d", ErrResponseBodyTooLarge, n)
return n, l.err
}

func (l *maxBytesReader) Close() error {
return l.r.Close()
}
40 changes: 40 additions & 0 deletions backend/httpclient/max_bytes_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package httpclient

import (
"errors"
"fmt"
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestMaxBytesReader(t *testing.T) {
tcs := []struct {
limit int64
bodyLength int
body string
err error
}{
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
{limit: 0, bodyLength: 0, body: "", err: errors.New("error: http: response body too large, response limit is set to: 0")},
}
for _, tc := range tcs {
t.Run(fmt.Sprintf("Test MaxBytesReader with limit: %d", tc.limit), func(t *testing.T) {
body := io.NopCloser(strings.NewReader("dummy"))
readCloser := MaxBytesReader(body, tc.limit)

bodyBytes, err := io.ReadAll(readCloser)
if err != nil {
require.EqualError(t, tc.err, err.Error())
} else {
require.NoError(t, tc.err)
}

require.Len(t, bodyBytes, tc.bodyLength)
require.Equal(t, string(bodyBytes), tc.body)
})
}
}
29 changes: 29 additions & 0 deletions backend/httpclient/response_limit_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package httpclient

import (
"net/http"
)

// ResponseLimitMiddlewareName is the middleware name used by ResponseLimitMiddleware.
const ResponseLimitMiddlewareName = "response-limit"

func ResponseLimitMiddleware(limit int64) Middleware {
return NamedMiddlewareFunc(ResponseLimitMiddlewareName, func(opts Options, next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
res, err := next.RoundTrip(req)
if err != nil {
return nil, err
}

if limit <= 0 {
return res, nil
}

if res != nil && res.StatusCode != http.StatusSwitchingProtocols {
res.Body = MaxBytesReader(res.Body, limit)
}

return res, nil
})
})
}
58 changes: 58 additions & 0 deletions backend/httpclient/response_limit_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package httpclient

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestResponseLimitMiddleware(t *testing.T) {
tcs := []struct {
limit int64
bodyLength int
body string
err error
}{
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
{limit: 0, bodyLength: 5, body: "dummy", err: nil},
}
for _, tc := range tcs {
t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d", tc.limit), func(t *testing.T) {
finalRoundTripper := RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Request: req, Body: io.NopCloser(strings.NewReader("dummy"))}, nil
})

mw := ResponseLimitMiddleware(tc.limit)
rt := mw.CreateMiddleware(Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(MiddlewareName)
require.True(t, ok)
require.Equal(t, ResponseLimitMiddlewareName, middlewareName.MiddlewareName())

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://test.com/query", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
require.NotNil(t, res.Body)

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
require.EqualError(t, tc.err, err.Error())
} else {
require.NoError(t, tc.err)
}
require.NoError(t, res.Body.Close())

require.Len(t, bodyBytes, tc.bodyLength)
require.Equal(t, string(bodyBytes), tc.body)
})
}
}

0 comments on commit 31dbdd2

Please sign in to comment.