Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use client authn on device auth request #757

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 25 additions & 35 deletions deviceauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -93,47 +90,40 @@ func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*Devic
return retrieveDeviceAuth(ctx, c, v)
}

func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
// deviceAuthFromInternal maps an *internal.DeviceAuthResponse struct into
// a *DeviceAuthResponse struct.
func deviceAuthFromInternal(da *internal.DeviceAuthResponse) *DeviceAuthResponse {
if da == nil {
return nil
}

req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
return &DeviceAuthResponse{
DeviceCode: da.DeviceCode,
UserCode: da.UserCode,
VerificationURI: da.VerificationURI,
VerificationURIComplete: da.VerificationURIComplete,
Expiry: time.Now().UTC().Add(time.Second * time.Duration(da.Expiry)),
Interval: da.Interval,
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
}

t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
// retrieveDeviceAuth takes a *Config and uses that to retrieve an *internal.DeviceAuthResponse.
// This response is then mapped from *internal.DeviceAuthResponse into an *oauth2.DeviceAuthResponse which is returned along
// with an error.
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}

body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
da, err := internal.RetrieveDeviceAuth(ctx, c.ClientID, c.ClientSecret, c.Endpoint.DeviceAuthURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
}
return nil, err
}
dar := deviceAuthFromInternal(da)

da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}

if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}

return da, nil
return dar, err
}

// DeviceAccessToken polls the server to exchange a device code for a token.
Expand Down
97 changes: 97 additions & 0 deletions internal/deviceauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package internal

import (
"context"
"encoding/json"
"fmt"
"io"
"net/url"
"time"
)

// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
//
// This type is a mirror of oauth2.DeviceAuthResponse, with the only difference
// being that in this struct `expires_in` isn't mapped to a timestamp. It solely
// exists to break an otherwise-circular dependency. Other internal packages should
// convert this DeviceAuthResponse into an oauth2.DeviceAuthResponse before use.
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry int64 `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}

func RetrieveDeviceAuth(ctx context.Context, clientID, clientSecret, deviceAuthURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*DeviceAuthResponse, error) {
needsAuthStyleProbe := authStyle == AuthStyleUnknown
if needsAuthStyleProbe {
if style, ok := styleCache.lookupAuthStyle(deviceAuthURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
authStyle = AuthStyleInHeader // the first way we'll try
}
}

req, err := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

t := time.Now()
r, err := ContextClient(ctx).Do(req)

if err != nil && needsAuthStyleProbe {
// If we get an error, assume the server wants the
// clientID & clientSecret in a different form.
authStyle = AuthStyleInParams // the second way we'll try
req, _ := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
r, err = ContextClient(ctx).Do(req)
}
if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(deviceAuthURL, authStyle)
}

if err != nil {
return nil, err
}
defer r.Body.Close()

body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}

da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}

if da.Expiry != 0 {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry + int64(t.Nanosecond())
}
return da, nil
}
88 changes: 88 additions & 0 deletions internal/deviceauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package internal

import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

func TestDeviceAuth_ClientAuthnInParams(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want {
t.Errorf("client_id = %q; want %q", got, want)
}
if got, want := r.FormValue("client_secret"), clientSecret; got != want {
t.Errorf("client_secret = %q; want %q", got, want)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}

func TestDeviceAuth_ClientAuthnInHeader(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := u, clientID; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := p, clientSecret; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInHeader, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}

func TestDeviceAuth_ClientAuthnProbe(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"
const clientSecret = "client-secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := u, clientID; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
if got, want := p, clientSecret; got != want {
io.WriteString(w, `{"error":"invalid_client"}`)
w.WriteHeader(http.StatusBadRequest)
}
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
}))
defer ts.Close()
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
}
}
34 changes: 34 additions & 0 deletions internal/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
)

// ParseKey converts the binary contents of a private key file
Expand All @@ -35,3 +38,34 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
}
return parsed, nil
}

// addClientAuthnRequestParams adds client_secret_post client authentication
func addClientAuthnRequestParams(clientID, clientSecret string, v url.Values, authStyle AuthStyle) url.Values {
if authStyle == AuthStyleInParams {
v = cloneURLValues(v)
if clientID != "" {
v.Set("client_id", clientID)
}
if clientSecret != "" {
v.Set("client_secret", clientSecret)
}
}
return v
}

// addClientAuthnRequestHeaders adds client_secret_basic client authentication
func addClientAuthnRequestHeaders(clientID, clientSecret string, req *http.Request, authStyle AuthStyle) {
if authStyle == AuthStyleInHeader {
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
}
}

func NewRequestWithClientAuthn(httpMethod string, endpointURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
v = addClientAuthnRequestParams(clientID, clientSecret, v, authStyle)
req, err := http.NewRequest(httpMethod, endpointURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
addClientAuthnRequestHeaders(clientID, clientSecret, req, authStyle)
return req, nil
}
15 changes: 1 addition & 14 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -181,23 +180,11 @@ func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
// the POST body (along with any values in v); false means to send it
// in the Authorization header.
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
if authStyle == AuthStyleInParams {
v = cloneURLValues(v)
if clientID != "" {
v.Set("client_id", clientID)
}
if clientSecret != "" {
v.Set("client_secret", clientSecret)
}
}
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
req, err := NewRequestWithClientAuthn("POST", tokenURL, clientID, clientSecret, v, authStyle)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if authStyle == AuthStyleInHeader {
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
}
return req, nil
}

Expand Down