Skip to content

Commit

Permalink
Make Retry abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Nov 19, 2023
1 parent 48ed58c commit 17660a4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 51 deletions.
66 changes: 48 additions & 18 deletions network/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,31 @@ func NewClient(
}
}

var (
conn net.Conn
origErr error
)
var origErr error
// Create a new connection and retry a few times if needed.
conn, origErr = client.retry.DialTimeout(client.Network, client.Address, client.DialTimeout)
//nolint:wrapcheck
if conn, err := client.retry.Retry(func() (any, error) {
if client.DialTimeout > 0 {
return net.DialTimeout(client.Network, client.Address, client.DialTimeout)
} else {
return net.Dial(client.Network, client.Address)
}
}); err != nil {
origErr = err
} else {
if netConn, ok := conn.(net.Conn); ok {
client.conn = netConn
} else {
origErr = fmt.Errorf("unexpected connection type: %T", conn)
}
}
if origErr != nil {
err := gerr.ErrClientConnectionFailed.Wrap(origErr)
logger.Error().Err(err).Msg("Failed to create a new connection")
span.RecordError(err)
return nil
}

client.conn = conn
client.connected.Store(true)

// Set the TCP keep alive.
Expand Down Expand Up @@ -151,7 +162,11 @@ func NewClient(

logger.Trace().Str("address", client.Address).Msg("New client created")
client.ID = GetID(
conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, logger)
client.conn.LocalAddr().Network(),
client.conn.LocalAddr().String(),
config.DefaultSeed,
logger,
)

metrics.ServerConnections.Inc()

Expand Down Expand Up @@ -260,21 +275,36 @@ func (c *Client) Reconnect() error {
c.Address = address
c.Network = network

var (
conn net.Conn
err error
)
var origErr error
// Create a new connection and retry a few times if needed.
conn, err = c.retry.DialTimeout(c.Network, c.Address, c.DialTimeout)
if err != nil {
c.logger.Error().Err(err).Msg("Failed to reconnect")
span.RecordError(err)
return gerr.ErrClientConnectionFailed.Wrap(err)
//nolint:wrapcheck
if conn, err := c.retry.Retry(func() (any, error) {
if c.DialTimeout > 0 {
return net.DialTimeout(c.Network, c.Address, c.DialTimeout)
} else {
return net.Dial(c.Network, c.Address)
}
}); err != nil {
origErr = err
} else {
if netConn, ok := conn.(net.Conn); ok {
c.conn = netConn
} else {
origErr = fmt.Errorf("unexpected connection type: %T", conn)
}
}
if origErr != nil {
c.logger.Error().Err(origErr).Msg("Failed to reconnect")
span.RecordError(origErr)
return gerr.ErrClientConnectionFailed.Wrap(origErr)
}

c.conn = conn
c.ID = GetID(
conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, c.logger)
c.conn.LocalAddr().Network(),
c.conn.LocalAddr().String(),
config.DefaultSeed,
c.logger,
)
c.connected.Store(true)
c.logger.Debug().Str("address", c.Address).Msg("Reconnected to server")
metrics.ServerConnections.Inc()
Expand Down
48 changes: 21 additions & 27 deletions network/retry.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package network

import (
"errors"
"math"
"net"
"time"

"github.com/rs/zerolog"
Expand All @@ -13,8 +13,10 @@ const (
BackoffDurationCap = time.Minute
)

type RetryCallback func() (any, error)

type IRetry interface {
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
Retry(_ RetryCallback) (any, error)
}

type Retry struct {
Expand All @@ -27,22 +29,21 @@ type Retry struct {

var _ IRetry = (*Retry)(nil)

// DialTimeout dials a connection with a timeout, retrying if it fails.
// Retry runs the callback function and retries it if it fails.
// It'll wait for the duration of the backoff between retries.
func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
func (r *Retry) Retry(callback RetryCallback) (any, error) {
var (
conn net.Conn
err error
retry int
object any
err error
retry int
)

if r == nil {
// Just dial the connection once.
if timeout == 0 {
return net.Dial(network, address) //nolint: wrapcheck
}
if callback == nil {
return nil, errors.New("callback is nil")
}

return net.DialTimeout(network, address, timeout) //nolint: wrapcheck
if r == nil && callback != nil {
return callback()
}

// The first attempt counts as a retry.
Expand Down Expand Up @@ -80,30 +81,23 @@ func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net
"retry": retry,
"delay": backoffDuration.String(),
},
).Msg("Trying to connect")
).Msg("Trying to run callback again")
} else {
r.logger.Trace().Msg("Trying to connect for the first time")
r.logger.Trace().Msg("First attempt to run callback")
}

// Dial the connection with a timeout if one is provided, otherwise dial the
// connection without a timeout. Dialing without a timeout will block
// indefinitely.
if timeout > 0 {
conn, err = net.DialTimeout(network, address, timeout)
} else {
conn, err = net.Dial(network, address)
}
// If the connection was successful, return it.
// Try and retry the callback.
object, err = callback()
if err == nil {
return conn, nil
return object, nil
}

time.Sleep(backoffDuration)
}

r.logger.Error().Err(err).Msgf("Failed to connect after %d retries", retry)
r.logger.Error().Err(err).Msgf("Failed to run callback after %d retries", retry)

return nil, err //nolint: wrapcheck
return nil, err
}

func NewRetry(
Expand Down
27 changes: 21 additions & 6 deletions network/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package network

import (
"context"
"net"
"testing"
"time"

Expand All @@ -24,9 +25,9 @@ func TestRetry(t *testing.T) {
t.Run("nil", func(t *testing.T) {
// Nil retry should just dial the connection once.
var retry *Retry
_, err := retry.DialTimeout("", "", 0)
_, err := retry.Retry(nil)
assert.Error(t, err)
assert.ErrorContains(t, err, "dial: unknown network ")
assert.ErrorContains(t, err, "callback is nil")
})
t.Run("retry without timeout", func(t *testing.T) {
retry := NewRetry(0, 0, 0, false, logger)
Expand All @@ -35,10 +36,17 @@ func TestRetry(t *testing.T) {
assert.Equal(t, float64(0), retry.BackoffMultiplier)
assert.False(t, retry.DisableBackoffCaps)

conn, err := retry.DialTimeout("tcp", "localhost:5432", 0)
conn, err := retry.Retry(func() (any, error) {
return net.Dial("tcp", "localhost:5432") //nolint: wrapcheck
})
assert.NoError(t, err)
assert.NotNil(t, conn)
conn.Close()
assert.IsType(t, &net.TCPConn{}, conn)
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.Close()
} else {
t.Errorf("Unexpected connection type: %T", conn)
}
})
t.Run("retry with timeout", func(t *testing.T) {
retry := NewRetry(
Expand All @@ -53,10 +61,17 @@ func TestRetry(t *testing.T) {
assert.Equal(t, config.DefaultBackoffMultiplier, retry.BackoffMultiplier)
assert.False(t, retry.DisableBackoffCaps)

conn, err := retry.DialTimeout("tcp", "localhost:5432", time.Second)
conn, err := retry.Retry(func() (any, error) {
return net.DialTimeout("tcp", "localhost:5432", config.DefaultDialTimeout) //nolint: wrapcheck
})
assert.NoError(t, err)
assert.NotNil(t, conn)
conn.Close()
assert.IsType(t, &net.TCPConn{}, conn)
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.Close()
} else {
t.Errorf("Unexpected connection type: %T", conn)
}
})
})
}

0 comments on commit 17660a4

Please sign in to comment.