Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for multiple keysets configured #277

Merged
merged 7 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/hashicorp/cap/jwt"
"github.com/hashicorp/cap/oidc"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/patrickmn/go-cache"
Expand Down Expand Up @@ -145,15 +146,29 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)

var err error
var keySet jwt.KeySet
var keySets []jwt.KeySet

// Configure the key set for the validator
switch config.authType() {
case JWKS:
keySet, err = jwt.NewJSONWebKeySet(b.providerCtx, config.JWKSURL, config.JWKSCAPEM)
keySets = []jwt.KeySet{keySet}
case MultiJWKS:
pairs, err := NewJWKSPairsConfig(config)

for _, p := range pairs {
keySet, keySetErr := jwt.NewJSONWebKeySet(b.providerCtx, p.JWKSUrl, p.JWKSCAPEM)
if keySetErr != nil {
err = multierror.Append(err, keySetErr)
}
keySets = append(keySets, keySet)
}
case StaticKeys:
keySet, err = jwt.NewStaticKeySet(config.ParsedJWTPubKeys)
keySets = []jwt.KeySet{keySet}
case OIDCDiscovery:
keySet, err = jwt.NewOIDCDiscoveryKeySet(b.providerCtx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
keySets = []jwt.KeySet{keySet}
default:
return nil, errors.New("unsupported config type")
}
Expand All @@ -162,7 +177,7 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)
return nil, fmt.Errorf("keyset configuration error: %w", err)
}

