diff --git a/authorize_helper.go b/authorize_helper.go index bffe10c4..3c3eafde 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -7,8 +7,8 @@ import ( "context" "html/template" "io" + "net" "net/url" - "regexp" "strings" "authelia.com/provider/oauth2/internal/consts" @@ -129,7 +129,7 @@ func isMatchingAsLoopback(requested *url.URL, registeredURI string) bool { // // Source: https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 if requested.Scheme == "http" && - isLoopbackAddress(requested.Host) && + isLoopbackAddress(requested) && registered.Hostname() == requested.Hostname() && // The port is skipped here - see codedoc above! registered.Path == requested.Path && @@ -140,14 +140,14 @@ func isMatchingAsLoopback(requested *url.URL, registeredURI string) bool { return false } -var ( - regexLoopbackAddress = regexp.MustCompile(`^(127\.0\.0\.1|\[::1])(:\d+)?$`) -) - // Check if address is either an IPv4 loopback or an IPv6 loopback- // An optional port is ignored -func isLoopbackAddress(address string) bool { - return regexLoopbackAddress.MatchString(address) +func isLoopbackAddress(uri *url.URL) bool { + if uri == nil { + return false + } + + return net.ParseIP(uri.Hostname()).IsLoopback() } // IsValidRedirectURI validates a redirect_uri as specified in: @@ -185,7 +185,8 @@ func IsRedirectURISecureStrict(redirectURI *url.URL) bool { func IsLocalhost(redirectURI *url.URL) bool { hn := redirectURI.Hostname() - return strings.HasSuffix(hn, ".localhost") || hn == "127.0.0.1" || hn == "::1" || hn == "localhost" + + return strings.HasSuffix(hn, ".localhost") || hn == "localhost" || isLoopbackAddress(redirectURI) } func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, template *template.Template, rw io.Writer) { diff --git a/authorize_helper_whitebox_test.go b/authorize_helper_whitebox_test.go index f2d027a1..d04afab3 100644 --- a/authorize_helper_whitebox_test.go +++ b/authorize_helper_whitebox_test.go @@ -4,6 +4,8 @@ package oauth2 import ( + "github.com/stretchr/testify/require" + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -17,53 +19,47 @@ func TestIsLookbackAddress(t *testing.T) { }{ { "ShouldReturnTrueIPv4Loopback", - "127.0.0.1", - true, - }, - { - "ShouldReturnTrueIPv4LoopbackWithPort", - "127.0.0.1:1230", + "http://127.0.0.1:1235", true, }, { "ShouldReturnTrueIPv6Loopback", - "[::1]", + "http://[::1]:1234", true, }, { - "ShouldReturnTrueIPv6LoopbackWithPort", - "[::1]:1230", - true, - }, { "ShouldReturnFalse12700255", - "127.0.0.255", - false, + "https://127.0.0.255", + true, }, { - "ShouldReturnFalse12700255WithPort", - "127.0.0.255:1230", - false, + "ShouldReturnTrue127.0.0.255", + "https://127.0.0.255", + true, }, { "ShouldReturnFalseInvalidFourthOctet", - "127.0.0.11230", + "https://127.0.0.11230", false, }, { "ShouldReturnFalseInvalidIPv4", - "127x0x0x11230", + "https://127x0x0x11230", false, }, { "ShouldReturnFalseInvalidIPv6", - "[::1]1230", + "https://[::1]1230", false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expected, isLoopbackAddress(tc.have)) + have, err := url.Parse(tc.have) + + require.NoError(t, err) + assert.Equal(t, tc.expected, isLoopbackAddress(have)) }) } }