Skip to content

Commit

Permalink
Handle retryable errors in postgres e2e tests (#50895)
Browse files Browse the repository at this point in the history
This wraps the test pgx.Conn in a helper struct that adds
retries for retryable failures for all calls to Exec.
  • Loading branch information
GavinFrazar authored Jan 10, 2025
1 parent 585c40f commit 457dc0a
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 49 deletions.
146 changes: 112 additions & 34 deletions e2e/aws/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"os"
"strconv"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
mysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -41,6 +46,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/cryptosuites"
Expand All @@ -50,6 +56,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)

func TestDatabases(t *testing.T) {
Expand Down Expand Up @@ -140,29 +147,14 @@ func postgresConnTest(t *testing.T, cluster *helpers.TeleInstance, user string,
assert.NotNil(t, pgConn)
}, waitForConnTimeout, connRetryTick, "connecting to postgres")

// dont wait forever on the exec or close.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Execute a query.
results, err := pgConn.Exec(ctx, query).ReadAll()
require.NoError(t, err)
for i, r := range results {
require.NoError(t, r.Err, "error in result %v", i)
}

// Disconnect.
err = pgConn.Close(ctx)
require.NoError(t, err)
execPGTestQuery(t, pgConn, query)
}

// postgresLocalProxyConnTest tests connection to a postgres database via
// local proxy tunnel.
func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout)
defer cancel()
lp := startLocalALPNProxy(t, ctx, user, cluster, route)
lp := startLocalALPNProxy(t, user, cluster, route)

pgconnConfig, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%v/", lp.GetAddr()))
require.NoError(t, err)
Expand All @@ -180,30 +172,36 @@ func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, use
assert.NotNil(t, pgConn)
}, waitForConnTimeout, connRetryTick, "connecting to postgres")

// dont wait forever on the exec or close.
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
execPGTestQuery(t, pgConn, query)
}

func execPGTestQuery(t *testing.T, conn *pgconn.PgConn, query string) {
t.Helper()
defer func() {
// dont wait forever to gracefully terminate.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Disconnect.
require.NoError(t, conn.Close(ctx))
}()

// dont wait forever on the exec.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Execute a query.
results, err := pgConn.Exec(ctx, query).ReadAll()
results, err := conn.Exec(ctx, query).ReadAll()
require.NoError(t, err)
for i, r := range results {
require.NoError(t, r.Err, "error in result %v", i)
}

// Disconnect.
err = pgConn.Close(ctx)
require.NoError(t, err)
}

// mysqlLocalProxyConnTest tests connection to a MySQL database via
// local proxy tunnel.
func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout)
defer cancel()

lp := startLocalALPNProxy(t, ctx, user, cluster, route)
lp := startLocalALPNProxy(t, user, cluster, route)

var conn *mysqlclient.Conn
// retry for a while, the database service might need time to give
Expand All @@ -223,19 +221,22 @@ func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user s
assert.NoError(t, err)
assert.NotNil(t, conn)
}, waitForConnTimeout, connRetryTick, "connecting to mysql")
defer func() {
// Disconnect.
require.NoError(t, conn.Close())
}()

// Execute a query.
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
_, err := conn.Execute(query)
require.NoError(t, err)

// Disconnect.
require.NoError(t, conn.Close())
}

