Skip to content

Commit

Permalink
Improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
UgnineSirdis committed Apr 11, 2024
1 parent fe4baec commit 8371f8d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 52 deletions.
29 changes: 18 additions & 11 deletions internal/credentials/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@ const (

var (
errEmptyTokenEndpointError = errors.New("OAuth2 token exchange: empty token endpoint")
errCouldNotParseResponse = errors.New("OAuth2 token exchange: could not parse response")
errCouldNotExchangeToken = errors.New("OAuth2 token exchange: could not exchange token")
errUnsupportedTokenType = errors.New("OAuth2 token exchange: unsupported token type")
errIncorrectExpirationTime = errors.New("OAuth2 token exchange: incorrect expiration time")
errDifferentScope = errors.New("OAuth2 token exchange: got different scope")
errCouldNotMakeHTTPRequest = errors.New("OAuth2 token exchange: could not make http request")
errNoSigningMethodError = errors.New("JWT token source: no signing method")
errNoPrivateKeyError = errors.New("JWT token source: no private key")
errCouldNotSignJWTToken = errors.New("JWT token source: could not sign jwt token")
)

type Oauth2TokenExchangeCredentialsOption interface {
Expand Down Expand Up @@ -267,7 +274,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
}

if result.StatusCode != http.StatusOK {
description := "OAuth2 token exchange: could not exchange token: " + result.Status
description := result.Status

//nolint:tagliatelle
type errorResponse struct {
Expand All @@ -279,7 +286,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
if err := json.Unmarshal(data, &parsedErrorResponse); err != nil {
description += ", could not parse response: " + err.Error()

return xerrors.WithStackTrace(errors.New(description))
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
}

if parsedErrorResponse.ErrorName != "" {
Expand All @@ -294,7 +301,7 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
description += ", error_uri: " + parsedErrorResponse.ErrorURI
}

return xerrors.WithStackTrace(errors.New(description))
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
}

//nolint:tagliatelle
Expand All @@ -306,24 +313,24 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
}
var parsedResponse response
if err := json.Unmarshal(data, &parsedResponse); err != nil {
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not parse response: %w", err))
return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err))
}

if !strings.EqualFold(parsedResponse.TokenType, "bearer") {
return xerrors.WithStackTrace(
fmt.Errorf("OAuth2 token exchange: unsupported token type: %q", parsedResponse.TokenType))
fmt.Errorf("%w: %q", errUnsupportedTokenType, parsedResponse.TokenType))
}

if parsedResponse.ExpiresIn <= 0 {
return xerrors.WithStackTrace(
fmt.Errorf("OAuth2 token exchange: incorrect expiration time: %d", parsedResponse.ExpiresIn))
fmt.Errorf("%w: %d", errIncorrectExpirationTime, parsedResponse.ExpiresIn))
}

if parsedResponse.Scope != "" {
scope := provider.getScopeParam()
if parsedResponse.Scope != scope {
return xerrors.WithStackTrace(
fmt.Errorf("OAuth2 token exchange: got different scope. Expected %q, but got %q", scope, parsedResponse.Scope))
fmt.Errorf("%w. Expected %q, but got %q", errDifferentScope, scope, parsedResponse.Scope))
}
}

Expand All @@ -344,12 +351,12 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time.Time) error {
body, err := provider.getRequestParams()
if err != nil {
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err))
return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err))
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.tokenEndpoint, strings.NewReader(body))
if err != nil {
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err))
return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err))
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(body)))
Expand All @@ -362,7 +369,7 @@ func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time

result, err := client.Do(req)
if err != nil {
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not exchange token: %w", err))
return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotExchangeToken, err))
}

defer result.Body.Close()
Expand Down Expand Up @@ -618,7 +625,7 @@ func (s *jwtTokenSource) Token() (Token, error) {
var token Token
token.Token, err = t.SignedString(s.privateKey)
if err != nil {
return token, xerrors.WithStackTrace(fmt.Errorf("JWT token source: could not sign jwt token: %w", err))
return token, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotSignJWTToken, err))
}
token.TokenType = "urn:ietf:params:oauth:token-type:jwt"

Expand Down
93 changes: 52 additions & 41 deletions internal/credentials/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ type Oauth2TokenExchangeTestParams struct {
Response string
Status int
ExpectedToken string
ExpectedError error
ExpectedErrorPart string
}

