From 78ccd46aa95313dd09c1d15f7570c74b40a4389e Mon Sep 17 00:00:00 2001 From: Jeevanandam M Date: Sat, 26 Oct 2024 22:56:05 -0700 Subject: [PATCH] feat!: redesign request retry flow, retry-after header, default retry conditions #886 --- README.md | 2 + client.go | 136 ++++++------ client_test.go | 57 ++--- context_test.go | 9 +- middleware.go | 9 +- request.go | 179 ++++++++++++---- request_test.go | 131 +++++++----- resty.go | 40 ++-- resty_test.go | 23 +- retry.go | 305 +++++++++++--------------- retry_test.go | 556 ++++++++++++++++++++++++------------------------ util.go | 89 +++----- util_test.go | 3 - 13 files changed, 777 insertions(+), 762 deletions(-) diff --git a/README.md b/README.md index aa86105c..b8cfea00 100644 --- a/README.md +++ b/README.md @@ -700,6 +700,8 @@ client.RemoveProxy() Resty uses [backoff](http://www.awsarchitectureblog.com/2015/03/backoff.html) to increase retry intervals after each attempt. +TODO update retry docs + Usage example: ```go diff --git a/client.go b/client.go index 19d32190..6dc3c7c1 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -63,9 +64,9 @@ var ( hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length") hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding") hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition") - hdrLocationKey = http.CanonicalHeaderKey("Location") hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization") hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate") + hdrRetryAfterKey = http.CanonicalHeaderKey("Retry-After") plainTextType = "text/plain; charset=utf-8" jsonContentType = "application/json" @@ -178,9 +179,9 @@ type Client struct { retryWaitTime time.Duration retryMaxWaitTime time.Duration retryConditions []RetryConditionFunc - retryHooks []OnRetryFunc - retryAfter RetryAfterFunc - retryResetReaders bool + retryHooks []RetryHookFunc + retryStrategy RetryStrategyFunc + isRetryDefaultConditions bool headerAuthorizationKey string responseBodyLimit int64 resBodyUnlimitedReads bool @@ -198,7 +199,6 @@ type Client struct { proxyURL *url.URL requestLog RequestLogCallback responseLog ResponseLogCallback - rateLimiter RateLimiter generateCurlOnDebug bool loadBalancer LoadBalancer beforeRequest []RequestMiddleware @@ -635,7 +635,8 @@ func (c *Client) R() *Request { RetryCount: c.retryCount, RetryWaitTime: c.retryWaitTime, RetryMaxWaitTime: c.retryMaxWaitTime, - RetryResetReaders: c.retryResetReaders, + RetryStrategy: c.retryStrategy, + IsRetryDefaultConditions: c.isRetryDefaultConditions, CloseConnection: c.closeConnection, DoNotParseResponse: c.notParseResponse, DebugBodyLimit: c.debugBodyLimit, @@ -1204,20 +1205,57 @@ func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { return c } -// RetryAfter method returns the retry after callback function, that is -// used to calculate wait time between retries if it's registered; otherwise, it is nil. -func (c *Client) RetryAfter() RetryAfterFunc { +// RetryStrategy method returns the retry strategy function; otherwise, it is nil. +// +// See [Client.SetRetryStrategy] +func (c *Client) RetryStrategy() RetryStrategyFunc { + c.lock.RLock() + defer c.lock.RUnlock() + return c.retryStrategy +} + +// SetRetryStrategy method used to set the custom Retry strategy into Resty client, +// it is used to get wait time before each retry. It can be overridden at request +// level, see [Request.SetRetryStrategy] +// +// Default (nil) implies exponential backoff with a jitter strategy +func (c *Client) SetRetryStrategy(rs RetryStrategyFunc) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.retryStrategy = rs + return c +} + +// EnableRetryDefaultConditions method enables the Resty's default retry conditions +func (c *Client) EnableRetryDefaultConditions() *Client { + c.SetRetryDefaultConditions(true) + return c +} + +// DisableRetryDefaultConditions method disables the Resty's default retry conditions +func (c *Client) DisableRetryDefaultConditions() *Client { + c.SetRetryDefaultConditions(false) + return c +} + +// IsRetryDefaultConditions method returns true if Resty's default retry conditions +// are enabled otherwise false +// +// Default value is `true` +func (c *Client) IsRetryDefaultConditions() bool { c.lock.RLock() defer c.lock.RUnlock() - return c.retryAfter + return c.isRetryDefaultConditions } -// SetRetryAfter sets a callback to calculate the wait time between retries. -// Default (nil) implies exponential backoff with jitter -func (c *Client) SetRetryAfter(callback RetryAfterFunc) *Client { +// SetRetryDefaultConditions method is used to enable/disable the Resty's default +// retry conditions +// +// It can be overridden at request level, see [Request.SetRetryDefaultConditions] +func (c *Client) SetRetryDefaultConditions(b bool) *Client { c.lock.Lock() defer c.lock.Unlock() - c.retryAfter = callback + c.isRetryDefaultConditions = b return c } @@ -1241,17 +1279,8 @@ func (c *Client) AddRetryCondition(condition RetryConditionFunc) *Client { return c } -// AddRetryAfterErrorCondition adds the basic condition of retrying after encountering -// an error from the HTTP response -func (c *Client) AddRetryAfterErrorCondition() *Client { - c.AddRetryCondition(func(response *Response, err error) bool { - return response.IsError() - }) - return c -} - // RetryHooks method returns all the retry hook functions. -func (c *Client) RetryHooks() []OnRetryFunc { +func (c *Client) RetryHooks() []RetryHookFunc { c.lock.RLock() defer c.lock.RUnlock() return c.retryHooks @@ -1259,29 +1288,13 @@ func (c *Client) RetryHooks() []OnRetryFunc { // AddRetryHook adds a side-effecting retry hook to an array of hooks // that will be executed on each retry. -func (c *Client) AddRetryHook(hook OnRetryFunc) *Client { +func (c *Client) AddRetryHook(hook RetryHookFunc) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryHooks = append(c.retryHooks, hook) return c } -// RetryResetReaders method returns true if the retry reset readers are enabled; otherwise, it is nil. -func (c *Client) RetryResetReaders() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.retryResetReaders -} - -// SetRetryResetReaders method enables the Resty client to seek the start of all -// file readers are given as multipart files if the object implements [io.ReadSeeker]. -func (c *Client) SetRetryResetReaders(b bool) *Client { - c.lock.Lock() - defer c.lock.Unlock() - c.retryResetReaders = b - return c -} - // SetTLSClientConfig method sets TLSClientConfig for underlying client Transport. // // For Example: @@ -1539,22 +1552,6 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client { return c } -// RateLimiter method returns the rate limiter interface -func (c *Client) RateLimiter() RateLimiter { - c.lock.RLock() - defer c.lock.RUnlock() - return c.rateLimiter -} - -// SetRateLimiter sets an optional [RateLimiter]. If set, the rate limiter will control -// all requests were made by this client. -func (c *Client) SetRateLimiter(rl RateLimiter) *Client { - c.lock.Lock() - defer c.lock.Unlock() - c.rateLimiter = rl - return c -} - // Transport method returns [http.Transport] currently in use or error // in case the currently used `transport` is not a [http.Transport]. // @@ -1947,30 +1944,18 @@ func (c *Client) Close() error { func (c *Client) executeBefore(req *Request) error { var err error - if isStringEmpty(req.Method) { - req.Method = MethodGet - } - // user defined on before request methods // to modify the *resty.Request object for _, f := range c.beforeRequestMiddlewares() { if err = f(c, req); err != nil { - return wrapNoRetryErr(err) - } - } - - // If there is a rate limiter set for this client, the Execute call - // will return an error if the rate limit is exceeded. - if req.client.RateLimiter() != nil { - if !req.client.RateLimiter().Allow() { - return ErrRateLimitExceeded + return err } } // resty middlewares for _, f := range c.beforeRequest { if err = f(c, req); err != nil { - return wrapNoRetryErr(err) + return err } } @@ -1981,7 +1966,7 @@ func (c *Client) executeBefore(req *Request) error { // call pre-request if defined if c.preReqHook != nil { if err = c.preReqHook(c, req.RawRequest); err != nil { - return wrapNoRetryErr(err) + return err } } @@ -1996,10 +1981,9 @@ func (c *Client) execute(req *Request) (*Response, error) { } if err := requestDebugLogger(c, req); err != nil { - return nil, wrapNoRetryErr(err) + return nil, err } - req.RawRequest.Body = wrapRequestBufferReleaser(req) req.Time = time.Now() resp, err := c.Client().Do(req.RawRequest) @@ -2046,7 +2030,7 @@ func (c *Client) execute(req *Request) (*Response, error) { } } - return response, wrapNoRetryErr(err) + return response, err } // getting TLS client config if not exists then create one diff --git a/client_test.go b/client_test.go index 5b37d8d8..f021f1eb 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -9,13 +10,14 @@ import ( "compress/gzip" "compress/lzw" "context" - "crypto/rand" + cryprand "crypto/rand" "crypto/tls" "errors" "fmt" "io" "log" "math" + "math/rand" "net" "net/http" "net/url" @@ -202,19 +204,19 @@ func TestClientTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() - c := dcnl().SetTimeout(time.Second * 3) + c := dcnl().SetTimeout(time.Millisecond * 200) _, err := c.R().Get(ts.URL + "/set-timeout-test") - assertEqual(t, true, strings.Contains(strings.ToLower(err.Error()), "timeout")) + assertEqual(t, true, strings.Contains(err.Error(), "Client.Timeout")) } func TestClientTimeoutWithinThreshold(t *testing.T) { ts := createGetServer(t) defer ts.Close() - c := dcnl().SetTimeout(time.Second * 3) - resp, err := c.R().Get(ts.URL + "/set-timeout-test-with-sequence") + c := dcnl().SetTimeout(200 * time.Millisecond) + resp, err := c.R().Get(ts.URL + "/set-timeout-test-with-sequence") assertError(t, err) seq1, _ := strconv.ParseInt(resp.String(), 10, 32) @@ -490,6 +492,12 @@ func TestClientSettingsCoverage(t *testing.T) { c.DisableDebug() + assertEqual(t, true, c.IsRetryDefaultConditions()) + c.DisableRetryDefaultConditions() + assertEqual(t, false, c.IsRetryDefaultConditions()) + c.EnableRetryDefaultConditions() + assertEqual(t, true, c.IsRetryDefaultConditions()) + // [Start] Custom Transport scenario ct := dcnl() ct.SetTransport(&CustomRoundTripper{}) @@ -1204,19 +1212,26 @@ func TestPostRedirectWithBody(t *testing.T) { ts := createPostServer(t) defer ts.Close() - targetURL, _ := url.Parse(ts.URL) - t.Log("ts.URL:", ts.URL) - t.Log("targetURL.Host:", targetURL.Host) + mu := sync.Mutex{} + rnd := rand.New(rand.NewSource(time.Now().UnixNano())) - c := dcnl() + c := dcnl().SetBaseURL(ts.URL) + + totalRequests := 4000 wg := sync.WaitGroup{} - for i := 0; i < 100; i++ { - wg.Add(1) + wg.Add(totalRequests) + for i := 0; i < totalRequests; i++ { + if i%50 == 0 { + time.Sleep(20 * time.Millisecond) // to prevent test server socket exhaustion + } go func() { defer wg.Done() + mu.Lock() + randNumber := rnd.Int() + mu.Unlock() resp, err := c.R(). - SetBody([]byte(strconv.Itoa(newRnd().Int()))). - Post(targetURL.String() + "/redirect-with-body") + SetBody([]byte(strconv.Itoa(randNumber))). + Post("/redirect-with-body") assertError(t, err) assertNotNil(t, resp) }() @@ -1252,14 +1267,6 @@ func TestUnixSocket(t *testing.T) { assertEqual(t, "Hello resty client from a server running on endpoint /hello!", res.String()) } -var _ RateLimiter = (*testRateLimiter)(nil) - -type testRateLimiter struct{} - -func (t *testRateLimiter) Allow() bool { - return false -} - func TestClientClone(t *testing.T) { parent := New() @@ -1268,9 +1275,6 @@ func TestClientClone(t *testing.T) { parent.SetBasicAuth("parent", "") parent.SetProxy("http://localhost:8080") - // set an interface field - tr := &testRateLimiter{} - parent.SetRateLimiter(tr) parent.SetCookie(&http.Cookie{ Name: "go-resty-1", Value: "This is cookie 1 value", @@ -1300,12 +1304,11 @@ func TestClientClone(t *testing.T) { // assert interface/pointer type assertEqual(t, parent.Client(), clone.Client()) - assertEqual(t, parent.RateLimiter(), clone.RateLimiter()) } func TestResponseBodyLimit(t *testing.T) { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - io.CopyN(w, rand.Reader, 100*800) + io.CopyN(w, cryprand.Reader, 100*800) }) defer ts.Close() diff --git a/context_test.go b/context_test.go index e3651758..9a24afb8 100644 --- a/context_test.go +++ b/context_test.go @@ -191,20 +191,19 @@ func TestSetContextCancelWithError(t *testing.T) { } func TestClientRetryWithSetContext(t *testing.T) { - var attemptctx int32 + var attempt int32 ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) - attp := atomic.AddInt32(&attemptctx, 1) - if attp <= 4 { - time.Sleep(time.Second * 2) + if atomic.AddInt32(&attempt, 1) <= 4 { + time.Sleep(100 * time.Millisecond) } _, _ = w.Write([]byte("TestClientRetry page")) }) defer ts.Close() c := dcnl(). - SetTimeout(time.Second * 1). + SetTimeout(50 * time.Millisecond). SetRetryCount(3) _, err := c.R(). diff --git a/middleware.go b/middleware.go index 2f20ff50..b04e2bdd 100644 --- a/middleware.go +++ b/middleware.go @@ -1,11 +1,11 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty import ( - "bytes" "fmt" "io" "mime/multipart" @@ -230,10 +230,7 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, nil) } } else { - // fix data race: must deep copy. - // TODO investigate in details and remove this copy line - bodyBuf := bytes.NewBuffer(append([]byte{}, r.bodyBuf.Bytes()...)) - r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, bodyBuf) + r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, r.bodyBuf) } if err != nil { diff --git a/request.go b/request.go index c75cd36d..0233285e 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -57,16 +58,13 @@ type Request struct { IsTrace bool AllowMethodGetPayload bool AllowMethodDeletePayload bool - - // Retry - RetryCount int - RetryWaitTime time.Duration - RetryMaxWaitTime time.Duration - RetryResetReaders bool - - // Attempt is to represent the request attempt made during a Resty - // request execution flow, including retry count. - Attempt int + IsDone bool + RetryCount int + RetryWaitTime time.Duration + RetryMaxWaitTime time.Duration + RetryStrategy RetryStrategyFunc + IsRetryDefaultConditions bool + Attempt int isMultiPart bool isFormData bool @@ -963,10 +961,36 @@ func (r *Request) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Request { return r } -// SetRetryResetReaders method enables the Resty client to seek the start of all -// file readers are given as multipart files if the object implements [io.ReadSeeker]. -func (r *Request) SetRetryResetReaders(b bool) *Request { - r.RetryResetReaders = b +// SetRetryStrategy method used to set the custom Retry strategy on request, +// it is used to get wait time before each retry. It overrides the retry +// strategy set at the client instance level, see [Client.SetRetryStrategy] +// +// Default (nil) implies exponential backoff with a jitter strategy +func (r *Request) SetRetryStrategy(rs RetryStrategyFunc) *Request { + r.RetryStrategy = rs + return r +} + +// EnableRetryDefaultConditions method enables the Resty's default retry +// conditions on request level +func (r *Request) EnableRetryDefaultConditions() *Request { + r.SetRetryDefaultConditions(true) + return r +} + +// DisableRetryDefaultConditions method disables the Resty's default retry +// conditions on request level +func (r *Request) DisableRetryDefaultConditions() *Request { + r.SetRetryDefaultConditions(false) + return r +} + +// SetRetryDefaultConditions method is used to enable/disable the Resty's default +// retry conditions on request level +// +// It overrides value set at the client instance level, see [Client.SetRetryDefaultConditions] +func (r *Request) SetRetryDefaultConditions(b bool) *Request { + r.IsRetryDefaultConditions = b return r } @@ -1164,10 +1188,7 @@ func (r *Request) Send() (*Response, error) { // for current [Request]. // // resp, err := client.R().Execute(resty.MethodGet, "http://httpbin.org/get") -func (r *Request) Execute(method, url string) (*Response, error) { - var resp *Response - var err error - +func (r *Request) Execute(method, url string) (res *Response, err error) { defer func() { if rec := recover(); rec != nil { if err, ok := rec.(error); ok { @@ -1189,47 +1210,104 @@ func (r *Request) Execute(method, url string) (*Response, error) { r.Method = method r.URL = url - if r.RetryCount == 0 { - r.Attempt = 1 - resp, err = r.client.execute(r) - r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) + if r.RetryCount < 0 { + r.RetryCount = 0 // default behavior is no retry + } + + var backoff *backoffWithJitter + if r.RetryCount > 0 { + backoff = newBackoffWithJitter(r.RetryWaitTime, r.RetryMaxWaitTime) + } + + // first request + retry count = total no. of requests + + for i := 0; i <= r.RetryCount; i++ { + r.Attempt++ + err = nil + res, err = r.client.execute(r) if err != nil { - r.sendLoadBalancerFeedback() + if isInvalidRequestError(err) { + return + } + if r.Context().Err() != nil { + return res, wrapErrors(r.Context().Err(), err) + } } - return resp, unwrapNoRetryErr(err) - } - err = backoff( - func() (*Response, error) { - r.Attempt++ - resp, err = r.client.execute(r) - if err != nil { - r.log.Warnf("%v, Attempt %v", err, r.Attempt) - r.sendLoadBalancerFeedback() + // we have reached the maximum retry count stop here + if r.Attempt-1 == r.RetryCount { + break + } + + if backoff != nil { + needsRetry := false + + // apply default retry conditions + if r.IsRetryDefaultConditions { + needsRetry = applyRetryDefaultConditions(res, err) + } + + // apply user-defined retry conditions if default one + // is still false + if !needsRetry && res != nil { + // user defined retry conditions + retryConditions := append(r.retryConditions, r.client.RetryConditions()...) + for _, retryCondition := range retryConditions { + if needsRetry = retryCondition(res, err); needsRetry { + break + } + } + } + + // retry not required stop here + if !needsRetry { + break } - return resp, err - }, - Retries(r.RetryCount), - WaitTime(r.RetryWaitTime), - MaxWaitTime(r.RetryMaxWaitTime), - RetryConditions(append(r.retryConditions, r.client.RetryConditions()...)), - RetryHooks(r.client.RetryHooks()), - ResetMultipartReaders(r.RetryResetReaders), - ) + // by default reset file readers + if err = r.resetFileReaders(); err != nil { + // if any error in reset readers, stop here + break + } + + // run user-defined retry hooks + for _, retryHookFunc := range r.client.RetryHooks() { + retryHookFunc(res, err) + } + + // let's drain the response body, before retry wait + drainBody(res) + + waitDuration, waitErr := backoff.NextWaitDuration(r.client, res, err, r.Attempt) + if waitErr != nil { + // if any error in retry strategy, stop here + err = wrapErrors(waitErr, err) + break + } + + timer := time.NewTimer(waitDuration) + select { + case <-r.Context().Done(): + timer.Stop() + return nil, wrapErrors(r.Context().Err(), err) + case <-timer.C: + } + } + } if r.isMultiPart { for _, mf := range r.multipartFields { mf.close() } } - if err != nil { r.log.Errorf("%v", err) } - r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) - return resp, unwrapNoRetryErr(err) + r.IsDone = true + r.client.onErrorHooks(r, res, err) + r.sendLoadBalancerFeedback() // TODO revisit on call and success criteria + return } // Clone returns a deep copy of r with its context changed to ctx. @@ -1426,12 +1504,21 @@ func (r *Request) sendLoadBalancerFeedback() { if r.client.LoadBalancer() != nil { r.client.LoadBalancer().Feedback(&RequestFeedback{ BaseURL: r.baseURL, - Success: false, + Success: false, // TODO revisit condition to define success or not Attempt: r.Attempt, }) } } +func (r *Request) resetFileReaders() error { + for _, f := range r.multipartFields { + if err := f.resetReader(); err != nil { + return err + } + } + return nil +} + func jsonIndent(v []byte) []byte { buf := acquireBuffer() defer releaseBuffer(buf) diff --git a/request_test.go b/request_test.go index 470ed46a..96973ee5 100644 --- a/request_test.go +++ b/request_test.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -17,10 +18,9 @@ import ( "path/filepath" "strconv" "strings" + "sync" "testing" "time" - - "golang.org/x/time/rate" ) type AuthSuccess struct { @@ -67,52 +67,15 @@ func TestGetGH524(t *testing.T) { assertEqual(t, resp.Request.Header.Get("Content-Type"), "") // unable to reproduce reported issue } -func TestRateLimiter(t *testing.T) { - ts := createGetServer(t) - defer ts.Close() - - // Test a burst with a valid capacity and then a consecutive request that must fail. - - // Allow a rate of 1 every 100 ms but also allow bursts of 10 requests. - client := dcnl().SetRateLimiter(rate.NewLimiter(rate.Every(100*time.Millisecond), 10)) - - // Execute a burst of 10 requests. - for i := 0; i < 10; i++ { - resp, err := client.R(). - SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/") - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - } - // Next request issued directly should fail because burst of 10 has been consumed. - { - _, err := client.R(). - SetQueryParam("request_no", strconv.Itoa(11)).Get(ts.URL + "/") - assertErrorIs(t, ErrRateLimitExceeded, err) - } - - // Test continues request at a valid rate - - // Allow a rate of 1 every ms with no burst. - client = dcnl().SetRateLimiter(rate.NewLimiter(rate.Every(1*time.Millisecond), 1)) - - // Sending requests every ms+tiny delta must succeed. - for i := 0; i < 100; i++ { - resp, err := client.R(). - SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/") - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - time.Sleep(1*time.Millisecond + 100*time.Microsecond) - } -} - -func TestIllegalRetryCount(t *testing.T) { +func TestRequestNegativeRetryCount(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnl().SetRetryCount(-1).R().Get(ts.URL + "/") assertNil(t, err) - assertNil(t, resp) + assertNotNil(t, resp) + assertEqual(t, "TestGet: text response", resp.String()) } func TestGetCustomUserAgent(t *testing.T) { @@ -425,10 +388,6 @@ func TestForceContentTypeForGH276andGH240(t *testing.T) { c := dcnl() c.SetDebug(false) c.SetRetryCount(3) - c.SetRetryAfter(RetryAfterFunc(func(*Client, *Response) (time.Duration, error) { - retried++ - return 0, nil - })) resp, err := c.R(). SetBody(map[string]any{"username": "testuser", "password": "testpass"}). @@ -1826,8 +1785,11 @@ func TestTraceInfoWithoutEnableTrace(t *testing.T) { } func TestTraceInfoOnTimeout(t *testing.T) { - client := dcnl() - client.SetBaseURL("http://resty-nowhere.local").EnableTrace() + client := NewWithTransportSettings(&TransportSettings{ + DialerTimeout: 100 * time.Millisecond, + }). + SetBaseURL("http://resty-nowhere.local"). + EnableTrace() resp, err := client.R().Get("/") assertNotNil(t, err) @@ -2157,15 +2119,72 @@ func TestRequestPanicContext(t *testing.T) { func TestRequestSettingsCoverage(t *testing.T) { c := dcnl() - c.R().SetCloseConnection(true) - - c.R().DisableTrace() - - c.R().SetResponseBodyUnlimitedReads(true) - - c.R().DisableDebug() + r1 := c.R() + assertEqual(t, false, r1.CloseConnection) + r1.SetCloseConnection(true) + assertEqual(t, true, r1.CloseConnection) + + r2 := c.R() + assertEqual(t, false, r2.IsTrace) + r2.EnableTrace() + assertEqual(t, true, r2.IsTrace) + r2.DisableTrace() + assertEqual(t, false, r2.IsTrace) + + r3 := c.R() + assertEqual(t, false, r3.ResponseBodyUnlimitedReads) + r3.SetResponseBodyUnlimitedReads(true) + assertEqual(t, true, r3.ResponseBodyUnlimitedReads) + r3.SetResponseBodyUnlimitedReads(false) + assertEqual(t, false, r3.ResponseBodyUnlimitedReads) + + r4 := c.R() + assertEqual(t, false, r4.Debug) + r4.EnableDebug() + assertEqual(t, true, r4.Debug) + r4.DisableDebug() + assertEqual(t, false, r4.Debug) + + r5 := c.R() + assertEqual(t, true, r5.IsRetryDefaultConditions) + r5.DisableRetryDefaultConditions() + assertEqual(t, false, r5.IsRetryDefaultConditions) + r5.EnableRetryDefaultConditions() + assertEqual(t, true, r5.IsRetryDefaultConditions) invalidJsonBytes := []byte(`{\" \": "value here"}`) result := jsonIndent(invalidJsonBytes) assertEqual(t, string(invalidJsonBytes), string(result)) } + +func TestRequestDataRace(t *testing.T) { + ts := createPostServer(t) + defer ts.Close() + + usersmap := map[string]any{ + "user1": ExampleUser{FirstName: "firstname1", LastName: "lastname1", ZipCode: "10001"}, + "user2": &ExampleUser{FirstName: "firstname2", LastName: "lastname3", ZipCode: "10002"}, + "user3": ExampleUser{FirstName: "firstname3", LastName: "lastname3", ZipCode: "10003"}, + } + + var users []map[string]any + users = append(users, usersmap) + + c := dcnl().SetBaseURL(ts.URL) + + totalRequests := 4000 + wg := sync.WaitGroup{} + wg.Add(totalRequests) + for i := 0; i < totalRequests; i++ { + if i%100 == 0 { + time.Sleep(20 * time.Millisecond) // to prevent test server socket exhaustion + } + go func() { + defer wg.Done() + res, err := c.R().SetContext(context.Background()).SetBody(users).Post("/usersmap") + assertError(t, err) + assertEqual(t, http.StatusAccepted, res.StatusCode()) + }() + } + wg.Wait() +} diff --git a/resty.go b/resty.go index 8375de66..08a6ef10 100644 --- a/resty.go +++ b/resty.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT // Package resty provides Simple HTTP and REST client library for Go. package resty @@ -158,24 +159,25 @@ func createCookieJar() *cookiejar.Jar { func createClient(hc *http.Client) *Client { c := &Client{ // not setting language default values - lock: &sync.RWMutex{}, - queryParams: url.Values{}, - formData: url.Values{}, - header: http.Header{}, - cookies: make([]*http.Cookie, 0), - retryWaitTime: defaultWaitTime, - retryMaxWaitTime: defaultMaxWaitTime, - pathParams: make(map[string]string), - rawPathParams: make(map[string]string), - headerAuthorizationKey: hdrAuthorizationKey, - jsonEscapeHTML: true, - httpClient: hc, - debugBodyLimit: math.MaxInt32, - contentTypeEncoders: make(map[string]ContentTypeEncoder), - contentTypeDecoders: make(map[string]ContentTypeDecoder), - contentDecompressorKeys: make([]string, 0), - contentDecompressors: make(map[string]ContentDecompressor), - stopChan: make(chan bool), + lock: &sync.RWMutex{}, + queryParams: url.Values{}, + formData: url.Values{}, + header: http.Header{}, + cookies: make([]*http.Cookie, 0), + retryWaitTime: defaultWaitTime, + retryMaxWaitTime: defaultMaxWaitTime, + isRetryDefaultConditions: true, + pathParams: make(map[string]string), + rawPathParams: make(map[string]string), + headerAuthorizationKey: hdrAuthorizationKey, + jsonEscapeHTML: true, + httpClient: hc, + debugBodyLimit: math.MaxInt32, + contentTypeEncoders: make(map[string]ContentTypeEncoder), + contentTypeDecoders: make(map[string]ContentTypeDecoder), + contentDecompressorKeys: make([]string, 0), + contentDecompressors: make(map[string]ContentDecompressor), + stopChan: make(chan bool), } // Logger diff --git a/resty_test.go b/resty_test.go index c0e2f040..c84076bb 100644 --- a/resty_test.go +++ b/resty_test.go @@ -1,6 +1,7 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty @@ -30,6 +31,10 @@ import ( "time" ) +var ( + hdrLocationKey = http.CanonicalHeaderKey("Location") +) + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Testing Unexported methods //___________________________________ @@ -71,7 +76,7 @@ func createGetServer(t *testing.T) *httptest.Server { case "/set-retrycount-test": attp := atomic.AddInt32(&attempt, 1) if attp <= 4 { - time.Sleep(time.Second * 6) + time.Sleep(time.Millisecond * 150) } _, _ = w.Write([]byte("TestClientRetry page")) case "/set-retrywaittime-test": @@ -99,10 +104,10 @@ func createGetServer(t *testing.T) *httptest.Server { atomic.AddInt32(&attempt, 1) case "/set-timeout-test-with-sequence": seq := atomic.AddInt32(&sequence, 1) - time.Sleep(time.Second * 2) + time.Sleep(100 * time.Millisecond) _, _ = fmt.Fprintf(w, "%d", seq) case "/set-timeout-test": - time.Sleep(time.Second * 6) + time.Sleep(400 * time.Millisecond) _, _ = w.Write([]byte("TestClientTimeout page")) case "/my-image.png": fileBytes, _ := os.ReadFile(filepath.Join(getTestDataPath(), "test-img.png")) @@ -124,6 +129,16 @@ func createGetServer(t *testing.T) *httptest.Server { case "/not-found-no-error": w.Header().Set(hdrContentTypeKey, "application/json") w.WriteHeader(http.StatusNotFound) + case "/retry-after-delay": + w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") + if atomic.LoadInt32(&attempt) == 0 { + w.Header().Set(hdrRetryAfterKey, "1") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{ "message": "too many" }`)) + } else { + _, _ = w.Write([]byte(`{ "message": "hello" }`)) + } + atomic.AddInt32(&attempt, 1) } switch { diff --git a/retry.go b/retry.go index ccf39693..f70dd1f5 100644 --- a/retry.go +++ b/retry.go @@ -1,249 +1,196 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty import ( - "context" + "crypto/tls" "math" "math/rand" + "net/http" + "net/url" + "regexp" + "strconv" "sync" "time" ) const ( - defaultMaxRetries = 3 defaultWaitTime = time.Duration(100) * time.Millisecond defaultMaxWaitTime = time.Duration(2000) * time.Millisecond ) type ( - // Option is to create convenient retry options like wait time, max retries, etc. - Option func(*Options) - // RetryConditionFunc type is for the retry condition function // input: non-nil Response OR request execution error RetryConditionFunc func(*Response, error) bool - // OnRetryFunc is for side-effecting functions triggered on retry - OnRetryFunc func(*Response, error) - - // RetryAfterFunc returns time to wait before retry - // For example, it can parse HTTP Retry-After header - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html - // Non-nil error is returned if it is found that the request is not retryable - // (0, nil) is a special result that means 'use default algorithm' - RetryAfterFunc func(*Client, *Response) (time.Duration, error) - - // Options struct is used to hold retry settings. - Options struct { - maxRetries int - waitTime time.Duration - maxWaitTime time.Duration - retryConditions []RetryConditionFunc - retryHooks []OnRetryFunc - resetReaders bool - } + // RetryHookFunc is for side-effecting functions triggered on retry + RetryHookFunc func(*Response, error) + + // RetryStrategyFunc type is for custom retry strategy implementation + // By default Resty uses the capped exponential backoff with a jitter strategy + RetryStrategyFunc func(*Response, error) (time.Duration, error) ) -// Retries sets the max number of retries -func Retries(value int) Option { - return func(o *Options) { - o.maxRetries = value - } -} +var ( + regexErrTooManyRedirects = regexp.MustCompile(`stopped after \d+ redirects\z`) + regexErrScheme = regexp.MustCompile("unsupported protocol scheme") + regexErrInvalidHeader = regexp.MustCompile("invalid header") +) -// WaitTime sets the default wait time to sleep between requests -func WaitTime(value time.Duration) Option { - return func(o *Options) { - o.waitTime = value +func applyRetryDefaultConditions(res *Response, err error) bool { + // no retry on TLS error + if _, ok := err.(*tls.CertificateVerificationError); ok { + return false } -} -// MaxWaitTime sets the max wait time to sleep between requests -func MaxWaitTime(value time.Duration) Option { - return func(o *Options) { - o.maxWaitTime = value + // validate url error, so we can decide to retry or not + if u, ok := err.(*url.Error); ok { + if regexErrTooManyRedirects.MatchString(u.Error()) { + return false + } + if regexErrScheme.MatchString(u.Error()) { + return false + } + if regexErrInvalidHeader.MatchString(u.Error()) { + return false + } + return u.Temporary() // possible retry if it's true } -} -// RetryConditions sets the conditions that will be checked for retry -func RetryConditions(conditions []RetryConditionFunc) Option { - return func(o *Options) { - o.retryConditions = conditions + if res == nil { + return false } -} -// RetryHooks sets the hooks that will be executed after each retry -func RetryHooks(hooks []OnRetryFunc) Option { - return func(o *Options) { - o.retryHooks = hooks + // certain HTTP status codes are temporary so that we can retry + // - 429 Too Many Requests + // - 500 or above (it's better to ignore 501 Not Implemented) + // - 0 No status code received + if res.StatusCode() == http.StatusTooManyRequests || + (res.StatusCode() >= 500 && res.StatusCode() != http.StatusNotImplemented) || + res.StatusCode() == 0 { + return true } -} -// ResetMultipartReaders sets a boolean value which will lead the start being seeked out -// on all multipart file readers if they implement [io.ReadSeeker] -func ResetMultipartReaders(value bool) Option { - return func(o *Options) { - o.resetReaders = value - } + return false } -// backoff retries with increasing timeout duration up until X amount of retries -// (Default is 3 attempts, Override with option Retries(n)) -func backoff(operation func() (*Response, error), options ...Option) error { - // Defaults - opts := Options{ - maxRetries: defaultMaxRetries, - waitTime: defaultWaitTime, - maxWaitTime: defaultMaxWaitTime, - retryConditions: []RetryConditionFunc{}, +func newBackoffWithJitter(min, max time.Duration) *backoffWithJitter { + if min <= 0 { + min = defaultWaitTime } - - for _, o := range options { - o(&opts) + if max == 0 { + max = defaultMaxWaitTime } - var ( - resp *Response - err error - ) - - for attempt := 0; attempt <= opts.maxRetries; attempt++ { - resp, err = operation() - ctx := context.Background() - if resp != nil && resp.Request.ctx != nil { - ctx = resp.Request.ctx - } - if ctx.Err() != nil { - return err - } - - err1 := unwrapNoRetryErr(err) // raw error, it used for return users callback. - needsRetry := err != nil && err == err1 // retry on a few operation errors by default - - for _, condition := range opts.retryConditions { - needsRetry = condition(resp, err1) - if needsRetry { - break - } - } - - if !needsRetry { - return err - } - - if opts.resetReaders { - if err := resetFileReaders(resp.Request.multipartFields...); err != nil { - return err - } - } - - for _, hook := range opts.retryHooks { - hook(resp, err) - } + return &backoffWithJitter{ + lock: new(sync.Mutex), + rnd: rand.New(rand.NewSource(time.Now().UnixNano())), + min: min, + max: max, + } +} - // Don't need to wait when no retries left. - // Still run retry hooks even on last retry to keep compatibility. - if attempt == opts.maxRetries { - return err - } +type backoffWithJitter struct { + lock *sync.Mutex + rnd *rand.Rand + min time.Duration + max time.Duration +} - waitTime, err2 := sleepDuration(resp, opts.waitTime, opts.maxWaitTime, attempt) - if err2 != nil { - if err == nil { - err = err2 +func (b *backoffWithJitter) NextWaitDuration(c *Client, res *Response, err error, attempt int) (time.Duration, error) { + if res != nil { + if res.StatusCode() == http.StatusTooManyRequests || res.StatusCode() == http.StatusServiceUnavailable { + if delay, ok := parseRetryAfterHeader(res.Header().Get(hdrRetryAfterKey)); ok { + return delay, nil } - return err - } - - select { - case <-time.After(waitTime): - case <-ctx.Done(): - return ctx.Err() } } - return err -} - -func sleepDuration(resp *Response, min, max time.Duration, attempt int) (time.Duration, error) { const maxInt = 1<<31 - 1 // max int for arch 386 - if max < 0 { - max = maxInt - } - if resp == nil { - return jitterBackoff(min, max, attempt), nil + if b.max < 0 { + b.max = maxInt } - retryAfterFunc := resp.Request.client.RetryAfter() - - // Check for custom callback - if retryAfterFunc == nil { - return jitterBackoff(min, max, attempt), nil + retryStrategyFunc := c.RetryStrategy() + if res == nil || retryStrategyFunc == nil { + return b.balanceMinMax(b.defaultStrategy(attempt)), nil } - result, err := retryAfterFunc(resp.Request.client, resp) - if err != nil { - return 0, err // i.e. 'API quota exceeded' - } - if result == 0 { - return jitterBackoff(min, max, attempt), nil + delay, rsErr := retryStrategyFunc(res, err) + if rsErr != nil { + return 0, rsErr } - if result < 0 || max < result { - result = max - } - if result < min { - result = min - } - return result, nil + return b.balanceMinMax(delay), nil } // Return capped exponential backoff with jitter // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ -func jitterBackoff(min, max time.Duration, attempt int) time.Duration { - base := float64(min) - capLevel := float64(max) - - temp := math.Min(capLevel, base*math.Exp2(float64(attempt))) +func (b *backoffWithJitter) defaultStrategy(attempt int) time.Duration { + temp := math.Min(float64(b.max), float64(b.min)*math.Exp2(float64(attempt))) ri := time.Duration(temp / 2) - if ri == 0 { + if ri <= 0 { ri = time.Nanosecond } - result := randDuration(ri) - - if result < min { - result = min - } - - return result + return b.randDuration(ri) } -var rnd = newRnd() -var rndMu sync.Mutex - -func randDuration(center time.Duration) time.Duration { - rndMu.Lock() - defer rndMu.Unlock() +func (b *backoffWithJitter) randDuration(center time.Duration) time.Duration { + b.lock.Lock() + defer b.lock.Unlock() var ri = int64(center) - var jitter = rnd.Int63n(ri) + var jitter = b.rnd.Int63n(ri) return time.Duration(math.Abs(float64(ri + jitter))) } -func newRnd() *rand.Rand { - var seed = time.Now().UnixNano() - var src = rand.NewSource(seed) - return rand.New(src) +func (b *backoffWithJitter) balanceMinMax(delay time.Duration) time.Duration { + if delay <= 0 || b.max < delay { + return b.max + } + if delay < b.min { + return b.min + } + return delay } -func resetFileReaders(fields ...*MultipartField) error { - for _, f := range fields { - if err := f.resetReader(); err != nil { - return err +var timeNow = time.Now + +// parseRetryAfterHeader parses the Retry-After header and returns the +// delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after +// The bool returned will be true if the header was successfully parsed. +// Otherwise, the header was either not present, or was not parseable according to the spec. +// +// Retry-After headers come in two flavors: Seconds or HTTP-Date +// +// Examples: +// - Retry-After: Fri, 31 Dec 1999 23:59:59 GMT +// - Retry-After: 120 +func parseRetryAfterHeader(v string) (time.Duration, bool) { + if isStringEmpty(v) { + return 0, false + } + + // Retry-After: 120 + if delay, err := strconv.ParseInt(v, 10, 64); err == nil { + if delay < 0 { // a negative delay doesn't make sense + return 0, false } + return time.Second * time.Duration(delay), true + } + + // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT + retryTime, err := time.Parse(time.RFC1123, v) + if err != nil { + return 0, false + } + if until := retryTime.Sub(timeNow()); until > 0 { + return until, true } - return nil + // date is in the past + return 0, true } diff --git a/retry_test.go b/retry_test.go index 62600b6a..ac206873 100644 --- a/retry_test.go +++ b/retry_test.go @@ -1,12 +1,14 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty import ( "bytes" "context" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -19,119 +21,8 @@ import ( "time" ) -func TestBackoffSuccess(t *testing.T) { - attempts := 3 - externalCounter := 0 - retryErr := backoff(func() (*Response, error) { - externalCounter++ - if externalCounter < attempts { - return nil, errors.New("not yet got the number we're after") - } - - return nil, nil - }) - - assertError(t, retryErr) - assertEqual(t, externalCounter, attempts) -} - -func TestBackoffNoWaitForLastRetry(t *testing.T) { - attempts := 1 - externalCounter := 0 - numRetries := 1 - - canceledCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - - client := New() - client.SetRetryAfter(func(*Client, *Response) (time.Duration, error) { - return 6, nil - }) - resp := &Response{ - Request: &Request{ - ctx: canceledCtx, - client: client, - }, - } - - retryErr := backoff(func() (*Response, error) { - externalCounter++ - return resp, nil - }, RetryConditions([]RetryConditionFunc{func(response *Response, err error) bool { - if externalCounter == attempts+numRetries { - // Backoff returns context canceled if goes to sleep after last retry. - cancel() - } - return true - }}), Retries(numRetries)) - - assertNil(t, retryErr) -} - -func TestBackoffTenAttemptsSuccess(t *testing.T) { - attempts := 10 - externalCounter := 0 - retryErr := backoff(func() (*Response, error) { - externalCounter++ - if externalCounter < attempts { - return nil, errors.New("not yet got the number we're after") - } - return nil, nil - }, Retries(attempts), WaitTime(5), MaxWaitTime(500)) - - assertError(t, retryErr) - assertEqual(t, externalCounter, attempts) -} - -// Check to make sure the conditional of the retry condition is being used -func TestConditionalBackoffCondition(t *testing.T) { - attempts := 3 - counter := 0 - check := RetryConditionFunc(func(*Response, error) bool { - return attempts != counter - }) - retryErr := backoff(func() (*Response, error) { - counter++ - return nil, nil - }, RetryConditions([]RetryConditionFunc{check})) - - assertError(t, retryErr) - assertEqual(t, counter, attempts) -} - -// Check to make sure that if the conditional is false we don't retry -func TestConditionalBackoffConditionNonExecution(t *testing.T) { - attempts := 3 - counter := 0 - - retryErr := backoff(func() (*Response, error) { - counter++ - return nil, nil - }, RetryConditions([]RetryConditionFunc{filler})) - - assertError(t, retryErr) - assertNotEqual(t, counter, attempts) -} - -// Check to make sure that RetryHooks are executed -func TestOnRetryBackoff(t *testing.T) { - attempts := 3 - counter := 0 - - hook := func(r *Response, err error) { - counter++ - } - - retryErr := backoff(func() (*Response, error) { - return nil, nil - }, RetryHooks([]OnRetryFunc{hook})) - - assertError(t, retryErr) - assertNotEqual(t, counter, attempts) -} - // Check to make sure the functions added to add conditionals work -func TestConditionalGet(t *testing.T) { +func TestRetryConditionalGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() attemptCount := 1 @@ -143,8 +34,10 @@ func TestConditionalGet(t *testing.T) { return attemptCount != externalCounter }) - client := dcnl().AddRetryCondition(check).SetRetryCount(1) + client := dcnl() resp, err := client.R(). + AddRetryCondition(check). + SetRetryCount(2). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") @@ -160,23 +53,21 @@ func TestConditionalGet(t *testing.T) { func TestConditionalGetRequestLevel(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attemptCount := 1 - externalCounter := 0 + externalCounter := 0 // This check should pass on first run, and let the response through - check := RetryConditionFunc(func(*Response, error) bool { + check := RetryConditionFunc(func(r *Response, _ error) bool { externalCounter++ - return attemptCount != externalCounter + return false }) // Clear the default client. - client := dcnl() - resp, err := client.R(). + c := dcnl() + resp, err := c.R(). AddRetryCondition(check). SetRetryCount(1). - SetRetryWaitTime(time.Duration(50)*time.Millisecond). - SetRetryMaxWaitTime(time.Duration(1)*time.Second). - SetRetryResetReaders(true). + SetRetryWaitTime(50*time.Millisecond). + SetRetryMaxWaitTime(1*time.Second). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") @@ -184,7 +75,8 @@ func TestConditionalGetRequestLevel(t *testing.T) { assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) - assertEqual(t, externalCounter, attemptCount) + assertEqual(t, 1, resp.Request.Attempt) + assertEqual(t, 1, externalCounter) logResponse(t, resp) } @@ -194,7 +86,7 @@ func TestClientRetryGet(t *testing.T) { defer ts.Close() c := dcnl(). - SetTimeout(time.Second * 3). + SetTimeout(time.Millisecond * 50). SetRetryCount(3) resp, err := c.R().Get(ts.URL + "/set-retrycount-test") @@ -208,18 +100,16 @@ func TestClientRetryGet(t *testing.T) { strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) } -func TestClientRetryWait(t *testing.T) { +func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := time.Duration(3) * time.Second - retryMaxWaitTime := time.Duration(9) * time.Second + retryWaitTime := 10 * time.Millisecond + retryMaxWaitTime := 100 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). @@ -227,16 +117,16 @@ func TestClientRetryWait(t *testing.T) { SetRetryMaxWaitTime(retryMaxWaitTime). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - _, _ = c.R().Get(ts.URL + "/set-retrywaittime-test") + res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") + + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) - // 6 attempts were made - assertEqual(t, attempt, 6) + // retryCount+1 == attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -245,8 +135,11 @@ func TestClientRetryWait(t *testing.T) { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < retryWaitTime || slept > retryMaxWaitTime { - t.Errorf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) + } + if slept > retryMaxWaitTime+5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s > max (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } @@ -255,13 +148,11 @@ func TestClientRetryWaitMaxInfinite(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := time.Duration(3) * time.Second + retryWaitTime := time.Duration(10) * time.Millisecond retryMaxWaitTime := time.Duration(-1.0) // negative value c := dcnl(). @@ -270,16 +161,16 @@ func TestClientRetryWaitMaxInfinite(t *testing.T) { SetRetryMaxWaitTime(retryMaxWaitTime). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - _, _ = c.R().Get(ts.URL + "/set-retrywaittime-test") + res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") + + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) - // 6 attempts were made - assertEqual(t, attempt, 6) + // retryCount+1 == attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -288,8 +179,8 @@ func TestClientRetryWaitMaxInfinite(t *testing.T) { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < retryWaitTime { - t.Errorf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } } } @@ -308,20 +199,19 @@ func TestClientRetryWaitMaxMinimum(t *testing.T) { assertError(t, err) } -func TestClientRetryWaitCallbackError(t *testing.T) { +func TestClientRetryStrategyFuncError(t *testing.T) { ts := createGetServer(t) defer ts.Close() attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 150 * time.Millisecond - retryAfter := func(client *Client, resp *Response) (time.Duration, error) { + retryStrategyFunc := func(res *Response, err error) (time.Duration, error) { return 0, errors.New("quota exceeded") } @@ -329,11 +219,10 @@ func TestClientRetryWaitCallbackError(t *testing.T) { SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - SetRetryAfter(retryAfter). + SetRetryStrategy(retryStrategyFunc). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept + retryIntervals[attempt] = parseTimeSleptFromResponse(r.String()) attempt++ return true }, @@ -342,46 +231,45 @@ func TestClientRetryWaitCallbackError(t *testing.T) { _, err := c.R().Get(ts.URL + "/set-retrywaittime-test") // 1 attempts were made - assertEqual(t, attempt, 1) + assertEqual(t, 1, attempt) // non-nil error was returned - assertNotEqual(t, nil, err) + assertNotNil(t, err) } -func TestClientRetryWaitCallback(t *testing.T) { +func TestClientRetryStrategyFunc(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - - retryCount := 5 + retryCount := 10 retryIntervals := make([]uint64, retryCount+1) - // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + // Set retry wait times to constant delay + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 50 * time.Millisecond - retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 5 * time.Second, nil + // custom strategy func with constant delay + retryStrategyFunc := func(res *Response, err error) (time.Duration, error) { + return 50 * time.Millisecond, nil } c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - SetRetryAfter(retryAfter). + SetRetryStrategy(retryStrategyFunc). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - _, _ = c.R().Get(ts.URL + "/set-retrywaittime-test") + res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") + + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) - // 6 attempts were made - assertEqual(t, attempt, 6) + // retryCount+1 == attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -390,46 +278,50 @@ func TestClientRetryWaitCallback(t *testing.T) { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < 5*time.Second-5*time.Millisecond || 5*time.Second+5*time.Millisecond < slept { - t.Logf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) + } + if retryMaxWaitTime+5*time.Millisecond < slept { + t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } -func TestClientRetryWaitCallbackTooShort(t *testing.T) { +func TestRequestRetryStrategyFunc(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - - retryCount := 5 + retryCount := 10 retryIntervals := make([]uint64, retryCount+1) - // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + // Set retry wait times to constant delay + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 50 * time.Millisecond - retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 2 * time.Second, nil // too short duration + // custom strategy func with constant delay + retryStrategyFunc := func(res *Response, err error) (time.Duration, error) { + return 50 * time.Millisecond, nil } - c := dcnl(). + c := dcnl() + + res, _ := c.R(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - SetRetryAfter(retryAfter). + SetRetryStrategy(retryStrategyFunc). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, - ) - _, _ = c.R().Get(ts.URL + "/set-retrywaittime-test") + ). + Get(ts.URL + "/set-retrywaittime-test") - // 6 attempts were made - assertEqual(t, attempt, 6) + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) + + // retryCount+1 == attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -438,46 +330,47 @@ func TestClientRetryWaitCallbackTooShort(t *testing.T) { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < retryWaitTime-5*time.Millisecond || retryWaitTime+5*time.Millisecond < slept { - t.Logf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) + } + if retryMaxWaitTime+5*time.Millisecond < slept { + t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } -func TestClientRetryWaitCallbackTooLong(t *testing.T) { +func TestClientRetryStrategyWaitTooShort(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 1 * time.Second - retryMaxWaitTime := 3 * time.Second + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 150 * time.Millisecond - retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 4 * time.Second, nil // too long duration + retryStrategyFunc := func(res *Response, err error) (time.Duration, error) { + return 10 * time.Millisecond, nil } c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - SetRetryAfter(retryAfter). + SetRetryStrategy(retryStrategyFunc). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - _, _ = c.R().Get(ts.URL + "/set-retrywaittime-test") + res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") - // 6 attempts were made - assertEqual(t, attempt, 6) + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) + + // retryCount+1 == attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -486,64 +379,60 @@ func TestClientRetryWaitCallbackTooLong(t *testing.T) { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < retryMaxWaitTime-5*time.Millisecond || retryMaxWaitTime+5*time.Millisecond < slept { - t.Logf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) + } + if retryWaitTime+5*time.Millisecond < slept { + t.Logf("Client has slept %f seconds which is min < s (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } } } -func TestClientRetryWaitCallbackSwitchToDefault(t *testing.T) { +func TestClientRetryStrategyWaitTooLong(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 1 * time.Second - retryMaxWaitTime := 3 * time.Second + retryWaitTime := 10 * time.Millisecond + retryMaxWaitTime := 50 * time.Millisecond - retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 0, nil // use default algorithm to determine retry-after time + retryStrategyFunc := func(res *Response, err error) (time.Duration, error) { + return 1 * time.Second, nil } c := dcnl(). - EnableTrace(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - SetRetryAfter(retryAfter). + SetRetryStrategy(retryStrategyFunc). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - resp, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") + res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") - // 6 attempts were made - assertEqual(t, attempt, 6) - assertEqual(t, resp.Request.Attempt, 6) - assertEqual(t, resp.Request.TraceInfo().RequestAttempt, 6) + retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) + + // retryCount+1 == attempt attempts were made + assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) - expected := (1 << (uint(i - 1))) * time.Second - if expected > retryMaxWaitTime { - expected = retryMaxWaitTime - } - // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests - if slept < expected/2-5*time.Millisecond || expected+5*time.Millisecond < slept { - t.Errorf("Client has slept %f seconds before retry %d", slept.Seconds(), i) + if slept < retryMaxWaitTime-5*time.Millisecond { + t.Logf("Client has slept %f seconds which is s < max (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) + } + if retryMaxWaitTime+5*time.Millisecond < slept { + t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } @@ -552,14 +441,12 @@ func TestClientRetryCancel(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 - retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := time.Duration(10) * time.Second - retryMaxWaitTime := time.Duration(20) * time.Second + retryWaitTime := 100 * time.Millisecond + retryMaxWaitTime := 200 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). @@ -567,20 +454,19 @@ func TestClientRetryCancel(t *testing.T) { SetRetryMaxWaitTime(retryMaxWaitTime). AddRetryCondition( func(r *Response, _ error) bool { - timeSlept, _ := strconv.ParseUint(string(r.bodyBytes), 10, 64) - retryIntervals[attempt] = timeSlept - attempt++ + retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) - timeout := 2 * time.Second + timeout := 100 * time.Millisecond ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) - _, _ = c.R().SetContext(ctx).Get(ts.URL + "/set-retrywaittime-test") + req := c.R().SetContext(ctx) + _, _ = req.Get(ts.URL + "/set-retrywaittime-test") // 1 attempts were made - assertEqual(t, attempt, 1) + assertEqual(t, 1, req.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) @@ -661,14 +547,14 @@ func TestClientRetryErrorRecover(t *testing.T) { assertNil(t, resp.Error()) } -func TestClientRetryCount(t *testing.T) { +func TestClientRetryCountWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() attempt := 0 c := dcnl(). - SetTimeout(time.Second * 3). + SetTimeout(time.Millisecond * 50). SetRetryCount(1). AddRetryCondition( func(r *Response, _ error) bool { @@ -685,20 +571,19 @@ func TestClientRetryCount(t *testing.T) { assertEqual(t, 0, len(resp.Header())) // 2 attempts were made - assertEqual(t, attempt, 2) + assertEqual(t, 2, resp.Request.Attempt) assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") || strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) } -func TestClientErrorRetry(t *testing.T) { +func TestClientRetryTooManyRequestsAndRecover(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). - SetTimeout(time.Second * 3). - SetRetryCount(1). - AddRetryAfterErrorCondition() + SetTimeout(time.Second * 1). + SetRetryCount(2) resp, err := c.R(). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). @@ -720,15 +605,17 @@ func TestClientRetryHook(t *testing.T) { ts := createGetServer(t) defer ts.Close() - attempt := 0 + hookCalledCount := 0 retryHook := func(r *Response, _ error) { - attempt++ + hookCalledCount++ } + retryCount := 3 + c := dcnl(). - SetRetryCount(2). - SetTimeout(time.Second * 3). + SetRetryCount(retryCount). + SetTimeout(50 * time.Millisecond). AddRetryHook(retryHook) // Since reflect.DeepEqual can not compare two functions @@ -745,16 +632,13 @@ func TestClientRetryHook(t *testing.T) { assertEqual(t, 0, len(resp.Cookies())) assertEqual(t, 0, len(resp.Header())) - assertEqual(t, 3, attempt) + assertEqual(t, retryCount+1, resp.Request.Attempt) + assertEqual(t, 3, hookCalledCount) assertEqual(t, true, strings.HasPrefix(err.Error(), "Get "+ts.URL+"/set-retrycount-test") || strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/set-retrycount-test\"")) } -func filler(*Response, error) bool { - return false -} - var errSeekFailure = fmt.Errorf("failing seek test") type failingSeeker struct { @@ -783,11 +667,7 @@ func TestResetMultipartReaderSeekStartError(t *testing.T) { c := dcnl(). SetRetryCount(2). - SetTimeout(time.Second * 3). - SetRetryResetReaders(true). - AddRetryAfterErrorCondition() - - assertEqual(t, true, c.RetryResetReaders()) + SetTimeout(200 * time.Millisecond) resp, err := c.R(). SetFileReader("name", "filename", testSeeker). @@ -810,8 +690,6 @@ func TestClientResetMultipartReaders(t *testing.T) { c := dcnl(). SetRetryCount(2). SetTimeout(time.Second * 3). - SetRetryResetReaders(true). - AddRetryAfterErrorCondition(). AddRetryHook( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) @@ -822,8 +700,6 @@ func TestClientResetMultipartReaders(t *testing.T) { }, ) - assertEqual(t, true, c.RetryResetReaders()) - resp, err := c.R(). SetFileReader("name", "filename", bufReader). Post(ts.URL + "/set-reset-multipart-readers-test") @@ -844,7 +720,6 @@ func TestRequestResetMultipartReaders(t *testing.T) { c := dcnl(). SetTimeout(time.Second * 3). - AddRetryAfterErrorCondition(). AddRetryHook( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) @@ -855,15 +730,142 @@ func TestRequestResetMultipartReaders(t *testing.T) { }, ) - assertEqual(t, false, c.RetryResetReaders()) - req := c.R(). SetRetryCount(2). - SetRetryResetReaders(true). SetFileReader("name", "filename", bufReader) resp, err := req.Post(ts.URL + "/set-reset-multipart-readers-test") - assertEqual(t, true, req.RetryResetReaders) assertEqual(t, 500, resp.StatusCode()) assertNil(t, err) } + +func TestParseRetryAfterHeader(t *testing.T) { + testStaticTime(t) + + tests := []struct { + name string + header string + sleep time.Duration + ok bool + }{ + {"seconds", "2", time.Second * 2, true}, + {"date", "Fri, 31 Dec 1999 23:59:59 GMT", time.Second * 2, true}, + {"past-date", "Fri, 31 Dec 1999 23:59:00 GMT", 0, true}, + {"two-headers", "3", time.Second * 3, true}, + {"empty", "", 0, false}, + {"negative", "-2", 0, false}, + {"bad-date", "Fri, 32 Dec 1999 23:59:59 GMT", 0, false}, + {"bad-date-format", "badbadbad", 0, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sleep, ok := parseRetryAfterHeader(test.header) + if ok != test.ok { + t.Errorf("expected ok=%t, got ok=%t", test.ok, ok) + } + if sleep != test.sleep { + t.Errorf("expected sleep=%v, got sleep=%v", test.sleep, sleep) + } + }) + } +} + +func TestRetryTooManyRequestsHeaderRetryAfter(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + c := dcnl() + + resp, err := c.R(). + SetRetryCount(2). + SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). + SetResult(AuthSuccess{}). + Get(ts.URL + "/retry-after-delay") + + assertError(t, err) + + authSuccess := resp.Result().(*AuthSuccess) + + assertEqual(t, http.StatusOK, resp.StatusCode()) + assertEqual(t, "hello", authSuccess.Message) + + assertNil(t, resp.Error()) +} + +func TestRetryDefaultConditions(t *testing.T) { + t.Run("redirect error", func(t *testing.T) { + ts := createRedirectServer(t) + defer ts.Close() + + _, err := dcnl().R(). + SetRetryCount(2). + Get(ts.URL + "/redirect-1") + + assertNotNil(t, err) + assertEqual(t, true, (err.Error() == `Get "/redirect-11": stopped after 10 redirects`)) + }) + + t.Run("invalid scheme error", func(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + c := dcnl().SetBaseURL(strings.Replace(ts.URL, "http", "ftp", 1)) + + _, err := c.R(). + SetRetryCount(2). + Get("/") + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), `unsupported protocol scheme "ftp"`)) + }) + + t.Run("invalid header error", func(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + _, err := dcnl().R(). + SetRetryCount(2). + SetHeader("Header-Name", "bad header value \033"). + Get(ts.URL + "/") + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), "net/http: invalid header field value")) + + _, err = dcnl().R(). + SetRetryCount(2). + SetHeader("Header-Name\033", "bad header value"). + Get(ts.URL + "/") + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), "net/http: invalid header field name")) + }) +} + +func TestRetryCoverage(t *testing.T) { + t.Run("apply retry default min and max value", func(t *testing.T) { + backoff := newBackoffWithJitter(0, 0) + assertEqual(t, defaultWaitTime, backoff.min) + assertEqual(t, defaultMaxWaitTime, backoff.max) + }) + + t.Run("mock tls cert error", func(t *testing.T) { + certError := tls.CertificateVerificationError{} + result1 := applyRetryDefaultConditions(nil, &certError) + assertEqual(t, false, result1) + }) +} + +func parseTimeSleptFromResponse(v string) uint64 { + timeSlept, _ := strconv.ParseUint(v, 10, 64) + return timeSlept +} + +func testStaticTime(t *testing.T) { + timeNow = func() time.Time { + now, err := time.Parse(time.RFC1123, "Fri, 31 Dec 1999 23:59:57 GMT") + if err != nil { + panic(err) + } + return now + } + t.Cleanup(func() { + timeNow = time.Now + }) +} diff --git a/util.go b/util.go index 595a5cf4..29518c8a 100644 --- a/util.go +++ b/util.go @@ -1,12 +1,12 @@ -// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT package resty import ( "bytes" - "errors" "fmt" "io" "log" @@ -17,7 +17,6 @@ import ( "runtime" "sort" "strings" - "sync" ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -63,16 +62,6 @@ func (l *logger) output(format string, v ...any) { l.l.Printf(format, v...) } -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Rate Limiter interface -//_______________________________________________________________________ - -var ErrRateLimitExceeded = errors.New("resty: rate limit exceeded") - -type RateLimiter interface { - Allow() bool -} - //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Package Helper methods //_______________________________________________________________________ @@ -200,35 +189,6 @@ func releaseBuffer(buf *bytes.Buffer) { } } -func wrapRequestBufferReleaser(r *Request) io.ReadCloser { - if r.bodyBuf == nil { - return r.RawRequest.Body - } - return &requestBufferReleaser{ - reqBuf: r.bodyBuf, - ReadCloser: r.RawRequest.Body, - } -} - -var _ io.ReadCloser = (*requestBufferReleaser)(nil) - -// requestBufferReleaser wraps request body and implements custom Close for it. -// The Close method closes original body and releases request body back to sync.Pool. -type requestBufferReleaser struct { - releaseOnce sync.Once - reqBuf *bytes.Buffer - io.ReadCloser -} - -func (rr *requestBufferReleaser) Close() error { - err := rr.ReadCloser.Close() - rr.releaseOnce.Do(func() { - releaseBuffer(rr.reqBuf) - }) - - return err -} - func closeq(v any) { if c, ok := v.(io.Closer); ok { silently(c.Close()) @@ -283,28 +243,6 @@ func (e *restyError) Unwrap() error { return e.inner } -type noRetryErr struct { - err error -} - -func (e *noRetryErr) Error() string { - return e.err.Error() -} - -func wrapNoRetryErr(err error) error { - if err != nil { - err = &noRetryErr{err: err} - } - return err -} - -func unwrapNoRetryErr(err error) error { - if e, ok := err.(*noRetryErr); ok { - err = e.err - } - return err -} - // cloneURLValues is a helper function to deep copy url.Values. func cloneURLValues(v url.Values) url.Values { if v == nil { @@ -329,3 +267,26 @@ func cloneCookie(c *http.Cookie) *http.Cookie { Unparsed: c.Unparsed, } } + +var mimeInvalidBoundaryErrStr = "mime: invalid boundary character" + +func isInvalidRequestError(err error) bool { + if u, ok := err.(*url.Error); ok { + if u.Op == "parse" { + return true + } + } + if err.Error() == mimeInvalidBoundaryErrStr || + err == ErrNoActiveHost || + err == ErrUnsupportedRequestBodyKind { + return true + } + return false +} + +func drainBody(res *Response) { + if res != nil && res.Body != nil { + defer closeq(res.Body) + _, _ = io.Copy(io.Discard, res.Body) + } +} diff --git a/util_test.go b/util_test.go index 4250da15..eb681e23 100644 --- a/util_test.go +++ b/util_test.go @@ -122,7 +122,4 @@ func TestUtilMiscTestCoverage(t *testing.T) { }{} err := decodeJSON(bytes.NewReader([]byte(`{\" \": \"some value\"}`)), &v) assertEqual(t, "invalid character '\\\\' looking for beginning of object key string", err.Error()) - - err = &noRetryErr{err: errors.New("hey error")} - assertEqual(t, "hey error", err.Error()) }