From 8371f8d909ae2dc03eec58d1e8a7ae2e1e3729ac Mon Sep 17 00:00:00 2001 From: Vasily Gerasimov Date: Thu, 11 Apr 2024 16:58:30 +0000 Subject: [PATCH] Improve error handling --- internal/credentials/oauth2.go | 29 +++++---- internal/credentials/oauth2_test.go | 93 ++++++++++++++++------------- 2 files changed, 70 insertions(+), 52 deletions(-) diff --git a/internal/credentials/oauth2.go b/internal/credentials/oauth2.go index 9efa8b504..2df73e952 100644 --- a/internal/credentials/oauth2.go +++ b/internal/credentials/oauth2.go @@ -27,8 +27,15 @@ const ( var ( errEmptyTokenEndpointError = errors.New("OAuth2 token exchange: empty token endpoint") + errCouldNotParseResponse = errors.New("OAuth2 token exchange: could not parse response") + errCouldNotExchangeToken = errors.New("OAuth2 token exchange: could not exchange token") + errUnsupportedTokenType = errors.New("OAuth2 token exchange: unsupported token type") + errIncorrectExpirationTime = errors.New("OAuth2 token exchange: incorrect expiration time") + errDifferentScope = errors.New("OAuth2 token exchange: got different scope") + errCouldNotMakeHTTPRequest = errors.New("OAuth2 token exchange: could not make http request") errNoSigningMethodError = errors.New("JWT token source: no signing method") errNoPrivateKeyError = errors.New("JWT token source: no private key") + errCouldNotSignJWTToken = errors.New("JWT token source: could not sign jwt token") ) type Oauth2TokenExchangeCredentialsOption interface { @@ -267,7 +274,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R } if result.StatusCode != http.StatusOK { - description := "OAuth2 token exchange: could not exchange token: " + result.Status + description := result.Status //nolint:tagliatelle type errorResponse struct { @@ -279,7 +286,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R if err := json.Unmarshal(data, &parsedErrorResponse); err != nil { description += ", could not parse response: " + err.Error() - return xerrors.WithStackTrace(errors.New(description)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)) } if parsedErrorResponse.ErrorName != "" { @@ -294,7 +301,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R description += ", error_uri: " + parsedErrorResponse.ErrorURI } - return xerrors.WithStackTrace(errors.New(description)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)) } //nolint:tagliatelle @@ -306,24 +313,24 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R } var parsedResponse response if err := json.Unmarshal(data, &parsedResponse); err != nil { - return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not parse response: %w", err)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err)) } if !strings.EqualFold(parsedResponse.TokenType, "bearer") { return xerrors.WithStackTrace( - fmt.Errorf("OAuth2 token exchange: unsupported token type: %q", parsedResponse.TokenType)) + fmt.Errorf("%w: %q", errUnsupportedTokenType, parsedResponse.TokenType)) } if parsedResponse.ExpiresIn <= 0 { return xerrors.WithStackTrace( - fmt.Errorf("OAuth2 token exchange: incorrect expiration time: %d", parsedResponse.ExpiresIn)) + fmt.Errorf("%w: %d", errIncorrectExpirationTime, parsedResponse.ExpiresIn)) } if parsedResponse.Scope != "" { scope := provider.getScopeParam() if parsedResponse.Scope != scope { return xerrors.WithStackTrace( - fmt.Errorf("OAuth2 token exchange: got different scope. Expected %q, but got %q", scope, parsedResponse.Scope)) + fmt.Errorf("%w. Expected %q, but got %q", errDifferentScope, scope, parsedResponse.Scope)) } } @@ -344,12 +351,12 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time.Time) error { body, err := provider.getRequestParams() if err != nil { - return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err)) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.tokenEndpoint, strings.NewReader(body)) if err != nil { - return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err)) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Length", strconv.Itoa(len(body))) @@ -362,7 +369,7 @@ func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time result, err := client.Do(req) if err != nil { - return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not exchange token: %w", err)) + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotExchangeToken, err)) } defer result.Body.Close() @@ -618,7 +625,7 @@ func (s *jwtTokenSource) Token() (Token, error) { var token Token token.Token, err = t.SignedString(s.privateKey) if err != nil { - return token, xerrors.WithStackTrace(fmt.Errorf("JWT token source: could not sign jwt token: %w", err)) + return token, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotSignJWTToken, err)) } token.TokenType = "urn:ietf:params:oauth:token-type:jwt" diff --git a/internal/credentials/oauth2_test.go b/internal/credentials/oauth2_test.go index 1be5999a1..2c5ff753b 100644 --- a/internal/credentials/oauth2_test.go +++ b/internal/credentials/oauth2_test.go @@ -94,6 +94,7 @@ type Oauth2TokenExchangeTestParams struct { Response string Status int ExpectedToken string + ExpectedError error ExpectedErrorPart string } @@ -107,70 +108,73 @@ func TestOauth2TokenExchange(t *testing.T) { testsParams := []Oauth2TokenExchangeTestParams{ { - Response: `{"access_token":"test_token","token_type":"BEARER","expires_in":42,"some_other_field":"x"}`, - Status: http.StatusOK, - ExpectedToken: "Bearer test_token", - ExpectedErrorPart: "", + Response: `{"access_token":"test_token","token_type":"BEARER","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "Bearer test_token", }, { - Response: `aaa`, - Status: http.StatusOK, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not parse response:", + Response: `aaa`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errCouldNotParseResponse, }, { - Response: `{}`, - Status: http.StatusBadRequest, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 400 Bad Request", + Response: `{}`, + Status: http.StatusBadRequest, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, }, { - Response: `not json`, - Status: http.StatusNotFound, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 404 Not Found", + Response: `not json`, + Status: http.StatusNotFound, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, }, { Response: `{"error": "invalid_request"}`, Status: http.StatusBadRequest, ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 400 Bad Request, error: invalid_request", + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "400 Bad Request, error: invalid_request", }, { Response: `{"error":"unauthorized_client","error_description":"something went bad"}`, Status: http.StatusInternalServerError, ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 500 Internal Server Error, error: unauthorized_client, description: \"something went bad\"", //nolint:lll + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "500 Internal Server Error, error: unauthorized_client, description: \"something went bad\"", //nolint:lll }, { Response: `{"error_description":"something went bad","error_uri":"my_error_uri"}`, Status: http.StatusForbidden, ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 403 Forbidden, description: \"something went bad\", error_uri: my_error_uri", //nolint:lll + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "403 Forbidden, description: \"something went bad\", error_uri: my_error_uri", }, { - Response: `{"access_token":"test_token","token_type":"","expires_in":42,"some_other_field":"x"}`, - Status: http.StatusOK, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: unsupported token type: \"\"", + Response: `{"access_token":"test_token","token_type":"","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errUnsupportedTokenType, }, { - Response: `{"access_token":"test_token","token_type":"basic","expires_in":42,"some_other_field":"x"}`, - Status: http.StatusOK, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: unsupported token type: \"basic\"", + Response: `{"access_token":"test_token","token_type":"basic","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errUnsupportedTokenType, }, { - Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":-42,"some_other_field":"x"}`, - Status: http.StatusOK, - ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: incorrect expiration time: -42", + Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":-42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errIncorrectExpirationTime, }, { Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":42,"scope":"s"}`, Status: http.StatusOK, ExpectedToken: "", - ExpectedErrorPart: "OAuth2 token exchange: got different scope. Expected \"test_scope1 test_scope2\", but got \"s\"", + ExpectedError: errDifferentScope, + ExpectedErrorPart: "Expected \"test_scope1 test_scope2\", but got \"s\"", }, } @@ -186,10 +190,15 @@ func TestOauth2TokenExchange(t *testing.T) { require.NoError(t, err) token, err := client.Token(ctx) - if params.ExpectedErrorPart == "" { + if params.ExpectedErrorPart == "" && params.ExpectedError == nil { require.NoError(t, err) } else { - require.ErrorContains(t, err, params.ExpectedErrorPart) + if params.ExpectedErrorPart != "" { + require.ErrorContains(t, err, params.ExpectedErrorPart) + } + if params.ExpectedError != nil { + require.ErrorIs(t, err, params.ExpectedError) + } } require.Equal(t, params.ExpectedToken, token) } @@ -270,13 +279,15 @@ func TestWrongParameters(t *testing.T) { WithSubjectToken(NewFixedTokenSource("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt")), WithRequestedTokenType("access_token"), ) - require.Error(t, err) + require.ErrorIs(t, err, errEmptyTokenEndpointError) } type errorTokenSource struct{} +var errTokenSource = errors.New("test error") + func (s *errorTokenSource) Token() (Token, error) { - return Token{"", ""}, errors.New("Test error") + return Token{"", ""}, errTokenSource } func TestErrorInSourceToken(t *testing.T) { @@ -290,7 +301,7 @@ func TestErrorInSourceToken(t *testing.T) { require.NoError(t, err) token, err := client.Token(context.Background()) - require.Error(t, err) + require.ErrorIs(t, err, errTokenSource) require.Equal(t, "", token) client, err = NewOauth2TokenExchangeCredentials( @@ -303,7 +314,7 @@ func TestErrorInSourceToken(t *testing.T) { require.NoError(t, err) token, err = client.Token(context.Background()) - require.Error(t, err) + require.ErrorIs(t, err, errTokenSource) require.Equal(t, "", token) } @@ -315,7 +326,7 @@ func TestErrorInHTTPRequest(t *testing.T) { require.NoError(t, err) token, err := client.Token(context.Background()) - require.Error(t, err) + require.ErrorIs(t, err, errCouldNotExchangeToken) require.Equal(t, "", token) } @@ -367,7 +378,7 @@ func TestJWTTokenBadParams(t *testing.T) { WithAudience("test_audience"), WithID("id"), ) - require.Error(t, err) + require.ErrorIs(t, err, errNoPrivateKeyError) _, err = NewJWTTokenSource( WithPrivateKey(privateKey), @@ -377,5 +388,5 @@ func TestJWTTokenBadParams(t *testing.T) { WithTokenTTL(time.Minute), WithAudience("test_audience"), ) - require.Error(t, err) + require.ErrorIs(t, err, errNoSigningMethodError) }