From 35d60ba236439a50134c73a8fc7d99c56e1d1c77 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Mon, 20 May 2024 19:59:06 +0300 Subject: [PATCH] * Added experimental `ydb.{Register,Unregister}DsnParser` global funcs for register/unregister external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor --- CHANGELOG.md | 1 + driver.go | 22 +- dsn.go | 125 +++++ internal/xsql/dsn_test.go => dsn_test.go | 74 +-- internal/xsql/dsn.go | 98 ---- options.go | 2 +- sql.go | 16 +- tests/integration/connection_secure_test.go | 4 +- tests/integration/connection_test.go | 480 +++++++++--------- .../connection_with_compression_test.go | 3 +- tests/integration/monitoring_test.go | 5 +- tests/integration/register_dsn_parser_test.go | 56 ++ tests/integration/static_credentials_test.go | 3 +- .../table_multiple_result_sets_test.go | 3 +- tests/integration/table_tx_lazy_test.go | 5 +- 15 files changed, 507 insertions(+), 390 deletions(-) create mode 100644 dsn.go rename internal/xsql/dsn_test.go => dsn_test.go (71%) delete mode 100644 internal/xsql/dsn.go create mode 100644 tests/integration/register_dsn_parser_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 98b06594c..04be31823 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Added experimental `ydb.{Register,Unregister}DsnParser` global funcs for register/unregister external custom DSN parser for `ydb.Open` and `sql.Open` driver constructor * Simple implement option WithReaderWithoutConsumer * Fixed bug: topic didn't send specified partition number to a server diff --git a/driver.go b/driver.go index 170e2f86d..7ef1e765d 100644 --- a/driver.go +++ b/driver.go @@ -3,6 +3,7 @@ package ydb import ( "context" "errors" + "fmt" "os" "sync" @@ -234,13 +235,20 @@ func (d *Driver) Topic() topic.Client { // See sugar.DSN helper for make dsn from endpoint and database // //nolint:nonamedreturns -func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, err error) { - d, err := newConnectionFromOptions(ctx, append( - []Option{ - WithConnectionString(dsn), - }, - opts..., - )...) +func Open(ctx context.Context, dsn string, opts ...Option) (_ *Driver, _ error) { + opts = append(append(make([]Option, 0, len(opts)+1), WithConnectionString(dsn)), opts...) + + for parserIdx := range dsnParsers { + if parser := dsnParsers[parserIdx]; parser != nil { + optsFromParser, err := parser(dsn) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("data source name '%s' wrong: %w", dsn, err)) + } + opts = append(opts, optsFromParser...) + } + } + + d, err := newConnectionFromOptions(ctx, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } diff --git a/dsn.go b/dsn.go new file mode 100644 index 000000000..58739d80a --- /dev/null +++ b/dsn.go @@ -0,0 +1,125 @@ +package ydb + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/ydb-platform/ydb-go-sdk/v3/balancers" + "github.com/ydb-platform/ydb-go-sdk/v3/credentials" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" +) + +const tablePathPrefixTransformer = "table_path_prefix" + +var dsnParsers = []func(dsn string) (opts []Option, _ error){ + func(dsn string) ([]Option, error) { + opts, err := parseConnectionString(dsn) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return opts, nil + }, +} + +// RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors +// +// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental +func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) { + dsnParsers = append(dsnParsers, parser) + + return len(dsnParsers) - 1 +} + +// UnregisterDsnParser unregisters DSN parser by key +// +// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental +func UnregisterDsnParser(registrationID int) { + dsnParsers[registrationID] = nil +} + +func parseConnectionString(dataSourceName string) (opts []Option, _ error) { + info, err := dsn.Parse(dataSourceName) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + opts = append(opts, With(info.Options...)) + if token := info.Params.Get("token"); token != "" { + opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token))) + } + if balancer := info.Params.Get("go_balancer"); balancer != "" { + opts = append(opts, WithBalancer(balancers.FromConfig(balancer))) + } else if balancer := info.Params.Get("balancer"); balancer != "" { + opts = append(opts, WithBalancer(balancers.FromConfig(balancer))) + } + if queryMode := info.Params.Get("go_query_mode"); queryMode != "" { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode))) + } else if queryMode := info.Params.Get("query_mode"); queryMode != "" { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode))) + } + if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" { + for _, queryMode := range strings.Split(fakeTx, ",") { + mode := xsql.QueryModeFromString(queryMode) + if mode == xsql.UnknownQueryMode { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) + } + opts = append(opts, withConnectorOptions(xsql.WithFakeTx(mode))) + } + } + if info.Params.Has("go_query_bind") { + var binders []xsql.ConnectorOption + queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",") + for _, transformer := range queryTransformers { + switch transformer { + case "declare": + binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{})) + case "positional": + binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{})) + case "numeric": + binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{})) + default: + if strings.HasPrefix(transformer, tablePathPrefixTransformer) { + prefix, err := extractTablePathPrefixFromBinderName(transformer) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + binders = append(binders, xsql.WithTablePathPrefix(prefix)) + } else { + return nil, xerrors.WithStackTrace( + fmt.Errorf("unknown query rewriter: %s", transformer), + ) + } + } + } + opts = append(opts, withConnectorOptions(binders...)) + } + + return opts, nil +} + +var ( + tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)") + errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer") +) + +func extractTablePathPrefixFromBinderName(binderName string) (string, error) { + ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1) + if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" { + return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName)) + } + + return ss[0][1], nil +} diff --git a/internal/xsql/dsn_test.go b/dsn_test.go similarity index 71% rename from internal/xsql/dsn_test.go rename to dsn_test.go index d28c96ac3..e10da8606 100644 --- a/internal/xsql/dsn_test.go +++ b/dsn_test.go @@ -1,17 +1,19 @@ -package xsql +package ydb import ( + "context" "testing" "github.com/stretchr/testify/require" "github.com/ydb-platform/ydb-go-sdk/v3/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" ) func TestParse(t *testing.T) { - newConnector := func(opts ...ConnectorOption) *Connector { - c := &Connector{} + newConnector := func(opts ...xsql.ConnectorOption) *xsql.Connector { + c := &xsql.Connector{} for _, opt := range opts { if opt != nil { if err := opt.Apply(c); err != nil { @@ -30,7 +32,7 @@ func TestParse(t *testing.T) { for _, tt := range []struct { dsn string opts []config.Option - connectorOpts []ConnectorOption + connectorOpts []xsql.ConnectorOption err error }{ { @@ -40,9 +42,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithFakeTx(ScriptingQueryMode), - WithFakeTx(SchemeQueryMode), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithFakeTx(xsql.ScriptingQueryMode), + xsql.WithFakeTx(xsql.SchemeQueryMode), }, err: nil, }, @@ -73,8 +75,8 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), }, err: nil, }, @@ -85,9 +87,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), }, err: nil, }, @@ -98,10 +100,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.NumericArgs{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.NumericArgs{}), }, err: nil, }, @@ -112,10 +114,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.PositionalArgs{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.PositionalArgs{}), }, err: nil, }, @@ -126,10 +128,10 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.AutoDeclare{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.AutoDeclare{}), }, err: nil, }, @@ -140,9 +142,9 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), }, err: nil, }, @@ -153,23 +155,25 @@ func TestParse(t *testing.T) { config.WithEndpoint("localhost:2135"), config.WithDatabase("/local"), }, - connectorOpts: []ConnectorOption{ - WithDefaultQueryMode(ScriptingQueryMode), - WithTablePathPrefix("path/to/tables"), - WithQueryBind(bind.PositionalArgs{}), - WithQueryBind(bind.AutoDeclare{}), + connectorOpts: []xsql.ConnectorOption{ + xsql.WithDefaultQueryMode(xsql.ScriptingQueryMode), + xsql.WithTablePathPrefix("path/to/tables"), + xsql.WithQueryBind(bind.PositionalArgs{}), + xsql.WithQueryBind(bind.AutoDeclare{}), }, err: nil, }, } { t.Run("", func(t *testing.T) { - opts, connectorOpts, err := Parse(tt.dsn) + opts, err := parseConnectionString(tt.dsn) if tt.err != nil { require.ErrorIs(t, err, tt.err) } else { require.NoError(t, err) - require.Equal(t, newConnector(tt.connectorOpts...), newConnector(connectorOpts...)) - compareConfigs(t, config.New(tt.opts...), config.New(opts...)) + d, err := newConnectionFromOptions(context.Background(), opts...) + require.NoError(t, err) + require.Equal(t, newConnector(tt.connectorOpts...), newConnector(d.databaseSQLOptions...)) + compareConfigs(t, config.New(tt.opts...), d.config) } }) } diff --git a/internal/xsql/dsn.go b/internal/xsql/dsn.go deleted file mode 100644 index 508308995..000000000 --- a/internal/xsql/dsn.go +++ /dev/null @@ -1,98 +0,0 @@ -package xsql - -import ( - "errors" - "fmt" - "regexp" - "strings" - - "github.com/ydb-platform/ydb-go-sdk/v3/balancers" - "github.com/ydb-platform/ydb-go-sdk/v3/config" - "github.com/ydb-platform/ydb-go-sdk/v3/credentials" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" -) - -const tablePathPrefixTransformer = "table_path_prefix" - -func Parse(dataSourceName string) (opts []config.Option, connectorOpts []ConnectorOption, _ error) { - info, err := dsn.Parse(dataSourceName) - if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - opts = append(opts, info.Options...) - if token := info.Params.Get("token"); token != "" { - opts = append(opts, config.WithCredentials(credentials.NewAccessTokenCredentials(token))) - } - if balancer := info.Params.Get("go_balancer"); balancer != "" { - opts = append(opts, config.WithBalancer(balancers.FromConfig(balancer))) - } else if balancer := info.Params.Get("balancer"); balancer != "" { - opts = append(opts, config.WithBalancer(balancers.FromConfig(balancer))) - } - if queryMode := info.Params.Get("go_query_mode"); queryMode != "" { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithDefaultQueryMode(mode)) - } else if queryMode := info.Params.Get("query_mode"); queryMode != "" { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithDefaultQueryMode(mode)) - } - if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" { - for _, queryMode := range strings.Split(fakeTx, ",") { - mode := QueryModeFromString(queryMode) - if mode == UnknownQueryMode { - return nil, nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) - } - connectorOpts = append(connectorOpts, WithFakeTx(mode)) - } - } - if info.Params.Has("go_query_bind") { - var binders []ConnectorOption - queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",") - for _, transformer := range queryTransformers { - switch transformer { - case "declare": - binders = append(binders, WithQueryBind(bind.AutoDeclare{})) - case "positional": - binders = append(binders, WithQueryBind(bind.PositionalArgs{})) - case "numeric": - binders = append(binders, WithQueryBind(bind.NumericArgs{})) - default: - if strings.HasPrefix(transformer, tablePathPrefixTransformer) { - prefix, err := extractTablePathPrefixFromBinderName(transformer) - if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - binders = append(binders, WithTablePathPrefix(prefix)) - } else { - return nil, nil, xerrors.WithStackTrace( - fmt.Errorf("unknown query rewriter: %s", transformer), - ) - } - } - } - connectorOpts = append(connectorOpts, binders...) - } - - return opts, connectorOpts, nil -} - -var ( - tablePathPrefixRe = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)") - errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer") -) - -func extractTablePathPrefixFromBinderName(binderName string) (string, error) { - ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1) - if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" { - return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName)) - } - - return ss[0][1], nil -} diff --git a/options.go b/options.go index 0798890ce..2adf098b4 100644 --- a/options.go +++ b/options.go @@ -31,7 +31,7 @@ import ( ) // Option contains configuration values for Driver -type Option func(ctx context.Context, c *Driver) error +type Option func(ctx context.Context, d *Driver) error func WithStaticCredentials(user, password string) Option { return func(ctx context.Context, c *Driver) error { diff --git a/sql.go b/sql.go index efe7cf92e..0ec78b964 100644 --- a/sql.go +++ b/sql.go @@ -22,6 +22,14 @@ func init() { //nolint:gochecknoinits sql.Register("ydb/v3", d) } +func withConnectorOptions(opts ...ConnectorOption) Option { + return func(ctx context.Context, d *Driver) error { + d.databaseSQLOptions = append(d.databaseSQLOptions, opts...) + + return nil + } +} + type sqlDriver struct { connectors map[*xsql.Connector]*Driver connectorsMtx xsync.RWMutex @@ -56,16 +64,12 @@ func (d *sqlDriver) Open(string) (driver.Conn, error) { } func (d *sqlDriver) OpenConnector(dataSourceName string) (driver.Connector, error) { - opts, connectorOpts, err := xsql.Parse(dataSourceName) - if err != nil { - return nil, xerrors.WithStackTrace(fmt.Errorf("data source name '%s' wrong: %w", dataSourceName, err)) - } - db, err := Open(context.Background(), "", With(opts...)) + db, err := Open(context.Background(), dataSourceName) if err != nil { return nil, xerrors.WithStackTrace(fmt.Errorf("failed to connect by data source name '%s': %w", dataSourceName, err)) } - return Connector(db, connectorOpts...) + return Connector(db, db.databaseSQLOptions...) } func (d *sqlDriver) attach(c *xsql.Connector, parent *Driver) { diff --git a/tests/integration/connection_secure_test.go b/tests/integration/connection_secure_test.go index c78c4fb13..7b70f927e 100644 --- a/tests/integration/connection_secure_test.go +++ b/tests/integration/connection_secure_test.go @@ -39,9 +39,7 @@ func TestConnectionSecure(t *testing.T) { const sumColumn = "sum" ctx := xtest.Context(t) - db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(dsn), + db, err := ydb.Open(ctx, dsn, ydb.WithAccessTokenCredentials( os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), ), diff --git a/tests/integration/connection_test.go b/tests/integration/connection_test.go index fcbcfb837..6db336e34 100644 --- a/tests/integration/connection_test.go +++ b/tests/integration/connection_test.go @@ -71,253 +71,281 @@ func TestConnection(t *testing.T) { ctx = xtest.Context(t) ) - db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), - ydb.WithAccessTokenCredentials( - os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), - ), - ydb.With( - config.WithOperationTimeout(time.Second*2), - config.WithOperationCancelAfter(time.Second*2), - ), - ydb.WithConnectionTTL(time.Millisecond*10000), - ydb.WithMinTLSVersion(tls.VersionTLS10), - ydb.WithLogger( - newLoggerWithMinLevel(t, log.WARN), - trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), - ), - ydb.WithApplicationName(userAgent), - ydb.WithRequestsType(requestType), - ydb.With( - config.WithGrpcOptions( - grpc.WithUnaryInterceptor(func( - ctx context.Context, - method string, - req, reply interface{}, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - checkMetadata(ctx) - return invoker(ctx, method, req, reply, cc, opts...) - }), - grpc.WithStreamInterceptor(func( - ctx context.Context, - desc *grpc.StreamDesc, - cc *grpc.ClientConn, - method string, - streamer grpc.Streamer, - opts ...grpc.CallOption, - ) (grpc.ClientStream, error) { - checkMetadata(ctx) - return streamer(ctx, desc, cc, method, opts...) - }), + t.Run("ydb.New", func(t *testing.T) { + db, err := ydb.New(ctx, //nolint:gocritic + ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), ), - ), - ) - if err != nil { - t.Fatal(err) - } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() - t.Run("discovery.WhoAmI", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db)) - response, err := discoveryClient.WhoAmI( - ctx, - &Ydb_Discovery.WhoAmIRequest{IncludeGroups: true}, - ) - if err != nil { - return err - } - var result Ydb_Discovery.WhoAmIResult - err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) - if err != nil { - return - } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Execute failed: %v", err) - } - }) - t.Run("scripting.ExecuteYql", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) - response, err := scriptingClient.ExecuteYql( - ctx, - &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, - ) - if err != nil { - return err - } - var result Ydb_Scripting.ExecuteYqlResult - err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) - if err != nil { - return - } - if len(result.GetResultSets()) != 1 { - return fmt.Errorf( - "unexpected result sets count: %d", - len(result.GetResultSets()), - ) - } - if len(result.GetResultSets()[0].GetColumns()) != 1 { - return fmt.Errorf( - "unexpected colums count: %d", - len(result.GetResultSets()[0].GetColumns()), - ) - } - if result.GetResultSets()[0].GetColumns()[0].GetName() != sumColumn { - return fmt.Errorf( - "unexpected colum name: %s", - result.GetResultSets()[0].GetColumns()[0].GetName(), - ) - } - if len(result.GetResultSets()[0].GetRows()) != 1 { - return fmt.Errorf( - "unexpected rows count: %d", - len(result.GetResultSets()[0].GetRows()), - ) - } - if result.GetResultSets()[0].GetRows()[0].GetItems()[0].GetInt32Value() != 101 { - return fmt.Errorf( - "unexpected result of select: %d", - result.GetResultSets()[0].GetRows()[0].GetInt64Value(), - ) - } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Execute failed: %v", err) + ydb.With( + config.WithOperationTimeout(time.Second*2), + config.WithOperationCancelAfter(time.Second*2), + ), + ydb.WithConnectionTTL(time.Millisecond*10000), + ydb.WithMinTLSVersion(tls.VersionTLS10), + ydb.WithLogger( + newLogger(t), + trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), + ), + ) + if err != nil { + t.Fatal(err) } - }) - t.Run("scripting.StreamExecuteYql", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) - client, err := scriptingClient.StreamExecuteYql( - ctx, - &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, - ) - if err != nil { - return err - } - response, err := client.Recv() - if err != nil { - return err - } - if len(response.GetResult().GetResultSet().GetColumns()) != 1 { - return fmt.Errorf( - "unexpected colums count: %d", - len(response.GetResult().GetResultSet().GetColumns()), - ) - } - if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn { - return fmt.Errorf( - "unexpected colum name: %s", - response.GetResult().GetResultSet().GetColumns()[0].GetName(), - ) - } - if len(response.GetResult().GetResultSet().GetRows()) != 1 { - return fmt.Errorf( - "unexpected rows count: %d", - len(response.GetResult().GetResultSet().GetRows()), - ) - } - if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 { - return fmt.Errorf( - "unexpected result of select: %d", - response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(), - ) + defer func() { + // cleanup connection + if e := db.Close(ctx); e != nil { + t.Fatalf("close failed: %+v", e) } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Stream execute failed: %v", err) - } + }() }) - t.Run("with.scripting.StreamExecuteYql", func(t *testing.T) { - var childDB *ydb.Driver - childDB, err = db.With( - ctx, - ydb.WithDialTimeout(time.Second*5), + t.Run("ydb.Open", func(t *testing.T) { + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ydb.With( + config.WithOperationTimeout(time.Second*2), + config.WithOperationCancelAfter(time.Second*2), + ), + ydb.WithConnectionTTL(time.Millisecond*10000), + ydb.WithMinTLSVersion(tls.VersionTLS10), + ydb.WithLogger( + newLoggerWithMinLevel(t, log.WARN), + trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), + ), + ydb.WithApplicationName(userAgent), + ydb.WithRequestsType(requestType), + ydb.With( + config.WithGrpcOptions( + grpc.WithUnaryInterceptor(func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + checkMetadata(ctx) + return invoker(ctx, method, req, reply, cc, opts...) + }), + grpc.WithStreamInterceptor(func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + checkMetadata(ctx) + return streamer(ctx, desc, cc, method, opts...) + }), + ), + ), ) if err != nil { - t.Fatalf("failed to open sub-connection: %v", err) + t.Fatal(err) } defer func() { - _ = childDB.Close(ctx) - }() - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(childDB)) - client, err := scriptingClient.StreamExecuteYql( - ctx, - &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, - ) - if err != nil { - return err - } - response, err := client.Recv() - if err != nil { - return err + // cleanup connection + if e := db.Close(ctx); e != nil { + t.Fatalf("close failed: %+v", e) } - if len(response.GetResult().GetResultSet().GetColumns()) != 1 { - return fmt.Errorf( - "unexpected colums count: %d", - len(response.GetResult().GetResultSet().GetColumns()), - ) - } - if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn { - return fmt.Errorf( - "unexpected colum name: %s", - response.GetResult().GetResultSet().GetColumns()[0].GetName(), + }() + t.Run("discovery.WhoAmI", func(t *testing.T) { + if err = retry.Retry(ctx, func(ctx context.Context) (err error) { + discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db)) + response, err := discoveryClient.WhoAmI( + ctx, + &Ydb_Discovery.WhoAmIRequest{IncludeGroups: true}, ) + if err != nil { + return err + } + var result Ydb_Discovery.WhoAmIResult + err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) + if err != nil { + return + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("Execute failed: %v", err) } - if len(response.GetResult().GetResultSet().GetRows()) != 1 { - return fmt.Errorf( - "unexpected rows count: %d", - len(response.GetResult().GetResultSet().GetRows()), + }) + t.Run("scripting.ExecuteYql", func(t *testing.T) { + if err = retry.Retry(ctx, func(ctx context.Context) (err error) { + scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) + response, err := scriptingClient.ExecuteYql( + ctx, + &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, ) + if err != nil { + return err + } + var result Ydb_Scripting.ExecuteYqlResult + err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) + if err != nil { + return + } + if len(result.GetResultSets()) != 1 { + return fmt.Errorf( + "unexpected result sets count: %d", + len(result.GetResultSets()), + ) + } + if len(result.GetResultSets()[0].GetColumns()) != 1 { + return fmt.Errorf( + "unexpected colums count: %d", + len(result.GetResultSets()[0].GetColumns()), + ) + } + if result.GetResultSets()[0].GetColumns()[0].GetName() != sumColumn { + return fmt.Errorf( + "unexpected colum name: %s", + result.GetResultSets()[0].GetColumns()[0].GetName(), + ) + } + if len(result.GetResultSets()[0].GetRows()) != 1 { + return fmt.Errorf( + "unexpected rows count: %d", + len(result.GetResultSets()[0].GetRows()), + ) + } + if result.GetResultSets()[0].GetRows()[0].GetItems()[0].GetInt32Value() != 101 { + return fmt.Errorf( + "unexpected result of select: %d", + result.GetResultSets()[0].GetRows()[0].GetInt64Value(), + ) + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("Execute failed: %v", err) } - if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 { - return fmt.Errorf( - "unexpected result of select: %d", - response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(), + }) + t.Run("scripting.StreamExecuteYql", func(t *testing.T) { + if err = retry.Retry(ctx, func(ctx context.Context) (err error) { + scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) + client, err := scriptingClient.StreamExecuteYql( + ctx, + &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, ) + if err != nil { + return err + } + response, err := client.Recv() + if err != nil { + return err + } + if len(response.GetResult().GetResultSet().GetColumns()) != 1 { + return fmt.Errorf( + "unexpected colums count: %d", + len(response.GetResult().GetResultSet().GetColumns()), + ) + } + if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn { + return fmt.Errorf( + "unexpected colum name: %s", + response.GetResult().GetResultSet().GetColumns()[0].GetName(), + ) + } + if len(response.GetResult().GetResultSet().GetRows()) != 1 { + return fmt.Errorf( + "unexpected rows count: %d", + len(response.GetResult().GetResultSet().GetRows()), + ) + } + if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 { + return fmt.Errorf( + "unexpected result of select: %d", + response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(), + ) + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("Stream execute failed: %v", err) } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Stream execute failed: %v", err) - } - }) - t.Run("export.ExportToS3", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - exportClient := Ydb_Export_V1.NewExportServiceClient(ydb.GRPCConn(db)) - response, err := exportClient.ExportToS3( + }) + t.Run("with.scripting.StreamExecuteYql", func(t *testing.T) { + var childDB *ydb.Driver + childDB, err = db.With( ctx, - &Ydb_Export.ExportToS3Request{ - OperationParams: &Ydb_Operations.OperationParams{ - OperationTimeout: durationpb.New(time.Second), - CancelAfter: durationpb.New(time.Second), - }, - Settings: &Ydb_Export.ExportToS3Settings{}, - }, + ydb.WithDialTimeout(time.Second*5), ) if err != nil { - return err + t.Fatalf("failed to open sub-connection: %v", err) } - if response.GetOperation().GetStatus() != Ydb.StatusIds_BAD_REQUEST { - return fmt.Errorf( - "operation must be BAD_REQUEST: %s", - response.GetOperation().GetStatus().String(), + defer func() { + _ = childDB.Close(ctx) + }() + if err = retry.Retry(ctx, func(ctx context.Context) (err error) { + scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(childDB)) + client, err := scriptingClient.StreamExecuteYql( + ctx, + &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, ) + if err != nil { + return err + } + response, err := client.Recv() + if err != nil { + return err + } + if len(response.GetResult().GetResultSet().GetColumns()) != 1 { + return fmt.Errorf( + "unexpected colums count: %d", + len(response.GetResult().GetResultSet().GetColumns()), + ) + } + if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn { + return fmt.Errorf( + "unexpected colum name: %s", + response.GetResult().GetResultSet().GetColumns()[0].GetName(), + ) + } + if len(response.GetResult().GetResultSet().GetRows()) != 1 { + return fmt.Errorf( + "unexpected rows count: %d", + len(response.GetResult().GetResultSet().GetRows()), + ) + } + if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 { + return fmt.Errorf( + "unexpected result of select: %d", + response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(), + ) + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("Stream execute failed: %v", err) } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("check export failed: %v", err) - } + }) + t.Run("export.ExportToS3", func(t *testing.T) { + if err = retry.Retry(ctx, func(ctx context.Context) (err error) { + exportClient := Ydb_Export_V1.NewExportServiceClient(ydb.GRPCConn(db)) + response, err := exportClient.ExportToS3( + ctx, + &Ydb_Export.ExportToS3Request{ + OperationParams: &Ydb_Operations.OperationParams{ + OperationTimeout: durationpb.New(time.Second), + CancelAfter: durationpb.New(time.Second), + }, + Settings: &Ydb_Export.ExportToS3Settings{}, + }, + ) + if err != nil { + return err + } + if response.GetOperation().GetStatus() != Ydb.StatusIds_BAD_REQUEST { + return fmt.Errorf( + "operation must be BAD_REQUEST: %s", + response.GetOperation().GetStatus().String(), + ) + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("check export failed: %v", err) + } + }) }) } diff --git a/tests/integration/connection_with_compression_test.go b/tests/integration/connection_with_compression_test.go index c0be662b4..4b5decb39 100644 --- a/tests/integration/connection_with_compression_test.go +++ b/tests/integration/connection_with_compression_test.go @@ -64,8 +64,7 @@ func TestConnectionWithCompression(t *testing.T) { ) db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), + os.Getenv("YDB_CONNECTION_STRING"), // corner case for check replacement of endpoint+database+secure ydb.WithAccessTokenCredentials( os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), ), diff --git a/tests/integration/monitoring_test.go b/tests/integration/monitoring_test.go index 922cf80fa..5220c1b25 100644 --- a/tests/integration/monitoring_test.go +++ b/tests/integration/monitoring_test.go @@ -20,10 +20,7 @@ import ( func TestMonitoring(t *testing.T) { ctx := xtest.Context(t) - db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), - ) + db, err := ydb.Open(ctx, os.Getenv("YDB_CONNECTION_STRING")) if err != nil { t.Fatal(err) } diff --git a/tests/integration/register_dsn_parser_test.go b/tests/integration/register_dsn_parser_test.go new file mode 100644 index 000000000..080a1afa4 --- /dev/null +++ b/tests/integration/register_dsn_parser_test.go @@ -0,0 +1,56 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func TestRegisterDsnParser(t *testing.T) { + t.Run("native", func(t *testing.T) { + var visited bool + registrationID := ydb.RegisterDsnParser(func(dsn string) (opts []ydb.Option, _ error) { + return []ydb.Option{ + func(ctx context.Context, d *ydb.Driver) error { + visited = true + + return nil + }, + }, nil + }) + defer ydb.UnregisterDsnParser(registrationID) + db, err := ydb.Open(context.Background(), os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + require.True(t, visited) + defer func() { + _ = db.Close(context.Background()) + }() + }) + t.Run("database/sql", func(t *testing.T) { + var visited bool + registrationID := ydb.RegisterDsnParser(func(dsn string) (opts []ydb.Option, _ error) { + return []ydb.Option{ + func(ctx context.Context, d *ydb.Driver) error { + visited = true + + return nil + }, + }, nil + }) + defer ydb.UnregisterDsnParser(registrationID) + db, err := sql.Open("ydb", os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + require.True(t, visited) + defer func() { + _ = db.Close() + }() + }) +} diff --git a/tests/integration/static_credentials_test.go b/tests/integration/static_credentials_test.go index e4eff5cfa..392e8a601 100644 --- a/tests/integration/static_credentials_test.go +++ b/tests/integration/static_credentials_test.go @@ -64,8 +64,7 @@ func TestStaticCredentials(t *testing.T) { t.Logf("token: %s\n", token) db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), + os.Getenv("YDB_CONNECTION_STRING"), ydb.WithCredentials(staticCredentials), ) if err != nil { diff --git a/tests/integration/table_multiple_result_sets_test.go b/tests/integration/table_multiple_result_sets_test.go index 6833ca6b3..2799479e9 100644 --- a/tests/integration/table_multiple_result_sets_test.go +++ b/tests/integration/table_multiple_result_sets_test.go @@ -40,8 +40,7 @@ func TestTableMultipleResultSets(t *testing.T) { ) db, err := ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), + os.Getenv("YDB_CONNECTION_STRING"), ydb.WithLogger( newLogger(t), trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), diff --git a/tests/integration/table_tx_lazy_test.go b/tests/integration/table_tx_lazy_test.go index 95182f798..70ece1867 100644 --- a/tests/integration/table_tx_lazy_test.go +++ b/tests/integration/table_tx_lazy_test.go @@ -24,10 +24,7 @@ func TestTableTxLazy(t *testing.T) { t.Run("connect", func(t *testing.T) { var err error - db, err = ydb.Open(ctx, - "", // corner case for check replacement of endpoint+database+secure - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), - ) + db, err = ydb.Open(ctx, os.Getenv("YDB_CONNECTION_STRING")) require.NoError(t, err) })