diff --git a/pkg/backfill/backfill_test.go b/pkg/backfill/backfill_test.go
deleted file mode 100644
index 2fd30457..00000000
--- a/pkg/backfill/backfill_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package backfill
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestBatchStatementBuilder(t *testing.T) {
-	tests := map[string]struct {
-		tableName       string
-		identityColumns []string
-		batchSize       int
-		lasValues       []string
-		expected        string
-	}{
-		"single identity column no last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id"},
-			batchSize:       10,
-			expected:        `WITH batch AS (SELECT "id" FROM "table_name" ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
-		},
-		"multiple identity columns no last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id", "zip"},
-			batchSize:       10,
-			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
-		},
-		"single identity column with last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id"},
-			batchSize:       10,
-			lasValues:       []string{"1"},
-			expected:        `WITH batch AS (SELECT "id" FROM "table_name" WHERE ("id") > ('1') ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
-		},
-		"multiple identity columns with last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id", "zip"},
-			batchSize:       10,
-			lasValues:       []string{"1", "1234"},
-			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE ("id", "zip") > ('1', '1234') ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
-		},
-	}
-
-	for name, test := range tests {
-		t.Run(name, func(t *testing.T) {
-			builder := newBatchStatementBuilder(test.tableName, test.identityColumns, test.batchSize)
-			actual := builder.buildQuery(test.lasValues)
-			assert.Equal(t, test.expected, actual)
-		})
-	}
-}
diff --git a/pkg/backfill/templates/build.go b/pkg/backfill/templates/build.go
new file mode 100644
index 00000000..9e6fe706
--- /dev/null
+++ b/pkg/backfill/templates/build.go
@@ -0,0 +1,117 @@
+// SPDX-License-Identifier: Apache-2.0
+
+// Package templates builds batched SQL statements from text templates,
+// quoting every identifier and literal it substitutes.
+package templates
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+	"text/template"
+
+	"github.com/lib/pq"
+)
+
+// BatchConfig holds the values substituted into the batch SQL template.
+type BatchConfig struct {
+	// TableName is the table to process; it is quoted as an identifier.
+	TableName string
+	// PrimaryKey lists the columns that identify and order the rows. It must
+	// not be empty.
+	PrimaryKey []string
+	// LastValue, when non-empty, restricts the batch to rows whose PrimaryKey
+	// tuple is greater than these values. It must then hold one value per
+	// PrimaryKey column; each value is quoted as a literal.
+	LastValue []string
+	// BatchSize is the LIMIT applied to the batch.
+	BatchSize int
+}
+
+// BuildSQL renders the SQL template for cfg. It returns an error when cfg
+// would produce malformed SQL or when the template fails to render.
+func BuildSQL(cfg BatchConfig) (string, error) {
+	if err := cfg.validate(); err != nil {
+		return "", err
+	}
+	return executeTemplate("sql", SQL, cfg)
+}
+
+// validate rejects configurations that would render malformed SQL.
+func (cfg BatchConfig) validate() error {
+	if len(cfg.PrimaryKey) == 0 {
+		return errors.New("batch config: primary key must have at least one column")
+	}
+	// A row comparison needs the same number of entries on both sides.
+	if n := len(cfg.LastValue); n != 0 && n != len(cfg.PrimaryKey) {
+		return fmt.Errorf("batch config: got %d last values for %d primary key columns", n, len(cfg.PrimaryKey))
+	}
+	return nil
+}
+
+// funcMap holds the helpers available to templates. It is built once because
+// none of the helpers depend on the configuration being rendered.
+var funcMap = template.FuncMap{
+	"ql": pq.QuoteLiteral,
+	"qi": pq.QuoteIdentifier,
+	"commaSeparate": func(slice []string) string {
+		return strings.Join(slice, ", ")
+	},
+	"quoteIdentifiers": func(slice []string) []string {
+		return mapStrings(slice, pq.QuoteIdentifier)
+	},
+	"quoteLiterals": func(slice []string) []string {
+		return mapStrings(slice, pq.QuoteLiteral)
+	},
+	"updateSetClause": func(tableName string, columns []string) string {
+		table := pq.QuoteIdentifier(tableName)
+		return strings.Join(mapStrings(columns, func(c string) string {
+			col := pq.QuoteIdentifier(c)
+			return col + " = " + table + "." + col
+		}), ", ")
+	},
+	"updateWhereClause": func(tableName string, columns []string) string {
+		table := pq.QuoteIdentifier(tableName)
+		return strings.Join(mapStrings(columns, func(c string) string {
+			col := pq.QuoteIdentifier(c)
+			return table + "." + col + " = batch." + col
+		}), " AND ")
+	},
+	"updateReturnClause": func(tableName string, columns []string) string {
+		table := pq.QuoteIdentifier(tableName)
+		return strings.Join(mapStrings(columns, func(c string) string {
+			return table + "." + pq.QuoteIdentifier(c)
+		}), ", ")
+	},
+	"selectLastValue": func(columns []string) string {
+		return strings.Join(mapStrings(columns, func(c string) string {
+			return "LAST_VALUE(" + pq.QuoteIdentifier(c) + ") OVER()"
+		}), ", ")
+	},
+}
+
+// mapStrings returns a new slice holding f applied to each element of in.
+func mapStrings(in []string, f func(string) string) []string {
+	out := make([]string, len(in))
+	for i, s := range in {
+		out[i] = f(s)
+	}
+	return out
+}
+
+// executeTemplate parses content as a template called name and renders it
+// with cfg. Parse and render failures are returned instead of panicking.
+func executeTemplate(name, content string, cfg BatchConfig) (string, error) {
+	tmpl, err := template.New(name).Funcs(funcMap).Parse(content)
+	if err != nil {
+		return "", fmt.Errorf("parsing template %q: %w", name, err)
+	}
+
+	var buf bytes.Buffer
+	if err := tmpl.Execute(&buf, cfg); err != nil {
+		return "", fmt.Errorf("executing template %q: %w", name, err)
+	}
+
+	return buf.String(), nil
+}
diff --git a/pkg/backfill/templates/build_test.go b/pkg/backfill/templates/build_test.go
new file mode 100644
index 00000000..4e9128c1
--- /dev/null
+++ b/pkg/backfill/templates/build_test.go
@@ -0,0 +1,175 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package templates
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestBatchStatementBuilder checks the SQL rendered for valid configurations.
+func TestBatchStatementBuilder(t *testing.T) {
+	tests := map[string]struct {
+		config   BatchConfig
+		expected string
+	}{
+		"single identity column no last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id"},
+				BatchSize:  10,
+			},
+			expected: expectSingleIDColumnNoLastValue,
+		},
+		"multiple identity columns no last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id", "zip"},
+				BatchSize:  10,
+			},
+			expected: multipleIDColumnsNoLastValue,
+		},
+		"single identity column with last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id"},
+				LastValue:  []string{"1"},
+				BatchSize:  10,
+			},
+			expected: singleIDColumnWithLastValue,
+		},
+		"multiple identity columns with last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id", "zip"},
+				LastValue:  []string{"1", "1234"},
+				BatchSize:  10,
+			},
+			expected: multipleIDColumnsWithLastValue,
+		},
+	}
+
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			actual, err := BuildSQL(test.config)
+			assert.NoError(t, err)
+
+			assert.Equal(t, test.expected, actual)
+		})
+	}
+}
+
+// TestBuildSQLRejectsInvalidConfig checks that configurations which would
+// render malformed SQL are reported as errors.
+func TestBuildSQLRejectsInvalidConfig(t *testing.T) {
+	tests := map[string]BatchConfig{
+		"no primary key columns": {
+			TableName: "table_name",
+			BatchSize: 10,
+		},
+		"last value count differs from primary key count": {
+			TableName:  "table_name",
+			PrimaryKey: []string{"id", "zip"},
+			LastValue:  []string{"1"},
+			BatchSize:  10,
+		},
+	}
+
+	for name, config := range tests {
+		t.Run(name, func(t *testing.T) {
+			actual, err := BuildSQL(config)
+			assert.Error(t, err)
+			assert.Empty(t, actual)
+		})
+	}
+}
+
+// TestExecuteTemplateReturnsParseError checks that a malformed template is
+// reported as an error instead of a panic.
+func TestExecuteTemplateReturnsParseError(t *testing.T) {
+	_, err := executeTemplate("broken", "{{ if }}", BatchConfig{})
+	assert.Error(t, err)
+}
+
+const expectSingleIDColumnNoLastValue = `WITH batch AS
+(
+  SELECT "id"
+  FROM "table_name"
+  ORDER BY "id"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id"
+  FROM batch
+  WHERE "table_name"."id" = batch."id"
+  RETURNING "table_name"."id"
+)
+SELECT LAST_VALUE("id") OVER()
+FROM update
+`
+
+const multipleIDColumnsNoLastValue = `WITH batch AS
+(
+  SELECT "id", "zip"
+  FROM "table_name"
+  ORDER BY "id", "zip"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id", "zip" = "table_name"."zip"
+  FROM batch
+  WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip"
+  RETURNING "table_name"."id", "table_name"."zip"
+)
+SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER()
+FROM update
+`
+
+const singleIDColumnWithLastValue = `WITH batch AS
+(
+  SELECT "id"
+  FROM "table_name"
+  WHERE ("id") > ('1')
+  ORDER BY "id"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id"
+  FROM batch
+  WHERE "table_name"."id" = batch."id"
+  RETURNING "table_name"."id"
+)
+SELECT LAST_VALUE("id") OVER()
+FROM update
+`
+
+const multipleIDColumnsWithLastValue = `WITH batch AS
+(
+  SELECT "id", "zip"
+  FROM "table_name"
+  WHERE ("id", "zip") > ('1', '1234')
+  ORDER BY "id", "zip"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id", "zip" = "table_name"."zip"
+  FROM batch
+  WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip"
+  RETURNING "table_name"."id", "table_name"."zip"
+)
+SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER()
+FROM update
+`
diff --git a/pkg/backfill/templates/sql.go b/pkg/backfill/templates/sql.go
new file mode 100644
index 00000000..9a11f9dc
--- /dev/null
+++ b/pkg/backfill/templates/sql.go
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package templates
+
+// SQL is the template for the batch statement. It selects up to BatchSize
+// primary keys after LastValue, assigns each primary key column to itself on
+// the selected rows, and selects LAST_VALUE of every primary key column.
+const SQL = `WITH batch AS
+(
+  SELECT {{ commaSeparate (quoteIdentifiers .PrimaryKey) }}
+  FROM {{ .TableName | qi }}
+  {{ if .LastValue -}}
+  WHERE ({{ commaSeparate (quoteIdentifiers .PrimaryKey) }}) > ({{ commaSeparate (quoteLiterals .LastValue) }})
+  {{ end -}}
+  ORDER BY {{ commaSeparate (quoteIdentifiers .PrimaryKey) }}
+  LIMIT {{ .BatchSize }}
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE {{ .TableName | qi }}
+  SET {{ updateSetClause .TableName .PrimaryKey }}
+  FROM batch
+  WHERE {{ updateWhereClause .TableName .PrimaryKey }}
+  RETURNING {{ updateReturnClause .TableName .PrimaryKey }}
+)
+SELECT {{ selectLastValue .PrimaryKey }}
+FROM update
+`