Skip to content

Commit

Permalink
Allow for multiple keysets configured
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlanda committed Feb 2, 2024
1 parent cca903f commit 1c4c2b4
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 22 deletions.
26 changes: 21 additions & 5 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/hashicorp/cap/jwt"

Check failure on line 12 in backend.go

View workflow job for this annotation

GitHub Actions / run-tests / Run Tests

github.com/hashicorp/[email protected]: replacement directory /Users/john.landa/Dev/cap does not exist
"github.com/hashicorp/cap/oidc"

Check failure on line 13 in backend.go

View workflow job for this annotation

GitHub Actions / run-tests / Run Tests

github.com/hashicorp/[email protected]: replacement directory /Users/john.landa/Dev/cap does not exist
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/patrickmn/go-cache"
Expand Down Expand Up @@ -144,16 +145,31 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)
}

var err error
var keySet jwt.KeySet
var keySets []jwt.KeySet

// Configure the key set for the validator
switch config.authType() {
case JWKS:
keySet, err = jwt.NewJSONWebKeySet(b.providerCtx, config.JWKSURL, config.JWKSCAPEM)
keySet, err2 := jwt.NewJSONWebKeySet(b.providerCtx, config.JWKSURL, config.JWKSCAPEM)
keySets = []jwt.KeySet{keySet}
err = err2
case MultiJWKS:
for k, v := range config.JWKSPairs {
// TODO (@johnlanda): read this in as map[string]string in config so we don't need type conversion.
keySet, err2 := jwt.NewJSONWebKeySet(b.providerCtx, k, v.(string))
if err2 != nil {
err = multierror.Append(err, fmt.Errorf("keyset configuration error: %w", err2))
}
keySets = append(keySets, keySet)
}
case StaticKeys:
keySet, err = jwt.NewStaticKeySet(config.ParsedJWTPubKeys)
keySet, err2 := jwt.NewStaticKeySet(config.ParsedJWTPubKeys)
keySets = []jwt.KeySet{keySet}
err = err2
case OIDCDiscovery:
keySet, err = jwt.NewOIDCDiscoveryKeySet(b.providerCtx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
keySet, err2 := jwt.NewOIDCDiscoveryKeySet(b.providerCtx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
keySets = []jwt.KeySet{keySet}
err = err2
default:
return nil, errors.New("unsupported config type")
}
Expand All @@ -162,7 +178,7 @@ func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error)
return nil, fmt.Errorf("keyset configuration error: %w", err)
}

