From a5ab1e43ec65966b533f9247e3811b2b69e3945b Mon Sep 17 00:00:00 2001 From: motoki317 Date: Sun, 27 Aug 2023 18:10:20 +0900 Subject: [PATCH] Force auth overlay on /_oauth/login on soft-auth --- internal/auth.go | 18 ++++++++++++++++++ internal/auth_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ internal/server.go | 12 ++++++++---- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/internal/auth.go b/internal/auth.go index ab50259e..121be8a8 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -146,6 +146,24 @@ func ValidateDomains(user string, domains CommaSeparatedList) bool { return false } +func GetRedirectURI(r *http.Request) string { + redirect := r.URL.Query().Get("redirect") + if redirect != "" { + return redirect + } + forwardedURI := r.Header.Get("X-Forwarded-Uri") + if forwardedURI != "" { + u, err := url.ParseRequestURI(forwardedURI) + if err == nil { + redirect = u.Query().Get("redirect") + if redirect != "" { + return redirect + } + } + } + return "/" +} + func ValidateLoginRedirect(r *http.Request, redirect string) (*url.URL, error) { u, err := url.ParseRequestURI(redirect) if err != nil { diff --git a/internal/auth_test.go b/internal/auth_test.go index 597fa3f3..416a41ec 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/traPtitech/traefik-forward-auth/internal/provider" ) @@ -254,6 +255,45 @@ func TestAuthValidateUser(t *testing.T) { assert.True(v, "should allow user in whitelist") } +func TestGetRedirectURI(t *testing.T) { + cases := []struct { + name string + path string + headers map[string]string + want string + }{ + { + name: "no redirect param", + path: "/", + want: "/", + }, + { + name: "has redirect param", + path: "/?redirect=/foo", + want: "/foo", + }, + { + name: "has redirect param from forwarded uri header", + path: "/", + headers: map[string]string{ + "X-Forwarded-Uri": "/?redirect=/bar", + }, + want: "/bar", + }, + } + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", cc.path, nil) + require.NoError(t, err) + for k, v := range cc.headers { + req.Header.Add(k, v) + } + got := GetRedirectURI(req) + assert.Equal(t, cc.want, got) + }) + } +} + func TestAuthValidateRedirect(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) diff --git a/internal/server.go b/internal/server.go index 4fe65feb..f8cb8355 100644 --- a/internal/server.go +++ b/internal/server.go @@ -143,6 +143,8 @@ func (s *Server) authHandler(providerName, rule string, soft bool) http.HandlerF } } + forceLogin := s.LoginHandler(providerName) + return func(w http.ResponseWriter, r *http.Request) { // Logging setup logger := s.logger(r, "Auth", rule, "Authenticating request") @@ -170,6 +172,11 @@ func (s *Server) authHandler(providerName, rule string, soft bool) http.HandlerF } if user == nil { if soft { + isForceLogin := strings.HasPrefix(r.Header.Get("X-Forwarded-Uri"), config.Path+"/login") + if isForceLogin { + forceLogin(w, r) + return + } unauthorized(w) return } else { @@ -311,10 +318,7 @@ func (s *Server) LoginHandler(providerName string) http.HandlerFunc { logger.Info("Explicit user login") // Calculate and validate redirect - redirect := r.URL.Query().Get("redirect") - if redirect == "" { - redirect = "/" - } + redirect := GetRedirectURI(r) redirectURL, err := ValidateLoginRedirect(r, redirect) if err != nil { logger.WithFields(logrus.Fields{