Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v17] Handle retryable errors in postgres e2e tests #50895

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 112 additions & 34 deletions e2e/aws/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"os"
"strconv"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
mysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -41,6 +46,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/cryptosuites"
Expand All @@ -50,6 +56,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)

func TestDatabases(t *testing.T) {
Expand Down Expand Up @@ -140,29 +147,14 @@ func postgresConnTest(t *testing.T, cluster *helpers.TeleInstance, user string,
assert.NotNil(t, pgConn)
}, waitForConnTimeout, connRetryTick, "connecting to postgres")

// dont wait forever on the exec or close.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Execute a query.
results, err := pgConn.Exec(ctx, query).ReadAll()
require.NoError(t, err)
for i, r := range results {
require.NoError(t, r.Err, "error in result %v", i)
}

// Disconnect.
err = pgConn.Close(ctx)
require.NoError(t, err)
execPGTestQuery(t, pgConn, query)
}

// postgresLocalProxyConnTest tests connection to a postgres database via
// local proxy tunnel.
func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout)
defer cancel()
lp := startLocalALPNProxy(t, ctx, user, cluster, route)
lp := startLocalALPNProxy(t, user, cluster, route)

pgconnConfig, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%v/", lp.GetAddr()))
require.NoError(t, err)
Expand All @@ -180,30 +172,36 @@ func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, use
assert.NotNil(t, pgConn)
}, waitForConnTimeout, connRetryTick, "connecting to postgres")

// dont wait forever on the exec or close.
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
execPGTestQuery(t, pgConn, query)
}

func execPGTestQuery(t *testing.T, conn *pgconn.PgConn, query string) {
t.Helper()
defer func() {
// dont wait forever to gracefully terminate.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Disconnect.
require.NoError(t, conn.Close(ctx))
}()

// dont wait forever on the exec.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Execute a query.
results, err := pgConn.Exec(ctx, query).ReadAll()
results, err := conn.Exec(ctx, query).ReadAll()
require.NoError(t, err)
for i, r := range results {
require.NoError(t, r.Err, "error in result %v", i)
}

// Disconnect.
err = pgConn.Close(ctx)
require.NoError(t, err)
}

// mysqlLocalProxyConnTest tests connection to a MySQL database via
// local proxy tunnel.
func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase, query string) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 2*waitForConnTimeout)
defer cancel()

lp := startLocalALPNProxy(t, ctx, user, cluster, route)
lp := startLocalALPNProxy(t, user, cluster, route)

var conn *mysqlclient.Conn
// retry for a while, the database service might need time to give
Expand All @@ -223,19 +221,22 @@ func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user s
assert.NoError(t, err)
assert.NotNil(t, conn)
}, waitForConnTimeout, connRetryTick, "connecting to mysql")
defer func() {
// Disconnect.
require.NoError(t, conn.Close())
}()

// Execute a query.
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
_, err := conn.Execute(query)
require.NoError(t, err)

// Disconnect.
require.NoError(t, conn.Close())
}

