From 5ff5176b88f50f9d61268694a284f7080bc36d90 Mon Sep 17 00:00:00 2001 From: lvlcn-t <75443136+lvlcn-t@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:06:00 +0200 Subject: [PATCH] feat(rest): add NewWithClient constructor to create a new client with a custom http.Client --- rest/client.go | 56 ++++++++++++++++++++++++++++----------------- rest/client_test.go | 2 +- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/rest/client.go b/rest/client.go index 3780f2e..d64a5d6 100644 --- a/rest/client.go +++ b/rest/client.go @@ -17,7 +17,10 @@ import ( ) // DefaultClient is the default rest client used for making requests. -var DefaultClient = newDefaultClient() +var DefaultClient = defaultClient() + +// DefaultTransport is the default transport used for the default client. +var DefaultTransport = defaultTransport() // Do makes a request to the given endpoint with the given payload and response type. // It applies the given options and returns an error if the request fails. @@ -101,14 +104,14 @@ type RequestOption func(*Request) var _ Client = (*restClient)(nil) const ( + // DefaultTimeout is the default timeout for requests. + DefaultTimeout = 60 * time.Second // maxIdleConns controls the maximum number of idle (keep-alive) connections across all hosts. maxIdleConns = 100 // maxIdleConnsPerHost controls the maximum number of idle (keep-alive) connections to keep per-host. maxIdleConnsPerHost = 100 // idleConnTimeout controls the maximum amount of time an idle (keep-alive) connection will remain idle before closing itself. idleConnTimeout = 90 * time.Second - // defaultTimeout is the default timeout for requests. - defaultTimeout = 60 * time.Second // maxRequestRate is the maximum number of requests that can be made in a single second. maxRequestRate rate.Limit = 10 // maxRequestBurst is the maximum number of requests that can be made in a single moment. @@ -128,29 +131,30 @@ type restClient struct { wg sync.WaitGroup } -// NewClient creates a new rest client with the given base URL. -// You can optionally provide a timeout for requests. If no timeout is provided, the default timeout is used. -func NewClient(baseURL string, timeout ...time.Duration) (Client, error) { +// New creates a new rest client with the given base URL. +// You can optionally provide a timeout for requests. If no timeout is provided, the [DefaultTimeout] will be used. +func New(baseURL string, timeout ...time.Duration) (Client, error) { + if len(timeout) == 0 { + return NewWithClient(baseURL, nil) + } + + return NewWithClient(baseURL, &http.Client{Transport: defaultTransport(), Timeout: timeout[0]}) +} + +// NewWithClient creates a new rest client with the given base URL and [http.Client]. +// If the client is nil, it will create a new client with the [DefaultTransport] and [DefaultTimeout]. +func NewWithClient(baseURL string, client *http.Client) (Client, error) { if _, err := url.Parse(baseURL); err != nil { return nil, fmt.Errorf("invalid base URL: %w", err) } - t := defaultTimeout - if len(timeout) > 0 { - t = timeout[0] + if client == nil { + client = &http.Client{Transport: defaultTransport(), Timeout: DefaultTimeout} } - tp := http.DefaultTransport.(*http.Transport).Clone() - tp.MaxIdleConns = maxIdleConns - tp.MaxIdleConnsPerHost = maxIdleConnsPerHost - tp.IdleConnTimeout = idleConnTimeout - return &restClient{ baseURL: baseURL, - client: &http.Client{ - Timeout: t, - Transport: tp, - }, + client: client, limiter: rate.NewLimiter(maxRequestRate, maxRequestBurst), }, nil } @@ -310,11 +314,21 @@ func WithTracer(c *httptrace.ClientTrace) RequestOption { } } -// newDefaultClient creates a new rest client without a base URL. -func newDefaultClient() Client { - c, err := NewClient("") +// defaultClient creates a new [restClient] without a base URL. +// Panics if the client cannot be created. +func defaultClient() Client { + c, err := New("") if err != nil { panic(fmt.Sprintf("failed to create default client: %v", err)) } return c } + +// defaultTransport creates a new [http.Transport] with custom settings. +func defaultTransport() *http.Transport { + tp := http.DefaultTransport.(*http.Transport).Clone() + tp.MaxIdleConns = maxIdleConns + tp.MaxIdleConnsPerHost = maxIdleConnsPerHost + tp.IdleConnTimeout = idleConnTimeout + return tp +} diff --git a/rest/client_test.go b/rest/client_test.go index fa96c78..54723bb 100644 --- a/rest/client_test.go +++ b/rest/client_test.go @@ -69,7 +69,7 @@ func TestNewClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewClient(tt.baseURL, 5*time.Second) + _, err := New(tt.baseURL, 5*time.Second) if (err != nil) != tt.wantErr { t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) }