Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support risingwave schema inspect #2898

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion sql/postgres/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ type (
schema string
// System variables that are set on `Open`.
version int
crdb bool
// Whether the connected database is a CockroachDB.
crdb bool
// Whether the connected database is a RisingWave.
risingwave bool
}
)

Expand Down Expand Up @@ -121,6 +124,20 @@ func Open(db schema.ExecQuerier) (migrate.Driver, error) {
},
}, nil
}
c.risingwave, err = c.isRisingWaveConn()
if err != nil {
return nil, fmt.Errorf("postgres: failed checking if connected to RisingWave: %w", err)
}
if c.risingwave {
return noLockDriver{
&Driver{
conn: c,
Differ: &sqlx.Diff{DiffDriver: &risingwaveDiff{diff{c}}},
Inspector: &risingwaveInspect{inspect{c}},
PlanApplier: &planApply{c},
},
}, nil
}
return &Driver{
conn: c,
Differ: &sqlx.Diff{DiffDriver: &diff{c}},
Expand Down
11 changes: 7 additions & 4 deletions sql/postgres/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (i *inspect) InspectRealm(ctx context.Context, opts *schema.InspectRealmOpt
// referenced objects in the public schema (or any other default search_path) are returned
// qualified in the inspection.
func (i *inspect) noSearchPath(ctx context.Context) (func() error, error) {
if i.crdb {
if i.crdb || i.risingwave {
// Skip logic for CockroachDB.
return func() error { return nil }, nil
}
Expand Down Expand Up @@ -473,6 +473,9 @@ func (i *inspect) indexes(ctx context.Context, s *schema.Schema) error {
if i.crdb {
return i.crdbIndexes(ctx, s)
}
if i.risingwave {
return i.risingwaveIndexes(ctx, s)
}
rows, err := i.querySchema(ctx, i.indexesQuery(), s)
if err != nil {
return fmt.Errorf("postgres: querying schema %q indexes: %w", s.Name, err)
Expand Down Expand Up @@ -1500,7 +1503,7 @@ SELECT
FROM
pg_catalog.pg_namespace
WHERE
nspname NOT IN ('information_schema', 'pg_catalog', 'pg_toast', 'crdb_internal', 'pg_extension')
nspname NOT IN ('information_schema', 'pg_catalog', 'pg_toast', 'crdb_internal', 'rw_catalog', 'pg_extension')
AND nspname NOT LIKE 'pg_%temp_%'
ORDER BY
nspname`
Expand Down Expand Up @@ -1532,7 +1535,7 @@ FROM
JOIN pg_catalog.pg_namespace AS t2 ON t2.nspname = t1.table_schema
JOIN pg_catalog.pg_class AS t3 ON t3.relnamespace = t2.oid AND t3.relname = t1.table_name
LEFT JOIN pg_catalog.pg_partitioned_table AS t4 ON t4.partrelid = t3.oid
LEFT JOIN pg_depend AS t5 ON t5.classid = 'pg_catalog.pg_class'::regclass::oid AND t5.objid = t3.oid AND t5.deptype = 'e'
LEFT JOIN pg_depend AS t5 ON t5.classid::text = 'pg_catalog.pg_class' AND t5.objid = t3.oid AND t5.deptype = 'e'
WHERE
t1.table_type = 'BASE TABLE'
AND NOT COALESCE(t3.relispartition, false)
Expand All @@ -1555,7 +1558,7 @@ FROM
JOIN pg_catalog.pg_namespace AS t2 ON t2.nspname = t1.table_schema
JOIN pg_catalog.pg_class AS t3 ON t3.relnamespace = t2.oid AND t3.relname = t1.table_name
LEFT JOIN pg_catalog.pg_partitioned_table AS t4 ON t4.partrelid = t3.oid
LEFT JOIN pg_depend AS t5 ON t5.classid = 'pg_catalog.pg_class'::regclass::oid AND t5.objid = t3.oid AND t5.deptype = 'e'
LEFT JOIN pg_depend AS t5 ON t5.classid::text = 'pg_catalog.pg_class' AND t5.objid = t3.oid AND t5.deptype = 'e'
WHERE
t1.table_type = 'BASE TABLE'
AND NOT COALESCE(t3.relispartition, false)
Expand Down
149 changes: 149 additions & 0 deletions sql/postgres/risingwave.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package postgres

import (
"context"
"database/sql"
"fmt"
"strings"

"ariga.io/atlas/sql/internal/sqlx"
"ariga.io/atlas/sql/schema"
)

type (
risingwaveDiff struct{ diff }
risingwaveInspect struct{ inspect }
)

var _ sqlx.DiffDriver = (*risingwaveDiff)(nil)

func (c *conn) isRisingWaveConn() (bool, error) {
rows, err := c.QueryContext(context.Background(), "select version()")
if err != nil {
return false, err
}
defer rows.Close()

var version string
err = sqlx.ScanOne(rows, &version)
if err != nil {
return false, fmt.Errorf("postgres: failed scanning rows: %w", err)
}
return strings.Contains(strings.ToLower(version), "risingwave"), nil
}

func (i *inspect) risingwaveIndexes(ctx context.Context, s *schema.Schema) error {
rows, err := i.querySchema(ctx, risingwaveIndexesQuery, s)
if err != nil {
return fmt.Errorf("postgres: querying schema %q indexes: %w", s.Name, err)
}
defer rows.Close()
if err := i.risingwaveAddIndexes(s, rows); err != nil {
return err
}
return rows.Err()
}

// RisingWave doesn't support:
// - Unique Indexes
// - Index Constraints
// - Partial Indexes or Partial Index Predicates
// - Indexes on Expressions
func (i *inspect) risingwaveAddIndexes(s *schema.Schema, rows *sql.Rows) error {
names := make(map[string]*schema.Index)
for rows.Next() {
var (
primary bool
table, name string
desc, nullsfirst, nullslast sql.NullBool
column, comment sql.NullString
)
if err := rows.Scan(&table, &name, &column, &primary, &desc, &nullsfirst, &nullslast, &comment); err != nil {
return fmt.Errorf("risingwave: scanning indexes for schema %q: %w", s.Name, err)
}

t, ok := s.Table(table)
if !ok {
return fmt.Errorf("table %q was not found in schema", table)
}

idx, ok := names[name]
if !ok {
idx = &schema.Index{
Name: name,
Unique: primary,
Table: t,
}
if sqlx.ValidString(comment) {
idx.Attrs = append(idx.Attrs, &schema.Comment{Text: comment.String})
}
if primary {
t.PrimaryKey = idx
} else {
t.Indexes = append(t.Indexes, idx)
}
names[name] = idx
}
// TODO: Extract isdesc from RisingWave indexes.
part := &schema.IndexPart{SeqNo: len(idx.Parts) + 1, Desc: desc.Bool}
if nullsfirst.Bool || nullslast.Bool {
part.Attrs = append(part.Attrs, &IndexColumnProperty{
NullsFirst: nullsfirst.Bool,
NullsLast: nullslast.Bool,
})
}
switch {
case sqlx.ValidString(column):
part.C, ok = t.Column(column.String)
if !ok {
return fmt.Errorf("risingwave: column %q was not found for index %q", column.String, idx.Name)
}
part.C.Indexes = append(part.C.Indexes, idx)
default:
return fmt.Errorf("risingwave: invalid part for index %q", idx.Name)
}
idx.Parts = append(idx.Parts, part)
}
return nil
}

const (
/// table, name, typ, column, primary, comment
risingwaveIndexesQuery = `
SELECT
t.relname AS table_name,
i.relname AS index_name,
a.attname AS column_name,
idx.indisprimary AS primary,
pg_index_column_has_property(idx.indexrelid, idx.ord, 'desc') AS isdesc,
pg_index_column_has_property(idx.indexrelid, idx.ord, 'nulls_first') AS nulls_first,
pg_index_column_has_property(idx.indexrelid, idx.ord, 'nulls_last') AS nulls_last,
obj_description(i.oid, 'pg_class') AS comment
FROM
(
select
*,
generate_series(1,array_length(i.indkey,1)) as ord,
unnest(i.indkey) AS key
from pg_index i
) idx
JOIN pg_class i ON i.oid = idx.indexrelid
JOIN pg_class t ON t.oid = idx.indrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
LEFT JOIN (
select conindid, jsonb_object_agg(conname, contype) AS nametypes
from pg_constraint
group by conindid
) con ON con.conindid = idx.indexrelid
LEFT JOIN pg_attribute a ON (a.attrelid, a.attnum) = (idx.indrelid, idx.key)
WHERE
n.nspname = $1
AND t.relname IN (%s)
ORDER BY
table_name, index_name, idx.ord
`
)