Skip to content

Commit

Permalink
Allow for multiple keysets configured (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlanda authored Feb 8, 2024
1 parent cca903f commit 882a9b7
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 20 deletions.
21 changes: 20 additions & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/hashicorp/cap/jwt"
"github.com/hashicorp/cap/oidc"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/patrickmn/go-cache"
Expand Down Expand Up @@ -145,15 +146,33 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)

var err error
var keySet jwt.KeySet
var keySets []jwt.KeySet

// Configure the key set for the validator
switch config.authType() {
case JWKS:
keySet, err = jwt.NewJSONWebKeySet(b.providerCtx, config.JWKSURL, config.JWKSCAPEM)
keySets = []jwt.KeySet{keySet}
case MultiJWKS:
pairs, pairsErr := NewJWKSPairsConfig(config)
if pairsErr != nil {
return nil, pairsErr
}

for _, p := range pairs {
keySet, keySetErr := jwt.NewJSONWebKeySet(b.providerCtx, p.JWKSUrl, p.JWKSCAPEM)
if keySetErr != nil {
err = multierror.Append(err, keySetErr)
continue
}
keySets = append(keySets, keySet)
}
case StaticKeys:
keySet, err = jwt.NewStaticKeySet(config.ParsedJWTPubKeys)
keySets = []jwt.KeySet{keySet}
case OIDCDiscovery:
keySet, err = jwt.NewOIDCDiscoveryKeySet(b.providerCtx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
keySets = []jwt.KeySet{keySet}
default:
return nil, errors.New("unsupported config type")
}
Expand All @@ -162,7 +181,7 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)
return nil, fmt.Errorf("keyset configuration error: %w", err)
}

validator, err := jwt.NewValidator(keySet)
validator, err := jwt.NewValidator(keySets...)
if err != nil {
return nil, fmt.Errorf("JWT validator configuration error: %w", err)
}
Expand Down
54 changes: 54 additions & 0 deletions backend_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package jwtauth

import (
"context"
"testing"

"github.com/hashicorp/cap/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_jwtAuthBackend_jwtValidator(t *testing.T) {
type args struct {
config *jwtConfig
}
tests := []struct {
name string
args args
want *jwt.Validator
expectedErr string
}{
{
name: "invalid ca pem",
args: args{
config: &jwtConfig{
JWKSPairs: []interface{}{
map[string]any{
"jwks_url": "https://www.foobar.com/something",
"jwks_ca_pem": "defg",
},
map[string]any{
"jwks_url": "https://www.barbaz.com/something",
"jwks_ca_pem": "",
},
},
},
},
expectedErr: "could not parse CA PEM value successfully",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := &jwtAuthBackend{}
b.providerCtx = context.TODO()

got, err := b.jwtValidator(tt.args.config)
if tt.expectedErr != "" {
require.ErrorContains(t, err, tt.expectedErr)
return
}
assert.Equalf(t, tt.want, got, "jwtValidator(%v)", tt.args.config)
})
}
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ go 1.20

