From d29805a6e6d062e9276f8921fb56b459e6a4862c Mon Sep 17 00:00:00 2001 From: alishakawaguchi Date: Tue, 21 Jan 2025 08:03:18 -0800 Subject: [PATCH] Postgres init schema - improved handling of creating extensions (#3162) --- .../gen/go/db/dbschemas/postgresql/querier.go | 1 + .../go/db/dbschemas/postgresql/system.sql.go | 43 +++++++++++++++++ .../sql/postgresql/queries/system.sql | 14 ++++++ .../sqlmanager/postgres/postgres-manager.go | 46 ++++++++++++++++++- backend/pkg/sqlmanager/shared/types.go | 5 ++ .../init-statement-builder.go | 2 +- 6 files changed, 109 insertions(+), 2 deletions(-) diff --git a/backend/gen/go/db/dbschemas/postgresql/querier.go b/backend/gen/go/db/dbschemas/postgresql/querier.go index 644083d615..4417bf0e1e 100644 --- a/backend/gen/go/db/dbschemas/postgresql/querier.go +++ b/backend/gen/go/db/dbschemas/postgresql/querier.go @@ -15,6 +15,7 @@ type Querier interface { GetDataTypesBySchemaAndTables(ctx context.Context, db DBTX, arg *GetDataTypesBySchemaAndTablesParams) ([]*GetDataTypesBySchemaAndTablesRow, error) GetDatabaseSchema(ctx context.Context, db DBTX) ([]*GetDatabaseSchemaRow, error) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, db DBTX, schematables []string) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) + GetExtensions(ctx context.Context, db DBTX) ([]*GetExtensionsRow, error) GetIndicesBySchemasAndTables(ctx context.Context, db DBTX, schematables []string) ([]*GetIndicesBySchemasAndTablesRow, error) GetPostgresRolePermissions(ctx context.Context, db DBTX) ([]*GetPostgresRolePermissionsRow, error) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTableConstraintsParams) ([]*GetTableConstraintsRow, error) diff --git a/backend/gen/go/db/dbschemas/postgresql/system.sql.go b/backend/gen/go/db/dbschemas/postgresql/system.sql.go index 8944430143..991360e122 100644 --- a/backend/gen/go/db/dbschemas/postgresql/system.sql.go +++ b/backend/gen/go/db/dbschemas/postgresql/system.sql.go @@ -45,6 +45,7 @@ column_default_functions AS ( WHERE ad.adrelid IN (SELECT oid FROM relevant_schemas_tables) AND d.refclassid = 'pg_proc'::regclass AND d.classid = 'pg_attrdef'::regclass + AND p.oid NOT IN(SELECT objid FROM pg_catalog.pg_depend WHERE deptype = 'e') -- excludes extensions ) SELECT schema_name, @@ -826,6 +827,48 @@ func (q *Queries) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, return items, nil } +const getExtensions = `-- name: GetExtensions :many +SELECT + e.extname AS extension_name, + e.extversion AS installed_version, + n.nspname as schema_name +FROM + pg_catalog.pg_extension e +LEFT JOIN pg_catalog.pg_namespace n ON e.extnamespace = n.oid +WHERE extname != 'plpgsql' +ORDER BY + extname +` + +type GetExtensionsRow struct { + ExtensionName string + InstalledVersion string + SchemaName sql.NullString +} + +func (q *Queries) GetExtensions(ctx context.Context, db DBTX) ([]*GetExtensionsRow, error) { + rows, err := db.QueryContext(ctx, getExtensions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*GetExtensionsRow + for rows.Next() { + var i GetExtensionsRow + if err := rows.Scan(&i.ExtensionName, &i.InstalledVersion, &i.SchemaName); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getIndicesBySchemasAndTables = `-- name: GetIndicesBySchemasAndTables :many SELECT ns.nspname AS schema_name, diff --git a/backend/pkg/dbschemas/sql/postgresql/queries/system.sql b/backend/pkg/dbschemas/sql/postgresql/queries/system.sql index bf1b0ea0b5..c3b5148202 100644 --- a/backend/pkg/dbschemas/sql/postgresql/queries/system.sql +++ b/backend/pkg/dbschemas/sql/postgresql/queries/system.sql @@ -583,6 +583,7 @@ column_default_functions AS ( WHERE ad.adrelid IN (SELECT oid FROM relevant_schemas_tables) AND d.refclassid = 'pg_proc'::regclass AND d.classid = 'pg_attrdef'::regclass + AND p.oid NOT IN(SELECT objid FROM pg_catalog.pg_depend WHERE deptype = 'e') -- excludes extensions ) SELECT schema_name, @@ -604,6 +605,19 @@ ORDER BY function_name; +-- name: GetExtensions :many +SELECT + e.extname AS extension_name, + e.extversion AS installed_version, + n.nspname as schema_name +FROM + pg_catalog.pg_extension e +LEFT JOIN pg_catalog.pg_namespace n ON e.extnamespace = n.oid +WHERE extname != 'plpgsql' +ORDER BY + extname; + + -- name: GetCustomTriggersBySchemaAndTables :many SELECT n.nspname AS schema_name, diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager.go b/backend/pkg/sqlmanager/postgres/postgres-manager.go index 1daa77b3ee..0b6b7f91a8 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager.go @@ -2,6 +2,7 @@ package sqlmanager_postgres import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -16,7 +17,8 @@ import ( ) const ( - SchemasLabel = "schemas" + SchemasLabel = "schemas" + ExtensionsLabel = "extensions" ) type PostgresManager struct { @@ -332,6 +334,35 @@ func (p *PostgresManager) GetSequencesByTables(ctx context.Context, schema strin return output, nil } +func (p *PostgresManager) getExtensions(ctx context.Context) ([]*sqlmanager_shared.ExtensionDataType, error) { + rows, err := p.querier.GetExtensions(ctx, p.db) + if err != nil && !neosyncdb.IsNoRows(err) { + return nil, err + } else if err != nil && neosyncdb.IsNoRows(err) { + return []*sqlmanager_shared.ExtensionDataType{}, nil + } + + output := make([]*sqlmanager_shared.ExtensionDataType, 0, len(rows)) + for _, row := range rows { + output = append(output, &sqlmanager_shared.ExtensionDataType{ + Name: row.ExtensionName, + Definition: wrapPgIdempotentExtension(row.SchemaName, row.ExtensionName, row.InstalledVersion), + }) + } + return output, nil +} + +func wrapPgIdempotentExtension( + schema sql.NullString, + extensionName, + version string, +) string { + if schema.Valid && strings.EqualFold(schema.String, "public") { + return fmt.Sprintf(`CREATE EXTENSION IF NOT EXISTS %q VERSION %q;`, extensionName, version) + } + return fmt.Sprintf(`CREATE EXTENSION IF NOT EXISTS %q VERSION %q SCHEMA %q;`, extensionName, version, schema.String) +} + func (p *PostgresManager) getFunctionsByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { rows, err := p.querier.GetCustomFunctionsBySchemaAndTables(ctx, p.db, &pg_queries.GetCustomFunctionsBySchemaAndTablesParams{ Schema: schema, @@ -553,6 +584,18 @@ func (p *PostgresManager) GetSchemaInitStatements( return nil }) + extensionStmts := []string{} + errgrp.Go(func() error { + extensions, err := p.getExtensions(errctx) + if err != nil { + return fmt.Errorf("unable to get postgres extensions: %w", err) + } + for _, extension := range extensions { + extensionStmts = append(extensionStmts, extension.Definition) + } + return nil + }) + createTables := []string{} nonFkAlterStmts := []string{} fkAlterStmts := []string{} @@ -585,6 +628,7 @@ func (p *PostgresManager) GetSchemaInitStatements( return []*sqlmanager_shared.InitSchemaStatements{ {Label: SchemasLabel, Statements: schemaStmts}, + {Label: ExtensionsLabel, Statements: extensionStmts}, {Label: "data types", Statements: dataTypeStmts}, {Label: "create table", Statements: createTables}, {Label: "non-fk alter table", Statements: nonFkAlterStmts}, diff --git a/backend/pkg/sqlmanager/shared/types.go b/backend/pkg/sqlmanager/shared/types.go index 9827a5d217..9f574bc2af 100644 --- a/backend/pkg/sqlmanager/shared/types.go +++ b/backend/pkg/sqlmanager/shared/types.go @@ -133,6 +133,11 @@ type DataType struct { Definition string } +type ExtensionDataType struct { + Name string + Definition string +} + // These are all items that live at the schema level, but are used by tables type SchemaTableDataTypeResponse struct { // Custom Sequences not tied to the SERIAL data type diff --git a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go index ee1cbcbef6..fca6f980b2 100644 --- a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go +++ b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go @@ -174,7 +174,7 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( err = destdb.Db().BatchExec(ctx, batchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { slogger.Error(fmt.Sprintf("unable to exec pg %s statements: %s", block.Label, err.Error())) - if block.Label != sqlmanager_postgres.SchemasLabel { + if block.Label != sqlmanager_postgres.SchemasLabel && block.Label != sqlmanager_postgres.ExtensionsLabel { return nil, fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err) } initErrors = append(initErrors, &InitSchemaError{