validator, err := jwt.NewValidator(keySet)
validator, err := jwt.NewValidator(keySets...)
if err != nil {
return nil, fmt.Errorf("JWT validator configuration error: %w", err)
}
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,5 @@ require (
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/hashicorp/cap v0.4.1 => /Users/john.landa/Dev/cap
68 changes: 51 additions & 17 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"

Expand Down Expand Up @@ -72,6 +73,10 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Type: framework.TypeString,
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the JWKS URL. If not set, system certificates are used.",
},
"jwks_pairs": {
Type: framework.TypeMap,
Description: `Set of JWKS Url and CA certificate (or chain of certificates) pairs. CA certificates must be in PEM format. Cannot be used with "jwks_url" or "jwks_ca_pem".`,
},
"default_role": {
Type: framework.TypeLowerCaseString,
Description: "The default role to use if none is provided during login. If not set, a role is required during login.",
Expand Down Expand Up @@ -200,6 +205,7 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques
"jwt_validation_pubkeys": config.JWTValidationPubKeys,
"jwt_supported_algs": config.JWTSupportedAlgs,
"jwks_url": config.JWKSURL,
"jwks_pairs": config.JWKSPairs,
"jwks_ca_pem": config.JWKSCAPEM,
"bound_issuer": config.BoundIssuer,
"provider_config": providerConfig,
Expand All @@ -219,6 +225,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
OIDCResponseMode: d.Get("oidc_response_mode").(string),
OIDCResponseTypes: d.Get("oidc_response_types").([]string),
JWKSURL: d.Get("jwks_url").(string),
JWKSPairs: d.Get("jwks_pairs").(map[string]interface{}),
JWKSCAPEM: d.Get("jwks_ca_pem").(string),
DefaultRole: d.Get("default_role").(string),
JWTValidationPubKeys: d.Get("jwt_validation_pubkeys").([]string),
Expand Down Expand Up @@ -255,10 +262,13 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
if config.JWKSURL != "" {
methodCount++
}
if len(config.JWKSPairs) > 0 {
methodCount++
}

switch {
case methodCount != 1:
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url' or 'oidc_discovery_url' must be set"), nil
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url', 'jwks_urls' or 'oidc_discovery_url' must be set"), nil

case config.OIDCClientID != "" && config.OIDCClientSecret == "",
config.OIDCClientID == "" && config.OIDCClientSecret != "":
Expand All @@ -279,30 +289,27 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil

case config.JWKSURL != "":
keyset, err := jwt.NewJSONWebKeySet(ctx, config.JWKSURL, config.JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem"), nil
}
case config.JWKSCAPEM != "" && len(config.JWKSPairs) > 0:
return logical.ErrorResponse("CA PEMs must be provided as part of the 'jwks_pairs' when using multiple JWKS URLs"), nil

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
case config.JWKSURL != "":
if r := b.validateJWKSURL(ctx, config.JWKSURL, config.JWKSCAPEM); r != nil {
return r, nil
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "error", err)
return logical.ErrorResponse("error checking jwks URL"), nil
case len(config.JWKSPairs) > 0:
for k, v := range config.JWKSPairs {
// TODO (@johnlanda): we should just read this in as a map[string]string if possible so we don't have to do
// the type conversion.
if r := b.validateJWKSURL(ctx, k, v.(string)); r != nil {
return r, nil
}
}

case len(config.JWTValidationPubKeys) != 0:
for _, v := range config.JWTValidationPubKeys {
if _, err := certutil.ParsePublicKeyPEM([]byte(v)); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error parsing public key: {{err}}", err).Error()), nil
return logical.ErrorResponse(fmt.Errorf("error parsing public key: %w", err).Error()), nil
}
}

Expand Down Expand Up @@ -403,6 +410,29 @@ func (b *jwtAuthBackend) createCAContext(ctx context.Context, caPEM string) (con
return caCtx, nil
}

func (b *jwtAuthBackend) validateJWKSURL(ctx context.Context, JWKSURL, JWKSCAPEM string) *logical.Response {
keyset, err := jwt.NewJSONWebKeySet(ctx, JWKSURL, JWKSCAPEM)
if err != nil {
b.Logger().Error("error checking jwks_ca_pem", "error", err)
return logical.ErrorResponse("error checking jwks_ca_pem")
}

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
b.Logger().Error("error checking jwks URL", "url", JWKSURL, "error", err)
return logical.ErrorResponse("error checking jwks URL %s", JWKSURL)
}

return nil
}

