From 009c18f687003aa5605638b9ec40f1a2b4ac2462 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Mon, 20 May 2024 19:59:06 +0300 Subject: [PATCH] * Added experimental `ydb.RegisterSqlOpenDsnParser` for register external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor --- CHANGELOG.md | 2 + driver.go | 20 +-- dsn.go | 116 ++++++++++++++++++ internal/xsql/dsn_test.go => dsn_test.go | 74 +++++------ internal/xsql/dsn.go | 98 --------------- options.go | 2 +- sql.go | 16 ++- tests/integration/register_dsn_parser_test.go | 54 ++++++++ 8 files changed, 235 insertions(+), 147 deletions(-) create mode 100644 dsn.go rename internal/xsql/dsn_test.go => dsn_test.go (71%) delete mode 100644 internal/xsql/dsn.go create mode 100644 tests/integration/register_dsn_parser_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 873c3687b..d22e2c8d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added experimental `ydb.RegisterDsnParser` for register external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor + ## v3.67.2 * Fixed incorrect formatting of decimal. Implementation of decimal has been reverted to latest working version diff --git a/driver.go b/driver.go index 170e2f86d..ee752b4a4 100644 --- a/driver.go +++ b/driver.go @@ -3,6 +3,7 @@ package ydb import ( "context" "errors" + "fmt" "os" "sync" @@ -234,13 +235,18 @@ func (d *Driver) Topic() topic.Client { // See sugar.DSN helper for make dsn from endpoint and database // //nolint:nonamedreturns -func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, err error) { - d, err := newConnectionFromOptions(ctx, append( - []Option{ - WithConnectionString(dsn), - }, - opts..., - )...) +func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, _ error) { + opts = append(append(make([]Option, 0, len(opts)+1), WithConnectionString(dsn)), opts...) + + for _, parser := range dsnParsers { + optsFromParser, err := parser(dsn) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("data source name '%s' wrong: %w", dsn, err)) + } + opts = append(opts, optsFromParser...) + } + + d, err := newConnectionFromOptions(ctx, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } diff --git a/dsn.go b/dsn.go new file mode 100644 index 000000000..5cdf8bf23 --- /dev/null +++ b/dsn.go @@ -0,0 +1,116 @@ +package ydb + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/ydb-platform/ydb-go-sdk/v3/balancers" + "github.com/ydb-platform/ydb-go-sdk/v3/credentials" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" +) + +const tablePathPrefixTransformer = "table_path_prefix" + +var dsnParsers = []func(dsn string) (opts []Option, _ error){ + func(dsn string) ([]Option, error) { + opts, err := parseConnectionString(dsn) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return opts, nil + }, +} + +// RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors +// +// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental +func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) { + dsnParsers = append(dsnParsers, parser) +} + +func parseConnectionString(dataSourceName string) (opts []Option, _ error) { + info, err := dsn.Parse(dataSourceName) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + opts = append(opts, With(info.Options...)) + if token := info.Params.Get("token"); token != "" { + opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token))) + } + if balancer := info.Params.Get("go_balancer"); balancer != "" { + opts = append(opts, WithBalancer(balancers.FromConfig(balancer))) + } else if balancer := info.Params.Get("balancer"); balancer != "" { + opts = append(opts, WithBalancer(balancers.FromConfig(balancer))) + } + if queryMode := info.Params.Get("go_query_mode"); queryMode != "" { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode))) + } else if queryMode := info.Params.Get("query_mode"); queryMode != "" { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode))) + } + if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" { + for _, queryMode := range strings.Split(fakeTx, ",") { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithFakeTx(mode))) + } + } + if info.Params.Has("go_query_bind") { + var binders []xsql.ConnectorOption + queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",") + for _, transformer := range queryTransformers { + switch transformer { + case "declare": + binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{})) + case "positional": + binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{})) + case "numeric": + binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{})) + default: + if strings.HasPrefix(transformer, tablePathPrefixTransformer) { + prefix, err := extractTablePathPrefixFromBinderName(transformer) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + binders = append(binders, xsql.WithTablePathPrefix(prefix)) + } else { + return nil, xerrors.WithStackTrace( + fmt.Errorf("unknown query rewriter: %s", transformer), + ) + } + } + } + opts = append(opts, withConnectorOptions(binders...)) + } + + return opts, nil +} + +var ( + tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)") + errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer") +) + +func extractTablePathPrefixFromBinderName(binderName string) (string, error) { + ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1) + if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" { + return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName)) + } + + return ss[0][1], nil +} diff --git a/internal/xsql/dsn_test.go b/dsn_test.go similarity index 71% rename from internal/xsql/dsn_test.go rename to dsn_test.go index d28c96ac3..e10da8606 100644 --- a/internal/xsql/dsn_test.go +++ b/dsn_test.go @@ -1,17 +1,19 @@ -package xsql +package ydb import ( + "context" "testing" "github.com/stretchr/testify/require" "github.com/ydb-platform/ydb-go-sdk/v3/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" ) func TestParse(t *testing.T) { - newConnector := func(opts ...ConnectorOption) *Connector { - c := &Connector{} + newConnector := func(opts ...xsql.ConnectorOption) *xsql.Connector { + c := &xsql.Connector{} for _, opt := range opts { if opt != nil { if err := opt.Apply(c); err != nil { @@ -30,7 +32,7 @@ func TestParse(t *testing.T) { for _, tt := range []struct { dsn string opts []config.Option - connectorOpts []ConnectorOption + connectorOpts []xsql.ConnectorOption err error }{ { @@ -40,9 +42,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithFakeTx(ScriptingQueryMode), - WithFakeTx(SchemeQueryMode), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithFakeTx(xsql.ScriptingQueryMode), + xsql.WithFakeTx(xsql.SchemeQueryMode), }, err: nil, }, @@ -73,8 +75,8 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), }, err: nil, }, @@ -85,9 +87,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), }, err: nil, }, @@ -98,10 +100,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.NumericArgs{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.NumericArgs{}), }, err: nil, }, @@ -112,10 +114,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.PositionalArgs{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.PositionalArgs{}), }, err: nil, }, @@ -126,10 +128,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.AutoDeclare{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.AutoDeclare{}), }, err: nil, }, @@ -140,9 +142,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), }, err: nil, }, @@ -153,23 +155,25 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.PositionalArgs{}), - WithQueryBind(bind.AutoDeclare{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.PositionalArgs{}), + xsql.WithQueryBind(bind.AutoDeclare{}), }, err: nil, }, } { t.Run("", func(t *testing.T) { - opts, connectorOpts, err := Parse(tt.dsn) + opts, err := parseConnectionString(tt.dsn) if tt.err != nil { require.ErrorIs(t, err, tt.err) } else { require.NoError(t, err) - require.Equal(t, newConnector(tt.connectorOpts...), newConnector(connectorOpts...)) - compareConfigs(t, config.New(tt.opts...), config.New(opts...)) + d, err := newConnectionFromOptions(context.Background(), opts...) + require.NoError(t, err) + require.Equal(t, newConnector(tt.connectorOpts...), newConnector(d.databaseSQLOptions...)) + compareConfigs(t, config.New(tt.opts...), d.config) } }) } diff --git a/internal/xsql/dsn.go b/internal/xsql/dsn.go deleted file mode 100644 index 508308995..000000000 --- a/internal/xsql/dsn.go +++ /dev/null @@ -1,98 +0,0 @@ -package xsql - -import ( - "errors" - "fmt" - "regexp" - "strings" - - "github.com/ydb-platform/ydb-go-sdk/v3/balancers" - "github.com/ydb-platform/ydb-go-sdk/v3/config" - "github.com/ydb-platform/ydb-go-sdk/v3/credentials" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" -) - -const tablePathPrefixTransformer = "table_path_prefix" - -func Parse(dataSourceName string) (opts []config.Option, connectorOpts []ConnectorOption, _ error) { - info, err := dsn.Parse(dataSourceName) - if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - opts = append(opts, info.Options...) - if token := info.Params.Get("token"); token != "" { - opts = append(opts, config.WithCredentials(credentials.NewAccessTokenCredentials(token))) - } - if balancer := info.Params.Get("go_balancer"); balancer != "" { - opts = append(opts, config.WithBalancer(balancers.FromConfig(balancer))) - } else if balancer := info.Params.Get("balancer"); balancer != "" { - opts = append(opts, config.WithBalancer(balancers.FromConfig(balancer))) - } - if queryMode := info.Params.Get("go_query_mode"); queryMode != "" { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithDefaultQueryMode(mode)) - } else if queryMode := info.Params.Get("query_mode"); queryMode != "" { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithDefaultQueryMode(mode)) - } - if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" { - for _, queryMode := range strings.Split(fakeTx, ",") { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithFakeTx(mode)) - } - } - if info.Params.Has("go_query_bind") { - var binders []ConnectorOption - queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",") - for _, transformer := range queryTransformers { - switch transformer { - case "declare": - binders = append(binders, WithQueryBind(bind.AutoDeclare{})) - case "positional": - binders = append(binders, WithQueryBind(bind.PositionalArgs{})) - case "numeric": - binders = append(binders, WithQueryBind(bind.NumericArgs{})) - default: - if strings.HasPrefix(transformer, tablePathPrefixTransformer) { - prefix, err := extractTablePathPrefixFromBinderName(transformer) - if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - binders = append(binders, WithTablePathPrefix(prefix)) - } else { - return nil, nil, xerrors.WithStackTrace( - fmt.Errorf("unknown query rewriter: %s", transformer), - ) - } - } - } - connectorOpts = append(connectorOpts, binders...) - } - - return opts, connectorOpts, nil -} - -var ( - tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)") - errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer") -) - -func extractTablePathPrefixFromBinderName(binderName string) (string, error) { - ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1) - if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" { - return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName)) - } - - return ss[0][1], nil -} diff --git a/options.go b/options.go index 0798890ce..2adf098b4 100644 --- a/options.go +++ b/options.go @@ -31,7 +31,7 @@ import ( ) // Option contains configuration values for Driver -type Option func(ctx context.Context, c *Driver) error +type Option func(ctx context.Context, d *Driver) error func WithStaticCredentials(user, password string) Option { return func(ctx context.Context, c *Driver) error { diff --git a/sql.go b/sql.go index efe7cf92e..0ec78b964 100644 --- a/sql.go +++ b/sql.go @@ -22,6 +22,14 @@ func init() { //nolint:gochecknoinits sql.Register("ydb/v3", d) } +func withConnectorOptions(opts ...ConnectorOption) Option { + return func(ctx context.Context, d *Driver) error { + d.databaseSQLOptions = append(d.databaseSQLOptions, opts...) + + return nil + } +} + type sqlDriver struct { connectors map[*xsql.Connector]*Driver connectorsMtx xsync.RWMutex @@ -56,16 +64,12 @@ func (d *sqlDriver) Open(string) (driver.Conn, error) { } func (d *sqlDriver) OpenConnector(dataSourceName string) (driver.Connector, error) { - opts, connectorOpts, err := xsql.Parse(dataSourceName) - if err != nil { - return nil, xerrors.WithStackTrace(fmt.Errorf("data source name '%s' wrong: %w", dataSourceName, err)) - } - db, err := Open(context.Background(), "", With(opts...)) + db, err := Open(context.Background(), dataSourceName) if err != nil { return nil, xerrors.WithStackTrace(fmt.Errorf("failed to connect by data source name '%s': %w", dataSourceName, err)) } - return Connector(db, connectorOpts...) + return Connector(db, db.databaseSQLOptions...) } func (d *sqlDriver) attach(c *xsql.Connector, parent *Driver) { diff --git a/tests/integration/register_dsn_parser_test.go b/tests/integration/register_dsn_parser_test.go new file mode 100644 index 000000000..8c4d88bdc --- /dev/null +++ b/tests/integration/register_dsn_parser_test.go @@ -0,0 +1,54 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func TestRegisterDsnParser(t *testing.T) { + t.Run("native", func(t *testing.T) { + var visited bool + ydb.RegisterDsnParser(func(dsn string) (opts []ydb.Option, _ error) { + return []ydb.Option{ + func(ctx context.Context, d *ydb.Driver) error { + visited = true + + return nil + }, + }, nil + }) + db, err := ydb.Open(context.Background(), os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + require.True(t, visited) + defer func() { + _ = db.Close(context.Background()) + }() + }) + t.Run("database/sql", func(t *testing.T) { + var visited bool + ydb.RegisterDsnParser(func(dsn string) (opts []ydb.Option, _ error) { + return []ydb.Option{ + func(ctx context.Context, d *ydb.Driver) error { + visited = true + + return nil + }, + }, nil + }) + db, err := sql.Open("ydb", os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + require.True(t, visited) + defer func() { + _ = db.Close() + }() + }) +}