Skip to content

Commit

Permalink
Merge pull request #15676 from ahrtr/jwt_panic_3.5_20230410
Browse files Browse the repository at this point in the history
[3.5] etcdserver: verify field 'username' and 'revision' present when decoding a JWT token
  • Loading branch information
serathius authored Apr 11, 2023
2 parents 3cd07fe + 643e6e1 commit 9d2cda4
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
17 changes: 13 additions & 4 deletions server/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf
// rev isn't used in JWT, it is only used in simple token
var (
username string
revision uint64
revision float64
)

parsed, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
Expand Down Expand Up @@ -74,10 +74,19 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf
return nil, false
}

username = claims["username"].(string)
revision = uint64(claims["revision"].(float64))
username, ok = claims["username"].(string)
if !ok {
t.lg.Warn("failed to obtain user claims from jwt token")
return nil, false
}

revision, ok = claims["revision"].(float64)
if !ok {
t.lg.Warn("failed to obtain revision claims from jwt token")
return nil, false
}

return &AuthInfo{Username: username, Revision: revision}, true
return &AuthInfo{Username: username, Revision: uint64(revision)}, true
}

func (t *tokenJWT) assign(ctx context.Context, username string, revision uint64) (string, error) {
Expand Down
75 changes: 75 additions & 0 deletions server/auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"context"
"fmt"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -202,3 +205,75 @@ func TestJWTBad(t *testing.T) {
func testJWTOpts() string {
return fmt.Sprintf("%s,pub-key=%s,priv-key=%s,sign-method=RS256", tokenTypeJWT, jwtRSAPubKey, jwtRSAPrivKey)
}

func TestJWTTokenWithMissingFields(t *testing.T) {
testCases := []struct {
name string
username string // An empty string means not present
revision uint64 // 0 means not present
expectValid bool
}{
{
name: "valid token",
username: "hello",
revision: 100,
expectValid: true,
},
{
name: "no username",
username: "",
revision: 100,
expectValid: false,
},
{
name: "no revision",
username: "hello",
revision: 0,
expectValid: false,
},
}

for _, tc := range testCases {
tc := tc
optsMap := map[string]string{
"priv-key": jwtRSAPrivKey,
"sign-method": "RS256",
"ttl": "1h",
}

t.Run(tc.name, func(t *testing.T) {
// prepare claims
claims := jwt.MapClaims{
"exp": time.Now().Add(time.Hour).Unix(),
}
if tc.username != "" {
claims["username"] = tc.username
}
if tc.revision != 0 {
claims["revision"] = tc.revision
}

// generate a JWT token with the given claims
var opts jwtOptions
err := opts.ParseWithDefaults(optsMap)
require.NoError(t, err)
key, err := opts.Key()
require.NoError(t, err)

tk := jwt.NewWithClaims(opts.SignMethod, claims)
token, err := tk.SignedString(key)
require.NoError(t, err)

// verify the token
jwtProvider, err := newTokenProviderJWT(zap.NewNop(), optsMap)
require.NoError(t, err)
ai, ok := jwtProvider.info(context.TODO(), token, 123)

require.Equal(t, tc.expectValid, ok)
if ok {
require.Equal(t, tc.username, ai.Username)
require.Equal(t, tc.revision, ai.Revision)
}
})
}
}

0 comments on commit 9d2cda4

Please sign in to comment.