type jwtConfig struct {
OIDCDiscoveryURL string `json:"oidc_discovery_url"`
OIDCDiscoveryCAPEM string `json:"oidc_discovery_ca_pem"`
Expand All @@ -412,6 +442,7 @@ type jwtConfig struct {
OIDCResponseTypes []string `json:"oidc_response_types"`
JWKSURL string `json:"jwks_url"`
JWKSCAPEM string `json:"jwks_ca_pem"`
JWKSPairs map[string]interface{} `json:"jwks_pairs"`
JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"`
JWTSupportedAlgs []string `json:"jwt_supported_algs"`
BoundIssuer string `json:"bound_issuer"`
Expand All @@ -425,6 +456,7 @@ type jwtConfig struct {
const (
StaticKeys = iota
JWKS
MultiJWKS
OIDCDiscovery
OIDCFlow
unconfigured
Expand All @@ -437,6 +469,8 @@ func (c jwtConfig) authType() int {
return StaticKeys
case c.JWKSURL != "":
return JWKS
case len(c.JWKSPairs) > 0:
return MultiJWKS
case c.OIDCDiscoveryURL != "":
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
return OIDCFlow
Expand Down
13 changes: 13 additions & 0 deletions path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestConfig_JWT_Read(t *testing.T) {
"jwt_supported_algs": []string{},
"jwks_url": "",
"jwks_ca_pem": "",
"jwks_pairs": map[string]interface{}{},
"bound_issuer": "http://vault.example.com/",
"provider_config": map[string]interface{}{},
"namespace_in_state": false,
Expand Down Expand Up @@ -139,6 +140,7 @@ func TestConfig_JWT_Write(t *testing.T) {
JWTValidationPubKeys: []string{testJWTPubKey},
JWTSupportedAlgs: []string{},
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
BoundIssuer: "http://vault.example.com/",
ProviderConfig: map[string]interface{}{},
NamespaceInState: true,
Expand Down Expand Up @@ -168,6 +170,7 @@ func TestConfig_JWKS_Update(t *testing.T) {
data := map[string]interface{}{
"jwks_url": s.server.URL + "/certs",
"jwks_ca_pem": cert,
"jwks_pairs": map[string]interface{}{},
"oidc_discovery_url": "",
"oidc_discovery_ca_pem": "",
"oidc_client_id": "",
Expand Down Expand Up @@ -348,6 +351,7 @@ func TestConfig_OIDC_Write(t *testing.T) {
expected := &jwtConfig{
JWTValidationPubKeys: []string{},
JWTSupportedAlgs: []string{},
JWKSPairs: map[string]interface{}{},
OIDCResponseTypes: []string{},
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
OIDCClientID: "abc",
Expand Down Expand Up @@ -439,6 +443,7 @@ func TestConfig_OIDC_Write_ProviderConfig(t *testing.T) {
expected := &jwtConfig{
JWTValidationPubKeys: []string{},
JWTSupportedAlgs: []string{},
JWKSPairs: map[string]interface{}{},
OIDCResponseTypes: []string{},
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
ProviderConfig: map[string]interface{}{
Expand Down Expand Up @@ -499,6 +504,7 @@ func TestConfig_OIDC_Write_ProviderConfig(t *testing.T) {
expected := &jwtConfig{
JWTValidationPubKeys: []string{},
JWTSupportedAlgs: []string{},
JWKSPairs: map[string]interface{}{},
OIDCResponseTypes: []string{},
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
ProviderConfig: map[string]interface{}{},
Expand Down Expand Up @@ -530,6 +536,7 @@ func TestConfig_OIDC_Create_Namespace(t *testing.T) {
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
NamespaceInState: true,
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand All @@ -544,6 +551,7 @@ func TestConfig_OIDC_Create_Namespace(t *testing.T) {
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
NamespaceInState: true,
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand All @@ -558,6 +566,7 @@ func TestConfig_OIDC_Create_Namespace(t *testing.T) {
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
NamespaceInState: false,
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand Down Expand Up @@ -606,6 +615,7 @@ func TestConfig_OIDC_Update_Namespace(t *testing.T) {
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
NamespaceInState: true,
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand All @@ -625,6 +635,7 @@ func TestConfig_OIDC_Update_Namespace(t *testing.T) {
NamespaceInState: false,
DefaultRole: "ui",
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand All @@ -643,6 +654,7 @@ func TestConfig_OIDC_Update_Namespace(t *testing.T) {
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
NamespaceInState: false,
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand All @@ -662,6 +674,7 @@ func TestConfig_OIDC_Update_Namespace(t *testing.T) {
NamespaceInState: true,
DefaultRole: "ui",
OIDCResponseTypes: []string{},
JWKSPairs: map[string]interface{}{},
JWTSupportedAlgs: []string{},
JWTValidationPubKeys: []string{},
ProviderConfig: map[string]interface{}{},
Expand Down

0 comments on commit 1c4c2b4

Please sign in to comment.