diff --git a/connector/connector.go b/connector/connector.go index d812390f0c..ca3fb08dac 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -103,3 +103,7 @@ type RefreshConnector interface { type TokenIdentityConnector interface { TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error) } + +type PayloadExtender interface { + ExtendPayload(scopes []string, payload []byte, connectorData []byte) ([]byte, error) +} diff --git a/server/handlers.go b/server/handlers.go index 5954820caa..e2233354e4 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -721,14 +721,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe implicitOrHybrid = true var err error - accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) + accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID, authReq.ConnectorData) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) + idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID, authReq.ConnectorData) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -936,14 +936,14 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s } func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { - accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) + accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID, authCode.ConnectorData) if err != nil { s.logger.ErrorContext(ctx, "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } - idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) + idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID, authCode.ConnectorData) if err != nil { s.logger.ErrorContext(ctx, "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1201,14 +1201,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Groups: identity.Groups, } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID) + accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID, identity.ConnectorData) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID) + idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID, identity.ConnectorData) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1405,9 +1405,9 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli var expiry time.Time switch requestedTokenType { case tokenTypeID: - resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID) + resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID, identity.ConnectorData) case tokenTypeAccess: - resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID) + resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID, identity.ConnectorData) default: s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest) return diff --git a/server/introspectionhandler_test.go b/server/introspectionhandler_test.go index 695bbad8e6..fee6d51b6a 100644 --- a/server/introspectionhandler_test.go +++ b/server/introspectionhandler_test.go @@ -265,7 +265,7 @@ func TestHandleIntrospect(t *testing.T) { Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, - }, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test") + }, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test", nil) require.NoError(t, err) activeRefreshToken, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) diff --git a/server/oauth2.go b/server/oauth2.go index cc81a8a52d..37fb0fc270 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -303,8 +303,8 @@ type federatedIDClaims struct { UserID string `json:"user_id,omitempty"` } -func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) { - return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID) +func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string, connectorData []byte) (accessToken string, expiry time.Time, err error) { + return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID, connectorData) } func getClientID(aud audience, azp string) (string, error) { @@ -350,13 +350,20 @@ func genSubject(userID string, connID string) (string, error) { return internal.Marshal(sub) } -func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { + +func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string, connectorData []byte) (idToken string, expiry time.Time, err error) { keys, err := s.storage.GetKeys() if err != nil { s.logger.ErrorContext(ctx, "failed to get keys", "err", err) return "", expiry, err } + conn, err := s.getConnector(connID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get connector", "connector", connID, "err", err) + return "", expiry, err + } + signingKey := keys.SigningKey if signingKey == nil { return "", expiry, fmt.Errorf("no key to sign payload with") @@ -445,6 +452,17 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage return "", expiry, fmt.Errorf("could not serialize claims: %v", err) } + switch c := conn.Connector.(type) { + case connector.PayloadExtender: + extendedPayload, err := c.ExtendPayload(scopes, payload, connectorData) + if err != nil { + s.logger.WarnContext(ctx, "failed to enhance payload", "err", err) + break + } + payload = extendedPayload + default: + } + if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil { return "", expiry, fmt.Errorf("failed to sign payload: %v", err) } diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 391d552251..a5fdc93db3 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -364,14 +364,15 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) + accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, rCtx.connectorData) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } - idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) + + idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID, rCtx.connectorData) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError())