diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go index 41e86d2d..ea9923d0 100644 --- a/pkg/backfill/backfill.go +++ b/pkg/backfill/backfill.go @@ -7,11 +7,9 @@ import ( "database/sql" "errors" "fmt" - "strings" "time" - "github.com/lib/pq" - + "github.com/xataio/pgroll/pkg/backfill/templates" "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error { } // Create a batcher for the table. - b := newBatcher(table, bf.batchSize) + b := batcher{ + BatchConfig: templates.BatchConfig{ + TableName: table.Name, + PrimaryKey: identityColumns, + BatchSize: bf.batchSize, + }, + } // Update each batch of rows, invoking callbacks for each one. for batch := 0; ; batch++ { @@ -158,30 +162,30 @@ func getIdentityColumns(table *schema.Table) []string { return nil } +// A batcher is responsible for updating a batch of rows in a table. +// It holds the state necessary to update the next batch of rows. type batcher struct { - statementBuilder *batchStatementBuilder - lastValues []string -} - -func newBatcher(table *schema.Table, batchSize int) *batcher { - return &batcher{ - statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize), - lastValues: make([]string, len(getIdentityColumns(table))), - } + templates.BatchConfig } func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { // Build the query to update the next batch of rows - query := b.statementBuilder.buildQuery(b.lastValues) + sql, err := templates.BuildSQL(b.BatchConfig) + if err != nil { + return err + } // Execute the query to update the next batch of rows and update the last PK // value for the next batch - wrapper := make([]any, len(b.lastValues)) - for i := range b.lastValues { - wrapper[i] = &b.lastValues[i] + if b.LastValue == nil { + b.LastValue = make([]string, len(b.PrimaryKey)) + } + wrapper := make([]any, len(b.LastValue)) + for i := range b.LastValue { + wrapper[i] = &b.LastValue[i] } - err := tx.QueryRowContext(ctx, query).Scan(wrapper...) + err = tx.QueryRowContext(ctx, sql).Scan(wrapper...) if err != nil { return err } @@ -189,78 +193,3 @@ func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { return nil }) } - -type batchStatementBuilder struct { - tableName string - identityColumns []string - batchSize int -} - -func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder { - quotedCols := make([]string, len(identityColumnNames)) - for i, col := range identityColumnNames { - quotedCols[i] = pq.QuoteIdentifier(col) - } - return &batchStatementBuilder{ - tableName: pq.QuoteIdentifier(tableName), - identityColumns: quotedCols, - batchSize: batchSize, - } -} - -// buildQuery builds the query used to update the next batch of rows. -func (sb *batchStatementBuilder) buildQuery(lastValues []string) string { - return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s", - sb.buildBatchSubQuery(lastValues), - sb.buildUpdateBatchSubQuery(), - sb.buildLastValueQuery()) -} - -// fetch the next batch of PK of rows to update -func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string { - whereClause := "" - if len(lastValues) != 0 && lastValues[0] != "" { - whereClause = fmt.Sprintf("WHERE (%s) > (%s)", - strings.Join(sb.identityColumns, ", "), strings.Join(quoteLiteralList(lastValues), ", ")) - } - - return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE", - strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize) -} - -func quoteLiteralList(l []string) []string { - quoted := make([]string, len(l)) - for i, v := range l { - quoted[i] = pq.QuoteLiteral(v) - } - return quoted -} - -// update the rows in the batch -func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string { - conditions := make([]string, len(sb.identityColumns)) - for i, col := range sb.identityColumns { - conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col) - } - updateWhereClause := "WHERE " + strings.Join(conditions, " AND ") - - setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName) - for i := 1; i < len(sb.identityColumns); i++ { - setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName) - } - updateReturning := sb.tableName + "." + sb.identityColumns[0] - for i := 1; i < len(sb.identityColumns); i++ { - updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i] - } - return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s", - sb.tableName, setStmt, updateWhereClause, updateReturning) -} - -// fetch the last values of the PK column -func (sb *batchStatementBuilder) buildLastValueQuery() string { - lastValues := make([]string, len(sb.identityColumns)) - for i, col := range sb.identityColumns { - lastValues[i] = "LAST_VALUE(" + col + ") OVER()" - } - return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", ")) -}