Skip to content

Commit

Permalink
idtoken: validate if token is expired (#492)
Browse files Browse the repository at this point in the history
Updates: #484
  • Loading branch information
codyoss authored May 20, 2020
1 parent 05ec534 commit b65b9f3
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 9 deletions.
11 changes: 10 additions & 1 deletion idtoken/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"math/big"
"net/http"
"strings"
"time"

htransport "google.golang.org/api/transport/http"
)
Expand All @@ -27,7 +28,11 @@ const (
googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs"
)

var defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
var (
defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
// now aliases time.Now for testing.
now = time.Now
)

// Payload represents a decoded payload of an ID Token.
type Payload struct {
Expand Down Expand Up @@ -129,6 +134,10 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin
return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
}

if now().Unix() > payload.Expires {
return nil, fmt.Errorf("idtoken: token expired")
}

switch header.Algorithm {
case "RS256":
if err := v.validateRS256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil {
Expand Down
90 changes: 82 additions & 8 deletions idtoken/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ import (
"math/big"
"net/http"
"testing"
"time"

"google.golang.org/api/option"
)

const (
keyID = "1234"
testAudience = "test-audience"
keyID = "1234"
testAudience = "test-audience"
expiry int64 = 233431200
)

var (
beforeExp = func() time.Time { return time.Unix(expiry-1, 0) }
afterExp = func() time.Time { return time.Unix(expiry+1, 0) }
)

func TestValidateRS256(t *testing.T) {
Expand All @@ -34,11 +41,41 @@ func TestValidateRS256(t *testing.T) {
keyID string
n *big.Int
e int
nowFunc func() time.Time
wantErr bool
}{
{name: "works", keyID: keyID, n: pk.N, e: pk.E, wantErr: false},
{name: "no matching key", keyID: "5678", n: pk.N, e: pk.E, wantErr: true},
{name: "sig does not match", keyID: keyID, n: new(big.Int).SetBytes([]byte("42")), e: 42, wantErr: true},
{
name: "works",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: false,
},
{
name: "no matching key",
keyID: "5678",
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: true,
},
{
name: "sig does not match",
keyID: keyID,
n: new(big.Int).SetBytes([]byte("42")),
e: 42,
nowFunc: beforeExp,
wantErr: true,
},
{
name: "token expired",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: afterExp,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -64,6 +101,9 @@ func TestValidateRS256(t *testing.T) {
}
}),
}
oldNow := now
defer func() { now = oldNow }()
now = tt.nowFunc

v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
if err != nil {
Expand All @@ -87,11 +127,41 @@ func TestValidateES256(t *testing.T) {
keyID string
x *big.Int
y *big.Int
nowFunc func() time.Time
wantErr bool
}{
{name: "works", keyID: keyID, x: pk.X, y: pk.Y, wantErr: false},
{name: "no matching key", keyID: "5678", x: pk.X, y: pk.Y, wantErr: true},
{name: "sig does not match", keyID: keyID, x: new(big.Int), y: new(big.Int), wantErr: true},
{
name: "works",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: false,
},
{
name: "no matching key",
keyID: "5678",
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: true,
},
{
name: "sig does not match",
keyID: keyID,
x: new(big.Int),
y: new(big.Int),
nowFunc: beforeExp,
wantErr: true,
},
{
name: "token expired",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: afterExp,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -117,6 +187,9 @@ func TestValidateES256(t *testing.T) {
}
}),
}
oldNow := now
defer func() { now = oldNow }()
now = tt.nowFunc

v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
if err != nil {
Expand Down Expand Up @@ -176,6 +249,7 @@ func commonToken(t *testing.T, alg string) *jwt {
payload := Payload{
Issuer: "example.com",
Audience: testAudience,
Expires: expiry,
}

hb, err := json.Marshal(&header)
Expand Down

0 comments on commit b65b9f3

Please sign in to comment.