Skip to content

Commit

Permalink
Force auth overlay on /_oauth/login on soft-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
motoki317 committed Aug 27, 2023
1 parent ca3733c commit a5ab1e4
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
18 changes: 18 additions & 0 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ func ValidateDomains(user string, domains CommaSeparatedList) bool {
return false
}

func GetRedirectURI(r *http.Request) string {
redirect := r.URL.Query().Get("redirect")
if redirect != "" {
return redirect
}
forwardedURI := r.Header.Get("X-Forwarded-Uri")
if forwardedURI != "" {
u, err := url.ParseRequestURI(forwardedURI)
if err == nil {
redirect = u.Query().Get("redirect")
if redirect != "" {
return redirect
}
}
}
return "/"
}

func ValidateLoginRedirect(r *http.Request, redirect string) (*url.URL, error) {
u, err := url.ParseRequestURI(redirect)
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/traPtitech/traefik-forward-auth/internal/provider"
)
Expand Down Expand Up @@ -254,6 +255,45 @@ func TestAuthValidateUser(t *testing.T) {
assert.True(v, "should allow user in whitelist")
}

func TestGetRedirectURI(t *testing.T) {
cases := []struct {
name string
path string
headers map[string]string
want string
}{
{
name: "no redirect param",
path: "/",
want: "/",
},
{
name: "has redirect param",
path: "/?redirect=/foo",
want: "/foo",
},
{
name: "has redirect param from forwarded uri header",
path: "/",
headers: map[string]string{
"X-Forwarded-Uri": "/?redirect=/bar",
},
want: "/bar",
},
}
for _, cc := range cases {
t.Run(cc.name, func(t *testing.T) {
req, err := http.NewRequest("GET", cc.path, nil)
require.NoError(t, err)
for k, v := range cc.headers {
req.Header.Add(k, v)
}
got := GetRedirectURI(req)
assert.Equal(t, cc.want, got)
})
}
}

func TestAuthValidateRedirect(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
Expand Down
12 changes: 8 additions & 4 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ func (s *Server) authHandler(providerName, rule string, soft bool) http.HandlerF
}
}

forceLogin := s.LoginHandler(providerName)

return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "Auth", rule, "Authenticating request")
Expand Down Expand Up @@ -170,6 +172,11 @@ func (s *Server) authHandler(providerName, rule string, soft bool) http.HandlerF
}
if user == nil {
if soft {
isForceLogin := strings.HasPrefix(r.Header.Get("X-Forwarded-Uri"), config.Path+"/login")
if isForceLogin {
forceLogin(w, r)
return
}
unauthorized(w)
return
} else {
Expand Down Expand Up @@ -311,10 +318,7 @@ func (s *Server) LoginHandler(providerName string) http.HandlerFunc {
logger.Info("Explicit user login")

// Calculate and validate redirect
redirect := r.URL.Query().Get("redirect")
if redirect == "" {
redirect = "/"
}
redirect := GetRedirectURI(r)
redirectURL, err := ValidateLoginRedirect(r, redirect)
if err != nil {
logger.WithFields(logrus.Fields{
Expand Down

0 comments on commit a5ab1e4

Please sign in to comment.