diff --git a/client.go b/client.go index 013c947..85fb149 100644 --- a/client.go +++ b/client.go @@ -646,6 +646,8 @@ func (c *Client) R() *Request { debugLogCurlCmd: c.debugLogCurlCmd, unescapeQueryParams: c.unescapeQueryParams, credentials: c.credentials, + retryConditions: slices.Clone(c.retryConditions), + retryHooks: slices.Clone(c.retryHooks), } if c.ctx != nil { @@ -1362,16 +1364,18 @@ func (c *Client) RetryConditions() []RetryConditionFunc { return c.retryConditions } -// AddRetryCondition method adds a retry condition function to an array of functions -// that are checked to determine if the request is retried. The request will -// retry if any functions return true and the error is nil. +// AddRetryConditions method adds one or more retry condition functions into the request. +// These retry conditions are executed to determine if the request can be retried. +// The request will retry if any functions return `true`, otherwise return `false`. // -// NOTE: These retry conditions are applied on all requests made using this Client. -// For [Request] specific retry conditions, check [Request.AddRetryCondition] -func (c *Client) AddRetryCondition(condition RetryConditionFunc) *Client { +// NOTE: +// - The client-level retry conditions are applied to all requests. +// - The request-level retry conditions are executed first before client-level +// retry conditions. See [Request.AddRetryConditions], [Request.SetRetryConditions] +func (c *Client) AddRetryConditions(conditions ...RetryConditionFunc) *Client { c.lock.Lock() defer c.lock.Unlock() - c.retryConditions = append(c.retryConditions, condition) + c.retryConditions = append(c.retryConditions, conditions...) return c } @@ -1382,12 +1386,16 @@ func (c *Client) RetryHooks() []RetryHookFunc { return c.retryHooks } -// AddRetryHook adds a side-effecting retry hook to an array of hooks -// that will be executed on each retry. -func (c *Client) AddRetryHook(hook RetryHookFunc) *Client { +// AddRetryHooks method adds one or more side-effecting retry hooks to an array +// of hooks that will be executed on each retry. +// +// NOTE: +// - All the retry hooks are executed on request retry. +// - The request-level retry hooks are executed first before client-level hooks. +func (c *Client) AddRetryHooks(hooks ...RetryHookFunc) *Client { c.lock.Lock() defer c.lock.Unlock() - c.retryHooks = append(c.retryHooks, hook) + c.retryHooks = append(c.retryHooks, hooks...) return c } diff --git a/client_test.go b/client_test.go index 82aa4b8..1fe38c9 100644 --- a/client_test.go +++ b/client_test.go @@ -518,6 +518,7 @@ func TestClientSettingsCoverage(t *testing.T) { assertEqual(t, time.Millisecond*100, c.RetryWaitTime()) assertEqual(t, time.Second*2, c.RetryMaxWaitTime()) assertEqual(t, false, c.IsTrace()) + assertEqual(t, 0, len(c.RetryConditions())) authToken := "sample auth token value" c.SetAuthToken(authToken) @@ -1144,7 +1145,7 @@ func TestClientOnResponseError(t *testing.T) { SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). SetRetryCount(0). SetRetryMaxWaitTime(time.Microsecond). - AddRetryCondition(func(response *Response, err error) bool { + AddRetryConditions(func(response *Response, err error) bool { if err != nil { return true } diff --git a/request.go b/request.go index 78713bd..6336d12 100644 --- a/request.go +++ b/request.go @@ -94,6 +94,7 @@ type Request struct { multipartBoundary string multipartFields []*MultipartField retryConditions []RetryConditionFunc + retryHooks []RetryHookFunc resultCurlCmd string generateCurlCmd bool debugLogCurlCmd bool @@ -945,14 +946,43 @@ func (r *Request) SetDebug(d bool) *Request { return r } -// AddRetryCondition method adds a retry condition function to the request's -// array of functions is checked to determine if the request can be retried. -// The request will retry if any functions return true and the error is nil. +// AddRetryConditions method adds one or more retry condition functions into the request. +// These retry conditions are executed to determine if the request can be retried. +// The request will retry if any functions return `true`, otherwise return `false`. // -// NOTE: The request level retry conditions are checked before all retry -// conditions from the client instance. -func (r *Request) AddRetryCondition(condition RetryConditionFunc) *Request { - r.retryConditions = append(r.retryConditions, condition) +// NOTE: +// - The client-level retry conditions are applied to all requests. +// - The request-level retry conditions are executed first before client-level +// retry conditions. See [Request.SetRetryConditions] +func (r *Request) AddRetryConditions(conditions ...RetryConditionFunc) *Request { + r.retryConditions = append(r.retryConditions, conditions...) + return r +} + +// SetRetryConditions method overwrites the retry conditions in the request. +// These retry conditions are executed to determine if the request can be retried. +// The request will retry if any function returns `true`, otherwise return `false`. +func (r *Request) SetRetryConditions(conditions ...RetryConditionFunc) *Request { + r.retryConditions = conditions + return r +} + +// AddRetryHooks method adds one or more side-effecting retry hooks in the request. +// +// NOTE: +// - All the retry hooks are executed on each request retry. +// - The request-level retry hooks are executed first before client-level hooks. +func (r *Request) AddRetryHooks(hooks ...RetryHookFunc) *Request { + r.retryHooks = append(r.retryHooks, hooks...) + return r +} + +// SetRetryHooks method overwrites side-effecting retry hooks in the request. +// +// NOTE: +// - All the retry hooks are executed on each request retry. +func (r *Request) SetRetryHooks(hooks ...RetryHookFunc) *Request { + r.retryHooks = hooks return r } @@ -1355,8 +1385,7 @@ func (r *Request) Execute(method, url string) (res *Response, err error) { // is still false if !needsRetry && res != nil { // user defined retry conditions - retryConditions := append(r.retryConditions, r.client.RetryConditions()...) - for _, retryCondition := range retryConditions { + for _, retryCondition := range r.retryConditions { if needsRetry = retryCondition(res, err); needsRetry { break } @@ -1375,7 +1404,7 @@ func (r *Request) Execute(method, url string) (res *Response, err error) { } // run user-defined retry hooks - for _, retryHookFunc := range r.client.RetryHooks() { + for _, retryHookFunc := range r.retryHooks { retryHookFunc(res, err) } @@ -1393,11 +1422,11 @@ func (r *Request) Execute(method, url string) (res *Response, err error) { select { case <-r.Context().Done(): isCtxDone = true - timer.Stop() err = wrapErrors(r.Context().Err(), err) break case <-timer.C: } + timer.Stop() if isCtxDone { break } diff --git a/request_test.go b/request_test.go index f046466..a66038a 100644 --- a/request_test.go +++ b/request_test.go @@ -2137,7 +2137,7 @@ func TestRequestNoRetryOnNonIdempotentMethod(t *testing.T) { c := dcnl(). SetTimeout(time.Second * 3). - AddRetryHook( + AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) diff --git a/retry_test.go b/retry_test.go index 0691fdd..42fd753 100644 --- a/retry_test.go +++ b/retry_test.go @@ -37,7 +37,7 @@ func TestRetryConditionalGet(t *testing.T) { client := dcnl() resp, err := client.R(). - AddRetryCondition(check). + AddRetryConditions(check). SetRetryCount(2). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") @@ -51,7 +51,7 @@ func TestRetryConditionalGet(t *testing.T) { logResponse(t, resp) } -func TestConditionalGetRequestLevel(t *testing.T) { +func TestRequestConditionalGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() @@ -67,7 +67,7 @@ func TestConditionalGetRequestLevel(t *testing.T) { resp, err := c.R(). EnableDebug(). - AddRetryCondition(check). + AddRetryConditions(check). SetRetryCount(1). SetRetryWaitTime(50*time.Millisecond). SetRetryMaxWaitTime(1*time.Second). @@ -118,7 +118,7 @@ func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) { c.SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -164,7 +164,7 @@ func TestClientRetryWaitMaxInfinite(t *testing.T) { SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -199,7 +199,7 @@ func TestClientRetryWaitMaxMinimum(t *testing.T) { c := dcnl(). SetRetryCount(1). SetRetryMaxWaitTime(retryMaxWaitTime). - AddRetryCondition(func(*Response, error) bool { return true }) + AddRetryConditions(func(*Response, error) bool { return true }) _, err := c.R().Get(ts.URL + "/set-retrywaittime-test") assertError(t, err) } @@ -225,7 +225,7 @@ func TestClientRetryStrategyFuncError(t *testing.T) { SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryStrategy(retryStrategyFunc). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[attempt] = parseTimeSleptFromResponse(r.String()) attempt++ @@ -263,7 +263,7 @@ func TestClientRetryStrategyFunc(t *testing.T) { SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryStrategy(retryStrategyFunc). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -315,7 +315,7 @@ func TestRequestRetryStrategyFunc(t *testing.T) { SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryStrategy(retryStrategyFunc). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -364,7 +364,7 @@ func TestClientRetryStrategyWaitTooShort(t *testing.T) { SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryStrategy(retryStrategyFunc). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -413,7 +413,7 @@ func TestClientRetryStrategyWaitTooLong(t *testing.T) { SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryStrategy(retryStrategyFunc). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -457,7 +457,7 @@ func TestClientRetryCancel(t *testing.T) { SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true @@ -496,7 +496,7 @@ func TestClientRetryPost(t *testing.T) { c := dcnl() c.SetRetryCount(3) - c.AddRetryCondition(RetryConditionFunc(func(r *Response, _ error) bool { + c.AddRetryConditions(RetryConditionFunc(func(r *Response, _ error) bool { return r.StatusCode() >= http.StatusInternalServerError })) @@ -528,7 +528,7 @@ func TestClientRetryErrorRecover(t *testing.T) { c := dcnl(). SetRetryCount(2). SetError(AuthError{}). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { err, ok := r.Error().(*AuthError) retry := ok && r.StatusCode() == 429 && err.Message == "too many" @@ -561,7 +561,7 @@ func TestClientRetryCountWithTimeout(t *testing.T) { c := dcnl(). SetTimeout(50 * time.Millisecond). SetRetryCount(1). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { attempt++ return true @@ -618,7 +618,7 @@ func TestClientRetryHookWithTimeout(t *testing.T) { c := dcnl(). SetRetryCount(retryCount). SetTimeout(50 * time.Millisecond). - AddRetryHook(retryHook) + AddRetryHooks(retryHook) // Since reflect.DeepEqual can not compare two functions // just compare pointers of the two hooks @@ -690,7 +690,7 @@ func TestClientResetMultipartReaders(t *testing.T) { c := dcnl(). SetRetryCount(2). SetTimeout(time.Second * 3). - AddRetryHook( + AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) @@ -720,7 +720,7 @@ func TestRequestResetMultipartReaders(t *testing.T) { c := dcnl(). SetTimeout(time.Second * 3). - AddRetryHook( + AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) @@ -770,7 +770,7 @@ func TestParseRetryAfterHeader(t *testing.T) { } } -func TestRetryTooManyRequestsHeaderRetryAfter(t *testing.T) { +func TestRequestRetryTooManyRequestsHeaderRetryAfter(t *testing.T) { ts := createGetServer(t) defer ts.Close() @@ -843,7 +843,7 @@ func TestRetryDefaultConditions(t *testing.T) { }) } -func TestRetryRequestPutIoReadSeekerForBuffer(t *testing.T) { +func TestRequestRetryPutIoReadSeekerForBuffer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) assertError(t, err) @@ -853,7 +853,7 @@ func TestRetryRequestPutIoReadSeekerForBuffer(t *testing.T) { })) c := dcnl(). - AddRetryCondition( + AddRetryConditions( func(r *Response, err error) bool { return err != nil || r.StatusCode() > 499 }, @@ -875,7 +875,7 @@ func TestRetryRequestPutIoReadSeekerForBuffer(t *testing.T) { assertEqual(t, "", resp.String()) } -func TestRetryRequestPostIoReadSeeker(t *testing.T) { +func TestRequestRetryPostIoReadSeeker(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) assertError(t, err) @@ -885,7 +885,7 @@ func TestRetryRequestPostIoReadSeeker(t *testing.T) { })) c := dcnl(). - AddRetryCondition( + AddRetryConditions( func(r *Response, err error) bool { return err != nil || r.StatusCode() > 499 }, @@ -906,7 +906,63 @@ func TestRetryRequestPostIoReadSeeker(t *testing.T) { assertEqual(t, "", resp.String()) } -func TestRetryQueryParamsGH938(t *testing.T) { +func TestRequestRetryHooks(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + hookFunc := func(msg string) RetryHookFunc { + return func(res *Response, err error) { + res.Request.log.Debugf(msg) + } + } + + c, lb := dcldb() + c.AddRetryConditions(func(r *Response, err error) bool { + return true + }). + AddRetryHooks( + hookFunc("This is client hook1"), + hookFunc("This is client hook2"), + ) + + _, _ = c.R(). + SetRetryCount(1). + AddRetryHooks(hookFunc("This is request hook1")). + SetRetryHooks(hookFunc("This is request overwrite hook1")). + Get("/set-retrycount-test") + + debugLog := lb.String() + assertEqual(t, false, strings.Contains(debugLog, "This is client hook1")) + assertEqual(t, false, strings.Contains(debugLog, "This is client hook2")) + assertEqual(t, false, strings.Contains(debugLog, "This is request hook1")) + assertEqual(t, true, strings.Contains(debugLog, "This is request overwrite hook1")) +} + +func TestRequestSetRetryConditions(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + condFunc := func(fn func() bool) RetryConditionFunc { + return func(r *Response, err error) bool { + return fn() + } + } + + c := dcnl(). + AddRetryConditions( + condFunc(func() bool { return true }), + condFunc(func() bool { return true }), + ) + + res, _ := c.R(). + SetRetryCount(2). + SetRetryConditions(condFunc(func() bool { return false })). // disable retry with overwrite condition + Get("/set-retrycount-test") + + assertEqual(t, 1, res.Request.Attempt) +} + +func TestRequestRetryQueryParamsGH938(t *testing.T) { ts := createGetServer(t) defer ts.Close() @@ -917,7 +973,7 @@ func TestRetryQueryParamsGH938(t *testing.T) { SetRetryCount(5). SetRetryWaitTime(10 * time.Millisecond). SetRetryMaxWaitTime(20 * time.Millisecond). - AddRetryCondition( + AddRetryConditions( func(r *Response, _ error) bool { assertEqual(t, expectedQueryParams, r.Request.RawRequest.URL.RawQuery) return true // always retry