Skip to content

Commit

Permalink
Remove batchStatementBuilder
Browse files Browse the repository at this point in the history
Use the template instead.
  • Loading branch information
andrew-farries committed Jan 28, 2025
1 parent 25e2983 commit 6fcc687
Showing 1 changed file with 22 additions and 93 deletions.
115 changes: 22 additions & 93 deletions pkg/backfill/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"time"

"github.com/lib/pq"

"github.com/xataio/pgroll/pkg/backfill/templates"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
Expand Down Expand Up @@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error {
}

// Create a batcher for the table.
b := newBatcher(table, bf.batchSize)
b := batcher{
BatchConfig: templates.BatchConfig{
TableName: table.Name,
PrimaryKey: identityColumns,
BatchSize: bf.batchSize,
},
}

// Update each batch of rows, invoking callbacks for each one.
for batch := 0; ; batch++ {
Expand Down Expand Up @@ -158,109 +162,34 @@ func getIdentityColumns(table *schema.Table) []string {
return nil
}

// A batcher is responsible for updating a batch of rows in a table.
// It holds the state necessary to update the next batch of rows.
type batcher struct {
statementBuilder *batchStatementBuilder
lastValues []string
}

func newBatcher(table *schema.Table, batchSize int) *batcher {
return &batcher{
statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize),
lastValues: make([]string, len(getIdentityColumns(table))),
}
templates.BatchConfig
}

func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Build the query to update the next batch of rows
query := b.statementBuilder.buildQuery(b.lastValues)
sql, err := templates.BuildSQL(b.BatchConfig)
if err != nil {
return err
}

// Execute the query to update the next batch of rows and update the last PK
// value for the next batch
wrapper := make([]any, len(b.lastValues))
for i := range b.lastValues {
wrapper[i] = &b.lastValues[i]
if b.LastValue == nil {
b.LastValue = make([]string, len(b.PrimaryKey))
}
wrapper := make([]any, len(b.LastValue))
for i := range b.LastValue {
wrapper[i] = &b.LastValue[i]
}
err := tx.QueryRowContext(ctx, query).Scan(wrapper...)
err = tx.QueryRowContext(ctx, sql).Scan(wrapper...)
if err != nil {
return err
}

return nil
})
}

type batchStatementBuilder struct {
tableName string
identityColumns []string
batchSize int
}

func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder {
quotedCols := make([]string, len(identityColumnNames))
for i, col := range identityColumnNames {
quotedCols[i] = pq.QuoteIdentifier(col)
}
return &batchStatementBuilder{
tableName: pq.QuoteIdentifier(tableName),
identityColumns: quotedCols,
batchSize: batchSize,
}
}

// buildQuery builds the query used to update the next batch of rows.
func (sb *batchStatementBuilder) buildQuery(lastValues []string) string {
return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s",
sb.buildBatchSubQuery(lastValues),
sb.buildUpdateBatchSubQuery(),
sb.buildLastValueQuery())
}

// fetch the next batch of PK of rows to update
func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string {
whereClause := ""
if len(lastValues) != 0 && lastValues[0] != "" {
whereClause = fmt.Sprintf("WHERE (%s) > (%s)",
strings.Join(sb.identityColumns, ", "), strings.Join(quoteLiteralList(lastValues), ", "))
}

return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE",
strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize)
}

func quoteLiteralList(l []string) []string {
quoted := make([]string, len(l))
for i, v := range l {
quoted[i] = pq.QuoteLiteral(v)
}
return quoted
}

// update the rows in the batch
func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string {
conditions := make([]string, len(sb.identityColumns))
for i, col := range sb.identityColumns {
conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col)
}
updateWhereClause := "WHERE " + strings.Join(conditions, " AND ")

setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName)
for i := 1; i < len(sb.identityColumns); i++ {
setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName)
}
updateReturning := sb.tableName + "." + sb.identityColumns[0]
for i := 1; i < len(sb.identityColumns); i++ {
updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i]
}
return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s",
sb.tableName, setStmt, updateWhereClause, updateReturning)
}

// fetch the last values of the PK column
func (sb *batchStatementBuilder) buildLastValueQuery() string {
lastValues := make([]string, len(sb.identityColumns))
for i, col := range sb.identityColumns {
lastValues[i] = "LAST_VALUE(" + col + ") OVER()"
}
return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", "))
}

0 comments on commit 6fcc687

Please sign in to comment.