Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jwt.go #761

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 120 additions & 98 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

// Package jwt implements the OAuth 2.0 JSON Web Token flow, commonly
// known as "two-legged OAuth 2.0".
//
// See: https://tools.ietf.org/html/draft-ietf-oauth-jwt-bearer-12
package jwt

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
Expand All @@ -29,157 +28,180 @@ var (
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
)

// Config is the configuration for using JWT to fetch tokens,
// commonly known as "two-legged OAuth 2.0".
// Config holds the configuration for using JWT to fetch tokens.
type Config struct {
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string

// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey []byte

// PrivateKeyID contains an optional hint indicating which key is being
// used.
PrivateKeyID string

// Subject is the optional user to impersonate.
Subject string

// Scopes optionally specifies a list of requested permission scopes.
Scopes []string

// TokenURL is the endpoint required to complete the 2-legged JWT flow.
TokenURL string

// Expires optionally specifies how long the token is valid for.
Expires time.Duration

// Audience optionally specifies the intended audience of the
// request. If empty, the value of TokenURL is used as the
// intended audience.
Audience string

// PrivateClaims optionally specifies custom private claims in the JWT.
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
PrivateClaims map[string]interface{}

// UseIDToken optionally specifies whether ID token should be used instead
// of access token when the server returns both.
UseIDToken bool
Email string
PrivateKey []byte
PrivateKeyID string
Subject string
Scopes []string
TokenURL string
Expires time.Duration
Audience string
PrivateClaims map[string]interface{}
UseIDToken bool
}

// TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context.
// TokenSource returns a JWT TokenSource using the configuration in c.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c})
return oauth2.ReuseTokenSource(nil, jwtSource{ctx: ctx, conf: c})
}

// Client returns an HTTP client wrapping the context's
// HTTP transport and adding Authorization headers with tokens
// obtained from c.
//
// The returned client and its Transport should not be modified.
// Client returns an HTTP client that adds Authorization headers with tokens obtained from c.
func (c *Config) Client(ctx context.Context) *http.Client {
return oauth2.NewClient(ctx, c.TokenSource(ctx))
}

// jwtSource is a source that always does a signed JWT request for a token.
// It should typically be wrapped with a reuseTokenSource.
type jwtSource struct {
ctx context.Context
conf *Config
}

func (js jwtSource) Token() (*oauth2.Token, error) {
// Validate config
if err := js.validateConfig(); err != nil {
return nil, err
}

// Parse private key
pk, err := internal.ParseKey(js.conf.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}

// Generate JWT payload
claimSet, err := js.generateClaimSet()
if err != nil {
return nil, err
}
hc := oauth2.NewClient(js.ctx, nil)

h := *defaultHeader
h.KeyID = js.conf.PrivateKeyID
payload, err := jws.Encode(&h, claimSet, pk)
if err != nil {
return nil, fmt.Errorf("failed to encode JWT: %v", err)
}

// Request token
return js.requestToken(payload)
}

func (js jwtSource) validateConfig() error {
if js.conf.Email == "" {
return errors.New("email is required")
}
if len(js.conf.PrivateKey) == 0 {
return errors.New("private key is required")
}
if js.conf.TokenURL == "" {
return errors.New("token URL is required")
}
return nil
}

func (js jwtSource) generateClaimSet() (*jws.ClaimSet, error) {
claimSet := &jws.ClaimSet{
Iss: js.conf.Email,
Scope: strings.Join(js.conf.Scopes, " "),
Aud: js.conf.TokenURL,
PrivateClaims: js.conf.PrivateClaims,
}
if subject := js.conf.Subject; subject != "" {
claimSet.Sub = subject
// prn is the old name of sub. Keep setting it
// to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = subject

if js.conf.Subject != "" {
claimSet.Sub = js.conf.Subject
claimSet.Prn = js.conf.Subject
}
if t := js.conf.Expires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()

if js.conf.Expires > 0 {
claimSet.Exp = time.Now().Add(js.conf.Expires).Unix()
}
if aud := js.conf.Audience; aud != "" {
claimSet.Aud = aud

if js.conf.Audience != "" {
claimSet.Aud = js.conf.Audience
}
h := *defaultHeader
h.KeyID = js.conf.PrivateKeyID
payload, err := jws.Encode(&h, claimSet, pk)
if err != nil {
return nil, err

return claimSet, nil
}

func (js jwtSource) requestToken(payload string) (*oauth2.Token, error) {
hc := oauth2.NewClient(js.ctx, nil)
v := url.Values{
"grant_type": {defaultGrantType},
"assertion": {payload},
}
v := url.Values{}
v.Set("grant_type", defaultGrantType)
v.Set("assertion", payload)

resp, err := hc.PostForm(js.conf.TokenURL, v)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
return nil, fmt.Errorf("failed to fetch token: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {

if resp.StatusCode < 200 || resp.StatusCode > 299 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
}
}
// tokenRes is the JSON response body.

return js.parseTokenResponse(resp)
}

func (js jwtSource) parseTokenResponse(resp *http.Response) (*oauth2.Token, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read token response: %v", err)
}

var tokenRes struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
ExpiresIn int64 `json:"expires_in"`
}
if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
return nil, fmt.Errorf("failed to parse token response: %v", err)
}

token := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
Expiry: time.Now().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
}
raw := make(map[string]interface{})
json.Unmarshal(body, &raw) // no error checks for optional fields
token = token.WithExtra(raw)

if secs := tokenRes.ExpiresIn; secs > 0 {
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}
if v := tokenRes.IDToken; v != "" {
// decode returned id token to get expiry
claimSet, err := jws.Decode(v)
if err != nil {
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
}
token.Expiry = time.Unix(claimSet.Exp, 0)
}
if js.conf.UseIDToken {
if tokenRes.IDToken == "" {
return nil, fmt.Errorf("oauth2: response doesn't have JWT token")
return nil, errors.New("response missing ID token")
}
token.AccessToken = tokenRes.IDToken
}

return token, nil
}

// Helper functions for better debugging
func debugLog(msg string) {
fmt.Println("DEBUG:", msg)
}

func infoLog(msg string) {
fmt.Println("INFO:", msg)
}

func warnLog(msg string) {
fmt.Println("WARNING:", msg)
}

func errorLog(msg string) {
fmt.Println("ERROR:", msg)
}

// Additional notes to ensure code clarity and maintainability:
// 1. Proper documentation should be added to all exported functions.
// 2. Ensure this code adheres to the latest security practices.
// 3. Add more test cases to cover edge scenarios.
// 4. Future improvements could include support for additional JWT algorithms.

// End of file