Skip to content

Commit

Permalink
feat(rest): add global default client for easy access
Browse files Browse the repository at this point in the history
  • Loading branch information
lvlcn-t committed Sep 6, 2024
1 parent edecf18 commit 49baa5a
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 29 deletions.
104 changes: 79 additions & 25 deletions rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,47 @@ import (
"golang.org/x/time/rate"
)

var _ Client = (*client)(nil)
var (
_ Client = (*client)(nil)
// DefaultClient is the default rest client used for making requests.
DefaultClient Client = newDefaultClient()
)

// Do makes a request to the given endpoint with the given payload and response objects.
// It applies the given options and returns an error if the request fails.
//
// Example:
//
// // Define the request endpoint
// ctx := context.Background()
// endpoint := rest.Get("https://api.example.com/resource")
//
// // Define the response type
// type response struct {
// ID int `json:"id"`
// Name string `json:"name"`
// }
//
// // Make the request
// resp, status, err := rest.Do[response](ctx, endpoint, nil)
// if err != nil {
// // Handle error
// }
//
// The request will be made to "https://api.example.com/resource" with the payload marshaled to JSON
// and the response unmarshaled into a response object with the given type.
func Do[T any](ctx context.Context, endpoint *Endpoint, payload any, opts ...RequestOption) (resp T, code int, err error) {
status, err := DefaultClient.Do(ctx, endpoint, payload, &resp, opts...)
return resp, status, err
}

// Close closes the default rest client and gracefully awaits all pending requests to finish.
// If the context is canceled, it will close the idle connections immediately.
func Close(ctx context.Context) {
DefaultClient.Close(ctx)
}

// Client allows doing requests to different endpoints of one host.
// Client allows doing requests to different endpoints.
// It provides a simple way to make requests with rate limiting and request options.
// The client is safe for concurrent use.
//
Expand All @@ -31,7 +69,7 @@ type Client interface {
// client := rest.NewClient("https://api.example.com", 5*time.Second)
// defer client.Close(ctx)
//
// endpoint := rest.Get("/resource")
// endpoint := rest.Post("/resource")
// payload := map[string]string{"key": "value"}
// var response map[string]any
// status, err := client.Do(ctx, endpoint, payload, &response)
Expand Down Expand Up @@ -69,15 +107,16 @@ const (
maxIdleConnsPerHost = 100
// idleConnTimeout controls the maximum amount of time an idle (keep-alive) connection will remain idle before closing itself.
idleConnTimeout = 90 * time.Second
// defaultTimeout is the default timeout for requests.
defaultTimeout = 60 * time.Second
// maxRequestRate is the maximum number of requests that can be made in a single second.
maxRequestRate rate.Limit = 10
// maxRequestBurst is the maximum number of requests that can be made in a single moment.
maxRequestBurst = 10
)

var (
// defaultRateLimiter is the default rate limiter for the rest client.
// It allows 10 requests per second with a burst of 10 (burst is the maximum number of requests that can be made in a single moment).
defaultRateLimiter = rate.NewLimiter(rate.Limit(10), 10) //nolint:mnd // No need for another constant.
// ErrRateLimitExceeded is the error returned when the rate limit is exceeded.
ErrRateLimitExceeded = errors.New("rate limit exceeded")
)
// ErrRateLimitExceeded is the error returned when the rate limit is exceeded.
var ErrRateLimitExceeded = errors.New("rate limit exceeded")

// ErrDecodingResponse is the error returned when the response cannot be unmarshalled into the response object.
type ErrDecodingResponse struct{ err error }
Expand All @@ -99,7 +138,7 @@ func (e *ErrDecodingResponse) Unwrap() error {
}

// client is the default implementation of the Client interface.
// The client is used for making requests to different endpoints of one base URL.
// The client is used for making requests to different endpoints.
type client struct {
// baseURL is the base URL for all requests.
baseURL string
Expand All @@ -111,23 +150,34 @@ type client struct {
wg sync.WaitGroup
}

// NewClient creates a new rest client with the given base URL and timeout.
func NewClient(baseURL string, timeout time.Duration) (Client, error) {
// NewClient creates a new rest client with the given base URL.
// You can optionally provide a timeout for requests. If no timeout is provided, the default timeout is used.
func NewClient(baseURL string, timeout ...time.Duration) (Client, error) {
if _, err := url.Parse(baseURL); err != nil {
return nil, fmt.Errorf("invalid base URL: %w", err)
}

t := defaultTimeout
if len(timeout) > 0 {
t = timeout[0]
}

dt := http.DefaultTransport.(*http.Transport)
return &client{
baseURL: baseURL,
client: &http.Client{
Timeout: timeout,
Timeout: t,
Transport: &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
IdleConnTimeout: idleConnTimeout,
Proxy: dt.Proxy,
DialContext: dt.DialContext,
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
IdleConnTimeout: idleConnTimeout,
TLSHandshakeTimeout: dt.TLSHandshakeTimeout,
ExpectContinueTimeout: dt.ExpectContinueTimeout,
},
},
limiter: defaultRateLimiter,
limiter: rate.NewLimiter(maxRequestRate, maxRequestBurst),
}, nil
}

Expand Down Expand Up @@ -214,19 +264,14 @@ func (r *client) Close(ctx context.Context) {
r.wg.Wait()
}()

type closer interface{ CloseIdleConnections() }
select {
case <-ctx.Done():
if transport, ok := r.client.Transport.(closer); ok {
transport.CloseIdleConnections()
}
r.client.CloseIdleConnections()
case <-done:
}

// Ensure all idle connections are closed even if all requests should be done.
if transport, ok := r.client.Transport.(closer); ok {
transport.CloseIdleConnections()
}
r.client.CloseIdleConnections()
}

// WithDelay is a request option that adds a delay before executing the request
Expand Down Expand Up @@ -254,3 +299,12 @@ func WithBasicAuth(username, password string) RequestOption {
r.Request.SetBasicAuth(username, password)
}
}

