diff --git a/network/client.go b/network/client.go index 4d1b91ab..3f07b10b 100644 --- a/network/client.go +++ b/network/client.go @@ -89,12 +89,24 @@ func NewClient( } } - var ( - conn net.Conn - origErr error - ) + var origErr error // Create a new connection and retry a few times if needed. - conn, origErr = client.retry.DialTimeout(client.Network, client.Address, client.DialTimeout) + //nolint:wrapcheck + if conn, err := client.retry.Retry(func() (any, error) { + if client.DialTimeout > 0 { + return net.DialTimeout(client.Network, client.Address, client.DialTimeout) + } else { + return net.Dial(client.Network, client.Address) + } + }); err != nil { + origErr = err + } else { + if netConn, ok := conn.(net.Conn); ok { + client.conn = netConn + } else { + origErr = fmt.Errorf("unexpected connection type: %T", conn) + } + } if origErr != nil { err := gerr.ErrClientConnectionFailed.Wrap(origErr) logger.Error().Err(err).Msg("Failed to create a new connection") @@ -102,7 +114,6 @@ func NewClient( return nil } - client.conn = conn client.connected.Store(true) // Set the TCP keep alive. @@ -151,7 +162,11 @@ func NewClient( logger.Trace().Str("address", client.Address).Msg("New client created") client.ID = GetID( - conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, logger) + client.conn.LocalAddr().Network(), + client.conn.LocalAddr().String(), + config.DefaultSeed, + logger, + ) metrics.ServerConnections.Inc() @@ -260,21 +275,36 @@ func (c *Client) Reconnect() error { c.Address = address c.Network = network - var ( - conn net.Conn - err error - ) + var origErr error // Create a new connection and retry a few times if needed. - conn, err = c.retry.DialTimeout(c.Network, c.Address, c.DialTimeout) - if err != nil { - c.logger.Error().Err(err).Msg("Failed to reconnect") - span.RecordError(err) - return gerr.ErrClientConnectionFailed.Wrap(err) + //nolint:wrapcheck + if conn, err := c.retry.Retry(func() (any, error) { + if c.DialTimeout > 0 { + return net.DialTimeout(c.Network, c.Address, c.DialTimeout) + } else { + return net.Dial(c.Network, c.Address) + } + }); err != nil { + origErr = err + } else { + if netConn, ok := conn.(net.Conn); ok { + c.conn = netConn + } else { + origErr = fmt.Errorf("unexpected connection type: %T", conn) + } + } + if origErr != nil { + c.logger.Error().Err(origErr).Msg("Failed to reconnect") + span.RecordError(origErr) + return gerr.ErrClientConnectionFailed.Wrap(origErr) } - c.conn = conn c.ID = GetID( - conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, c.logger) + c.conn.LocalAddr().Network(), + c.conn.LocalAddr().String(), + config.DefaultSeed, + c.logger, + ) c.connected.Store(true) c.logger.Debug().Str("address", c.Address).Msg("Reconnected to server") metrics.ServerConnections.Inc() diff --git a/network/retry.go b/network/retry.go index f83b3210..3f720df3 100644 --- a/network/retry.go +++ b/network/retry.go @@ -1,8 +1,8 @@ package network import ( + "errors" "math" - "net" "time" "github.com/rs/zerolog" @@ -13,8 +13,10 @@ const ( BackoffDurationCap = time.Minute ) +type RetryCallback func() (any, error) + type IRetry interface { - DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) + Retry(_ RetryCallback) (any, error) } type Retry struct { @@ -27,22 +29,21 @@ type Retry struct { var _ IRetry = (*Retry)(nil) -// DialTimeout dials a connection with a timeout, retrying if it fails. +// Retry runs the callback function and retries it if it fails. // It'll wait for the duration of the backoff between retries. -func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { +func (r *Retry) Retry(callback RetryCallback) (any, error) { var ( - conn net.Conn - err error - retry int + object any + err error + retry int ) - if r == nil { - // Just dial the connection once. - if timeout == 0 { - return net.Dial(network, address) //nolint: wrapcheck - } + if callback == nil { + return nil, errors.New("callback is nil") + } - return net.DialTimeout(network, address, timeout) //nolint: wrapcheck + if r == nil && callback != nil { + return callback() } // The first attempt counts as a retry. @@ -80,30 +81,23 @@ func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net "retry": retry, "delay": backoffDuration.String(), }, - ).Msg("Trying to connect") + ).Msg("Trying to run callback again") } else { - r.logger.Trace().Msg("Trying to connect for the first time") + r.logger.Trace().Msg("First attempt to run callback") } - // Dial the connection with a timeout if one is provided, otherwise dial the - // connection without a timeout. Dialing without a timeout will block - // indefinitely. - if timeout > 0 { - conn, err = net.DialTimeout(network, address, timeout) - } else { - conn, err = net.Dial(network, address) - } - // If the connection was successful, return it. + // Try and retry the callback. + object, err = callback() if err == nil { - return conn, nil + return object, nil } time.Sleep(backoffDuration) } - r.logger.Error().Err(err).Msgf("Failed to connect after %d retries", retry) + r.logger.Error().Err(err).Msgf("Failed to run callback after %d retries", retry) - return nil, err //nolint: wrapcheck + return nil, err } func NewRetry( diff --git a/network/retry_test.go b/network/retry_test.go index dfebbf59..a3889042 100644 --- a/network/retry_test.go +++ b/network/retry_test.go @@ -2,6 +2,7 @@ package network import ( "context" + "net" "testing" "time" @@ -24,9 +25,9 @@ func TestRetry(t *testing.T) { t.Run("nil", func(t *testing.T) { // Nil retry should just dial the connection once. var retry *Retry - _, err := retry.DialTimeout("", "", 0) + _, err := retry.Retry(nil) assert.Error(t, err) - assert.ErrorContains(t, err, "dial: unknown network ") + assert.ErrorContains(t, err, "callback is nil") }) t.Run("retry without timeout", func(t *testing.T) { retry := NewRetry(0, 0, 0, false, logger) @@ -35,10 +36,17 @@ func TestRetry(t *testing.T) { assert.Equal(t, float64(0), retry.BackoffMultiplier) assert.False(t, retry.DisableBackoffCaps) - conn, err := retry.DialTimeout("tcp", "localhost:5432", 0) + conn, err := retry.Retry(func() (any, error) { + return net.Dial("tcp", "localhost:5432") //nolint: wrapcheck + }) assert.NoError(t, err) assert.NotNil(t, conn) - conn.Close() + assert.IsType(t, &net.TCPConn{}, conn) + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.Close() + } else { + t.Errorf("Unexpected connection type: %T", conn) + } }) t.Run("retry with timeout", func(t *testing.T) { retry := NewRetry( @@ -53,10 +61,17 @@ func TestRetry(t *testing.T) { assert.Equal(t, config.DefaultBackoffMultiplier, retry.BackoffMultiplier) assert.False(t, retry.DisableBackoffCaps) - conn, err := retry.DialTimeout("tcp", "localhost:5432", time.Second) + conn, err := retry.Retry(func() (any, error) { + return net.DialTimeout("tcp", "localhost:5432", config.DefaultDialTimeout) //nolint: wrapcheck + }) assert.NoError(t, err) assert.NotNil(t, conn) - conn.Close() + assert.IsType(t, &net.TCPConn{}, conn) + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.Close() + } else { + t.Errorf("Unexpected connection type: %T", conn) + } }) }) }