Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
asmyasnikov committed Apr 11, 2024
1 parent 5aabdd3 commit 20ac4f4
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 143 deletions.
132 changes: 58 additions & 74 deletions internal/credentials/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"time"

"github.com/golang-jwt/jwt/v4"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
)

type Oauth2TokenExchangeCredentialsOption interface {
Expand Down Expand Up @@ -69,22 +71,22 @@ func WithRequestedTokenType(requestedTokenType string) requestedTokenTypeOption
type audienceOption []string

func (audience audienceOption) ApplyOauth2CredentialsOption(c *Oauth2TokenExchange) {
c.audience = []string(audience)
c.audience = audience
}

func WithAudience(audience ...string) audienceOption {
return audienceOption(audience)
return audience
}

// Scope
type scopeOption []string

func (scope scopeOption) ApplyOauth2CredentialsOption(c *Oauth2TokenExchange) {
c.scope = []string(scope)
c.scope = scope
}

func WithScope(scope ...string) scopeOption {
return scopeOption(scope)
return scope
}

// RequestTimeout
Expand Down Expand Up @@ -124,45 +126,30 @@ func WithActorToken(actorToken TokenSource) *actorTokenSourceOption {
return &actorTokenSourceOption{actorToken}
}

const defaultRequestTimeout = time.Second * 10

func NewOauth2TokenExchangeCredentials(
opts ...Oauth2TokenExchangeCredentialsOption,
) (*Oauth2TokenExchange, error) {
c := &Oauth2TokenExchange{}
c := &Oauth2TokenExchange{
grantType: "urn:ietf:params:oauth:grant-type:token-exchange",
requestedTokenType: "urn:ietf:params:oauth:token-type:access_token",
requestTimeout: defaultRequestTimeout,
}

for _, opt := range opts {
if opt != nil {
opt.ApplyOauth2CredentialsOption(c)
}
}

err := c.init()
if err != nil {
return nil, err
if c.tokenEndpoint == "" {
return nil, xerrors.WithStackTrace(errors.New("OAuth2 token exchange: empty token endpoint"))
}

return c, nil
}

func (provider *Oauth2TokenExchange) init() error {
if provider.tokenEndpoint == "" {
return fmt.Errorf("OAuth2 token exchange: empty token endpoint")
}

if provider.grantType == "" {
provider.grantType = "urn:ietf:params:oauth:grant-type:token-exchange"
}

if provider.requestedTokenType == "" {
provider.requestedTokenType = "urn:ietf:params:oauth:token-type:access_token"
}

if provider.requestTimeout == 0 {
provider.requestTimeout = time.Second * 10
}

return nil
}

func (provider *Oauth2TokenExchange) getScopeParam() string {
var scope string
if len(provider.scope) != 0 {
Expand All @@ -175,6 +162,7 @@ func (provider *Oauth2TokenExchange) getScopeParam() string {
}
}
}

return scope
}

Expand All @@ -198,19 +186,20 @@ func (provider *Oauth2TokenExchange) getRequestParams() (string, error) {
if provider.subjectTokenSource != nil {
token, err := provider.subjectTokenSource.Token()
if err != nil {
return "", err
return "", xerrors.WithStackTrace(err)
}
params.Set("subject_token", token.Token)
params.Set("subject_token_type", token.TokenType)
}
if provider.actorTokenSource != nil {
token, err := provider.actorTokenSource.Token()
if err != nil {
return "", err
return "", xerrors.WithStackTrace(err)
}
params.Set("actor_token", token.Token)
params.Set("actor_token_type", token.TokenType)
}

return params.Encode(), nil
}

Expand All @@ -222,7 +211,7 @@ func (provider *Oauth2TokenExchange) processTokenExchangeResponse(result *http.R
if result.Body != nil {
data, err = io.ReadAll(result.Body)
if err != nil {
return err
return xerrors.WithStackTrace(err)
}
} else {
data = make([]byte, 0)
Expand All @@ -231,15 +220,17 @@ func (provider *Oauth2TokenExchange) processTokenExchangeResponse(result *http.R
if result.StatusCode != http.StatusOK {
description := fmt.Sprintf("OAuth2 token exchange: could not exchange token: %s", result.Status)

//nolint:tagliatelle
type errorResponse struct {
Error string `json:"error"`
Description string `json:"error_description"`
ErrorUri string `json:"error_uri"`
ErrorURI string `json:"error_uri"`
}
var parsedErrorResponse errorResponse
if err := json.Unmarshal(data, &parsedErrorResponse); err != nil {
description += fmt.Sprintf(", could not parse response: %s", err.Error())
return errors.New(description)

return xerrors.WithStackTrace(errors.New(description))
}

if parsedErrorResponse.Error != "" {
Expand All @@ -250,13 +241,14 @@ func (provider *Oauth2TokenExchange) processTokenExchangeResponse(result *http.R
description += fmt.Sprintf(", description: \"%s\"", parsedErrorResponse.Description)

Check failure on line 241 in internal/credentials/oauth2.go

View workflow job for this annotation

GitHub Actions / golangci-lint

sprintfQuotedString: use %q instead of "%s" for quoted strings (gocritic)
}

if parsedErrorResponse.ErrorUri != "" {
description += fmt.Sprintf(", error_uri: %s", parsedErrorResponse.ErrorUri)
if parsedErrorResponse.ErrorURI != "" {
description += fmt.Sprintf(", error_uri: %s", parsedErrorResponse.ErrorURI)
}

return errors.New(description)
return xerrors.WithStackTrace(errors.New(description))
}

//nolint:tagliatelle
type response struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand All @@ -265,32 +257,32 @@ func (provider *Oauth2TokenExchange) processTokenExchangeResponse(result *http.R
}
var parsedResponse response
if err := json.Unmarshal(data, &parsedResponse); err != nil {
return fmt.Errorf("OAuth2 token exchange: could not parse response: %w", err)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not parse response: %w", err))
}

if !strings.EqualFold(parsedResponse.TokenType, "bearer") {
return fmt.Errorf("OAuth2 token exchange: unsupported token type: \"%s\"", parsedResponse.TokenType)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: unsupported token type: \"%s\"", parsedResponse.TokenType))

Check failure on line 264 in internal/credentials/oauth2.go

View workflow job for this annotation

GitHub Actions / golangci-lint

line is 126 characters (lll)
}

if parsedResponse.ExpiresIn <= 0 {
return fmt.Errorf("OAuth2 token exchange: incorrect expiration time: %d", parsedResponse.ExpiresIn)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: incorrect expiration time: %d", parsedResponse.ExpiresIn))

Check failure on line 268 in internal/credentials/oauth2.go

View workflow job for this annotation

GitHub Actions / golangci-lint

line is 125 characters (lll)
}

if parsedResponse.Scope != "" {
scope := provider.getScopeParam()
if parsedResponse.Scope != scope {
return fmt.Errorf("OAuth2 token exchange: got different scope. Expected \"%s\", but got \"%s\"", scope, parsedResponse.Scope)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: got different scope. Expected \"%s\", but got \"%s\"", scope, parsedResponse.Scope))
}
}

provider.receivedToken = "Bearer " + parsedResponse.AccessToken

// Expire time
var expireDelta time.Duration = time.Duration(parsedResponse.ExpiresIn)
expireDelta := time.Duration(parsedResponse.ExpiresIn)
expireDelta *= time.Second
provider.receivedTokenExpireTime = now.Add(expireDelta)

var updateDelta time.Duration = time.Duration(parsedResponse.ExpiresIn / 2)
updateDelta := time.Duration(parsedResponse.ExpiresIn / 2) //nolint:gomnd
updateDelta *= time.Second
provider.updateTokenTime = now.Add(updateDelta)

Expand All @@ -300,12 +292,12 @@ func (provider *Oauth2TokenExchange) processTokenExchangeResponse(result *http.R
func (provider *Oauth2TokenExchange) exchangeToken(ctx context.Context, now time.Time) error {
body, err := provider.getRequestParams()
if err != nil {
return fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err))
}

req, err := http.NewRequestWithContext(ctx, "POST", provider.tokenEndpoint, strings.NewReader(body))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.tokenEndpoint, strings.NewReader(body))
if err != nil {
return fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err)
return xerrors.WithStackTrace(fmt.Errorf("OAuth2 token exchange: could not make http request: %w", err))
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(body)))
Expand All @@ -318,15 +310,15 @@ func (provider *Oauth2TokenExchange) exchangeToken(ctx context.Context, now time

result, err := client.Do(req)
if err != nil {
return fmt.Errorf("iam: could not exchange token: %w", err)
return xerrors.WithStackTrace(fmt.Errorf("iam: could not exchange token: %w", err))
}

defer result.Body.Close()

return provider.processTokenExchangeResponse(result, now)
}

