From 44a1e507eb8bc97ee5f557ee8cc6fd06a0182964 Mon Sep 17 00:00:00 2001
From: Mostafa Moradian <mostafa@gatewayd.io>
Date: Sun, 19 Nov 2023 02:02:27 +0100
Subject: [PATCH] Add tests for retry

---
 config/constants.go   |  2 +-
 network/retry.go      | 13 +++------
 network/retry_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 67 insertions(+), 10 deletions(-)
 create mode 100644 network/retry_test.go

diff --git a/config/constants.go b/config/constants.go
index 64919f30..224885b9 100644
--- a/config/constants.go
+++ b/config/constants.go
@@ -106,7 +106,7 @@ const (
 	DefaultDialTimeout        = 60 * time.Second
 	DefaultRetries            = 3
 	DefaultBackoff            = 1 * time.Second
-	DefaultBackoffMultiplier  = 2
+	DefaultBackoffMultiplier  = 2.0
 	DefaultDisableBackoffCaps = false
 
 	// Pool constants.
diff --git a/network/retry.go b/network/retry.go
index 0bfc83ef..d7926e06 100644
--- a/network/retry.go
+++ b/network/retry.go
@@ -11,7 +11,6 @@ import (
 const (
 	BackoffMultiplierCap = 10
 	BackoffDurationCap   = time.Minute
-	DefaultBackoff       = 1 * time.Second
 )
 
 type IRetry interface {
@@ -40,10 +39,10 @@ func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net
 	if r == nil {
 		// Just dial the connection once.
 		if timeout == 0 {
-			return net.Dial(network, address)
-		} else {
-			return net.DialTimeout(network, address, timeout)
+			return net.Dial(network, address) //nolint: wrapcheck
 		}
+
+		return net.DialTimeout(network, address, timeout) //nolint: wrapcheck
 	}
 
 	for ; retry < r.Retries; retry++ {
@@ -103,7 +102,7 @@ func (r *Retry) DialTimeout(network, address string, timeout time.Duration) (net
 
 	r.logger.Error().Err(err).Msgf("Failed to connect after %d retries", retry)
 
-	return nil, err
+	return nil, err //nolint: wrapcheck
 }
 
 func NewRetry(
@@ -121,10 +120,6 @@ func NewRetry(
 		logger:             logger,
 	}
 
-	if retry.Backoff == 0 {
-		retry.Backoff = DefaultBackoff
-	}
-
 	if retry.Retries == 0 {
 		retry.Retries = 1
 	}
diff --git a/network/retry_test.go b/network/retry_test.go
new file mode 100644
index 00000000..dfebbf59
--- /dev/null
+++ b/network/retry_test.go
@@ -0,0 +1,62 @@
+package network
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/gatewayd-io/gatewayd/config"
+	"github.com/gatewayd-io/gatewayd/logging"
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRetry(t *testing.T) {
+	logger := logging.NewLogger(context.Background(), logging.LoggerConfig{
+		Output:            []config.LogOutput{config.Console},
+		TimeFormat:        zerolog.TimeFormatUnix,
+		ConsoleTimeFormat: time.RFC3339,
+		Level:             zerolog.DebugLevel,
+		NoColor:           true,
+	})
+
+	t.Run("DialTimeout", func(t *testing.T) {
+		t.Run("nil", func(t *testing.T) {
+			// Nil retry should just dial the connection once.
+			var retry *Retry
+			_, err := retry.DialTimeout("", "", 0)
+			assert.Error(t, err)
+			assert.ErrorContains(t, err, "dial: unknown network ")
+		})
+		t.Run("retry without timeout", func(t *testing.T) {
+			retry := NewRetry(0, 0, 0, false, logger)
+			assert.Equal(t, 1, retry.Retries)
+			assert.Equal(t, time.Duration(0), retry.Backoff)
+			assert.Equal(t, float64(0), retry.BackoffMultiplier)
+			assert.False(t, retry.DisableBackoffCaps)
+
+			conn, err := retry.DialTimeout("tcp", "localhost:5432", 0)
+			assert.NoError(t, err)
+			assert.NotNil(t, conn)
+			conn.Close()
+		})
+		t.Run("retry with timeout", func(t *testing.T) {
+			retry := NewRetry(
+				config.DefaultRetries,
+				config.DefaultBackoff,
+				config.DefaultBackoffMultiplier,
+				config.DefaultDisableBackoffCaps,
+				logger,
+			)
+			assert.Equal(t, config.DefaultRetries, retry.Retries)
+			assert.Equal(t, config.DefaultBackoff, retry.Backoff)
+			assert.Equal(t, config.DefaultBackoffMultiplier, retry.BackoffMultiplier)
+			assert.False(t, retry.DisableBackoffCaps)
+
+			conn, err := retry.DialTimeout("tcp", "localhost:5432", time.Second)
+			assert.NoError(t, err)
+			assert.NotNil(t, conn)
+			conn.Close()
+		})
+	})
+}