diff --git a/middleware.go b/middleware.go index 162342a5..333be41d 100644 --- a/middleware.go +++ b/middleware.go @@ -116,28 +116,21 @@ func parseRequestURL(c *Client, r *Request) error { } func parseRequestHeader(c *Client, r *Request) error { - hdr := make(http.Header) - for k := range c.Header { - hdr[k] = append(hdr[k], c.Header[k]...) - } - - for k := range r.Header { - hdr.Del(k) - hdr[k] = append(hdr[k], r.Header[k]...) + for k, v := range c.Header { + if _, ok := r.Header[k]; ok { + continue + } + r.Header[k] = v[:] } - if IsStringEmpty(hdr.Get(hdrUserAgentKey)) { - hdr.Set(hdrUserAgentKey, hdrUserAgentValue) + if IsStringEmpty(r.Header.Get(hdrUserAgentKey)) { + r.Header.Set(hdrUserAgentKey, hdrUserAgentValue) } - ct := hdr.Get(hdrContentTypeKey) - if IsStringEmpty(hdr.Get(hdrAcceptKey)) && !IsStringEmpty(ct) && - (IsJSONType(ct) || IsXMLType(ct)) { - hdr.Set(hdrAcceptKey, hdr.Get(hdrContentTypeKey)) + if ct := r.Header.Get(hdrContentTypeKey); IsStringEmpty(r.Header.Get(hdrAcceptKey)) && !IsStringEmpty(ct) && (IsJSONType(ct) || IsXMLType(ct)) { + r.Header.Set(hdrAcceptKey, r.Header.Get(hdrContentTypeKey)) } - r.Header = hdr - return nil } diff --git a/middleware_test.go b/middleware_test.go index fef6f008..a9bac8c7 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -1,7 +1,9 @@ package resty import ( + "net/http" "net/url" + "reflect" "testing" ) @@ -227,3 +229,117 @@ func Test_parseRequestURL(t *testing.T) { }) } } + +func Test_parseRequestHeader(t *testing.T) { + for _, tt := range []struct { + name string + init func(c *Client, r *Request) + expectedHeader http.Header + }{ + { + name: "headers in request", + init: func(c *Client, r *Request) { + r.SetHeaders(map[string]string{ + "foo": "1", + "bar": "2", + }) + }, + expectedHeader: http.Header{ + http.CanonicalHeaderKey("foo"): []string{"1"}, + http.CanonicalHeaderKey("bar"): []string{"2"}, + http.CanonicalHeaderKey(hdrUserAgentKey): []string{hdrUserAgentValue}, + }, + }, + { + name: "headers in client", + init: func(c *Client, r *Request) { + c.SetHeaders(map[string]string{ + "foo": "1", + "bar": "2", + }) + }, + expectedHeader: http.Header{ + http.CanonicalHeaderKey("foo"): []string{"1"}, + http.CanonicalHeaderKey("bar"): []string{"2"}, + http.CanonicalHeaderKey(hdrUserAgentKey): []string{hdrUserAgentValue}, + }, + }, + { + name: "headers in client and request", + init: func(c *Client, r *Request) { + c.SetHeaders(map[string]string{ + "foo": "1", // ignored, because of the same header in the request + "bar": "2", + }) + r.SetHeaders(map[string]string{ + "foo": "3", + "xyz": "4", + }) + }, + expectedHeader: http.Header{ + http.CanonicalHeaderKey("foo"): []string{"3"}, + http.CanonicalHeaderKey("bar"): []string{"2"}, + http.CanonicalHeaderKey("xyz"): []string{"4"}, + http.CanonicalHeaderKey(hdrUserAgentKey): []string{hdrUserAgentValue}, + }, + }, + { + name: "no headers", + init: func(c *Client, r *Request) {}, + expectedHeader: http.Header{ + http.CanonicalHeaderKey(hdrUserAgentKey): []string{hdrUserAgentValue}, + }, + }, + { + name: "user agent", + init: func(c *Client, r *Request) { + c.SetHeader(hdrUserAgentKey, "foo bar") + }, + expectedHeader: http.Header{ + http.CanonicalHeaderKey(hdrUserAgentKey): []string{"foo bar"}, + }, + }, + { + name: "json content type", + init: func(c *Client, r *Request) { + c.SetHeader(hdrContentTypeKey, "application/json") + }, + expectedHeader: http.Header{ + http.CanonicalHeaderKey(hdrContentTypeKey): []string{"application/json"}, + http.CanonicalHeaderKey(hdrAcceptKey): []string{"application/json"}, + http.CanonicalHeaderKey(hdrUserAgentKey): []string{hdrUserAgentValue}, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := New() + r := c.R() + tt.init(c, r) + if err := parseRequestHeader(c, r); err != nil { + t.Errorf("parseRequestHeader() error = %v", err) + } + if !reflect.DeepEqual(tt.expectedHeader, r.Header) { + t.Errorf("r.Header = %#+v does not match expected %#+v", r.Header, tt.expectedHeader) + } + }) + } +} + +func Benchmark_parseRequestHeader(b *testing.B) { + c := New() + r := c.R() + c.SetHeaders(map[string]string{ + "foo": "1", // ignored, because of the same header in the request + "bar": "2", + }) + r.SetHeaders(map[string]string{ + "foo": "3", + "xyz": "4", + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestHeader(c, r); err != nil { + b.Errorf("parseRequestHeader() error = %v", err) + } + } +}