// newDefaultClient creates a new rest client without a base URL.
func newDefaultClient() Client {
c, err := NewClient("")
if err != nil {
panic(fmt.Sprintf("failed to create default client: %v", err))
}
return c
}
38 changes: 34 additions & 4 deletions rest/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,42 @@ import (
"time"

"github.com/jarcoal/httpmock"
"golang.org/x/time/rate"
)

type response struct {
ID int `json:"id"`
Name string `json:"name"`
}

func TestDefaultClient_Do(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
DefaultClient.(*client).client = http.DefaultClient

httpmock.RegisterResponder(http.MethodGet, "https://example.com/resource", httpmock.NewJsonResponderOrPanic(200, map[string]any{"id": 1, "name": "Resource"}))

type response struct {
ID int `json:"id"`
Name string `json:"name"`
}

endpoint := Get("https://example.com/resource")

resp, status, err := Do[response](context.Background(), endpoint, nil)
if err != nil {
t.Fatalf("Do() error = %v", err)
}

if status != http.StatusOK {
t.Errorf("Do() status = %v, want %v", status, http.StatusOK)
}

if resp.ID != 1 || resp.Name != "Resource" {
t.Errorf("Do() resp = %v, want %v", resp, response{ID: 1, Name: "Resource"})
}
}

func TestNewClient(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -199,7 +228,7 @@ func TestClient_Do(t *testing.T) { //nolint:gocyclo // Either complexity or dupl
c := &client{
baseURL: "https://example.com",
client: http.DefaultClient,
limiter: defaultRateLimiter,
limiter: rate.NewLimiter(maxRequestRate, maxRequestBurst),
}

for _, tt := range tests {
Expand Down Expand Up @@ -366,12 +395,13 @@ func TestClient_Client(t *testing.T) {
}

func TestClient_RateLimiter(t *testing.T) {
limiter := rate.NewLimiter(maxRequestRate, maxRequestBurst)
c := &client{
limiter: defaultRateLimiter,
limiter: limiter,
}

if c.RateLimiter() != defaultRateLimiter {
t.Errorf("RateLimiter() = %v, want %v", c.RateLimiter(), defaultRateLimiter)
if c.RateLimiter() != limiter {
t.Errorf("RateLimiter() = %v, want %v", c.RateLimiter(), limiter)
}
}

Expand Down

0 comments on commit 49baa5a

Please sign in to comment.