diff --git a/CHANGELOG.md b/CHANGELOG.md index 8046cf4fc..aca53cdd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,9 @@ +* Supported OAuth 2.0 Token Exchange credentials provider + ## v3.64.0 * Supported `table.Session.RenameTables` method -* Fixed out of range panic if next query result set part is empty -* Updated the indirect dependencies `golang.org/x/net` to `v0.17.0` and `golang.org/x/sys` to `v0.13.0` due to vulnerability issue +* Fixed out of range panic if next query result set part is empty +* Updated the indirect dependencies `golang.org/x/net` to `v0.17.0` and `golang.org/x/sys` to `v0.13.0` due to vulnerability issue ## v3.63.0 * Added versioning policy @@ -14,7 +16,7 @@ * Added `go` with anonymous function case in `gstack` ## v3.61.2 -* Changed default transaction control to `NoTx` for execute query through query service client +* Changed default transaction control to `NoTx` for execute query through query service client ## v3.61.1 * Renamed `db.Coordination().CreateSession()` to `db.Coordination().Session()` for compatibility with protos @@ -97,7 +99,7 @@ * Fixed sometime panic on topic writer closing * Added experimental query parameters builder `ydb.ParamsBuilder()` * Changed types of `table/table.{QueryParameters,ParameterOption}` to aliases on `internal/params.{Parameters,NamedValue}` -* Fixed bug with optional decimal serialization +* Fixed bug with optional decimal serialization ## v3.56.2 * Fixed return private error for commit to stopped partition in topic reader. diff --git a/credentials/credentials.go b/credentials/credentials.go index 8829eb973..b11106414 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,3 +34,21 @@ func NewStaticCredentials( ) *credentials.Static { return credentials.NewStaticCredentials(user, password, authEndpoint, opts...) } + +// NewOauth2TokenExchangeCredentials makes OAuth 2.0 token exchange protocol credentials object +// https://www.rfc-editor.org/rfc/rfc8693 +func NewOauth2TokenExchangeCredentials( + opts ...credentials.Oauth2TokenExchangeCredentialsOption, +) (Credentials, error) { + return credentials.NewOauth2TokenExchangeCredentials(opts...) +} + +// NewJWTTokenSource makes JWT token source for OAuth 2.0 token exchange credentials +func NewJWTTokenSource(opts ...credentials.JWTTokenSourceOption) (credentials.TokenSource, error) { + return credentials.NewJWTTokenSource(opts...) +} + +// NewFixedTokenSource makes fixed token source for OAuth 2.0 token exchange credentials +func NewFixedTokenSource(token, tokenType string) credentials.TokenSource { + return credentials.NewFixedTokenSource(token, tokenType) +} diff --git a/credentials/options.go b/credentials/options.go index 9da142092..5d944dae6 100644 --- a/credentials/options.go +++ b/credentials/options.go @@ -1,11 +1,20 @@ package credentials import ( + "time" + + "github.com/golang-jwt/jwt/v4" "google.golang.org/grpc" "github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials" ) +type Oauth2TokenExchangeCredentialsOption = credentials.Oauth2TokenExchangeCredentialsOption + +type TokenSource = credentials.TokenSource + +type Token = credentials.Token + // WithSourceInfo option append to credentials object the source info for reporting source info details on error case func WithSourceInfo(sourceInfo string) credentials.SourceInfoOption { return credentials.WithSourceInfo(sourceInfo) @@ -15,3 +24,118 @@ func WithSourceInfo(sourceInfo string) credentials.SourceInfoOption { func WithGrpcDialOptions(opts ...grpc.DialOption) credentials.StaticCredentialsOption { return credentials.WithGrpcDialOptions(opts...) } + +// TokenEndpoint +func WithTokenEndpoint(endpoint string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithTokenEndpoint(endpoint) +} + +// GrantType +func WithGrantType(grantType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithGrantType(grantType) +} + +// Resource +func WithResource(resource string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithResource(resource) +} + +// RequestedTokenType +func WithRequestedTokenType(requestedTokenType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithRequestedTokenType(requestedTokenType) +} + +// Scope +func WithScope(scope ...string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithScope(scope...) +} + +// RequestTimeout +func WithRequestTimeout(timeout time.Duration) Oauth2TokenExchangeCredentialsOption { + return credentials.WithRequestTimeout(timeout) +} + +// SubjectTokenSource +func WithSubjectToken(subjectToken credentials.TokenSource) Oauth2TokenExchangeCredentialsOption { + return credentials.WithSubjectToken(subjectToken) +} + +// SubjectTokenSource +func WithFixedSubjectToken(token, tokenType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithFixedSubjectToken(token, tokenType) +} + +// SubjectTokenSource +func WithJWTSubjectToken(opts ...credentials.JWTTokenSourceOption) Oauth2TokenExchangeCredentialsOption { + return credentials.WithJWTSubjectToken(opts...) +} + +// ActorTokenSource +func WithActorToken(actorToken credentials.TokenSource) Oauth2TokenExchangeCredentialsOption { + return credentials.WithActorToken(actorToken) +} + +// ActorTokenSource +func WithFixedActorToken(token, tokenType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithFixedActorToken(token, tokenType) +} + +// ActorTokenSource +func WithJWTActorToken(opts ...credentials.JWTTokenSourceOption) Oauth2TokenExchangeCredentialsOption { + return credentials.WithJWTActorToken(opts...) +} + +// Audience +type oauthCredentialsAndJWTCredentialsOption interface { + credentials.Oauth2TokenExchangeCredentialsOption + credentials.JWTTokenSourceOption +} + +func WithAudience(audience ...string) oauthCredentialsAndJWTCredentialsOption { + return credentials.WithAudience(audience...) +} + +// Issuer +func WithIssuer(issuer string) credentials.JWTTokenSourceOption { + return credentials.WithIssuer(issuer) +} + +// Subject +func WithSubject(subject string) credentials.JWTTokenSourceOption { + return credentials.WithSubject(subject) +} + +// ID +func WithID(id string) credentials.JWTTokenSourceOption { + return credentials.WithID(id) +} + +// TokenTTL +func WithTokenTTL(ttl time.Duration) credentials.JWTTokenSourceOption { + return credentials.WithTokenTTL(ttl) +} + +// SigningMethod +func WithSigningMethod(method jwt.SigningMethod) credentials.JWTTokenSourceOption { + return credentials.WithSigningMethod(method) +} + +// KeyID +func WithKeyID(id string) credentials.JWTTokenSourceOption { + return credentials.WithKeyID(id) +} + +// PrivateKey +func WithPrivateKey(key interface{}) credentials.JWTTokenSourceOption { + return credentials.WithPrivateKey(key) +} + +// PrivateKey +func WithRSAPrivateKeyPEMContent(key []byte) credentials.JWTTokenSourceOption { + return credentials.WithRSAPrivateKeyPEMContent(key) +} + +// PrivateKey +func WithRSAPrivateKeyPEMFile(path string) credentials.JWTTokenSourceOption { + return credentials.WithRSAPrivateKeyPEMFile(path) +} diff --git a/examples/auth/README.md b/examples/auth/README.md index 273113372..e2b384488 100644 --- a/examples/auth/README.md +++ b/examples/auth/README.md @@ -4,6 +4,7 @@ Auth examples helps to understand YDB authentication: * `access_token_credentials` - example of use access token credentials * `anonymous_credentials` - example of use anonymous credentials * `metadata_credentials` - example of use metadata credentials +* `oauth2_token_exchange_credentials` - example of use oauth 2.0 token exchange credentials * `service_account_credentials` - example of use service account key file credentials * `static_credentials` - example of use static credentials * `environ` - example of use environment variables to configure YDB authenticate diff --git a/examples/auth/oauth2_token_exchange_credentials/README.md b/examples/auth/oauth2_token_exchange_credentials/README.md new file mode 100644 index 000000000..a13adf8d8 --- /dev/null +++ b/examples/auth/oauth2_token_exchange_credentials/README.md @@ -0,0 +1,8 @@ +# Authenticate with oauth 2.0 token exchange credentials + +`oauth2_token_exchange_credentials` example provides code snippet for authentication to YDB with oauth 2.0 token exchange credentials + +## Runing code snippet +```bash +oauth2_token_exchange_credentials -ydb="grpcs://endpoint/?database=database" -token-endpoint="https://exchange.token.endpoint/oauth2/token/exchange" -key-id="123" -private-key-file="path/to/key/file" -audience="test-aud" -issuer="test-issuer" -subject="test-subject" +``` diff --git a/examples/auth/oauth2_token_exchange_credentials/main.go b/examples/auth/oauth2_token_exchange_credentials/main.go new file mode 100644 index 000000000..e20e8a882 --- /dev/null +++ b/examples/auth/oauth2_token_exchange_credentials/main.go @@ -0,0 +1,107 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/golang-jwt/jwt/v4" + ydb "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/credentials" +) + +var ( + dsn string + tokenEndpoint string + keyID string + privateKeyFile string + audience string + issuer string + subject string +) + +func init() { //nolint:gochecknoinits + required := []string{"ydb", "private-key-file", "key-id", "token-endpoint"} + flagSet := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flagSet.Usage = func() { + out := flagSet.Output() + _, _ = fmt.Fprintf(out, "Usage:\n%s [options]\n", os.Args[0]) + _, _ = fmt.Fprintf(out, "\nOptions:\n") + flagSet.PrintDefaults() + } + flagSet.StringVar(&dsn, + "ydb", "", + "YDB connection string", + ) + flagSet.StringVar(&tokenEndpoint, + "token-endpoint", "", + "oauth 2.0 token exchange endpoint", + ) + flagSet.StringVar(&keyID, + "key-id", "", + "key id for jwt token", + ) + flagSet.StringVar(&privateKeyFile, + "private-key-file", "", + "RSA private key file for jwt token in pem format", + ) + flagSet.StringVar(&audience, + "audience", "", + "audience", + ) + flagSet.StringVar(&issuer, + "issuer", "", + "jwt token issuer", + ) + flagSet.StringVar(&subject, + "subject", "", + "jwt token subject", + ) + if err := flagSet.Parse(os.Args[1:]); err != nil { + flagSet.Usage() + os.Exit(1) + } + flagSet.Visit(func(f *flag.Flag) { + for i, arg := range required { + if arg == f.Name { + required = append(required[:i], required[i+1:]...) + } + } + }) + if len(required) > 0 { + fmt.Printf("\nSome required options not defined: %v\n\n", required) + flagSet.Usage() + os.Exit(1) + } +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db, err := ydb.Open(ctx, dsn, + ydb.WithOauth2TokenExchangeCredentials( + credentials.WithTokenEndpoint(tokenEndpoint), + credentials.WithAudience(audience), + credentials.WithJWTSubjectToken( + credentials.WithSigningMethod(jwt.SigningMethodRS256), + credentials.WithKeyID(keyID), + credentials.WithRSAPrivateKeyPEMFile(privateKeyFile), + credentials.WithIssuer(issuer), + credentials.WithSubject(subject), + credentials.WithAudience(audience), + ), + ), + ) + if err != nil { + panic(err) + } + defer func() { _ = db.Close(ctx) }() + + whoAmI, err := db.Discovery().WhoAmI(ctx) + if err != nil { + panic(err) + } + + fmt.Println(whoAmI.String()) +} diff --git a/internal/credentials/oauth2.go b/internal/credentials/oauth2.go new file mode 100644 index 000000000..f49826127 --- /dev/null +++ b/internal/credentials/oauth2.go @@ -0,0 +1,846 @@ +package credentials + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/golang-jwt/jwt/v4" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/secret" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" +) + +const ( + defaultRequestTimeout = time.Second * 10 + defaultJWTTokenTTL = 3600 * time.Second + updateTimeDivider = 2 +) + +var ( + errEmptyTokenEndpointError = errors.New("OAuth2 token exchange: empty token endpoint") + errCouldNotParseResponse = errors.New("OAuth2 token exchange: could not parse response") + errCouldNotExchangeToken = errors.New("OAuth2 token exchange: could not exchange token") + errUnsupportedTokenType = errors.New("OAuth2 token exchange: unsupported token type") + errIncorrectExpirationTime = errors.New("OAuth2 token exchange: incorrect expiration time") + errDifferentScope = errors.New("OAuth2 token exchange: got different scope") + errCouldNotMakeHTTPRequest = errors.New("OAuth2 token exchange: could not make http request") + errCouldNotApplyOption = errors.New("OAuth2 token exchange: could not apply option") + errCouldNotCreateTokenSource = errors.New("OAuth2 token exchange: could not createTokenSource") + errNoSigningMethodError = errors.New("JWT token source: no signing method") + errNoPrivateKeyError = errors.New("JWT token source: no private key") + errCouldNotSignJWTToken = errors.New("JWT token source: could not sign jwt token") + errCouldNotApplyJWTOption = errors.New("JWT token source: could not apply option") + errCouldNotparseRSAPrivateKey = errors.New("JWT token source: could not parse RSA private key from PEM") + errCouldNotParseHomeDir = errors.New("JWT token source: could not parse home dir for private key") + errCouldNotReadPrivateKeyFile = errors.New("JWT token source: could not read from private key file") +) + +type Oauth2TokenExchangeCredentialsOption interface { + ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error +} + +// TokenEndpoint +type tokenEndpointOption string + +func (endpoint tokenEndpointOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.tokenEndpoint = string(endpoint) + + return nil +} + +func WithTokenEndpoint(endpoint string) tokenEndpointOption { + return tokenEndpointOption(endpoint) +} + +// GrantType +type grantTypeOption string + +func (grantType grantTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.grantType = string(grantType) + + return nil +} + +func WithGrantType(grantType string) grantTypeOption { + return grantTypeOption(grantType) +} + +// Resource +type resourceOption string + +func (resource resourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.resource = string(resource) + + return nil +} + +func WithResource(resource string) resourceOption { + return resourceOption(resource) +} + +// RequestedTokenType +type requestedTokenTypeOption string + +func (requestedTokenType requestedTokenTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.requestedTokenType = string(requestedTokenType) + + return nil +} + +func WithRequestedTokenType(requestedTokenType string) requestedTokenTypeOption { + return requestedTokenTypeOption(requestedTokenType) +} + +// Audience +type audienceOption []string + +func (audience audienceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.audience = audience + + return nil +} + +func WithAudience(audience ...string) audienceOption { + return audience +} + +// Scope +type scopeOption []string + +func (scope scopeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.scope = scope + + return nil +} + +func WithScope(scope ...string) scopeOption { + return scope +} + +// RequestTimeout +type requestTimeoutOption time.Duration + +func (timeout requestTimeoutOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + c.requestTimeout = time.Duration(timeout) + + return nil +} + +func WithRequestTimeout(timeout time.Duration) requestTimeoutOption { + return requestTimeoutOption(timeout) +} + +const ( + SubjectTokenSourceType = 1 + ActorTokenSourceType = 2 +) + +// SubjectTokenSource/ActorTokenSource +type tokenSourceOption struct { + source TokenSource + createFunc func() (TokenSource, error) + tokenSourceType int +} + +func (tokenSource *tokenSourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + src := tokenSource.source + var err error + if src == nil { + src, err = tokenSource.createFunc() + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotCreateTokenSource, err)) + } + } + switch tokenSource.tokenSourceType { + case SubjectTokenSourceType: + c.subjectTokenSource = src + case ActorTokenSourceType: + c.actorTokenSource = src + } + + return nil +} + +func WithSubjectToken(subjectToken TokenSource) *tokenSourceOption { + return &tokenSourceOption{ + source: subjectToken, + tokenSourceType: SubjectTokenSourceType, + } +} + +func WithFixedSubjectToken(token, tokenType string) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewFixedTokenSource(token, tokenType), nil + }, + tokenSourceType: SubjectTokenSourceType, + } +} + +func WithJWTSubjectToken(opts ...JWTTokenSourceOption) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewJWTTokenSource(opts...) + }, + tokenSourceType: SubjectTokenSourceType, + } +} + +// ActorTokenSource +func WithActorToken(actorToken TokenSource) *tokenSourceOption { + return &tokenSourceOption{ + source: actorToken, + tokenSourceType: ActorTokenSourceType, + } +} + +func WithFixedActorToken(token, tokenType string) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewFixedTokenSource(token, tokenType), nil + }, + tokenSourceType: ActorTokenSourceType, + } +} + +func WithJWTActorToken(opts ...JWTTokenSourceOption) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewJWTTokenSource(opts...) + }, + tokenSourceType: ActorTokenSourceType, + } +} + +type oauth2TokenExchange struct { + tokenEndpoint string + + // grant_type parameter + // urn:ietf:params:oauth:grant-type:token-exchange by default + grantType string + + resource string + audience []string + scope []string + + // requested_token_type parameter + // urn:ietf:params:oauth:token-type:access_token by default + requestedTokenType string + + subjectTokenSource TokenSource + + actorTokenSource TokenSource + + // Http request timeout + // 10 by default + requestTimeout time.Duration + + // Received data + receivedToken string + updateTokenTime time.Time + receivedTokenExpireTime time.Time + + mutex sync.RWMutex + updating atomic.Bool // true if separate goroutine is run and updates token in background + + sourceInfo string +} + +func NewOauth2TokenExchangeCredentials( + opts ...Oauth2TokenExchangeCredentialsOption, +) (*oauth2TokenExchange, error) { + c := &oauth2TokenExchange{ + grantType: "urn:ietf:params:oauth:grant-type:token-exchange", + requestedTokenType: "urn:ietf:params:oauth:token-type:access_token", + requestTimeout: defaultRequestTimeout, + sourceInfo: stack.Record(1), + } + + var err error + for _, opt := range opts { + if opt != nil { + err = opt.ApplyOauth2CredentialsOption(c) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyOption, err)) + } + } + } + + if c.tokenEndpoint == "" { + return nil, xerrors.WithStackTrace(errEmptyTokenEndpointError) + } + + return c, nil +} + +func (provider *oauth2TokenExchange) getScopeParam() string { + var scope string + if len(provider.scope) != 0 { + for _, s := range provider.scope { + if s != "" { + if scope != "" { + scope += " " + } + scope += s + } + } + } + + return scope +} + +func (provider *oauth2TokenExchange) addTokenSrc(params *url.Values, src TokenSource, tName, tTypeName string) error { + if src != nil { + token, err := src.Token() + if err != nil { + return xerrors.WithStackTrace(err) + } + params.Set(tName, token.Token) + params.Set(tTypeName, token.TokenType) + } + + return nil +} + +func (provider *oauth2TokenExchange) getRequestParams() (string, error) { + params := url.Values{} + params.Set("grant_type", provider.grantType) + if provider.resource != "" { + params.Set("resource", provider.resource) + } + for _, aud := range provider.audience { + if aud != "" { + params.Add("audience", aud) + } + } + scope := provider.getScopeParam() + if scope != "" { + params.Set("scope", scope) + } + + params.Set("requested_token_type", provider.requestedTokenType) + + err := provider.addTokenSrc(¶ms, provider.subjectTokenSource, "subject_token", "subject_token_type") + if err != nil { + return "", xerrors.WithStackTrace(err) + } + + err = provider.addTokenSrc(¶ms, provider.actorTokenSource, "actor_token", "actor_token_type") + if err != nil { + return "", xerrors.WithStackTrace(err) + } + + return params.Encode(), nil +} + +func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.Response, now time.Time) error { + var ( + data []byte + err error + ) + if result.Body != nil { + data, err = io.ReadAll(result.Body) + if err != nil { + return xerrors.WithStackTrace(err) + } + } else { + data = make([]byte, 0) + } + + if result.StatusCode != http.StatusOK { + description := result.Status + + //nolint:tagliatelle + type errorResponse struct { + ErrorName string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` + } + var parsedErrorResponse errorResponse + if err := json.Unmarshal(data, &parsedErrorResponse); err != nil { + description += ", could not parse response: " + err.Error() + + return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)) + } + + if parsedErrorResponse.ErrorName != "" { + description += ", error: " + parsedErrorResponse.ErrorName + } + + if parsedErrorResponse.ErrorDescription != "" { + description += fmt.Sprintf(", description: %q", parsedErrorResponse.ErrorDescription) + } + + if parsedErrorResponse.ErrorURI != "" { + description += ", error_uri: " + parsedErrorResponse.ErrorURI + } + + return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)) + } + + //nolint:tagliatelle + type response struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Scope string `json:"scope"` + } + var parsedResponse response + if err := json.Unmarshal(data, &parsedResponse); err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err)) + } + + if !strings.EqualFold(parsedResponse.TokenType, "bearer") { + return xerrors.WithStackTrace( + fmt.Errorf("%w: %q", errUnsupportedTokenType, parsedResponse.TokenType)) + } + + if parsedResponse.ExpiresIn <= 0 { + return xerrors.WithStackTrace( + fmt.Errorf("%w: %d", errIncorrectExpirationTime, parsedResponse.ExpiresIn)) + } + + if parsedResponse.Scope != "" { + scope := provider.getScopeParam() + if parsedResponse.Scope != scope { + return xerrors.WithStackTrace( + fmt.Errorf("%w. Expected %q, but got %q", errDifferentScope, scope, parsedResponse.Scope)) + } + } + + provider.receivedToken = "Bearer " + parsedResponse.AccessToken + + // Expire time + expireDelta := time.Duration(parsedResponse.ExpiresIn) + expireDelta *= time.Second + provider.receivedTokenExpireTime = now.Add(expireDelta) + + updateDelta := time.Duration(parsedResponse.ExpiresIn / updateTimeDivider) + updateDelta *= time.Second + provider.updateTokenTime = now.Add(updateDelta) + + return nil +} + +func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time.Time) error { + body, err := provider.getRequestParams() + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err)) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.tokenEndpoint, strings.NewReader(body)) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err)) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(body))) + req.Close = true + + client := http.Client{ + Transport: http.DefaultTransport, + Timeout: provider.requestTimeout, + } + + result, err := client.Do(req) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotExchangeToken, err)) + } + + defer result.Body.Close() + + return provider.processTokenExchangeResponse(result, now) +} + +func (provider *oauth2TokenExchange) exchangeTokenInBackground() { + provider.mutex.Lock() + defer provider.mutex.Unlock() + + now := time.Now() + if !provider.needUpdate(now) { + return + } + + ctx := context.Background() + _ = provider.exchangeToken(ctx, now) + + provider.updating.Store(false) +} + +func (provider *oauth2TokenExchange) checkBackgroundUpdate(now time.Time) { + if provider.needUpdate(now) && !provider.updating.Load() { + if provider.updating.CompareAndSwap(false, true) { + go provider.exchangeTokenInBackground() + } + } +} + +func (provider *oauth2TokenExchange) expired(now time.Time) bool { + return now.Compare(provider.receivedTokenExpireTime) > 0 +} + +func (provider *oauth2TokenExchange) needUpdate(now time.Time) bool { + return now.Compare(provider.updateTokenTime) > 0 +} + +func (provider *oauth2TokenExchange) fastCheck(now time.Time) string { + provider.mutex.RLock() + defer provider.mutex.RUnlock() + + if !provider.expired(now) { + provider.checkBackgroundUpdate(now) + + return provider.receivedToken + } + + return "" +} + +func (provider *oauth2TokenExchange) Token(ctx context.Context) (string, error) { + now := time.Now() + + token := provider.fastCheck(now) + if token != "" { + return token, nil + } + + provider.mutex.Lock() + defer provider.mutex.Unlock() + + if !provider.expired(now) { + return provider.receivedToken, nil + } + + if err := provider.exchangeToken(ctx, now); err != nil { + return "", err + } + + return provider.receivedToken, nil +} + +func (provider *oauth2TokenExchange) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "OAuth2TokenExchange{Endpoint:%q,GrantType:%s,Resource:%s,Audience:%v,Scope:%v,RequestedTokenType:%s", + provider.tokenEndpoint, + provider.grantType, + provider.resource, + provider.audience, + provider.scope, + provider.requestedTokenType, + ) + if provider.subjectTokenSource != nil { + fmt.Fprintf(buffer, ",SubjectToken:%s", provider.subjectTokenSource) + } + if provider.actorTokenSource != nil { + fmt.Fprintf(buffer, ",ActorToken:%s", provider.actorTokenSource) + } + if provider.sourceInfo != "" { + fmt.Fprintf(buffer, ",From:%q", provider.sourceInfo) + } + buffer.WriteByte('}') + + return buffer.String() +} + +type Token struct { + Token string + + // token type according to OAuth 2.0 token exchange protocol + // https://www.rfc-editor.org/rfc/rfc8693#TokenTypeIdentifiers + // for example urn:ietf:params:oauth:token-type:jwt + TokenType string +} + +type TokenSource interface { + Token() (Token, error) +} + +type fixedTokenSource struct { + fixedToken Token +} + +func (s *fixedTokenSource) Token() (Token, error) { + return s.fixedToken, nil +} + +func (s *fixedTokenSource) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "FixedTokenSource{Token:%q,Type:%s}", + secret.Token(s.fixedToken.Token), + s.fixedToken.TokenType, + ) + + return buffer.String() +} + +func NewFixedTokenSource(token, tokenType string) *fixedTokenSource { + return &fixedTokenSource{ + fixedToken: Token{ + Token: token, + TokenType: tokenType, + }, + } +} + +type JWTTokenSourceOption interface { + ApplyJWTTokenSourceOption(s *jwtTokenSource) error +} + +// Issuer +type issuerOption string + +func (issuer issuerOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.issuer = string(issuer) + + return nil +} + +func WithIssuer(issuer string) issuerOption { + return issuerOption(issuer) +} + +// Subject +type subjectOption string + +func (subject subjectOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.subject = string(subject) + + return nil +} + +func WithSubject(subject string) subjectOption { + return subjectOption(subject) +} + +// Audience +func (audience audienceOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.audience = audience + + return nil +} + +// ID +type idOption string + +func (id idOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.id = string(id) + + return nil +} + +func WithID(id string) idOption { + return idOption(id) +} + +// TokenTTL +type tokenTTLOption time.Duration + +func (ttl tokenTTLOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.tokenTTL = time.Duration(ttl) + + return nil +} + +func WithTokenTTL(ttl time.Duration) tokenTTLOption { + return tokenTTLOption(ttl) +} + +// SigningMethod +type signingMethodOption struct { + method jwt.SigningMethod +} + +func (method *signingMethodOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.signingMethod = method.method + + return nil +} + +func WithSigningMethod(method jwt.SigningMethod) *signingMethodOption { + return &signingMethodOption{method} +} + +// KeyID +type keyIDOption string + +func (id keyIDOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.keyID = string(id) + + return nil +} + +func WithKeyID(id string) keyIDOption { + return keyIDOption(id) +} + +// PrivateKey +type privateKeyOption struct { + key interface{} +} + +func (key *privateKeyOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + s.privateKey = key.key + + return nil +} + +func WithPrivateKey(key interface{}) *privateKeyOption { + return &privateKeyOption{key} +} + +// PrivateKey +type rsaPrivateKeyPemContentOption struct { + keyContent []byte +} + +func (key *rsaPrivateKeyPemContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(key.keyContent) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotparseRSAPrivateKey, err)) + } + s.privateKey = privateKey + + return nil +} + +func WithRSAPrivateKeyPEMContent(key []byte) *rsaPrivateKeyPemContentOption { + return &rsaPrivateKeyPemContentOption{key} +} + +// PrivateKey +type rsaPrivateKeyPemFileOption struct { + path string +} + +func (key *rsaPrivateKeyPemFileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + if len(key.path) > 0 && key.path[0] == '~' { + usr, err := user.Current() + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseHomeDir, err)) + } + key.path = filepath.Join(usr.HomeDir, key.path[1:]) + } + bytes, err := os.ReadFile(key.path) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err)) + } + + o := rsaPrivateKeyPemContentOption{bytes} + + return o.ApplyJWTTokenSourceOption(s) +} + +func WithRSAPrivateKeyPEMFile(path string) *rsaPrivateKeyPemFileOption { + return &rsaPrivateKeyPemFileOption{path} +} + +func NewJWTTokenSource(opts ...JWTTokenSourceOption) (*jwtTokenSource, error) { + s := &jwtTokenSource{ + tokenTTL: defaultJWTTokenTTL, + } + + var err error + for _, opt := range opts { + if opt != nil { + err = opt.ApplyJWTTokenSourceOption(s) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyJWTOption, err)) + } + } + } + + if s.signingMethod == nil { + return nil, xerrors.WithStackTrace(errNoSigningMethodError) + } + + if s.privateKey == nil { + return nil, xerrors.WithStackTrace(errNoPrivateKeyError) + } + + return s, nil +} + +type jwtTokenSource struct { + signingMethod jwt.SigningMethod + keyID string + privateKey interface{} // symmetric key in case of symmetric algorithm + + // JWT claims + issuer string + subject string + audience []string + id string + tokenTTL time.Duration +} + +func (s *jwtTokenSource) Token() (Token, error) { + var ( + now = time.Now() + issued = jwt.NewNumericDate(now.UTC()) + expire = jwt.NewNumericDate(now.Add(s.tokenTTL).UTC()) + err error + ) + t := jwt.Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": s.signingMethod.Alg(), + "kid": s.keyID, + }, + Claims: jwt.RegisteredClaims{ + Issuer: s.issuer, + Subject: s.subject, + IssuedAt: issued, + Audience: s.audience, + ExpiresAt: expire, + ID: s.id, + }, + Method: s.signingMethod, + } + + var token Token + token.Token, err = t.SignedString(s.privateKey) + if err != nil { + return token, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotSignJWTToken, err)) + } + token.TokenType = "urn:ietf:params:oauth:token-type:jwt" + + return token, nil +} + +func (s *jwtTokenSource) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "JWTTokenSource{Method:%s,KeyID:%s,Issuer:%q,Subject:%q,Audience:%v,ID:%s,TokenTTL:%s}", + s.signingMethod.Alg(), + s.keyID, + s.issuer, + s.subject, + s.audience, + s.id, + s.tokenTTL, + ) + + return buffer.String() +} diff --git a/internal/credentials/oauth2_test.go b/internal/credentials/oauth2_test.go new file mode 100644 index 000000000..3ce391961 --- /dev/null +++ b/internal/credentials/oauth2_test.go @@ -0,0 +1,470 @@ +package credentials + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "os/user" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" +) + +var ( + testPrivateKeyContent = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC75/JS3rMcLJxv\nFgpOzF5+2gH+Yig3RE2MTl9uwC0BZKAv6foYr7xywQyWIK+W1cBhz8R4LfFmZo2j\nM0aCvdRmNBdW0EDSTnHLxCsFhoQWLVq+bI5f5jzkcoiioUtaEpADPqwgVULVtN/n\nnPJiZ6/dU30C3jmR6+LUgEntUtWt3eq3xQIn5lG3zC1klBY/HxtfH5Hu8xBvwRQT\nJnh3UpPLj8XwSmriDgdrhR7o6umWyVuGrMKlLHmeivlfzjYtfzO1MOIMG8t2/zxG\nR+xb4Vwks73sH1KruH/0/JMXU97npwpe+Um+uXhpldPygGErEia7abyZB2gMpXqr\nWYKMo02NAgMBAAECggEAO0BpC5OYw/4XN/optu4/r91bupTGHKNHlsIR2rDzoBhU\nYLd1evpTQJY6O07EP5pYZx9mUwUdtU4KRJeDGO/1/WJYp7HUdtxwirHpZP0lQn77\nuccuX/QQaHLrPekBgz4ONk+5ZBqukAfQgM7fKYOLk41jgpeDbM2Ggb6QUSsJISEp\nzrwpI/nNT/wn+Hvx4DxrzWU6wF+P8kl77UwPYlTA7GsT+T7eKGVH8xsxmK8pt6lg\nsvlBA5XosWBWUCGLgcBkAY5e4ZWbkdd183o+oMo78id6C+PQPE66PLDtHWfpRRmN\nm6XC03x6NVhnfvfozoWnmS4+e4qj4F/emCHvn0GMywKBgQDLXlj7YPFVXxZpUvg/\nrheVcCTGbNmQJ+4cZXx87huqwqKgkmtOyeWsRc7zYInYgraDrtCuDBCfP//ZzOh0\nLxepYLTPk5eNn/GT+VVrqsy35Ccr60g7Lp/bzb1WxyhcLbo0KX7/6jl0lP+VKtdv\nmto+4mbSBXSM1Y5BVVoVgJ3T/wKBgQDsiSvPRzVi5TTj13x67PFymTMx3HCe2WzH\nJUyepCmVhTm482zW95pv6raDr5CTO6OYpHtc5sTTRhVYEZoEYFTM9Vw8faBtluWG\nBjkRh4cIpoIARMn74YZKj0C/0vdX7SHdyBOU3bgRPHg08Hwu3xReqT1kEPSI/B2V\n4pe5fVrucwKBgQCNFgUxUA3dJjyMES18MDDYUZaRug4tfiYouRdmLGIxUxozv6CG\nZnbZzwxFt+GpvPUV4f+P33rgoCvFU+yoPctyjE6j+0aW0DFucPmb2kBwCu5J/856\nkFwCx3blbwFHAco+SdN7g2kcwgmV2MTg/lMOcU7XwUUcN0Obe7UlWbckzQKBgQDQ\nnXaXHL24GGFaZe4y2JFmujmNy1dEsoye44W9ERpf9h1fwsoGmmCKPp90az5+rIXw\nFXl8CUgk8lXW08db/r4r+ma8Lyx0GzcZyplAnaB5/6j+pazjSxfO4KOBy4Y89Tb+\nTP0AOcCi6ws13bgY+sUTa/5qKA4UVw+c5zlb7nRpgwKBgGXAXhenFw1666482iiN\ncHSgwc4ZHa1oL6aNJR1XWH+aboBSwR+feKHUPeT4jHgzRGo/aCNHD2FE5I8eBv33\nof1kWYjAO0YdzeKrW0rTwfvt9gGg+CS397aWu4cy+mTI+MNfBgeDAIVBeJOJXLlX\nhL8bFAuNNVrCOp79TNnNIsh7\n-----END PRIVATE KEY-----\n" //nolint:lll + testPublicKeyContent = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+fyUt6zHCycbxYKTsxe\nftoB/mIoN0RNjE5fbsAtAWSgL+n6GK+8csEMliCvltXAYc/EeC3xZmaNozNGgr3U\nZjQXVtBA0k5xy8QrBYaEFi1avmyOX+Y85HKIoqFLWhKQAz6sIFVC1bTf55zyYmev\n3VN9At45kevi1IBJ7VLVrd3qt8UCJ+ZRt8wtZJQWPx8bXx+R7vMQb8EUEyZ4d1KT\ny4/F8Epq4g4Ha4Ue6OrplslbhqzCpSx5nor5X842LX8ztTDiDBvLdv88RkfsW+Fc\nJLO97B9Sq7h/9PyTF1Pe56cKXvlJvrl4aZXT8oBhKxImu2m8mQdoDKV6q1mCjKNN\njQIDAQAB\n-----END PUBLIC KEY-----\n" //nolint:lll +) + +type httpServerKey int + +const ( + keyServerAddr httpServerKey = 42 +) + +func WriteErr(w http.ResponseWriter, err error) { + WriteResponse(w, http.StatusInternalServerError, err.Error(), "text/html") +} + +func WriteResponse(w http.ResponseWriter, code int, body string, bodyType string) { + w.Header().Add("Content-Type", bodyType) + w.Header().Add("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(code) + _, _ = w.Write([]byte(body)) +} + +func runTokenExchangeServer( + ctx context.Context, + cancel context.CancelFunc, + port int, + currentTestParams *Oauth2TokenExchangeTestParams, +) { + defer cancel() + mux := http.NewServeMux() + mux.HandleFunc("/exchange", func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + WriteErr(w, err) + } + + fmt.Printf("got token exchange request: %s\n", body) + + params, err := url.ParseQuery(string(body)) + if err != nil { + WriteErr(w, err) + } + expectedParams := url.Values{} + expectedParams.Set("scope", "test_scope1 test_scope2") + expectedParams.Set("audience", "test_audience") + expectedParams.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + expectedParams.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token") + expectedParams.Set("subject_token", "test_source_token") + expectedParams.Set("subject_token_type", "urn:ietf:params:oauth:token-type:test_jwt") + + if !reflect.DeepEqual(expectedParams, params) { + WriteResponse(w, 555, fmt.Sprintf("Params are not as expected: \"%s\" != \"%s\"", + expectedParams.Encode(), body), "text/html") // error will be checked in test thread + } else { + WriteResponse(w, currentTestParams.Status, currentTestParams.Response, "application/json") + } + }) + server := http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: mux, + BaseContext: func(l net.Listener) context.Context { + ctx = context.WithValue(ctx, keyServerAddr, l.Addr().String()) + + return ctx + }, + ReadHeaderTimeout: 10 * time.Second, + } + err := server.ListenAndServe() + if err != nil { + fmt.Printf("Failed to run http server: %s", err.Error()) + } +} + +type Oauth2TokenExchangeTestParams struct { + Response string + Status int + ExpectedToken string + ExpectedError error + ExpectedErrorPart string +} + +func TestOauth2TokenExchange(t *testing.T) { + var currentTestParams Oauth2TokenExchangeTestParams + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runCtx, runCancel := context.WithCancel(ctx) + go runTokenExchangeServer(runCtx, runCancel, 14321, ¤tTestParams) + + testsParams := []Oauth2TokenExchangeTestParams{ + { + Response: `{"access_token":"test_token","token_type":"BEARER","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "Bearer test_token", + }, + { + Response: `aaa`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errCouldNotParseResponse, + }, + { + Response: `{}`, + Status: http.StatusBadRequest, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, + }, + { + Response: `not json`, + Status: http.StatusNotFound, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, + }, + { + Response: `{"error": "invalid_request"}`, + Status: http.StatusBadRequest, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "400 Bad Request, error: invalid_request", + }, + { + Response: `{"error":"unauthorized_client","error_description":"something went bad"}`, + Status: http.StatusInternalServerError, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "500 Internal Server Error, error: unauthorized_client, description: \"something went bad\"", //nolint:lll + }, + { + Response: `{"error_description":"something went bad","error_uri":"my_error_uri"}`, + Status: http.StatusForbidden, + ExpectedToken: "", + ExpectedError: errCouldNotExchangeToken, + ExpectedErrorPart: "403 Forbidden, description: \"something went bad\", error_uri: my_error_uri", + }, + { + Response: `{"access_token":"test_token","token_type":"","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errUnsupportedTokenType, + }, + { + Response: `{"access_token":"test_token","token_type":"basic","expires_in":42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errUnsupportedTokenType, + }, + { + Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":-42,"some_other_field":"x"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errIncorrectExpirationTime, + }, + { + Response: `{"access_token":"test_token","token_type":"Bearer","expires_in":42,"scope":"s"}`, + Status: http.StatusOK, + ExpectedToken: "", + ExpectedError: errDifferentScope, + ExpectedErrorPart: "Expected \"test_scope1 test_scope2\", but got \"s\"", + }, + } + + for _, params := range testsParams { + currentTestParams = params + + client, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http://localhost:14321/exchange"), + WithAudience("test_audience"), + WithScope("test_scope1", "test_scope2"), + WithSubjectToken(NewFixedTokenSource("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt")), + ) + require.NoError(t, err) + + token, err := client.Token(ctx) + if params.ExpectedErrorPart == "" && params.ExpectedError == nil { + require.NoError(t, err) + } else { + if params.ExpectedErrorPart != "" { + require.ErrorContains(t, err, params.ExpectedErrorPart) + } + if params.ExpectedError != nil { + require.ErrorIs(t, err, params.ExpectedError) + } + } + require.Equal(t, params.ExpectedToken, token) + } +} + +func TestOauth2TokenUpdate(t *testing.T) { + var currentTestParams Oauth2TokenExchangeTestParams + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runCtx, runCancel := context.WithCancel(ctx) + go runTokenExchangeServer(runCtx, runCancel, 14322, ¤tTestParams) + + // First exchange + currentTestParams = Oauth2TokenExchangeTestParams{ + Response: `{"access_token":"test_token_1", "token_type":"Bearer","expires_in":2}`, + Status: http.StatusOK, + } + + client, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http://localhost:14322/exchange"), + WithAudience("test_audience"), + WithScope("test_scope1", "test_scope2"), + WithFixedSubjectToken("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt"), + ) + require.NoError(t, err) + + token, err := client.Token(ctx) + t1 := time.Now() + require.NoError(t, err) + require.Equal(t, "Bearer test_token_1", token) + + // Second exchange + currentTestParams = Oauth2TokenExchangeTestParams{ + Response: `{"access_token":"test_token_2", "token_type":"Bearer","expires_in":10000}`, + Status: http.StatusOK, + } + + token, err = client.Token(ctx) + t2 := time.Now() + require.NoError(t, err) + if t2.Sub(t1) <= time.Second { // half expire period => no attempts to update + require.Equal(t, "Bearer test_token_1", token) + } + + time.Sleep(time.Second) // wait half expire period + for i := 1; i <= 100; i++ { + t3 := time.Now() + token, err = client.Token(ctx) + require.NoError(t, err) + if t3.Sub(t1) >= 2*time.Second { + require.Equal(t, "Bearer test_token_2", token) // Must update at least sync + } + if token == "Bearer test_token_2" { // already updated + break + } + require.Equal(t, "Bearer test_token_1", token) + + time.Sleep(10 * time.Millisecond) + } + + // Third exchange (never got, because token will be expired later) + currentTestParams = Oauth2TokenExchangeTestParams{ + Response: `{}`, + Status: http.StatusInternalServerError, + } + + for i := 1; i <= 5; i++ { + token, err = client.Token(ctx) + require.NoError(t, err) + require.Equal(t, "Bearer test_token_2", token) + } +} + +func TestWrongParameters(t *testing.T) { + _, err := NewOauth2TokenExchangeCredentials( + // No endpoint + WithFixedActorToken("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt"), + WithRequestedTokenType("access_token"), + ) + require.ErrorIs(t, err, errEmptyTokenEndpointError) +} + +type errorTokenSource struct{} + +var errTokenSource = errors.New("test error") + +func (s *errorTokenSource) Token() (Token, error) { + return Token{"", ""}, errTokenSource +} + +func TestErrorInSourceToken(t *testing.T) { + // Create + _, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http:trololo"), + WithJWTSubjectToken( + WithRSAPrivateKeyPEMContent([]byte("invalid")), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ), + ) + require.ErrorIs(t, err, errCouldNotCreateTokenSource) + + // Use + client, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http:trololo"), + WithGrantType("grant_type"), + WithRequestTimeout(time.Second), + WithResource("res"), + WithFixedSubjectToken("t", "tt"), + WithActorToken(&errorTokenSource{}), + WithSourceInfo("TestErrorInSourceToken"), + ) + require.NoError(t, err) + + // Check that token prints well + formatted := fmt.Sprint(client) + require.Equal(t, `OAuth2TokenExchange{Endpoint:"http:trololo",GrantType:grant_type,Resource:res,Audience:[],Scope:[],RequestedTokenType:urn:ietf:params:oauth:token-type:access_token,SubjectToken:FixedTokenSource{Token:"****(CRC-32c: 856A5AA8)",Type:tt},ActorToken:&{},From:"TestErrorInSourceToken"}`, formatted) //nolint:lll + + token, err := client.Token(context.Background()) + require.ErrorIs(t, err, errTokenSource) + require.Equal(t, "", token) + + client, err = NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http:trololo"), + WithGrantType("grant_type"), + WithRequestTimeout(time.Second), + WithResource("res"), + WithSubjectToken(&errorTokenSource{}), + ) + require.NoError(t, err) + + token, err = client.Token(context.Background()) + require.ErrorIs(t, err, errTokenSource) + require.Equal(t, "", token) +} + +func TestErrorInHTTPRequest(t *testing.T) { + client, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http://invalid_host:42/exchange"), + WithJWTSubjectToken( + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ), + WithJWTActorToken( + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + ), + WithScope("1", "2", "3"), + WithSourceInfo("TestErrorInHTTPRequest"), + ) + require.NoError(t, err) + + token, err := client.Token(context.Background()) + require.ErrorIs(t, err, errCouldNotExchangeToken) + require.Equal(t, "", token) + + // check format: + formatted := fmt.Sprint(client) + require.Equal(t, `OAuth2TokenExchange{Endpoint:"http://invalid_host:42/exchange",GrantType:urn:ietf:params:oauth:grant-type:token-exchange,Resource:,Audience:[],Scope:[1 2 3],RequestedTokenType:urn:ietf:params:oauth:token-type:access_token,SubjectToken:JWTTokenSource{Method:RS256,KeyID:key_id,Issuer:"test_issuer",Subject:"",Audience:[test_audience],ID:,TokenTTL:1h0m0s},ActorToken:JWTTokenSource{Method:RS256,KeyID:key_id,Issuer:"test_issuer",Subject:"",Audience:[],ID:,TokenTTL:1h0m0s},From:"TestErrorInHTTPRequest"}`, formatted) //nolint:lll +} + +func TestJWTTokenSource(t *testing.T) { + publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPublicKeyContent)) + require.NoError(t, err) + getPublicKey := func(*jwt.Token) (interface{}, error) { + return publicKey, nil + } + + var src TokenSource + src, err = NewJWTTokenSource( + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ) + require.NoError(t, err) + + token, err := src.Token() + require.NoError(t, err) + require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", token.TokenType) + + claims := jwt.RegisteredClaims{} + parsedToken, err := jwt.ParseWithClaims(token.Token, &claims, getPublicKey) + require.NoError(t, err) + + require.True(t, parsedToken.Valid) + require.NoError(t, parsedToken.Claims.Valid()) + require.Equal(t, "test_issuer", claims.Issuer) + require.Equal(t, "test_audience", claims.Audience[0]) + require.Equal(t, "key_id", parsedToken.Header["kid"].(string)) + require.Equal(t, "RS256", parsedToken.Header["alg"].(string)) +} + +func TestJWTTokenBadParams(t *testing.T) { + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(testPrivateKeyContent)) + require.NoError(t, err) + + _, err = NewJWTTokenSource( + // no private key + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + WithID("id"), + ) + require.ErrorIs(t, err, errNoPrivateKeyError) + + _, err = NewJWTTokenSource( + WithPrivateKey(privateKey), + WithKeyID("key_id"), + // no signing method + WithSubject("s"), + WithTokenTTL(time.Minute), + WithAudience("test_audience"), + ) + require.ErrorIs(t, err, errNoSigningMethodError) +} + +func TestJWTTokenSourceReadPrivateKeyFromFile(t *testing.T) { + const perm = 0o600 + usr, err := user.Current() + require.NoError(t, err) + fileName := strconv.Itoa(time.Now().Second()) + filePath := filepath.Join(usr.HomeDir, fileName) + beautifulFilePath := filepath.Join("~", fileName) + err = os.WriteFile( + filePath, + []byte(testPrivateKeyContent), + perm, + ) + require.NoError(t, err) + defer os.Remove(filePath) + + var src TokenSource + src, err = NewJWTTokenSource( + WithRSAPrivateKeyPEMFile(beautifulFilePath), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ) + require.NoError(t, err) + + token, err := src.Token() + require.NoError(t, err) + + // parse token + publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPublicKeyContent)) + require.NoError(t, err) + getPublicKey := func(*jwt.Token) (interface{}, error) { + return publicKey, nil + } + + claims := jwt.RegisteredClaims{} + _, err = jwt.ParseWithClaims(token.Token, &claims, getPublicKey) + require.NoError(t, err) +} diff --git a/internal/credentials/source_info.go b/internal/credentials/source_info.go index 017a49a2e..53f7c5065 100644 --- a/internal/credentials/source_info.go +++ b/internal/credentials/source_info.go @@ -14,6 +14,12 @@ func (sourceInfo SourceInfoOption) ApplyAccessTokenCredentialsOption(h *AccessTo h.sourceInfo = string(sourceInfo) } +func (sourceInfo SourceInfoOption) ApplyOauth2CredentialsOption(h *oauth2TokenExchange) error { + h.sourceInfo = string(sourceInfo) + + return nil +} + // WithSourceInfo option append to credentials object the source info for reporting source info details on error case func WithSourceInfo(sourceInfo string) SourceInfoOption { return SourceInfoOption(sourceInfo) diff --git a/options.go b/options.go index c4ed15129..305f700fc 100644 --- a/options.go +++ b/options.go @@ -54,6 +54,19 @@ func WithAccessTokenCredentials(accessToken string) Option { ) } +// WithOauth2TokenExchangeCredentials adds credentials that exchange token using +// OAuth 2.0 token exchange protocol: +// https://www.rfc-editor.org/rfc/rfc8693 +func WithOauth2TokenExchangeCredentials( + opts ...credentials.Oauth2TokenExchangeCredentialsOption, +) Option { + opts = append(opts, credentials.WithSourceInfo("ydb.WithOauth2TokenExchangeCredentials(opts)")) + + return WithCreateCredentialsFunc(func(context.Context) (credentials.Credentials, error) { + return credentials.NewOauth2TokenExchangeCredentials(opts...) + }) +} + // WithApplicationName add provided application name to all api requests func WithApplicationName(applicationName string) Option { return func(ctx context.Context, c *Driver) error {