-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
88020f3
commit 31dbdd2
Showing
5 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package httpclient | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"io" | ||
) | ||
|
||
// Similar implementation to http/net MaxBytesReader | ||
// https://pkg.go.dev/net/http#MaxBytesReader | ||
// What's happening differently here, is that the field that | ||
// is limited is the response and not the request, thus | ||
// the error handling/message needed to be accurate. | ||
|
||
// ErrResponseBodyTooLarge indicates response body is too large | ||
var ErrResponseBodyTooLarge = errors.New("http: response body too large") | ||
|
||
// MaxBytesReader is similar to io.LimitReader but is intended for | ||
// limiting the size of incoming request bodies. In contrast to | ||
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a | ||
// non-EOF error for a Read beyond the limit, and closes the | ||
// underlying reader when its Close method is called. | ||
// | ||
// MaxBytesReader prevents clients from accidentally or maliciously | ||
// sending a large request and wasting server resources. | ||
func MaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { | ||
return &maxBytesReader{r: r, n: n} | ||
} | ||
|
||
type maxBytesReader struct { | ||
r io.ReadCloser // underlying reader | ||
n int64 // max bytes remaining | ||
err error // sticky error | ||
} | ||
|
||
func (l *maxBytesReader) Read(p []byte) (n int, err error) { | ||
if l.err != nil { | ||
return 0, l.err | ||
} | ||
if len(p) == 0 { | ||
return 0, nil | ||
} | ||
// If they asked for a 32KB byte read but only 5 bytes are | ||
// remaining, no need to read 32KB. 6 bytes will answer the | ||
// question of the whether we hit the limit or go past it. | ||
if int64(len(p)) > l.n+1 { | ||
p = p[:l.n+1] | ||
} | ||
n, err = l.r.Read(p) | ||
|
||
if int64(n) <= l.n { | ||
l.n -= int64(n) | ||
l.err = err | ||
return n, err | ||
} | ||
|
||
n = int(l.n) | ||
l.n = 0 | ||
|
||
l.err = fmt.Errorf("error: %w, response limit is set to: %d", ErrResponseBodyTooLarge, n) | ||
return n, l.err | ||
} | ||
|
||
func (l *maxBytesReader) Close() error { | ||
return l.r.Close() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package httpclient | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"io" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestMaxBytesReader(t *testing.T) { | ||
tcs := []struct { | ||
limit int64 | ||
bodyLength int | ||
body string | ||
err error | ||
}{ | ||
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")}, | ||
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil}, | ||
{limit: 0, bodyLength: 0, body: "", err: errors.New("error: http: response body too large, response limit is set to: 0")}, | ||
} | ||
for _, tc := range tcs { | ||
t.Run(fmt.Sprintf("Test MaxBytesReader with limit: %d", tc.limit), func(t *testing.T) { | ||
body := io.NopCloser(strings.NewReader("dummy")) | ||
readCloser := MaxBytesReader(body, tc.limit) | ||
|
||
bodyBytes, err := io.ReadAll(readCloser) | ||
if err != nil { | ||
require.EqualError(t, tc.err, err.Error()) | ||
} else { | ||
require.NoError(t, tc.err) | ||
} | ||
|
||
require.Len(t, bodyBytes, tc.bodyLength) | ||
require.Equal(t, string(bodyBytes), tc.body) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package httpclient | ||
|
||
import ( | ||
"net/http" | ||
) | ||
|
||
// ResponseLimitMiddlewareName is the middleware name used by ResponseLimitMiddleware. | ||
const ResponseLimitMiddlewareName = "response-limit" | ||
|
||
func ResponseLimitMiddleware(limit int64) Middleware { | ||
return NamedMiddlewareFunc(ResponseLimitMiddlewareName, func(opts Options, next http.RoundTripper) http.RoundTripper { | ||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { | ||
res, err := next.RoundTrip(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if limit <= 0 { | ||
return res, nil | ||
} | ||
|
||
if res != nil && res.StatusCode != http.StatusSwitchingProtocols { | ||
res.Body = MaxBytesReader(res.Body, limit) | ||
} | ||
|
||
return res, nil | ||
}) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package httpclient | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestResponseLimitMiddleware(t *testing.T) { | ||
tcs := []struct { | ||
limit int64 | ||
bodyLength int | ||
body string | ||
err error | ||
}{ | ||
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")}, | ||
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil}, | ||
{limit: 0, bodyLength: 5, body: "dummy", err: nil}, | ||
} | ||
for _, tc := range tcs { | ||
t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d", tc.limit), func(t *testing.T) { | ||
finalRoundTripper := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { | ||
return &http.Response{StatusCode: http.StatusOK, Request: req, Body: io.NopCloser(strings.NewReader("dummy"))}, nil | ||
}) | ||
|
||
mw := ResponseLimitMiddleware(tc.limit) | ||
rt := mw.CreateMiddleware(Options{}, finalRoundTripper) | ||
require.NotNil(t, rt) | ||
middlewareName, ok := mw.(MiddlewareName) | ||
require.True(t, ok) | ||
require.Equal(t, ResponseLimitMiddlewareName, middlewareName.MiddlewareName()) | ||
|
||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://test.com/query", nil) | ||
require.NoError(t, err) | ||
res, err := rt.RoundTrip(req) | ||
require.NoError(t, err) | ||
require.NotNil(t, res) | ||
require.NotNil(t, res.Body) | ||
|
||
bodyBytes, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
require.EqualError(t, tc.err, err.Error()) | ||
} else { | ||
require.NoError(t, tc.err) | ||
} | ||
require.NoError(t, res.Body.Close()) | ||
|
||
require.Len(t, bodyBytes, tc.bodyLength) | ||
require.Equal(t, string(bodyBytes), tc.body) | ||
}) | ||
} | ||
} |