func (provider *Oauth2TokenExchange) exchangeTokenInBackgroud() {
func (provider *Oauth2TokenExchange) exchangeTokenInbackground() {
provider.mutex.Lock()
defer provider.mutex.Unlock()

Expand All @@ -341,10 +333,10 @@ func (provider *Oauth2TokenExchange) exchangeTokenInBackgroud() {
provider.updating.Store(false)
}

func (provider *Oauth2TokenExchange) checkBackgroudUpdate(now time.Time) {
func (provider *Oauth2TokenExchange) checkbackgroundUpdate(now time.Time) {
if provider.needUpdate(now) && !provider.updating.Load() {
if provider.updating.CompareAndSwap(false, true) {
go provider.exchangeTokenInBackgroud()
go provider.exchangeTokenInbackground()
}
}
}
Expand All @@ -362,9 +354,11 @@ func (provider *Oauth2TokenExchange) fastCheck(now time.Time) string {
defer provider.mutex.RUnlock()

if !provider.expired(now) {
provider.checkBackgroudUpdate(now)
provider.checkbackgroundUpdate(now)

return provider.receivedToken
}

return ""
}

Expand All @@ -386,6 +380,7 @@ func (provider *Oauth2TokenExchange) Token(ctx context.Context) (string, error)
if err := provider.exchangeToken(ctx, now); err != nil {
return "", err
}

return provider.receivedToken, nil
}

Expand Down Expand Up @@ -418,7 +413,7 @@ type Oauth2TokenExchange struct {
receivedTokenExpireTime time.Time

mutex sync.RWMutex
updating atomic.Bool // true if separate goroutine is run and updates token in backgroud
updating atomic.Bool // true if separate goroutine is run and updates token in background
}

type Token struct {
Expand All @@ -443,13 +438,12 @@ func (s *FixedTokenSource) Token() (Token, error) {
}

func NewFixedTokenSource(token, tokenType string) *FixedTokenSource {
s := &FixedTokenSource{
return &FixedTokenSource{
fixedToken: Token{
Token: token,
TokenType: tokenType,
},
}
return s
}

type JWTTokenSourceOption interface {
Expand Down Expand Up @@ -543,17 +537,22 @@ func WithPrivateKey(key interface{}) *privateKeyOption {
}

func NewJWTTokenSource(opts ...JWTTokenSourceOption) (*JWTTokenSource, error) {
s := &JWTTokenSource{}
s := &JWTTokenSource{
tokenTTL: time.Hour,
}

for _, opt := range opts {
if opt != nil {
opt.ApplyJWTTokenSourceOption(s)
}
}

err := s.init()
if err != nil {
return nil, err
if s.signingMethod == nil {
return nil, xerrors.WithStackTrace(fmt.Errorf("JWT token source: no signing method"))
}

if s.privateKey == nil {
return nil, xerrors.WithStackTrace(fmt.Errorf("JWT token source: no private key"))
}

return s, nil
Expand All @@ -572,22 +571,6 @@ type JWTTokenSource struct {
tokenTTL time.Duration
}

func (s *JWTTokenSource) init() error {
if s.signingMethod == nil {
return fmt.Errorf("JWT token source: no signing method")
}

if s.privateKey == nil {
return fmt.Errorf("JWT token source: no private key")
}

if s.tokenTTL == 0 {
s.tokenTTL = time.Duration(3600 * time.Second)
}

return nil
}

func (s *JWTTokenSource) Token() (Token, error) {
var (
now = time.Now()
Expand Down Expand Up @@ -615,8 +598,9 @@ func (s *JWTTokenSource) Token() (Token, error) {
var token Token
token.Token, err = t.SignedString(s.privateKey)
if err != nil {
return token, fmt.Errorf("JWTTokenSource: could not sign jwt token: %w", err)
return token, xerrors.WithStackTrace(fmt.Errorf("JWTTokenSource: could not sign jwt token: %w", err))
}
token.TokenType = "urn:ietf:params:oauth:token-type:jwt"

return token, nil
}
Loading

0 comments on commit 20ac4f4

Please sign in to comment.