// startLocalALPNProxy starts local ALPN proxy for the specified database.
func startLocalALPNProxy(t *testing.T, ctx context.Context, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy {
func startLocalALPNProxy(t *testing.T, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
proto, err := alpncommon.ToALPNProtocol(route.Protocol)
require.NoError(t, err)

Expand Down Expand Up @@ -337,7 +338,7 @@ type dbUserLogin struct {
port int
}

func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgx.Conn {
func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName string) *pgConn {
pgCfg, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s:%d/?sslmode=verify-full", info.address, info.port))
require.NoError(t, err)
pgCfg.User = info.username
Expand All @@ -353,7 +354,10 @@ func connectPostgres(t *testing.T, ctx context.Context, info dbUserLogin, dbName
t.Cleanup(func() {
_ = conn.Close(ctx)
})
return conn
return &pgConn{
logger: utils.NewSlogLoggerForTests(),
Conn: conn,
}
}

// secretPassword is used to unmarshal an AWS Secrets Manager
Expand Down Expand Up @@ -395,3 +399,77 @@ func getSecretValue(t *testing.T, ctx context.Context, secretID string) secretsm
require.NotNil(t, secretVal)
return *secretVal
}

// pgConn wraps a [pgx.Conn] and adds retries to all Exec calls.
type pgConn struct {
logger *slog.Logger
*pgx.Conn
}

func (c *pgConn) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) {
var out pgconn.CommandTag
err := withRetry(ctx, c.logger, func() error {
var err error
out, err = c.Conn.Exec(ctx, sql, args...)
return trace.Wrap(err)
})
return out, trace.Wrap(err)
}

// withRetry runs a given func a finite number of times until it returns nil
// error or the given context is done.
func withRetry(ctx context.Context, log *slog.Logger, f func() error) error {
linear, err := retryutils.NewLinear(retryutils.LinearConfig{
First: 0,
Step: 500 * time.Millisecond,
Max: 5 * time.Second,
Jitter: retryutils.HalfJitter,
})
if err != nil {
return trace.Wrap(err)
}

// retry a finite number of times before giving up.
const retries = 10
for i := 0; i < retries; i++ {
err := f()
if err == nil {
return nil
}

if isRetryable(err) {
log.DebugContext(ctx, "operation failed, retrying", "error", err)
} else {
return trace.Wrap(err)
}

linear.Inc()
select {
case <-linear.After():
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}
return trace.Wrap(err, "too many retries")
}

// isRetryable returns true if an error can be retried.
func isRetryable(err error) bool {
var pgErr *pgconn.PgError
err = trace.Unwrap(err)
if errors.As(err, &pgErr) {
// https://www.postgresql.org/docs/current/mvcc-serialization-failure-handling.html
switch pgErr.Code {
case pgerrcode.DeadlockDetected, pgerrcode.SerializationFailure,
pgerrcode.UniqueViolation, pgerrcode.ExclusionViolation:
return true
}
}
// Redshift reports this with a vague SQLSTATE XX000, which is the internal
// error code, but this is a serialization error that rolls back the
// transaction, so it should be retried.
if strings.Contains(err.Error(), "conflict with concurrent transaction") {
return true
}
return pgconn.SafeToRetry(err)
}
4 changes: 0 additions & 4 deletions e2e/aws/fixtures_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,6 @@ func withDiscoveryService(t *testing.T, discoveryGroup string, awsMatchers ...ty
options.serviceConfigFuncs = append(options.serviceConfigFuncs, func(cfg *servicecfg.Config) {
cfg.Discovery.Enabled = true
cfg.Discovery.DiscoveryGroup = discoveryGroup
// Reduce the polling interval to speed up the test execution
// in the case of a failure of the first attempt.
// The default polling interval is 5 minutes.
cfg.Discovery.PollInterval = 1 * time.Minute
cfg.Discovery.AWSMatchers = append(cfg.Discovery.AWSMatchers, awsMatchers...)
})
}
Expand Down
9 changes: 4 additions & 5 deletions e2e/aws/rds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/aws/aws-sdk-go-v2/service/rds"
mysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -440,7 +439,7 @@ func testRDS(t *testing.T) {
})
}

func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgx.Conn {
func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgConn {
t.Helper()
info := getRDSAdminInfo(t, ctx, instanceID)
const dbName = "postgres"
Expand Down Expand Up @@ -509,7 +508,7 @@ func getRDSAdminInfo(t *testing.T, ctx context.Context, instanceID string) dbUse

// provisionRDSPostgresAutoUsersAdmin provisions an admin user suitable for auto-user
// provisioning.
func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) {
func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) {
t.Helper()
// Create the admin user and grant rds_iam so Teleport can auth
// with IAM as an existing user.
Expand Down Expand Up @@ -600,7 +599,7 @@ const (
autoUserWaitStep = 10 * time.Second
)

func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down Expand Up @@ -641,7 +640,7 @@ func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *
}, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user)
}

func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down
11 changes: 5 additions & 6 deletions e2e/aws/redshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/redshift"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -96,7 +95,7 @@ func testRedshiftCluster(t *testing.T) {
// eachother.
labels := db.GetStaticLabels()
labels[types.DatabaseAdminLabel] = "test_admin_" + randASCII(t, 6)
cluster.Process.GetAuthServer().UpdateDatabase(ctx, db)
err = cluster.Process.GetAuthServer().UpdateDatabase(ctx, db)
require.NoError(t, err)
adminUser := mustGetDBAdmin(t, db)

Expand Down Expand Up @@ -213,7 +212,7 @@ func testRedshiftCluster(t *testing.T) {
}
}

func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgx.Conn {
func connectAsRedshiftClusterAdmin(t *testing.T, ctx context.Context, clusterID string) *pgConn {
t.Helper()
info := getRedshiftAdminInfo(t, ctx, clusterID)
const dbName = "dev"
Expand Down Expand Up @@ -247,7 +246,7 @@ func getRedshiftAdminInfo(t *testing.T, ctx context.Context, clusterID string) d

// provisionRedshiftAutoUsersAdmin provisions an admin user suitable for auto-user
// provisioning.
func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) {
func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgConn, adminUser string) {
t.Helper()
// Don't cleanup the db admin after, because test runs would interfere
// with each other.
Expand All @@ -261,7 +260,7 @@ func provisionRedshiftAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pg
}
}

func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down Expand Up @@ -300,7 +299,7 @@ func waitForRedshiftAutoUserDeactivate(t *testing.T, ctx context.Context, conn *
}, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user)
}

func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) {
func waitForRedshiftAutoUserDrop(t *testing.T, ctx context.Context, conn *pgConn, user string) {
t.Helper()
require.EventuallyWithT(t, func(c *assert.CollectT) {
// `Query` documents that it is always safe to attempt to read from the
Expand Down
Loading