Skip to content

Commit

Permalink
* Added experimental ydb.{Register,Unregister}DsnParser global func…
Browse files Browse the repository at this point in the history
…s for register/unregister external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor
  • Loading branch information
asmyasnikov committed May 21, 2024
1 parent d6d2af1 commit 35d60ba
Show file tree
Hide file tree
Showing 15 changed files with 507 additions and 390 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Added experimental `ydb.{Register,Unregister}DsnParser` global funcs for register/unregister external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor
* Simple implement option WithReaderWithoutConsumer
* Fixed bug: topic didn't send specified partition number to a server

Expand Down
22 changes: 15 additions & 7 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ydb
import (
"context"
"errors"
"fmt"
"os"
"sync"

Expand Down Expand Up @@ -234,13 +235,20 @@ func (d *Driver) Topic() topic.Client {
// See sugar.DSN helper for make dsn from endpoint and database
//
//nolint:nonamedreturns
func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, err error) {
d, err := newConnectionFromOptions(ctx, append(
[]Option{
WithConnectionString(dsn),
},
opts...,
)...)
func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, _ error) {
opts = append(append(make([]Option, 0, len(opts)+1), WithConnectionString(dsn)), opts...)

for parserIdx := range dsnParsers {
if parser := dsnParsers[parserIdx]; parser != nil {
optsFromParser, err := parser(dsn)
if err != nil {
return nil, xerrors.WithStackTrace(fmt.Errorf("data source name '%s' wrong: %w", dsn, err))
}
opts = append(opts, optsFromParser...)
}
}

d, err := newConnectionFromOptions(ctx, opts...)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
Expand Down
125 changes: 125 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package ydb

import (
"errors"
"fmt"
"regexp"
"strings"

"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
"github.com/ydb-platform/ydb-go-sdk/v3/credentials"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
)

const tablePathPrefixTransformer = "table_path_prefix"

var dsnParsers = []func(dsn string) (opts []Option, _ error){
func(dsn string) ([]Option, error) {
opts, err := parseConnectionString(dsn)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

return opts, nil
},
}

// RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) {
dsnParsers = append(dsnParsers, parser)

return len(dsnParsers) - 1
}

// UnregisterDsnParser unregisters DSN parser by key
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func UnregisterDsnParser(registrationID int) {
dsnParsers[registrationID] = nil
}

func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
info, err := dsn.Parse(dataSourceName)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
opts = append(opts, With(info.Options...))
if token := info.Params.Get("token"); token != "" {
opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token)))
}
if balancer := info.Params.Get("go_balancer"); balancer != "" {
opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
} else if balancer := info.Params.Get("balancer"); balancer != "" {
opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
}
if queryMode := info.Params.Get("go_query_mode"); queryMode != "" {
mode := xsql.QueryModeFromString(queryMode)
if mode == xsql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode)))
} else if queryMode := info.Params.Get("query_mode"); queryMode != "" {
mode := xsql.QueryModeFromString(queryMode)
if mode == xsql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode)))
}
if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
for _, queryMode := range strings.Split(fakeTx, ",") {
mode := xsql.QueryModeFromString(queryMode)
if mode == xsql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
opts = append(opts, withConnectorOptions(xsql.WithFakeTx(mode)))
}
}
if info.Params.Has("go_query_bind") {
var binders []xsql.ConnectorOption
queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",")
for _, transformer := range queryTransformers {
switch transformer {
case "declare":
binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{}))
case "positional":
binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{}))
case "numeric":
binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{}))
default:
if strings.HasPrefix(transformer, tablePathPrefixTransformer) {
prefix, err := extractTablePathPrefixFromBinderName(transformer)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
binders = append(binders, xsql.WithTablePathPrefix(prefix))
} else {
return nil, xerrors.WithStackTrace(
fmt.Errorf("unknown query rewriter: %s", transformer),
)
}
}
}
opts = append(opts, withConnectorOptions(binders...))
}

return opts, nil
}

var (
tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)")
errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer")
)

func extractTablePathPrefixFromBinderName(binderName string) (string, error) {
ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1)
if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" {
return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName))
}

return ss[0][1], nil
}
74 changes: 39 additions & 35 deletions internal/xsql/dsn_test.go → dsn_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package xsql
package ydb

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/ydb-platform/ydb-go-sdk/v3/config"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
)

func TestParse(t *testing.T) {
newConnector := func(opts ...ConnectorOption) *Connector {
c := &Connector{}
newConnector := func(opts ...xsql.ConnectorOption) *xsql.Connector {
c := &xsql.Connector{}
for _, opt := range opts {
if opt != nil {
if err := opt.Apply(c); err != nil {
Expand All @@ -30,7 +32,7 @@ func TestParse(t *testing.T) {
for _, tt := range []struct {
dsn string
opts []config.Option
connectorOpts []ConnectorOption
connectorOpts []xsql.ConnectorOption
err error
}{
{
Expand All @@ -40,9 +42,9 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithFakeTx(ScriptingQueryMode),
WithFakeTx(SchemeQueryMode),
connectorOpts: []xsql.ConnectorOption{
xsql.WithFakeTx(xsql.ScriptingQueryMode),
xsql.WithFakeTx(xsql.SchemeQueryMode),
},
err: nil,
},
Expand Down Expand Up @@ -73,8 +75,8 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
},
err: nil,
},
Expand All @@ -85,9 +87,9 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
},
err: nil,
},
Expand All @@ -98,10 +100,10 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
WithQueryBind(bind.NumericArgs{}),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
xsql.WithQueryBind(bind.NumericArgs{}),
},
err: nil,
},
Expand All @@ -112,10 +114,10 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
WithQueryBind(bind.PositionalArgs{}),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
xsql.WithQueryBind(bind.PositionalArgs{}),
},
err: nil,
},
Expand All @@ -126,10 +128,10 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
WithQueryBind(bind.AutoDeclare{}),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
xsql.WithQueryBind(bind.AutoDeclare{}),
},
err: nil,
},
Expand All @@ -140,9 +142,9 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
},
err: nil,
},
Expand All @@ -153,23 +155,25 @@ func TestParse(t *testing.T) {
config.WithEndpoint("localhost:2135"),
config.WithDatabase("/local"),
},
connectorOpts: []ConnectorOption{
WithDefaultQueryMode(ScriptingQueryMode),
WithTablePathPrefix("path/to/tables"),
WithQueryBind(bind.PositionalArgs{}),
WithQueryBind(bind.AutoDeclare{}),
connectorOpts: []xsql.ConnectorOption{
xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode),
xsql.WithTablePathPrefix("path/to/tables"),
xsql.WithQueryBind(bind.PositionalArgs{}),
xsql.WithQueryBind(bind.AutoDeclare{}),
},
err: nil,
},
} {
t.Run("", func(t *testing.T) {
opts, connectorOpts, err := Parse(tt.dsn)
opts, err := parseConnectionString(tt.dsn)
if tt.err != nil {
require.ErrorIs(t, err, tt.err)
} else {
require.NoError(t, err)
require.Equal(t, newConnector(tt.connectorOpts...), newConnector(connectorOpts...))
compareConfigs(t, config.New(tt.opts...), config.New(opts...))
d, err := newConnectionFromOptions(context.Background(), opts...)
require.NoError(t, err)
require.Equal(t, newConnector(tt.connectorOpts...), newConnector(d.databaseSQLOptions...))
compareConfigs(t, config.New(tt.opts...), d.config)
}
})
}
Expand Down
Loading

0 comments on commit 35d60ba

Please sign in to comment.