validator, err := jwt.NewValidator(keySet)
validator, err := jwt.NewValidator(keySets...)
if err != nil {
return nil, fmt.Errorf("JWT validator configuration error: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ go 1.20

require (
github.com/go-test/deep v1.1.0
github.com/hashicorp/cap v0.4.1
github.com/hashicorp/cap v0.5.0
github.com/hashicorp/errwrap v1.1.0
github.com/hashicorp/go-cleanhttp v0.5.2
github.com/hashicorp/go-hclog v1.6.2
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2
github.com/hashicorp/go-sockaddr v1.0.6
Expand Down Expand Up @@ -54,7 +55,6 @@ require (
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.0 // indirect
github.com/hashicorp/go-kms-wrapping/v2 v2.0.8 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-plugin v1.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU=
github.com/hashicorp/cap v0.4.1 h1:LVYrTLbPV8W6DPwIm/zC/fbc4UTpCQ7nJhCAPshLuG4=
github.com/hashicorp/cap v0.4.1/go.mod h1:oOoohCNd2JAgfvLz2NpFJTZiZ6CqH9dW8dZ2js52lGA=
github.com/hashicorp/cap v0.5.0 h1:YIlAYxdXXtx2IL1JDvP2OyEl55Ooi0yl573kSB9Orlw=
github.com/hashicorp/cap v0.5.0/go.mod h1:IAy00Er+ZFpMo+5x6B4bkO2HgpzgrkfsuDWMmHAuKUE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
Expand Down
36 changes: 36 additions & 0 deletions jwks_pairs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package jwtauth

import (
"github.com/mitchellh/mapstructure"
)

type JWKSPair struct {
JWKSUrl string `mapstructure:"jwks_url"`
JWKSCAPEM string `mapstructure:"jwks_ca_pem"`
}

func NewJWKSPairsConfig(jc *jwtConfig) ([]*JWKSPair, error) {
if len(jc.JWKSPairs) <= 0 {
return nil, nil
}

pairs := make([]*JWKSPair, 0, len(jc.JWKSPairs))
for i := 0; i < len(jc.JWKSPairs); i++ {
jp, err := Initialize(jc.JWKSPairs[i].(map[string]interface{}))
if err != nil {
return nil, err
}
pairs = append(pairs, jp)
}

return pairs, nil
}

func Initialize(jp map[string]interface{}) (*JWKSPair, error) {
var newJp JWKSPair
if err := mapstructure.Decode(jp, &newJp); err != nil {
return nil, err
}

return &newJp, nil
}
67 changes: 52 additions & 15 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"

Expand Down Expand Up @@ -72,6 +73,10 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Type: framework.TypeString,
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the JWKS URL. If not set, system certificates are used.",
},
"jwks_pairs": {
Type: framework.TypeSlice,
Description: `Set of JWKS Url and CA certificate (or chain of certificates) pairs. CA certificates must be in PEM format. Cannot be used with "jwks_url" or "jwks_ca_pem".`,
},
"default_role": {
Type: framework.TypeLowerCaseString,
Description: "The default role to use if none is provided during login. If not set, a role is required during login.",
Expand Down Expand Up @@ -200,6 +205,7 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques
"jwt_validation_pubkeys": config.JWTValidationPubKeys,
"jwt_supported_algs": config.JWTSupportedAlgs,
"jwks_url": config.JWKSURL,
"jwks_pairs": config.JWKSPairs,
"jwks_ca_pem": config.JWKSCAPEM,
"bound_issuer": config.BoundIssuer,
"provider_config": providerConfig,
Expand All @@ -219,6 +225,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
OIDCResponseMode: d.Get("oidc_response_mode").(string),
OIDCResponseTypes: d.Get("oidc_response_types").([]string),
JWKSURL: d.Get("jwks_url").(string),
JWKSPairs: d.Get("jwks_pairs").([]interface{}),
JWKSCAPEM: d.Get("jwks_ca_pem").(string),
DefaultRole: d.Get("default_role").(string),
JWTValidationPubKeys: d.Get("jwt_validation_pubkeys").([]string),
Expand Down Expand Up @@ -255,10 +262,14 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
if config.JWKSURL != "" {
methodCount++
}
if len(config.JWKSPairs) > 0 {
methodCount++
}

var jwksPairs []*JWKSPair
switch {
case methodCount != 1:
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url' or 'oidc_discovery_url' must be set"), nil
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url', 'jwks_pairs' or 'oidc_discovery_url' must be set"), nil

case config.OIDCClientID != "" && config.OIDCClientSecret == "",
config.OIDCClientID == "" && config.OIDCClientSecret != "":
Expand All @@ -279,30 +290,29 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil

case config.JWKSCAPEM != "" && len(config.JWKSPairs) > 0:
return logical.ErrorResponse("CA PEMs must be provided as part of the 'jwks_pairs' when using multiple JWKS URLs"), nil

case config.JWKSURL != "":
keyset, err := jwt.NewJSONWebKeySet(ctx, config.JWKSURL, config.JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem"), nil
if r := b.validateJWKSURL(ctx, config.JWKSURL, config.JWKSCAPEM); r != nil {
return r, nil
}

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
case len(config.JWKSPairs) > 0:
if jwksPairs, err = NewJWKSPairsConfig(config); err != nil {
return logical.ErrorResponse("invalid jwks_pairs: %s", err), nil
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "error", err)
return logical.ErrorResponse("error checking jwks URL"), nil
for _, p := range jwksPairs {
if r := b.validateJWKSURL(ctx, p.JWKSUrl, p.JWKSCAPEM); r != nil {
return r, nil
}
}

case len(config.JWTValidationPubKeys) != 0:
for _, v := range config.JWTValidationPubKeys {
if _, err := certutil.ParsePublicKeyPEM([]byte(v)); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error parsing public key: {{err}}", err).Error()), nil
return logical.ErrorResponse(fmt.Errorf("error parsing public key: %w", err).Error()), nil
}
}

Expand Down Expand Up @@ -403,6 +413,29 @@ func (b *jwtAuthBackend) createCAContext(ctx context.Context, caPEM string) (con
return caCtx, nil
}

func (b *jwtAuthBackend) validateJWKSURL(ctx context.Context, JWKSURL, JWKSCAPEM string) *logical.Response {
keyset, err := jwt.NewJSONWebKeySet(ctx, JWKSURL, JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem")
}

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "url", JWKSURL, "error", err)
return logical.ErrorResponse("error checking jwks URL %s", JWKSURL)
}

return nil
}

type jwtConfig struct {
OIDCDiscoveryURL string `json:"oidc_discovery_url"`
OIDCDiscoveryCAPEM string `json:"oidc_discovery_ca_pem"`
Expand All @@ -412,6 +445,7 @@ type jwtConfig struct {
OIDCResponseTypes []string `json:"oidc_response_types"`
JWKSURL string `json:"jwks_url"`
JWKSCAPEM string `json:"jwks_ca_pem"`
JWKSPairs []interface{} `json:"jwks_pairs"`
JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"`
JWTSupportedAlgs []string `json:"jwt_supported_algs"`
BoundIssuer string `json:"bound_issuer"`
Expand All @@ -425,6 +459,7 @@ type jwtConfig struct {
const (
StaticKeys = iota
JWKS
MultiJWKS
OIDCDiscovery
OIDCFlow
unconfigured
Expand All @@ -437,6 +472,8 @@ func (c jwtConfig) authType() int {
return StaticKeys
case c.JWKSURL != "":
return JWKS
case len(c.JWKSPairs) > 0:
return MultiJWKS
case c.OIDCDiscoveryURL != "":
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
return OIDCFlow
Expand Down
Loading