From 9813cfdf62899675e28311b247cc83f81b77dab9 Mon Sep 17 00:00:00 2001 From: Vasily Gerasimov Date: Thu, 11 Apr 2024 21:18:54 +0000 Subject: [PATCH] Additional options --- credentials/options.go | 52 ++++- internal/credentials/oauth2.go | 297 ++++++++++++++++++++++++---- internal/credentials/oauth2_test.go | 92 ++++++++- internal/credentials/source_info.go | 6 + options.go | 10 + 5 files changed, 400 insertions(+), 57 deletions(-) diff --git a/credentials/options.go b/credentials/options.go index 572f89b10..5d944dae6 100644 --- a/credentials/options.go +++ b/credentials/options.go @@ -9,6 +9,12 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials" ) +type Oauth2TokenExchangeCredentialsOption = credentials.Oauth2TokenExchangeCredentialsOption + +type TokenSource = credentials.TokenSource + +type Token = credentials.Token + // WithSourceInfo option append to credentials object the source info for reporting source info details on error case func WithSourceInfo(sourceInfo string) credentials.SourceInfoOption { return credentials.WithSourceInfo(sourceInfo) @@ -20,45 +26,65 @@ func WithGrpcDialOptions(opts ...grpc.DialOption) credentials.StaticCredentialsO } // TokenEndpoint -func WithTokenEndpoint(endpoint string) credentials.Oauth2TokenExchangeCredentialsOption { +func WithTokenEndpoint(endpoint string) Oauth2TokenExchangeCredentialsOption { return credentials.WithTokenEndpoint(endpoint) } // GrantType -func WithGrantType(grantType string) credentials.Oauth2TokenExchangeCredentialsOption { +func WithGrantType(grantType string) Oauth2TokenExchangeCredentialsOption { return credentials.WithGrantType(grantType) } // Resource -func WithResource(resource string) credentials.Oauth2TokenExchangeCredentialsOption { +func WithResource(resource string) Oauth2TokenExchangeCredentialsOption { return credentials.WithResource(resource) } // RequestedTokenType -func WithRequestedTokenType(requestedTokenType string) credentials.Oauth2TokenExchangeCredentialsOption { +func WithRequestedTokenType(requestedTokenType string) Oauth2TokenExchangeCredentialsOption { return credentials.WithRequestedTokenType(requestedTokenType) } // Scope -func WithScope(scope ...string) credentials.Oauth2TokenExchangeCredentialsOption { +func WithScope(scope ...string) Oauth2TokenExchangeCredentialsOption { return credentials.WithScope(scope...) } // RequestTimeout -func WithRequestTimeout(timeout time.Duration) credentials.Oauth2TokenExchangeCredentialsOption { +func WithRequestTimeout(timeout time.Duration) Oauth2TokenExchangeCredentialsOption { return credentials.WithRequestTimeout(timeout) } // SubjectTokenSource -func WithSubjectToken(subjectToken credentials.TokenSource) credentials.Oauth2TokenExchangeCredentialsOption { +func WithSubjectToken(subjectToken credentials.TokenSource) Oauth2TokenExchangeCredentialsOption { return credentials.WithSubjectToken(subjectToken) } +// SubjectTokenSource +func WithFixedSubjectToken(token, tokenType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithFixedSubjectToken(token, tokenType) +} + +// SubjectTokenSource +func WithJWTSubjectToken(opts ...credentials.JWTTokenSourceOption) Oauth2TokenExchangeCredentialsOption { + return credentials.WithJWTSubjectToken(opts...) +} + // ActorTokenSource -func WithActorToken(actorToken credentials.TokenSource) credentials.Oauth2TokenExchangeCredentialsOption { +func WithActorToken(actorToken credentials.TokenSource) Oauth2TokenExchangeCredentialsOption { return credentials.WithActorToken(actorToken) } +// ActorTokenSource +func WithFixedActorToken(token, tokenType string) Oauth2TokenExchangeCredentialsOption { + return credentials.WithFixedActorToken(token, tokenType) +} + +// ActorTokenSource +func WithJWTActorToken(opts ...credentials.JWTTokenSourceOption) Oauth2TokenExchangeCredentialsOption { + return credentials.WithJWTActorToken(opts...) +} + // Audience type oauthCredentialsAndJWTCredentialsOption interface { credentials.Oauth2TokenExchangeCredentialsOption @@ -103,3 +129,13 @@ func WithKeyID(id string) credentials.JWTTokenSourceOption { func WithPrivateKey(key interface{}) credentials.JWTTokenSourceOption { return credentials.WithPrivateKey(key) } + +// PrivateKey +func WithRSAPrivateKeyPEMContent(key []byte) credentials.JWTTokenSourceOption { + return credentials.WithRSAPrivateKeyPEMContent(key) +} + +// PrivateKey +func WithRSAPrivateKeyPEMFile(path string) credentials.JWTTokenSourceOption { + return credentials.WithRSAPrivateKeyPEMFile(path) +} diff --git a/internal/credentials/oauth2.go b/internal/credentials/oauth2.go index 2df73e952..f49826127 100644 --- a/internal/credentials/oauth2.go +++ b/internal/credentials/oauth2.go @@ -8,6 +8,9 @@ import ( "io" "net/http" "net/url" + "os" + "os/user" + "path/filepath" "strconv" "strings" "sync" @@ -16,7 +19,10 @@ import ( "github.com/golang-jwt/jwt/v4" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/secret" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) const ( @@ -26,27 +32,35 @@ const ( ) var ( - errEmptyTokenEndpointError = errors.New("OAuth2 token exchange: empty token endpoint") - errCouldNotParseResponse = errors.New("OAuth2 token exchange: could not parse response") - errCouldNotExchangeToken = errors.New("OAuth2 token exchange: could not exchange token") - errUnsupportedTokenType = errors.New("OAuth2 token exchange: unsupported token type") - errIncorrectExpirationTime = errors.New("OAuth2 token exchange: incorrect expiration time") - errDifferentScope = errors.New("OAuth2 token exchange: got different scope") - errCouldNotMakeHTTPRequest = errors.New("OAuth2 token exchange: could not make http request") - errNoSigningMethodError = errors.New("JWT token source: no signing method") - errNoPrivateKeyError = errors.New("JWT token source: no private key") - errCouldNotSignJWTToken = errors.New("JWT token source: could not sign jwt token") + errEmptyTokenEndpointError = errors.New("OAuth2 token exchange: empty token endpoint") + errCouldNotParseResponse = errors.New("OAuth2 token exchange: could not parse response") + errCouldNotExchangeToken = errors.New("OAuth2 token exchange: could not exchange token") + errUnsupportedTokenType = errors.New("OAuth2 token exchange: unsupported token type") + errIncorrectExpirationTime = errors.New("OAuth2 token exchange: incorrect expiration time") + errDifferentScope = errors.New("OAuth2 token exchange: got different scope") + errCouldNotMakeHTTPRequest = errors.New("OAuth2 token exchange: could not make http request") + errCouldNotApplyOption = errors.New("OAuth2 token exchange: could not apply option") + errCouldNotCreateTokenSource = errors.New("OAuth2 token exchange: could not createTokenSource") + errNoSigningMethodError = errors.New("JWT token source: no signing method") + errNoPrivateKeyError = errors.New("JWT token source: no private key") + errCouldNotSignJWTToken = errors.New("JWT token source: could not sign jwt token") + errCouldNotApplyJWTOption = errors.New("JWT token source: could not apply option") + errCouldNotparseRSAPrivateKey = errors.New("JWT token source: could not parse RSA private key from PEM") + errCouldNotParseHomeDir = errors.New("JWT token source: could not parse home dir for private key") + errCouldNotReadPrivateKeyFile = errors.New("JWT token source: could not read from private key file") ) type Oauth2TokenExchangeCredentialsOption interface { - ApplyOauth2CredentialsOption(c *oauth2TokenExchange) + ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error } // TokenEndpoint type tokenEndpointOption string -func (endpoint tokenEndpointOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (endpoint tokenEndpointOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.tokenEndpoint = string(endpoint) + + return nil } func WithTokenEndpoint(endpoint string) tokenEndpointOption { @@ -56,8 +70,10 @@ func WithTokenEndpoint(endpoint string) tokenEndpointOption { // GrantType type grantTypeOption string -func (grantType grantTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (grantType grantTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.grantType = string(grantType) + + return nil } func WithGrantType(grantType string) grantTypeOption { @@ -67,8 +83,10 @@ func WithGrantType(grantType string) grantTypeOption { // Resource type resourceOption string -func (resource resourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (resource resourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.resource = string(resource) + + return nil } func WithResource(resource string) resourceOption { @@ -78,8 +96,10 @@ func WithResource(resource string) resourceOption { // RequestedTokenType type requestedTokenTypeOption string -func (requestedTokenType requestedTokenTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (requestedTokenType requestedTokenTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.requestedTokenType = string(requestedTokenType) + + return nil } func WithRequestedTokenType(requestedTokenType string) requestedTokenTypeOption { @@ -89,8 +109,10 @@ func WithRequestedTokenType(requestedTokenType string) requestedTokenTypeOption // Audience type audienceOption []string -func (audience audienceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (audience audienceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.audience = audience + + return nil } func WithAudience(audience ...string) audienceOption { @@ -100,8 +122,10 @@ func WithAudience(audience ...string) audienceOption { // Scope type scopeOption []string -func (scope scopeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (scope scopeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.scope = scope + + return nil } func WithScope(scope ...string) scopeOption { @@ -111,38 +135,96 @@ func WithScope(scope ...string) scopeOption { // RequestTimeout type requestTimeoutOption time.Duration -func (timeout requestTimeoutOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { +func (timeout requestTimeoutOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { c.requestTimeout = time.Duration(timeout) + + return nil } func WithRequestTimeout(timeout time.Duration) requestTimeoutOption { return requestTimeoutOption(timeout) } -// SubjectTokenSource -type subjectTokenSourceOption struct { - source TokenSource +const ( + SubjectTokenSourceType = 1 + ActorTokenSourceType = 2 +) + +// SubjectTokenSource/ActorTokenSource +type tokenSourceOption struct { + source TokenSource + createFunc func() (TokenSource, error) + tokenSourceType int +} + +func (tokenSource *tokenSourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error { + src := tokenSource.source + var err error + if src == nil { + src, err = tokenSource.createFunc() + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotCreateTokenSource, err)) + } + } + switch tokenSource.tokenSourceType { + case SubjectTokenSourceType: + c.subjectTokenSource = src + case ActorTokenSourceType: + c.actorTokenSource = src + } + + return nil +} + +func WithSubjectToken(subjectToken TokenSource) *tokenSourceOption { + return &tokenSourceOption{ + source: subjectToken, + tokenSourceType: SubjectTokenSourceType, + } } -func (subjectToken *subjectTokenSourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { - c.subjectTokenSource = subjectToken.source +func WithFixedSubjectToken(token, tokenType string) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewFixedTokenSource(token, tokenType), nil + }, + tokenSourceType: SubjectTokenSourceType, + } } -func WithSubjectToken(subjectToken TokenSource) *subjectTokenSourceOption { - return &subjectTokenSourceOption{subjectToken} +func WithJWTSubjectToken(opts ...JWTTokenSourceOption) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewJWTTokenSource(opts...) + }, + tokenSourceType: SubjectTokenSourceType, + } } // ActorTokenSource -type actorTokenSourceOption struct { - source TokenSource +func WithActorToken(actorToken TokenSource) *tokenSourceOption { + return &tokenSourceOption{ + source: actorToken, + tokenSourceType: ActorTokenSourceType, + } } -func (actorToken *actorTokenSourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) { - c.actorTokenSource = actorToken.source +func WithFixedActorToken(token, tokenType string) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewFixedTokenSource(token, tokenType), nil + }, + tokenSourceType: ActorTokenSourceType, + } } -func WithActorToken(actorToken TokenSource) *actorTokenSourceOption { - return &actorTokenSourceOption{actorToken} +func WithJWTActorToken(opts ...JWTTokenSourceOption) *tokenSourceOption { + return &tokenSourceOption{ + createFunc: func() (TokenSource, error) { + return NewJWTTokenSource(opts...) + }, + tokenSourceType: ActorTokenSourceType, + } } type oauth2TokenExchange struct { @@ -175,6 +257,8 @@ type oauth2TokenExchange struct { mutex sync.RWMutex updating atomic.Bool // true if separate goroutine is run and updates token in background + + sourceInfo string } func NewOauth2TokenExchangeCredentials( @@ -184,11 +268,16 @@ func NewOauth2TokenExchangeCredentials( grantType: "urn:ietf:params:oauth:grant-type:token-exchange", requestedTokenType: "urn:ietf:params:oauth:token-type:access_token", requestTimeout: defaultRequestTimeout, + sourceInfo: stack.Record(1), } + var err error for _, opt := range opts { if opt != nil { - opt.ApplyOauth2CredentialsOption(c) + err = opt.ApplyOauth2CredentialsOption(c) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyOption, err)) + } } } @@ -443,6 +532,33 @@ func (provider *oauth2TokenExchange) Token(ctx context.Context) (string, error) return provider.receivedToken, nil } +func (provider *oauth2TokenExchange) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "OAuth2TokenExchange{Endpoint:%q,GrantType:%s,Resource:%s,Audience:%v,Scope:%v,RequestedTokenType:%s", + provider.tokenEndpoint, + provider.grantType, + provider.resource, + provider.audience, + provider.scope, + provider.requestedTokenType, + ) + if provider.subjectTokenSource != nil { + fmt.Fprintf(buffer, ",SubjectToken:%s", provider.subjectTokenSource) + } + if provider.actorTokenSource != nil { + fmt.Fprintf(buffer, ",ActorToken:%s", provider.actorTokenSource) + } + if provider.sourceInfo != "" { + fmt.Fprintf(buffer, ",From:%q", provider.sourceInfo) + } + buffer.WriteByte('}') + + return buffer.String() +} + type Token struct { Token string @@ -464,6 +580,19 @@ func (s *fixedTokenSource) Token() (Token, error) { return s.fixedToken, nil } +func (s *fixedTokenSource) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "FixedTokenSource{Token:%q,Type:%s}", + secret.Token(s.fixedToken.Token), + s.fixedToken.TokenType, + ) + + return buffer.String() +} + func NewFixedTokenSource(token, tokenType string) *fixedTokenSource { return &fixedTokenSource{ fixedToken: Token{ @@ -474,14 +603,16 @@ func NewFixedTokenSource(token, tokenType string) *fixedTokenSource { } type JWTTokenSourceOption interface { - ApplyJWTTokenSourceOption(s *jwtTokenSource) + ApplyJWTTokenSourceOption(s *jwtTokenSource) error } // Issuer type issuerOption string -func (issuer issuerOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (issuer issuerOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.issuer = string(issuer) + + return nil } func WithIssuer(issuer string) issuerOption { @@ -491,8 +622,10 @@ func WithIssuer(issuer string) issuerOption { // Subject type subjectOption string -func (subject subjectOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (subject subjectOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.subject = string(subject) + + return nil } func WithSubject(subject string) subjectOption { @@ -500,15 +633,19 @@ func WithSubject(subject string) subjectOption { } // Audience -func (audience audienceOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (audience audienceOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.audience = audience + + return nil } // ID type idOption string -func (id idOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (id idOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.id = string(id) + + return nil } func WithID(id string) idOption { @@ -518,8 +655,10 @@ func WithID(id string) idOption { // TokenTTL type tokenTTLOption time.Duration -func (ttl tokenTTLOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (ttl tokenTTLOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.tokenTTL = time.Duration(ttl) + + return nil } func WithTokenTTL(ttl time.Duration) tokenTTLOption { @@ -531,8 +670,10 @@ type signingMethodOption struct { method jwt.SigningMethod } -func (method *signingMethodOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (method *signingMethodOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.signingMethod = method.method + + return nil } func WithSigningMethod(method jwt.SigningMethod) *signingMethodOption { @@ -542,8 +683,10 @@ func WithSigningMethod(method jwt.SigningMethod) *signingMethodOption { // KeyID type keyIDOption string -func (id keyIDOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (id keyIDOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.keyID = string(id) + + return nil } func WithKeyID(id string) keyIDOption { @@ -555,22 +698,74 @@ type privateKeyOption struct { key interface{} } -func (key *privateKeyOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) { +func (key *privateKeyOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { s.privateKey = key.key + + return nil } func WithPrivateKey(key interface{}) *privateKeyOption { return &privateKeyOption{key} } +// PrivateKey +type rsaPrivateKeyPemContentOption struct { + keyContent []byte +} + +func (key *rsaPrivateKeyPemContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(key.keyContent) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotparseRSAPrivateKey, err)) + } + s.privateKey = privateKey + + return nil +} + +func WithRSAPrivateKeyPEMContent(key []byte) *rsaPrivateKeyPemContentOption { + return &rsaPrivateKeyPemContentOption{key} +} + +// PrivateKey +type rsaPrivateKeyPemFileOption struct { + path string +} + +func (key *rsaPrivateKeyPemFileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error { + if len(key.path) > 0 && key.path[0] == '~' { + usr, err := user.Current() + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseHomeDir, err)) + } + key.path = filepath.Join(usr.HomeDir, key.path[1:]) + } + bytes, err := os.ReadFile(key.path) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err)) + } + + o := rsaPrivateKeyPemContentOption{bytes} + + return o.ApplyJWTTokenSourceOption(s) +} + +func WithRSAPrivateKeyPEMFile(path string) *rsaPrivateKeyPemFileOption { + return &rsaPrivateKeyPemFileOption{path} +} + func NewJWTTokenSource(opts ...JWTTokenSourceOption) (*jwtTokenSource, error) { s := &jwtTokenSource{ tokenTTL: defaultJWTTokenTTL, } + var err error for _, opt := range opts { if opt != nil { - opt.ApplyJWTTokenSourceOption(s) + err = opt.ApplyJWTTokenSourceOption(s) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyJWTOption, err)) + } } } @@ -631,3 +826,21 @@ func (s *jwtTokenSource) Token() (Token, error) { return token, nil } + +func (s *jwtTokenSource) String() string { + buffer := xstring.Buffer() + defer buffer.Free() + fmt.Fprintf( + buffer, + "JWTTokenSource{Method:%s,KeyID:%s,Issuer:%q,Subject:%q,Audience:%v,ID:%s,TokenTTL:%s}", + s.signingMethod.Alg(), + s.keyID, + s.issuer, + s.subject, + s.audience, + s.id, + s.tokenTTL, + ) + + return buffer.String() +} diff --git a/internal/credentials/oauth2_test.go b/internal/credentials/oauth2_test.go index 2c5ff753b..3ce391961 100644 --- a/internal/credentials/oauth2_test.go +++ b/internal/credentials/oauth2_test.go @@ -8,6 +8,9 @@ import ( "net" "net/http" "net/url" + "os" + "os/user" + "path/filepath" "reflect" "strconv" "testing" @@ -222,7 +225,7 @@ func TestOauth2TokenUpdate(t *testing.T) { WithTokenEndpoint("http://localhost:14322/exchange"), WithAudience("test_audience"), WithScope("test_scope1", "test_scope2"), - WithSubjectToken(NewFixedTokenSource("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt")), + WithFixedSubjectToken("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt"), ) require.NoError(t, err) @@ -276,7 +279,7 @@ func TestOauth2TokenUpdate(t *testing.T) { func TestWrongParameters(t *testing.T) { _, err := NewOauth2TokenExchangeCredentials( // No endpoint - WithSubjectToken(NewFixedTokenSource("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt")), + WithFixedActorToken("test_source_token", "urn:ietf:params:oauth:token-type:test_jwt"), WithRequestedTokenType("access_token"), ) require.ErrorIs(t, err, errEmptyTokenEndpointError) @@ -291,15 +294,35 @@ func (s *errorTokenSource) Token() (Token, error) { } func TestErrorInSourceToken(t *testing.T) { + // Create + _, err := NewOauth2TokenExchangeCredentials( + WithTokenEndpoint("http:trololo"), + WithJWTSubjectToken( + WithRSAPrivateKeyPEMContent([]byte("invalid")), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ), + ) + require.ErrorIs(t, err, errCouldNotCreateTokenSource) + + // Use client, err := NewOauth2TokenExchangeCredentials( WithTokenEndpoint("http:trololo"), WithGrantType("grant_type"), WithRequestTimeout(time.Second), WithResource("res"), + WithFixedSubjectToken("t", "tt"), WithActorToken(&errorTokenSource{}), + WithSourceInfo("TestErrorInSourceToken"), ) require.NoError(t, err) + // Check that token prints well + formatted := fmt.Sprint(client) + require.Equal(t, `OAuth2TokenExchange{Endpoint:"http:trololo",GrantType:grant_type,Resource:res,Audience:[],Scope:[],RequestedTokenType:urn:ietf:params:oauth:token-type:access_token,SubjectToken:FixedTokenSource{Token:"****(CRC-32c: 856A5AA8)",Type:tt},ActorToken:&{},From:"TestErrorInSourceToken"}`, formatted) //nolint:lll + token, err := client.Token(context.Background()) require.ErrorIs(t, err, errTokenSource) require.Equal(t, "", token) @@ -321,19 +344,34 @@ func TestErrorInSourceToken(t *testing.T) { func TestErrorInHTTPRequest(t *testing.T) { client, err := NewOauth2TokenExchangeCredentials( WithTokenEndpoint("http://invalid_host:42/exchange"), - WithSubjectToken(NewFixedTokenSource("test_source_token", "test_token_type")), + WithJWTSubjectToken( + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ), + WithJWTActorToken( + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + ), + WithScope("1", "2", "3"), + WithSourceInfo("TestErrorInHTTPRequest"), ) require.NoError(t, err) token, err := client.Token(context.Background()) require.ErrorIs(t, err, errCouldNotExchangeToken) require.Equal(t, "", token) + + // check format: + formatted := fmt.Sprint(client) + require.Equal(t, `OAuth2TokenExchange{Endpoint:"http://invalid_host:42/exchange",GrantType:urn:ietf:params:oauth:grant-type:token-exchange,Resource:,Audience:[],Scope:[1 2 3],RequestedTokenType:urn:ietf:params:oauth:token-type:access_token,SubjectToken:JWTTokenSource{Method:RS256,KeyID:key_id,Issuer:"test_issuer",Subject:"",Audience:[test_audience],ID:,TokenTTL:1h0m0s},ActorToken:JWTTokenSource{Method:RS256,KeyID:key_id,Issuer:"test_issuer",Subject:"",Audience:[],ID:,TokenTTL:1h0m0s},From:"TestErrorInHTTPRequest"}`, formatted) //nolint:lll } func TestJWTTokenSource(t *testing.T) { - privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(testPrivateKeyContent)) - require.NoError(t, err) - publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPublicKeyContent)) require.NoError(t, err) getPublicKey := func(*jwt.Token) (interface{}, error) { @@ -342,7 +380,7 @@ func TestJWTTokenSource(t *testing.T) { var src TokenSource src, err = NewJWTTokenSource( - WithPrivateKey(privateKey), + WithRSAPrivateKeyPEMContent([]byte(testPrivateKeyContent)), WithKeyID("key_id"), WithSigningMethod(jwt.SigningMethodRS256), WithIssuer("test_issuer"), @@ -390,3 +428,43 @@ func TestJWTTokenBadParams(t *testing.T) { ) require.ErrorIs(t, err, errNoSigningMethodError) } + +func TestJWTTokenSourceReadPrivateKeyFromFile(t *testing.T) { + const perm = 0o600 + usr, err := user.Current() + require.NoError(t, err) + fileName := strconv.Itoa(time.Now().Second()) + filePath := filepath.Join(usr.HomeDir, fileName) + beautifulFilePath := filepath.Join("~", fileName) + err = os.WriteFile( + filePath, + []byte(testPrivateKeyContent), + perm, + ) + require.NoError(t, err) + defer os.Remove(filePath) + + var src TokenSource + src, err = NewJWTTokenSource( + WithRSAPrivateKeyPEMFile(beautifulFilePath), + WithKeyID("key_id"), + WithSigningMethod(jwt.SigningMethodRS256), + WithIssuer("test_issuer"), + WithAudience("test_audience"), + ) + require.NoError(t, err) + + token, err := src.Token() + require.NoError(t, err) + + // parse token + publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPublicKeyContent)) + require.NoError(t, err) + getPublicKey := func(*jwt.Token) (interface{}, error) { + return publicKey, nil + } + + claims := jwt.RegisteredClaims{} + _, err = jwt.ParseWithClaims(token.Token, &claims, getPublicKey) + require.NoError(t, err) +} diff --git a/internal/credentials/source_info.go b/internal/credentials/source_info.go index 017a49a2e..53f7c5065 100644 --- a/internal/credentials/source_info.go +++ b/internal/credentials/source_info.go @@ -14,6 +14,12 @@ func (sourceInfo SourceInfoOption) ApplyAccessTokenCredentialsOption(h *AccessTo h.sourceInfo = string(sourceInfo) } +func (sourceInfo SourceInfoOption) ApplyOauth2CredentialsOption(h *oauth2TokenExchange) error { + h.sourceInfo = string(sourceInfo) + + return nil +} + // WithSourceInfo option append to credentials object the source info for reporting source info details on error case func WithSourceInfo(sourceInfo string) SourceInfoOption { return SourceInfoOption(sourceInfo) diff --git a/options.go b/options.go index c4ed15129..3cf9990e4 100644 --- a/options.go +++ b/options.go @@ -54,6 +54,16 @@ func WithAccessTokenCredentials(accessToken string) Option { ) } +func WithOauth2TokenExchangeCredentials( + opts ...credentials.Oauth2TokenExchangeCredentialsOption, +) Option { + opts = append(opts, credentials.WithSourceInfo("ydb.WithOauth2TokenExchangeCredentials(opts)")) + + return WithCreateCredentialsFunc(func(context.Context) (credentials.Credentials, error) { + return credentials.NewOauth2TokenExchangeCredentials(opts...) + }) +} + // WithApplicationName add provided application name to all api requests func WithApplicationName(applicationName string) Option { return func(ctx context.Context, c *Driver) error {