Expand All @@ -107,70 +108,73 @@ func TestOauth2TokenExchange(t *testing.T) {

testsParams := []Oauth2TokenExchangeTestParams{
{
Response: `{"access_token":"test_token","token_type":"BEARER","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "Bearer test_token",
ExpectedErrorPart: "",
Response: `{"access_token":"test_token","token_type":"BEARER","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "Bearer test_token",
},
{
Response: `aaa`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not parse response:",
Response: `aaa`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedError: errCouldNotParseResponse,
},
{
Response: `{}`,
Status: http.StatusBadRequest,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 400 Bad Request",
Response: `{}`,
Status: http.StatusBadRequest,
ExpectedToken: "",
ExpectedError: errCouldNotExchangeToken,
},
{
Response: `not json`,
Status: http.StatusNotFound,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 404 Not Found",
Response: `not json`,
Status: http.StatusNotFound,
ExpectedToken: "",
ExpectedError: errCouldNotExchangeToken,
},
{
Response: `{"error": "invalid_request"}`,
Status: http.StatusBadRequest,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 400 Bad Request, error: invalid_request",
ExpectedError: errCouldNotExchangeToken,
ExpectedErrorPart: "400 Bad Request, error: invalid_request",
},
{
Response: `{"error":"unauthorized_client","error_description":"something went bad"}`,
Status: http.StatusInternalServerError,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 500 Internal Server Error, error: unauthorized_client, description: \"something went bad\"", //nolint:lll
ExpectedError: errCouldNotExchangeToken,
ExpectedErrorPart: "500 Internal Server Error, error: unauthorized_client, description: \"something went bad\"", //nolint:lll
},
{
Response: `{"error_description":"something went bad","error_uri":"my_error_uri"}`,
Status: http.StatusForbidden,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: could not exchange token: 403 Forbidden, description: \"something went bad\", error_uri: my_error_uri", //nolint:lll
ExpectedError: errCouldNotExchangeToken,
ExpectedErrorPart: "403 Forbidden, description: \"something went bad\", error_uri: my_error_uri",
},
{
Response: `{"access_token":"test_token","token_type":"","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: unsupported token type: \"\"",
Response: `{"access_token":"test_token","token_type":"","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedError: errUnsupportedTokenType,
},
{
Response: `{"access_token":"test_token","token_type":"basic","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: unsupported token type: \"basic\"",
Response: `{"access_token":"test_token","token_type":"basic","expires_in":42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedError: errUnsupportedTokenType,
},
{
Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":-42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: incorrect expiration time: -42",
Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":-42,"some_other_field":"x"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedError: errIncorrectExpirationTime,
},
{
Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":42,"scope":"s"}`,
Status: http.StatusOK,
ExpectedToken: "",
ExpectedErrorPart: "OAuth2 token exchange: got different scope. Expected \"test_scope1 test_scope2\", but got \"s\"",
ExpectedError: errDifferentScope,
ExpectedErrorPart: "Expected \"test_scope1 test_scope2\", but got \"s\"",
},
}

Expand All @@ -186,10 +190,15 @@ func TestOauth2TokenExchange(t *testing.T) {
require.NoError(t, err)

token, err := client.Token(ctx)
if params.ExpectedErrorPart == "" {
if params.ExpectedErrorPart == "" && params.ExpectedError == nil {
require.NoError(t, err)
} else {
require.ErrorContains(t, err, params.ExpectedErrorPart)
if params.ExpectedErrorPart != "" {
require.ErrorContains(t, err, params.ExpectedErrorPart)
}
if params.ExpectedError != nil {
require.ErrorIs(t, err, params.ExpectedError)
}
}
require.Equal(t, params.ExpectedToken, token)
}
Expand Down Expand Up @@ -270,13 +279,15 @@ func TestWrongParameters(t *testing.T) {
WithSubjectToken(NewFixedTokenSource("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt")),
WithRequestedTokenType("access_token"),
)
require.Error(t, err)
require.ErrorIs(t, err, errEmptyTokenEndpointError)
}

type errorTokenSource struct{}

var errTokenSource = errors.New("test error")

func (s *errorTokenSource) Token() (Token, error) {
return Token{"", ""}, errors.New("Test error")
return Token{"", ""}, errTokenSource
}

func TestErrorInSourceToken(t *testing.T) {
Expand All @@ -290,7 +301,7 @@ func TestErrorInSourceToken(t *testing.T) {
require.NoError(t, err)

token, err := client.Token(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, errTokenSource)
require.Equal(t, "", token)

client, err = NewOauth2TokenExchangeCredentials(
Expand All @@ -303,7 +314,7 @@ func TestErrorInSourceToken(t *testing.T) {
require.NoError(t, err)

token, err = client.Token(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, errTokenSource)
require.Equal(t, "", token)
}

Expand All @@ -315,7 +326,7 @@ func TestErrorInHTTPRequest(t *testing.T) {
require.NoError(t, err)

token, err := client.Token(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, errCouldNotExchangeToken)
require.Equal(t, "", token)
}

Expand Down Expand Up @@ -367,7 +378,7 @@ func TestJWTTokenBadParams(t *testing.T) {
WithAudience("test_audience"),
WithID("id"),
)
require.Error(t, err)
require.ErrorIs(t, err, errNoPrivateKeyError)

_, err = NewJWTTokenSource(
WithPrivateKey(privateKey),
Expand All @@ -377,5 +388,5 @@ func TestJWTTokenBadParams(t *testing.T) {
WithTokenTTL(time.Minute),
WithAudience("test_audience"),
)
require.Error(t, err)
require.ErrorIs(t, err, errNoSigningMethodError)
}

0 comments on commit 8371f8d

Please sign in to comment.