diff --git a/client/client.go b/client/client.go index 02b6a2e..8681870 100644 --- a/client/client.go +++ b/client/client.go @@ -97,29 +97,30 @@ var DefaultErrorHandler = func(res *http.Response, uri string) error { // TimeoutOption is an option to set timeout for the http client calls // value unit is nanoseconds -// -// Deprecated: Internal http.Client shouldn't be modified after construction. Use WithHttpClient instead func TimeoutOption(timeout time.Duration) Option { return func(request *Request) error { - request.SetTimeout(timeout) + httpClient, ok := request.HttpClient.(*http.Client) + if !ok { + return errors.New("unable to set timeout: httpclient is not *http.Client") + } + httpClient.Timeout = timeout return nil } } -// Deprecated: Internal http.Client shouldn't be modified after construction. Use WithHttpClient instead func ProxyOption(proxyURL string) Option { return func(request *Request) error { if proxyURL == "" { return nil } - err := request.SetProxy(proxyURL) - if err != nil { - return err + httpClient, ok := request.HttpClient.(*http.Client) + if !ok { + return errors.New("unable to set proxy: httpclient is not *http.Client") } - return nil + return setHttpClientTransportProxy(httpClient, proxyURL) } } @@ -130,6 +131,13 @@ func WithHttpClient(httpClient HTTPClient) Option { } } +func WithExtraHeader(key, value string) Option { + return func(request *Request) error { + request.Headers[key] = value + return nil + } +} + func WithExtraHeaders(headers map[string]string) Option { return func(request *Request) error { for k, v := range headers { @@ -169,3 +177,25 @@ func (r *Request) SetProxy(proxyUrl string) error { func (r *Request) AddHeader(key, value string) { r.Headers[key] = value } + +func setHttpClientTransportProxy(client *http.Client, proxyUrl string) error { + if proxyUrl == "" { + return errors.New("empty proxy url") + } + url, err := url.Parse(proxyUrl) + if err != nil { + return err + } + + if client.Transport == nil { + client.Transport = &http.Transport{Proxy: http.ProxyURL(url)} + return nil + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + return errors.New("http client transport is not *http.Transport") + } + transport.Proxy = http.ProxyURL(url) + return nil +} diff --git a/client/client_test.go b/client/client_test.go index d4d7892..5be62b7 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequest_GetBase(t *testing.T) { @@ -199,3 +200,71 @@ func (c *mockJSONClient) Do(_ *http.Request) (*http.Response, error) { Body: io.NopCloser(bytes.NewBuffer(c.body)), }, nil } + +func TestInitClientOptions(t *testing.T) { + const aBaseURL = "http://www.example.com/" + tests := []struct { + name string + options []Option + assertion func(t *testing.T, client *Request) + }{ + { + name: "set proxy", + options: []Option{ProxyOption("http://www.example.com")}, + assertion: func(t *testing.T, client *Request) { + // checks that the proxy is set + url, err := client.HttpClient.(*http.Client).Transport.(*http.Transport).Proxy(&http.Request{}) + require.NoError(t, err) + require.Equal(t, "http://www.example.com", url.String()) + }, + }, + { + name: "set timeout", + options: []Option{TimeoutOption(3 * time.Second)}, + assertion: func(t *testing.T, client *Request) { + // checks that the timeout is set + require.Equal(t, 3*time.Second, client.HttpClient.(*http.Client).Timeout) + }, + }, + { + name: "set proxy&timeout", + options: []Option{ProxyOption("http://www.example.com"), TimeoutOption(3 * time.Second)}, + assertion: func(t *testing.T, client *Request) { + // checks that the proxy is set + url, err := client.HttpClient.(*http.Client).Transport.(*http.Transport).Proxy(&http.Request{}) + require.NoError(t, err) + require.Equal(t, "http://www.example.com", url.String()) + + // checks that the timeout is set + require.Equal(t, 3*time.Second, client.HttpClient.(*http.Client).Timeout) + }, + }, + { + name: "WithExtraHeader multiple times", + options: []Option{ + WithExtraHeader("Content-Type", "application/json"), + WithExtraHeader("Accept", "application/json"), + WithExtraHeader("Server", "Apache"), + WithExtraHeaders(map[string]string{ + "Authorization": "Basic ", + "Connection": "Keep-Alive", + "Server": "nginx", + }), + }, + assertion: func(t *testing.T, client *Request) { + require.Equal(t, "application/json", client.Headers["Content-Type"]) + require.Equal(t, "application/json", client.Headers["Accept"]) + require.Equal(t, "Basic ", client.Headers["Authorization"]) + require.Equal(t, "Keep-Alive", client.Headers["Connection"]) + require.Equal(t, "nginx", client.Headers["Server"]) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := InitClient(aBaseURL, nil, test.options...) + test.assertion(t, &c) + }) + } +}