diff --git a/cors.go b/cors.go index 8af9c09..d6ada56 100644 --- a/cors.go +++ b/cors.go @@ -16,7 +16,7 @@ type cors struct { allowedOrigins []string allowedOriginValidator OriginValidator exposedHeaders []string - maxAge int + maxAge *int ignoreOptions bool allowCredentials bool optionStatusCode int @@ -94,8 +94,8 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) } - if ch.maxAge > 0 { - w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) + if ch.maxAge != nil { + w.Header().Set(corsMaxAgeHeader, strconv.Itoa(*ch.maxAge)) } if !ch.isMatch(method, defaultCorsMethods) { @@ -295,7 +295,7 @@ func MaxAge(age int) CORSOption { age = 600 } - ch.maxAge = age + ch.maxAge = &age return nil } } diff --git a/cors_test.go b/cors_test.go index 777a420..80d98d4 100644 --- a/cors_test.go +++ b/cors_test.go @@ -418,3 +418,47 @@ func TestCORSAllowStar(t *testing.T) { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) } } + +func TestCORSHandlerZeroMaxAgeForPreflight(t *testing.T) { + r := newRequest(http.MethodOptions, "http://www.example.com") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, http.MethodPost) + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + CORS(MaxAge(0))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() + + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Fatalf("bad status: got %v want %v", got, want) + } + + header := resp.Header.Get(corsMaxAgeHeader) + if got, want := header, "0"; got != want { + t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got) + } +} + +func TestCORSHandlerNoMaxAgeForPreflight(t *testing.T) { + r := newRequest(http.MethodOptions, "http://www.example.com") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, http.MethodPost) + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() + + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Fatalf("bad status: got %v want %v", got, want) + } + + header := resp.Header.Get(corsMaxAgeHeader) + if header != "" { + t.Fatalf("unexpected header %q", corsMaxAgeHeader) + } +}