diff --git a/pkg/webmiddleware/redirect.go b/pkg/webmiddleware/redirect.go index 85e110a74b..52b4982737 100644 --- a/pkg/webmiddleware/redirect.go +++ b/pkg/webmiddleware/redirect.go @@ -54,6 +54,7 @@ func (c RedirectConfiguration) build(url *url.URL) (*url.URL, bool) { if c.HostName != nil { hostname = c.HostName(hostname) } + originalPortStr := portStr if c.Port != nil { port, _ := strconv.ParseUint(portStr, 10, 0) port = uint64(c.Port(uint(port))) @@ -62,12 +63,15 @@ func (c RedirectConfiguration) build(url *url.URL) (*url.URL, bool) { host := hostname if portStr != "" { switch { - case portStr == "0": + case portStr == "0" && originalPortStr == "": // Just use the hostname. + case portStr == "0" && originalPortStr != "": + // Maintain the original port, in order to avoid loops. + host = net.JoinHostPort(host, originalPortStr) case target.Scheme == "http" && portStr == "80": - // This is the default. Just use the hostame. + // This is the default. Just use the hostname. case target.Scheme == "https" && portStr == "443": - // This is the default. Just use the hostame. + // This is the default. Just use the hostname. default: host = net.JoinHostPort(host, portStr) } diff --git a/pkg/webmiddleware/redirect_test.go b/pkg/webmiddleware/redirect_test.go index d7ab61c620..60e14f2b11 100644 --- a/pkg/webmiddleware/redirect_test.go +++ b/pkg/webmiddleware/redirect_test.go @@ -27,6 +27,8 @@ import ( ) func TestRedirect(t *testing.T) { + t.Parallel() + m := Redirect(RedirectConfiguration{ Scheme: func(s string) string { return SchemeHTTPS }, HostName: func(h string) string { @@ -47,6 +49,7 @@ func TestRedirect(t *testing.T) { }) t.Run("None", func(t *testing.T) { + t.Parallel() a := assertions.New(t) r := httptest.NewRequest(http.MethodGet, "https://dev.example.com/path?query=true", nil) rec := httptest.NewRecorder() @@ -86,7 +89,9 @@ func TestRedirect(t *testing.T) { Redirect: "https://dev.example.com/path", }, } { + tc := tc t.Run(tc.Name, func(t *testing.T) { + t.Parallel() a := assertions.New(t) r := httptest.NewRequest(http.MethodGet, tc.URL, nil) rec := httptest.NewRecorder() @@ -100,4 +105,19 @@ func TestRedirect(t *testing.T) { a.So(res.Header.Get("Location"), should.Equal, tc.Redirect) }) } + + t.Run("ExplicitPort", func(t *testing.T) { + t.Parallel() + a := assertions.New(t) + r := httptest.NewRequest(http.MethodGet, "https://dev.example.com:8885/", nil) + rec := httptest.NewRecorder() + m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })).ServeHTTP(rec, r) + res := rec.Result() + a.So(res.StatusCode, should.Equal, http.StatusOK) + body, _ := io.ReadAll(res.Body) + a.So(string(body), should.Equal, "OK") + }) }