require (
github.com/go-test/deep v1.1.0
github.com/hashicorp/cap v0.4.1
github.com/hashicorp/cap v0.5.0
github.com/hashicorp/errwrap v1.1.0
github.com/hashicorp/go-cleanhttp v0.5.2
github.com/hashicorp/go-hclog v1.6.2
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2
github.com/hashicorp/go-sockaddr v1.0.6
Expand Down Expand Up @@ -54,7 +55,6 @@ require (
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.0 // indirect
github.com/hashicorp/go-kms-wrapping/v2 v2.0.8 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-plugin v1.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU=
github.com/hashicorp/cap v0.4.1 h1:LVYrTLbPV8W6DPwIm/zC/fbc4UTpCQ7nJhCAPshLuG4=
github.com/hashicorp/cap v0.4.1/go.mod h1:oOoohCNd2JAgfvLz2NpFJTZiZ6CqH9dW8dZ2js52lGA=
github.com/hashicorp/cap v0.5.0 h1:YIlAYxdXXtx2IL1JDvP2OyEl55Ooi0yl573kSB9Orlw=
github.com/hashicorp/cap v0.5.0/go.mod h1:IAy00Er+ZFpMo+5x6B4bkO2HgpzgrkfsuDWMmHAuKUE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
Expand Down
42 changes: 42 additions & 0 deletions jwks_pairs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package jwtauth

import (
"fmt"

"github.com/mitchellh/mapstructure"
)

type JWKSPair struct {
JWKSUrl string `mapstructure:"jwks_url"`
JWKSCAPEM string `mapstructure:"jwks_ca_pem"`
}

func NewJWKSPairsConfig(jc *jwtConfig) ([]*JWKSPair, error) {
if len(jc.JWKSPairs) <= 0 {
return nil, nil
}

pairs := make([]*JWKSPair, 0, len(jc.JWKSPairs))
for i := 0; i < len(jc.JWKSPairs); i++ {
pairsMap, ok := jc.JWKSPairs[i].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("jwks_pairs must be provided as a list of json objects with the fields jwks_url and jwks_ca_pem")
}
jp, err := Initialize(pairsMap)
if err != nil {
return nil, err
}
pairs = append(pairs, jp)
}

return pairs, nil
}

func Initialize(jp map[string]interface{}) (*JWKSPair, error) {
var newJp JWKSPair
if err := mapstructure.Decode(jp, &newJp); err != nil {
return nil, err
}

return &newJp, nil
}
70 changes: 55 additions & 15 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"

Expand Down Expand Up @@ -72,6 +73,10 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Type: framework.TypeString,
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the JWKS URL. If not set, system certificates are used.",
},
"jwks_pairs": {
Type: framework.TypeSlice,
Description: `Set of JWKS Url and CA certificate (or chain of certificates) pairs. CA certificates must be in PEM format. Cannot be used with "jwks_url" or "jwks_ca_pem".`,
},
"default_role": {
Type: framework.TypeLowerCaseString,
Description: "The default role to use if none is provided during login. If not set, a role is required during login.",
Expand Down Expand Up @@ -200,6 +205,7 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques
"jwt_validation_pubkeys": config.JWTValidationPubKeys,
"jwt_supported_algs": config.JWTSupportedAlgs,
"jwks_url": config.JWKSURL,
"jwks_pairs": config.JWKSPairs,
"jwks_ca_pem": config.JWKSCAPEM,
"bound_issuer": config.BoundIssuer,
"provider_config": providerConfig,
Expand All @@ -219,6 +225,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
OIDCResponseMode: d.Get("oidc_response_mode").(string),
OIDCResponseTypes: d.Get("oidc_response_types").([]string),
JWKSURL: d.Get("jwks_url").(string),
JWKSPairs: d.Get("jwks_pairs").([]interface{}),
JWKSCAPEM: d.Get("jwks_ca_pem").(string),
DefaultRole: d.Get("default_role").(string),
JWTValidationPubKeys: d.Get("jwt_validation_pubkeys").([]string),
Expand Down Expand Up @@ -255,10 +262,14 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
if config.JWKSURL != "" {
methodCount++
}
if len(config.JWKSPairs) > 0 {
methodCount++
}

var jwksPairs []*JWKSPair
switch {
case methodCount != 1:
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url' or 'oidc_discovery_url' must be set"), nil
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url', 'jwks_pairs' or 'oidc_discovery_url' must be set"), nil

case config.OIDCClientID != "" && config.OIDCClientSecret == "",
config.OIDCClientID == "" && config.OIDCClientSecret != "":
Expand All @@ -279,30 +290,32 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil

case config.JWKSCAPEM != "" && len(config.JWKSPairs) > 0:
return logical.ErrorResponse("CA PEMs must be provided as part of the 'jwks_pairs' when using multiple JWKS URLs"), nil

case len(config.JWKSPairs) > 0 && config.BoundIssuer != "":
return logical.ErrorResponse("Bound issuer is not supported for use with 'jwks_pairs'"), nil

case config.JWKSURL != "":
keyset, err := jwt.NewJSONWebKeySet(ctx, config.JWKSURL, config.JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem"), nil
if r := b.validateJWKSURL(ctx, config.JWKSURL, config.JWKSCAPEM); r != nil {
return r, nil
}

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
case len(config.JWKSPairs) > 0:
if jwksPairs, err = NewJWKSPairsConfig(config); err != nil {
return logical.ErrorResponse("invalid jwks_pairs: %s", err), nil
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "error", err)
return logical.ErrorResponse("error checking jwks URL"), nil
for _, p := range jwksPairs {
if r := b.validateJWKSURL(ctx, p.JWKSUrl, p.JWKSCAPEM); r != nil {
return r, nil
}
}

case len(config.JWTValidationPubKeys) != 0:
for _, v := range config.JWTValidationPubKeys {
if _, err := certutil.ParsePublicKeyPEM([]byte(v)); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error parsing public key: {{err}}", err).Error()), nil
return logical.ErrorResponse(fmt.Errorf("error parsing public key: %w", err).Error()), nil
}
}

Expand Down Expand Up @@ -403,6 +416,29 @@ func (b *jwtAuthBackend) createCAContext(ctx context.Context, caPEM string) (con
return caCtx, nil
}

func (b *jwtAuthBackend) validateJWKSURL(ctx context.Context, JWKSURL, JWKSCAPEM string) *logical.Response {
keyset, err := jwt.NewJSONWebKeySet(ctx, JWKSURL, JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem")
}

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "url", JWKSURL, "error", err)
return logical.ErrorResponse("error checking jwks URL %s", JWKSURL)
}

return nil
}

type jwtConfig struct {
OIDCDiscoveryURL string `json:"oidc_discovery_url"`
OIDCDiscoveryCAPEM string `json:"oidc_discovery_ca_pem"`
Expand All @@ -412,6 +448,7 @@ type jwtConfig struct {
OIDCResponseTypes []string `json:"oidc_response_types"`
JWKSURL string `json:"jwks_url"`
JWKSCAPEM string `json:"jwks_ca_pem"`
JWKSPairs []interface{} `json:"jwks_pairs"`
JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"`
JWTSupportedAlgs []string `json:"jwt_supported_algs"`
BoundIssuer string `json:"bound_issuer"`
Expand All @@ -425,6 +462,7 @@ type jwtConfig struct {
const (
StaticKeys = iota
JWKS
MultiJWKS
OIDCDiscovery
OIDCFlow
unconfigured
Expand All @@ -437,6 +475,8 @@ func (c jwtConfig) authType() int {
return StaticKeys
case c.JWKSURL != "":
return JWKS
case len(c.JWKSPairs) > 0:
return MultiJWKS
case c.OIDCDiscoveryURL != "":
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
return OIDCFlow
Expand Down
Loading

0 comments on commit 882a9b7

Please sign in to comment.