Skip to content

Commit

Permalink
implement PKCE for AuthorizationCode grant
Browse files Browse the repository at this point in the history
  • Loading branch information
thegrumpylion committed Nov 15, 2020
1 parent b46cf9f commit a32a6ed
Show file tree
Hide file tree
Showing 12 changed files with 375 additions and 79 deletions.
36 changes: 36 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package oauth2

import (
"crypto/sha256"
"encoding/base64"
)

// ResponseType the type of authorization request
type ResponseType string

Expand Down Expand Up @@ -34,3 +39,34 @@ func (gt GrantType) String() string {
}
return ""
}

// CodeChallengeMethod PCKE method
type CodeChallengeMethod string

const (
// CodeChallengePlain PCKE Method
CodeChallengePlain CodeChallengeMethod = "plain"
// CodeChallengeS256 PCKE Method
CodeChallengeS256 CodeChallengeMethod = "S256"
)

func (ccm CodeChallengeMethod) String() string {
if ccm == CodeChallengePlain ||
ccm == CodeChallengeS256 {
return string(ccm)
}
return ""
}

// Validate code challenge
func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
switch ccm {
case CodeChallengePlain:
return cc == ver
case CodeChallengeS256:
s256 := sha256.Sum256([]byte(ver))
return base64.URLEncoding.EncodeToString(s256[:]) == cc
default:
return false
}
}
21 changes: 21 additions & 0 deletions const_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package oauth2_test

import (
"testing"

"github.com/go-oauth2/oauth2/v4"
)

func TestValidatePlain(t *testing.T) {
cc := oauth2.CodeChallengePlain
if !cc.Validate("plaintest", "plaintest") {
t.Fatal("not valid")
}
}

func TestValidateS256(t *testing.T) {
cc := oauth2.CodeChallengeS256
if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") {
t.Fatal("not valid")
}
}
3 changes: 3 additions & 0 deletions errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ var (
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
ErrMissingCodeVerifier = errors.New("missing code verifier")
ErrMissingCodeChallenge = errors.New("missing code challenge")
ErrInvalidCodeChallenge = errors.New("invalid code challenge")
)
69 changes: 39 additions & 30 deletions errors/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,51 @@ func (r *Response) SetHeader(key, value string) {

// https://tools.ietf.org/html/rfc6749#section-5.2
var (
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrCodeChallengeRquired = errors.New("invalid_request")
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
)

// Descriptions error description
var Descriptions = map[error]string{
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
}

// StatusCodes response error HTTP status code
var StatusCodes = map[error]int{
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrCodeChallengeRquired: 400,
ErrUnsupportedCodeChallengeMethod: 400,
ErrInvalidCodeChallengeLen: 400,
}
13 changes: 11 additions & 2 deletions example/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -33,7 +35,9 @@ var (

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
u := config.AuthCodeURL("xyz")
u := config.AuthCodeURL("xyz",
oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
http.Redirect(w, r, u, http.StatusFound)
})

Expand All @@ -49,7 +53,7 @@ func main() {
http.Error(w, "Code not found", http.StatusBadRequest)
return
}
token, err := config.Exchange(context.Background(), code)
token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -130,3 +134,8 @@ func main() {
log.Println("Client is running at 9094 port.Please open http://localhost:9094")
log.Fatal(http.ListenAndServe(":9094", nil))
}

func genCodeChallengeS256(s string) string {
s256 := sha256.Sum256([]byte(s))
return base64.URLEncoding.EncodeToString(s256[:])
}
21 changes: 12 additions & 9 deletions manage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ import (

// TokenGenerateRequest provide to generate the token request parameters
type TokenGenerateRequest struct {
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
Refresh string
AccessTokenExp time.Duration
Request *http.Request
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
CodeChallenge string
CodeChallengeMethod CodeChallengeMethod
Refresh string
CodeVerifier string
AccessTokenExp time.Duration
Request *http.Request
}

// Manager authorization management interface
Expand Down
29 changes: 29 additions & 0 deletions manage/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType,
if exp := tgr.AccessTokenExp; exp > 0 {
ti.SetAccessExpiresIn(exp)
}
if tgr.CodeChallenge != "" {
ti.SetCodeChallenge(tgr.CodeChallenge)
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
}

tv, err := m.authorizeGenerate.Token(ctx, td)
if err != nil {
Expand Down Expand Up @@ -251,6 +255,28 @@ func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.To
return ti, nil
}

func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
cc := ti.GetCodeChallenge()
// early return
if cc == "" && ver == "" {
return nil
}
if cc == "" {
return errors.ErrMissingCodeVerifier
}
if ver == "" {
return errors.New("missing code verifier")
}
ccm := ti.GetCodeChallengeMethod()
if ccm.String() == "" {
ccm = oauth2.CodeChallengePlain
}
if !ccm.Validate(cc, ver) {
return errors.ErrInvalidCodeChallenge
}
return nil
}

// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
Expand All @@ -275,6 +301,9 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType,
if err != nil {
return nil, err
}
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
return nil, err
}
tgr.UserID = ti.GetUserID()
tgr.Scope = ti.GetScope()
if exp := ti.GetAccessExpiresIn(); exp > 0 {
Expand Down
4 changes: 4 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type (
SetCodeCreateAt(time.Time)
GetCodeExpiresIn() time.Duration
SetCodeExpiresIn(time.Duration)
GetCodeChallenge() string
SetCodeChallenge(string)
GetCodeChallengeMethod() CodeChallengeMethod
SetCodeChallengeMethod(CodeChallengeMethod)

GetAccess() string
SetAccess(string)
Expand Down
48 changes: 35 additions & 13 deletions models/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@ func NewToken() *Token {

// Token token model
type Token struct {
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeChallenge string `bson:"CodeChallenge"`
CodeChallengeMethod string `bson:"CodeChallengeMethod"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}

// New create to token model instance
Expand Down Expand Up @@ -103,6 +105,26 @@ func (t *Token) SetCodeExpiresIn(exp time.Duration) {
t.CodeExpiresIn = exp
}

// GetCodeChallenge challenge code
func (t *Token) GetCodeChallenge() string {
return t.CodeChallenge
}

// SetCodeChallenge challenge code
func (t *Token) SetCodeChallenge(code string) {
t.CodeChallenge = code
}

// GetCodeChallengeMethod challenge method
func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
}

// SetCodeChallengeMethod challenge method
func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
t.CodeChallengeMethod = string(method)
}

// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
Expand Down
32 changes: 20 additions & 12 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (

// Config configuration parameters
type Config struct {
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
ForcePKCE bool
}

// NewConfig create to configuration instance
Expand All @@ -26,17 +28,23 @@ func NewConfig() *Config {
oauth2.ClientCredentials,
oauth2.Refreshing,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
}
}

// AuthorizeRequest authorization request
type AuthorizeRequest struct {
ResponseType oauth2.ResponseType
ClientID string
Scope string
RedirectURI string
State string
UserID string
AccessTokenExp time.Duration
Request *http.Request
ResponseType oauth2.ResponseType
ClientID string
Scope string
RedirectURI string
State string
UserID string
CodeChallenge string
CodeChallengeMethod oauth2.CodeChallengeMethod
AccessTokenExp time.Duration
Request *http.Request
}
Loading

0 comments on commit a32a6ed

Please sign in to comment.