// startLocalALPNProxy starts local ALPN proxy for the specified database.
func startLocalALPNProxy(t *testing.T, ctx context.Context, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy {
func startLocalALPNProxy(t *testing.T, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
proto, err := alpncommon.ToALPNProtocol(route.Protocol)
require.NoError(t, err)

Expand Down Expand Up @@ -337,7 +338,7 @@ type dbUserLogin struct {
port int
}

func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgx.Conn {
func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgConn {
pgCfg, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s:%d/?sslmode=verify-full", info.address, info.port))
require.NoError(t, err)
pgCfg.User = info.username
Expand All @@ -353,7 +354,10 @@ func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName
t.Cleanup(func() {
_ = conn.Close(ctx)
})
return conn
return &pgConn{
logger: utils.NewSlogLoggerForTests(),
Conn: conn,
}
}

// secretPassword is used to unmarshal an AWS Secrets Manager
Expand Down Expand Up @@ -395,3 +399,77 @@ func getSecretValue(t *testing.T, ctx context.Context, secretID string) secretsm
require.NotNil(t, secretVal)
return *secretVal
}

// pgConn wraps a [pgx.Conn] and adds retries to all Exec calls.
type pgConn struct {
logger *slog.Logger
*pgx.Conn
}

func (c *pgConn) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) {
var out pgconn.CommandTag
err := withRetry(ctx, c.logger, func() error {
var err error
out, err = c.Conn.Exec(ctx, sql, args...)
return trace.Wrap(err)
})
return out, trace.Wrap(err)
}

// withRetry runs a given func a finite number of times until it returns nil
// error or the given context is done.
func withRetry(ctx context.Context, log *slog.Logger, f func() error) error {
linear, err := retryutils.NewLinear(retryutils.LinearConfig{
First: 0,
Step: 500 * time.Millisecond,
Max: 5 * time.Second,
Jitter: retryutils.HalfJitter,
})
if err != nil {
return trace.Wrap(err)
}

// retry a finite number of times before giving up.
const retries = 10
for i := 0; i < retries; i++ {
err := f()
if err == nil {
return nil
}

if isRetryable(err) {
log.DebugContext(ctx, "operation failed, retrying", "error", err)
} else {
return trace.Wrap(err)
}

linear.Inc()
select {
case <-linear.After():
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}
return trace.Wrap(err, "too many retries")
}

// isRetryable returns true if an error can be retried.
func isRetryable(err error) bool {
var pgErr *pgconn.PgError
err = trace.Unwrap(err)
if errors.As(err, &pgErr) {
// https://www.postgresql.org/docs/current/mvcc-serialization-failure-handling.html
switch pgErr.Code {
case pgerrcode.DeadlockDetected, pgerrcode.SerializationFailure,
pgerrcode.UniqueViolation, pgerrcode.ExclusionViolation:
return true
}
}
// Redshift reports this with a vague SQLSTATE XX000, which is the internal
// error code, but this is a serialization error that rolls back the
// transaction, so it should be retried.
if strings.Contains(err.Error(), "conflict with concurrent transaction") {
return true
}
return pgconn.SafeToRetry(err)
}
4 changes: 0 additions & 4 deletions e2e/aws/fixtures_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,6 @@ func withDiscoveryService(t *testing.T, discoveryGroup string, awsMatchers ...ty
options.serviceConfigFuncs = append(options.serviceConfigFuncs, func(cfg *servicecfg.Config) {
cfg.Discovery.Enabled = true
cfg.Discovery.DiscoveryGroup = discoveryGroup
// Reduce the polling interval to speed up the test execution
// in the case of a failure of the first attempt.
// The default polling interval is 5 minutes.
cfg.Discovery.PollInterval = 1 * time.Minute
cfg.Discovery.AWSMatchers = append(cfg.Discovery.AWSMatchers, awsMatchers...)
})
}
Expand Down
9 changes: 4 additions & 5 deletions e2e/aws/rds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/aws/aws-sdk-go-v2/service/rds"
mysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -440,7 +439,7 @@ func testRDS(t *testing.T) {
})
}

func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgx.Conn {
func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgConn {
t.Helper()
info := getRDSAdminInfo(t, ctx, instanceID)
const dbName = "postgres"
Expand Down Expand Up @@ -509,7 +508,7 @@ func getRDSAdminInfo(t *testing.T, ctx context.Context, instanceID string) dbUse

// provisionRDSPostgresAutoUsersAdmin provisions an admin user suitable for auto-user
// provisioning.
func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) {
func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) {
t.Helper()
// Create the admin user and grant rds_iam so Teleport can auth
// with IAM as an existing user.
Expand Down Expand Up @@ -600,7 +599,7 @@ const (
autoUserWaitStep = 10 * time.Second
)

func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down Expand Up @@ -641,7 +640,7 @@ func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *
}, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user)
}

func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down
11 changes: 5 additions & 6 deletions e2e/aws/redshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -96,7 +95,7 @@ func testRedshiftCluster(t *testing.T) {
// eachother.
labels := db.GetStaticLabels()
labels[types.DatabaseAdminLabel] = "test_admin_" + randASCII(t, 6)
cluster.Process.GetAuthServer().UpdateDatabase(ctx, db)
err = cluster.Process.GetAuthServer().UpdateDatabase(ctx, db)
require.NoError(t, err)
adminUser := mustGetDBAdmin(t, db)

Expand Down Expand Up @@ -213,7 +212,7 @@ func testRedshiftCluster(t *testing.T) {
}
}

func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgx.Conn {
func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgConn {
t.Helper()
info := getRedshiftAdminInfo(t, ctx, clusterID)
const dbName = "dev"
Expand Down Expand Up @@ -247,7 +246,7 @@ func getRedshiftAdminInfo(t *testing.T, ctx context.Context, clusterID string) d

// provisionRedshiftAutoUsersAdmin provisions an admin user suitable for auto-user
// provisioning.
func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) {
func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) {
t.Helper()
// Don't cleanup the db admin after, because test runs would interfere
// with each other.
Expand All @@ -261,7 +260,7 @@ func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pg
}
}

func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down Expand Up @@ -300,7 +299,7 @@ func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *
}, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user)
}

func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down

0 comments on commit 457dc0a

Please sign in to comment.