Skip to content

Commit

Permalink
oauth2: use client authn on device auth request
Browse files Browse the repository at this point in the history
According to https://datatracker.ietf.org/doc/html/rfc8628#section-3.1,
the device auth request must include client authentication.

Fixes #685
  • Loading branch information
nsklikas committed Nov 27, 2024
1 parent 22134a4 commit dcc463a
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 49 deletions.
60 changes: 25 additions & 35 deletions deviceauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -93,47 +90,40 @@ func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*Devic
return retrieveDeviceAuth(ctx, c, v)
}

func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
// deviceAuthFromInternal maps an *internal.DeviceAuthResponse struct into
// a *DeviceAuthResponse struct.
func deviceAuthFromInternal(da *internal.DeviceAuthResponse) *DeviceAuthResponse {
if da == nil {
return nil
}

req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
return &DeviceAuthResponse{
DeviceCode: da.DeviceCode,
UserCode: da.UserCode,
VerificationURI: da.VerificationURI,
VerificationURIComplete: da.VerificationURIComplete,
Expiry: time.Now().UTC().Add(time.Second * time.Duration(da.Expiry)),
Interval: da.Interval,
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
}

t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
// retrieveDeviceAuth takes a *Config and uses that to retrieve an *internal.DeviceAuthResponse.
// This response is then mapped from *internal.DeviceAuthResponse into an *oauth2.DeviceAuthResponse which is returned along
// with an error.
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}

body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
da, err := internal.RetrieveDeviceAuth(ctx, c.ClientID, c.ClientSecret, c.Endpoint.DeviceAuthURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
}
return nil, err
}
dar := deviceAuthFromInternal(da)

da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}

if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}

return da, nil
return dar, err
}

// DeviceAccessToken polls the server to exchange a device code for a token.
Expand Down
97 changes: 97 additions & 0 deletions internal/deviceauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package internal

import (
"context"
"encoding/json"
"fmt"
"io"
"net/url"
"time"
)

// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
//
// This type is a mirror of oauth2.DeviceAuthResponse, with the only difference
// being that in this struct `expires_in` isn't mapped to a timestamp. It solely
// exists to break an otherwise-circular dependency. Other internal packages should
// convert this DeviceAuthResponse into an oauth2.DeviceAuthResponse before use.
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry int64 `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}

func RetrieveDeviceAuth(ctx context.Context, clientID, clientSecret, deviceAuthURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*DeviceAuthResponse, error) {
needsAuthStyleProbe := authStyle == AuthStyleUnknown
if needsAuthStyleProbe {
if style, ok := styleCache.lookupAuthStyle(deviceAuthURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
authStyle = AuthStyleInHeader // the first way we'll try
}
}

req, err := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

t := time.Now()
r, err := ContextClient(ctx).Do(req)

if err != nil && needsAuthStyleProbe {
// If we get an error, assume the server wants the
// clientID & clientSecret in a different form.
authStyle = AuthStyleInParams // the second way we'll try
req, _ := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
r, err = ContextClient(ctx).Do(req)
}
if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(deviceAuthURL, authStyle)
}

if err != nil {
return nil, err
}
defer r.Body.Close()

body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}

da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}

if da.Expiry != 0 {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry + int64(t.Nanosecond())
}
return da, nil
}
88 changes: 88 additions & 0 deletions internal/deviceauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package internal

import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

func TestDeviceAuth_ClientAuthnInParams(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want {
t.Errorf("client_id = %q; want %q", got, want)
}
if got, want := r.FormValue("client_secret"), clientSecret; got != want {
t.Errorf("client_secret = %q; want %q", got, want)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}

func TestDeviceAuth_ClientAuthnInHeader(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := u, clientID; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := p, clientSecret; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInHeader, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}

func TestDeviceAuth_ClientAuthnProbe(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := u, clientID; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := p, clientSecret; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}
34 changes: 34 additions & 0 deletions internal/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
)

// ParseKey converts the binary contents of a private key file
Expand All @@ -35,3 +38,34 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
}
return parsed, nil
}

// addClientAuthnRequestParams adds client_secret_post client authentication
func addClientAuthnRequestParams(clientID, clientSecret string, v url.Values, authStyle AuthStyle) url.Values {
if authStyle == AuthStyleInParams {
v = cloneURLValues(v)
if clientID != "" {
v.Set("client_id", clientID)
}
if clientSecret != "" {
v.Set("client_secret", clientSecret)
}
}
return v
}

// addClientAuthnRequestHeaders adds client_secret_basic client authentication
func addClientAuthnRequestHeaders(clientID, clientSecret string, req *http.Request, authStyle AuthStyle) {
if authStyle == AuthStyleInHeader {
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
}
}

func NewRequestWithClientAuthn(httpMethod string, endpointURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
v = addClientAuthnRequestParams(clientID, clientSecret, v, authStyle)
req, err := http.NewRequest(httpMethod, endpointURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
addClientAuthnRequestHeaders(clientID, clientSecret, req, authStyle)
return req, nil
}
15 changes: 1 addition & 14 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -181,23 +180,11 @@ func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
// the POST body (along with any values in v); false means to send it
// in the Authorization header.
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
if authStyle == AuthStyleInParams {
v = cloneURLValues(v)
if clientID != "" {
v.Set("client_id", clientID)
}
if clientSecret != "" {
v.Set("client_secret", clientSecret)
}
}
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
req, err := NewRequestWithClientAuthn("POST", tokenURL, clientID, clientSecret, v, authStyle)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if authStyle == AuthStyleInHeader {
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
}
return req, nil
}

Expand Down

0 comments on commit dcc463a

Please sign in to comment.