From a061d45493568b27b42064ecb8a0e0aa9b3dafc5 Mon Sep 17 00:00:00 2001 From: Tao Wu Date: Wed, 26 Jun 2024 23:03:43 +0800 Subject: [PATCH] support risingwave schema inspect --- sql/postgres/driver.go | 19 ++++- sql/postgres/inspect.go | 11 ++- sql/postgres/risingwave.go | 149 +++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 sql/postgres/risingwave.go diff --git a/sql/postgres/driver.go b/sql/postgres/driver.go index 91dec5b09cc..3eb12886b3d 100644 --- a/sql/postgres/driver.go +++ b/sql/postgres/driver.go @@ -38,7 +38,10 @@ type ( schema string // System variables that are set on `Open`. version int - crdb bool + // Whether the connected database is a CockroachDB. + crdb bool + // Whether the connected database is a RisingWave. + risingwave bool } ) @@ -121,6 +124,20 @@ func Open(db schema.ExecQuerier) (migrate.Driver, error) { }, }, nil } + c.risingwave, err = c.isRisingWaveConn() + if err != nil { + return nil, fmt.Errorf("postgres: failed checking if connected to RisingWave: %w", err) + } + if c.risingwave { + return noLockDriver{ + &Driver{ + conn: c, + Differ: &sqlx.Diff{DiffDriver: &risingwaveDiff{diff{c}}}, + Inspector: &risingwaveInspect{inspect{c}}, + PlanApplier: &planApply{c}, + }, + }, nil + } return &Driver{ conn: c, Differ: &sqlx.Diff{DiffDriver: &diff{c}}, diff --git a/sql/postgres/inspect.go b/sql/postgres/inspect.go index 996549c190c..65d5fd28878 100644 --- a/sql/postgres/inspect.go +++ b/sql/postgres/inspect.go @@ -95,7 +95,7 @@ func (i *inspect) InspectRealm(ctx context.Context, opts *schema.InspectRealmOpt // referenced objects in the public schema (or any other default search_path) are returned // qualified in the inspection. func (i *inspect) noSearchPath(ctx context.Context) (func() error, error) { - if i.crdb { + if i.crdb || i.risingwave { // Skip logic for CockroachDB. return func() error { return nil }, nil } @@ -473,6 +473,9 @@ func (i *inspect) indexes(ctx context.Context, s *schema.Schema) error { if i.crdb { return i.crdbIndexes(ctx, s) } + if i.risingwave { + return i.risingwaveIndexes(ctx, s) + } rows, err := i.querySchema(ctx, i.indexesQuery(), s) if err != nil { return fmt.Errorf("postgres: querying schema %q indexes: %w", s.Name, err) @@ -1500,7 +1503,7 @@ SELECT FROM pg_catalog.pg_namespace WHERE - nspname NOT IN ('information_schema', 'pg_catalog', 'pg_toast', 'crdb_internal', 'pg_extension') + nspname NOT IN ('information_schema', 'pg_catalog', 'pg_toast', 'crdb_internal', 'rw_catalog', 'pg_extension') AND nspname NOT LIKE 'pg_%temp_%' ORDER BY nspname` @@ -1532,7 +1535,7 @@ FROM JOIN pg_catalog.pg_namespace AS t2 ON t2.nspname = t1.table_schema JOIN pg_catalog.pg_class AS t3 ON t3.relnamespace = t2.oid AND t3.relname = t1.table_name LEFT JOIN pg_catalog.pg_partitioned_table AS t4 ON t4.partrelid = t3.oid - LEFT JOIN pg_depend AS t5 ON t5.classid = 'pg_catalog.pg_class'::regclass::oid AND t5.objid = t3.oid AND t5.deptype = 'e' + LEFT JOIN pg_depend AS t5 ON t5.classid::text = 'pg_catalog.pg_class' AND t5.objid = t3.oid AND t5.deptype = 'e' WHERE t1.table_type = 'BASE TABLE' AND NOT COALESCE(t3.relispartition, false) @@ -1555,7 +1558,7 @@ FROM JOIN pg_catalog.pg_namespace AS t2 ON t2.nspname = t1.table_schema JOIN pg_catalog.pg_class AS t3 ON t3.relnamespace = t2.oid AND t3.relname = t1.table_name LEFT JOIN pg_catalog.pg_partitioned_table AS t4 ON t4.partrelid = t3.oid - LEFT JOIN pg_depend AS t5 ON t5.classid = 'pg_catalog.pg_class'::regclass::oid AND t5.objid = t3.oid AND t5.deptype = 'e' + LEFT JOIN pg_depend AS t5 ON t5.classid::text = 'pg_catalog.pg_class' AND t5.objid = t3.oid AND t5.deptype = 'e' WHERE t1.table_type = 'BASE TABLE' AND NOT COALESCE(t3.relispartition, false) diff --git a/sql/postgres/risingwave.go b/sql/postgres/risingwave.go new file mode 100644 index 00000000000..bc80d72bb8c --- /dev/null +++ b/sql/postgres/risingwave.go @@ -0,0 +1,149 @@ +// Copyright 2021-present The Atlas Authors. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "ariga.io/atlas/sql/internal/sqlx" + "ariga.io/atlas/sql/schema" +) + +type ( + risingwaveDiff struct{ diff } + risingwaveInspect struct{ inspect } +) + +var _ sqlx.DiffDriver = (*risingwaveDiff)(nil) + +func (c *conn) isRisingWaveConn() (bool, error) { + rows, err := c.QueryContext(context.Background(), "select version()") + if err != nil { + return false, err + } + defer rows.Close() + + var version string + err = sqlx.ScanOne(rows, &version) + if err != nil { + return false, fmt.Errorf("postgres: failed scanning rows: %w", err) + } + return strings.Contains(strings.ToLower(version), "risingwave"), nil +} + +func (i *inspect) risingwaveIndexes(ctx context.Context, s *schema.Schema) error { + rows, err := i.querySchema(ctx, risingwaveIndexesQuery, s) + if err != nil { + return fmt.Errorf("postgres: querying schema %q indexes: %w", s.Name, err) + } + defer rows.Close() + if err := i.risingwaveAddIndexes(s, rows); err != nil { + return err + } + return rows.Err() +} + +// RisingWave doesn't support: +// - Unique Indexes +// - Index Constraints +// - Partial Indexes or Partial Index Predicates +// - Indexes on Expressions +func (i *inspect) risingwaveAddIndexes(s *schema.Schema, rows *sql.Rows) error { + names := make(map[string]*schema.Index) + for rows.Next() { + var ( + primary bool + table, name string + desc, nullsfirst, nullslast sql.NullBool + column, comment sql.NullString + ) + if err := rows.Scan(&table, &name, &column, &primary, &desc, &nullsfirst, &nullslast, &comment); err != nil { + return fmt.Errorf("risingwave: scanning indexes for schema %q: %w", s.Name, err) + } + + t, ok := s.Table(table) + if !ok { + return fmt.Errorf("table %q was not found in schema", table) + } + + idx, ok := names[name] + if !ok { + idx = &schema.Index{ + Name: name, + Unique: primary, + Table: t, + } + if sqlx.ValidString(comment) { + idx.Attrs = append(idx.Attrs, &schema.Comment{Text: comment.String}) + } + if primary { + t.PrimaryKey = idx + } else { + t.Indexes = append(t.Indexes, idx) + } + names[name] = idx + } + // TODO: Extract isdesc from RisingWave indexes. + part := &schema.IndexPart{SeqNo: len(idx.Parts) + 1, Desc: desc.Bool} + if nullsfirst.Bool || nullslast.Bool { + part.Attrs = append(part.Attrs, &IndexColumnProperty{ + NullsFirst: nullsfirst.Bool, + NullsLast: nullslast.Bool, + }) + } + switch { + case sqlx.ValidString(column): + part.C, ok = t.Column(column.String) + if !ok { + return fmt.Errorf("risingwave: column %q was not found for index %q", column.String, idx.Name) + } + part.C.Indexes = append(part.C.Indexes, idx) + default: + return fmt.Errorf("risingwave: invalid part for index %q", idx.Name) + } + idx.Parts = append(idx.Parts, part) + } + return nil +} + +const ( + /// table, name, typ, column, primary, comment + risingwaveIndexesQuery = ` +SELECT + t.relname AS table_name, + i.relname AS index_name, + a.attname AS column_name, + idx.indisprimary AS primary, + pg_index_column_has_property(idx.indexrelid, idx.ord, 'desc') AS isdesc, + pg_index_column_has_property(idx.indexrelid, idx.ord, 'nulls_first') AS nulls_first, + pg_index_column_has_property(idx.indexrelid, idx.ord, 'nulls_last') AS nulls_last, + obj_description(i.oid, 'pg_class') AS comment +FROM + ( + select + *, + generate_series(1,array_length(i.indkey,1)) as ord, + unnest(i.indkey) AS key + from pg_index i + ) idx + JOIN pg_class i ON i.oid = idx.indexrelid + JOIN pg_class t ON t.oid = idx.indrelid + JOIN pg_namespace n ON n.oid = t.relnamespace + LEFT JOIN ( + select conindid, jsonb_object_agg(conname, contype) AS nametypes + from pg_constraint + group by conindid + ) con ON con.conindid = idx.indexrelid + LEFT JOIN pg_attribute a ON (a.attrelid, a.attnum) = (idx.indrelid, idx.key) +WHERE + n.nspname = $1 + AND t.relname IN (%s) +ORDER BY + table_name, index_name, idx.ord +` +)