diff --git a/taosRestful/connection.go b/taosRestful/connection.go index 2dd5240..d8a12a8 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -24,7 +24,7 @@ import ( var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary type taosConn struct { - cfg *config + cfg *Config client *http.Client url *url.URL baseRawQuery string @@ -32,8 +32,8 @@ type taosConn struct { readBufferSize int } -func newTaosConn(cfg *config) (*taosConn, error) { - readBufferSize := cfg.readBufferSize +func newTaosConn(cfg *Config) (*taosConn, error) { + readBufferSize := cfg.ReadBufferSize if readBufferSize <= 0 { readBufferSize = 4 << 10 } @@ -47,9 +47,9 @@ func newTaosConn(cfg *config) (*taosConn, error) { IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - DisableCompression: cfg.disableCompression, + DisableCompression: cfg.DisableCompression, } - if cfg.skipVerify { + if cfg.SkipVerify { transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } @@ -58,24 +58,24 @@ func newTaosConn(cfg *config) (*taosConn, error) { Transport: transport, } path := "/rest/sql" - if len(cfg.dbName) != 0 { - path = fmt.Sprintf("%s/%s", path, cfg.dbName) + if len(cfg.DbName) != 0 { + path = fmt.Sprintf("%s/%s", path, cfg.DbName) } tc.url = &url.URL{ - Scheme: cfg.net, - Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), + Scheme: cfg.Net, + Host: fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port), Path: path, } tc.header = map[string][]string{ "Connection": {"keep-alive"}, } - if cfg.token != "" { - tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token) + if cfg.Token != "" { + tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.Token) } else { - basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd)) + basic := base64.StdEncoding.EncodeToString([]byte(cfg.User + ":" + cfg.Passwd)) tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)} } - if !cfg.disableCompression { + if !cfg.DisableCompression { tc.header["Accept-Encoding"] = []string{"gzip"} } return tc, nil @@ -103,7 +103,7 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if len(args) != 0 { - if !tc.cfg.interpolateParams { + if !tc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra round trips for preparing and closing a statement @@ -129,7 +129,7 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { if len(args) != 0 { - if !tc.cfg.interpolateParams { + if !tc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce round trip @@ -202,7 +202,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) ( } respBody := resp.Body defer ioutil.ReadAll(respBody) - if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + if !tc.cfg.DisableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { respBody, err = gzip.NewReader(resp.Body) if err != nil { return nil, err diff --git a/taosRestful/connector.go b/taosRestful/connector.go index 3ae7396..0919ee4 100644 --- a/taosRestful/connector.go +++ b/taosRestful/connector.go @@ -8,27 +8,27 @@ import ( ) type connector struct { - cfg *config + cfg *Config } // Connect implements driver.Connector interface. // Connect returns a connection to the database. func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Connect to Server - if len(c.cfg.user) == 0 { - c.cfg.user = common.DefaultUser + if len(c.cfg.User) == 0 { + c.cfg.User = common.DefaultUser } - if len(c.cfg.passwd) == 0 { - c.cfg.passwd = common.DefaultPassword + if len(c.cfg.Passwd) == 0 { + c.cfg.Passwd = common.DefaultPassword } - if c.cfg.port == 0 { - c.cfg.port = common.DefaultHttpPort + if c.cfg.Port == 0 { + c.cfg.Port = common.DefaultHttpPort } - if len(c.cfg.net) == 0 { - c.cfg.net = "http" + if len(c.cfg.Net) == 0 { + c.cfg.Net = "http" } - if len(c.cfg.addr) == 0 { - c.cfg.addr = "127.0.0.1" + if len(c.cfg.Addr) == 0 { + c.cfg.Addr = "127.0.0.1" } tc, err := newTaosConn(c.cfg) return tc, err diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index 845179f..aca016c 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -533,7 +533,7 @@ func TestSSL(t *testing.T) { func TestConnect(t *testing.T) { conn := connector{ - cfg: &config{}, + cfg: &Config{}, } db, err := conn.Connect(context.Background()) assert.NoError(t, err) diff --git a/taosRestful/driver.go b/taosRestful/driver.go index 54313f3..3ca3f8f 100644 --- a/taosRestful/driver.go +++ b/taosRestful/driver.go @@ -13,7 +13,7 @@ type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := parseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } @@ -26,3 +26,19 @@ func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { func init() { sql.Register("taosRestful", &TDengineDriver{}) } + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index c214d70..9c156bb 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -1,11 +1,13 @@ package taosRestful import ( + "context" "database/sql" "database/sql/driver" "errors" "fmt" "log" + "net/url" "testing" "time" @@ -289,3 +291,91 @@ func TestChinese(t *testing.T) { } assert.Equal(t, 1, counter) } + +func TestNewConnector(t *testing.T) { + cfg, err := ParseDSN(dataSourceName) + assert.NoError(t, err) + conn, err := NewConnector(cfg) + assert.NoError(t, err) + db := sql.OpenDB(conn) + defer func() { + err := db.Close() + assert.NoError(t, err) + }() + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +func TestOpen(t *testing.T) { + tdDriver := &TDengineDriver{} + conn, err := tdDriver.Open(dataSourceName) + assert.NoError(t, err) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + assert.NoError(t, err) +} + +func TestSpecialPassword(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + tests := []struct { + name string + user string + pass string + }{ + { + name: "test_special1_rs", + user: "test_special1_rs", + pass: "!q@w#a$1%3^&*()-", + }, + { + name: "test_special2_rs", + user: "test_special2_rs", + pass: "_q+3=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@http(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +} diff --git a/taosRestful/dsn.go b/taosRestful/dsn.go index 612962a..4e2f259 100644 --- a/taosRestful/dsn.go +++ b/taosRestful/dsn.go @@ -9,43 +9,43 @@ import ( ) var ( - errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} - errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} - errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} - errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} + ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} ) // Config is a configuration parsed from a DSN string. // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. -type config struct { - user string // Username - passwd string // Password (requires User) - net string // Network type - addr string // Network address (requires Net) - port int - dbName string // Database name - params map[string]string // Connection parameters - interpolateParams bool // Interpolate placeholders into query string - disableCompression bool - readBufferSize int - token string // cloud platform token - skipVerify bool +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + Port int + DbName string // Database name + Params map[string]string // Connection parameters + InterpolateParams bool // Interpolate placeholders into query string + DisableCompression bool + ReadBufferSize int + Token string // cloud platform Token + SkipVerify bool } // NewConfig creates a new Config and sets default values. -func newConfig() *config { - return &config{ - interpolateParams: true, - disableCompression: true, - readBufferSize: 4 << 10, +func NewConfig() *Config { + return &Config{ + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4 << 10, } } // ParseDSN parses the DSN string to a Config -func parseDSN(dsn string) (cfg *config, err error) { +func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = newConfig() + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -65,11 +65,11 @@ func parseDSN(dsn string) (cfg *config, err error) { // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] + cfg.Passwd = tryUnescape(dsn[k+1 : j]) break } } - cfg.user = dsn[:k] + cfg.User = tryUnescape(dsn[:k]) break } @@ -82,25 +82,25 @@ func parseDSN(dsn string) (cfg *config, err error) { // dsn[i-1] must be == ')' if an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped + return nil, ErrInvalidDSNUnescaped } //return nil, errInvalidDSNAddr } strList := strings.Split(dsn[k+1:i-1], ":") if len(strList) == 1 { - return nil, errInvalidDSNAddr + return nil, ErrInvalidDSNAddr } if len(strList[0]) != 0 { - cfg.addr = strList[0] - cfg.port, err = strconv.Atoi(strList[1]) + cfg.Addr = strList[0] + cfg.Port, err = strconv.Atoi(strList[1]) if err != nil { - return nil, errInvalidDSNPort + return nil, ErrInvalidDSNPort } } break } } - cfg.net = dsn[j+1 : k] + cfg.Net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] @@ -113,14 +113,14 @@ func parseDSN(dsn string) (cfg *config, err error) { break } } - cfg.dbName = dsn[i+1 : j] + cfg.DbName = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -128,7 +128,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -139,34 +139,34 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "disableCompression": - cfg.disableCompression, err = strconv.ParseBool(value) + cfg.DisableCompression, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "readBufferSize": - cfg.readBufferSize, err = strconv.Atoi(value) + cfg.ReadBufferSize, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid int value: " + value} } case "token": - cfg.token = value + cfg.Token = value case "skipVerify": - cfg.skipVerify, err = strconv.ParseBool(value) + cfg.SkipVerify, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -174,3 +174,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosRestful/dsn_test.go b/taosRestful/dsn_test.go index 2ec3fbd..0595cb9 100644 --- a/taosRestful/dsn_test.go +++ b/taosRestful/dsn_test.go @@ -1,8 +1,9 @@ package taosRestful import ( - "fmt" "testing" + + "github.com/stretchr/testify/assert" ) // @author: xftan @@ -10,31 +11,244 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tcs := []struct { - dsn string - errs string - user string - passwd string - net string - addr string - port int - dbName string - token string - skipVerify bool - }{{}, - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, - {dsn: "user:passwd@http()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@http(:)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {dsn: "user:passwd@http(:0)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, + name string + dsn string + errs string + want *Config + }{ + { + name: "invalid", + dsn: "abcd", + errs: "invalid DSN: missing the slash separating the database name", + }, + { + name: "normal", + dsn: "user:passwd@http(fqdn:6041)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "http", + Addr: "fqdn", + Port: 6041, + DbName: "dbname", + Params: nil, + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "", + SkipVerify: false, + }, + }, + { + name: "invalid addr", + dsn: "user:passwd@http()/dbname", + errs: "invalid DSN: network address not terminated (missing closing brace)", + }, + { + name: "default addr", + dsn: "user:passwd@http(:)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "http", + Addr: "", + Port: 0, + DbName: "dbname", + Params: nil, + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "", + SkipVerify: false, + }, + }, + { + name: "0port", + dsn: "user:passwd@http(:0)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "http", + Addr: "", + Port: 0, + DbName: "dbname", + Params: nil, + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "", + SkipVerify: false, + }, + }, + { + name: "no db", + dsn: "user:passwd@https(:0)/", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "", + SkipVerify: false, + }, + }, + { + name: "params", + dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: map[string]string{ + "test": "1", + }, + InterpolateParams: false, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "", + SkipVerify: false, + }, + }, + { + name: "token", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: false, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "token", + SkipVerify: false, + }, + }, + { + name: "skipVerify", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: false, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "token", + SkipVerify: true, + }, + }, + { + name: "skipVerify", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: false, + DisableCompression: true, + ReadBufferSize: 4096, + Token: "token", + SkipVerify: true, + }, + }, + { + name: "readBufferSize", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true&readBufferSize=8192", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: false, + DisableCompression: true, + ReadBufferSize: 8192, + Token: "token", + SkipVerify: true, + }, + }, + { + name: "disableCompression", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true&readBufferSize=8192&disableCompression=false", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "https", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + InterpolateParams: false, + DisableCompression: false, + ReadBufferSize: 8192, + Token: "token", + SkipVerify: true, + }, + }, + { + name: "special char", + dsn: "!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.:!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.@https(:)/dbname", + want: &Config{ + User: "!@#$%^&*()-_+=[]{}:;> 0 { + if c.cfg.Net == "cfg" && len(c.cfg.ConfigPath) > 0 { once.Do(func() { locker.Lock() - code := wrapper.TaosOptions(common.TSDB_OPTION_CONFIGDIR, c.cfg.configPath) + code := wrapper.TaosOptions(common.TSDB_OPTION_CONFIGDIR, c.cfg.ConfigPath) locker.Unlock() if code != 0 { err = errors.NewError(code, wrapper.TaosErrorStr(nil)) @@ -37,20 +54,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } // Connect to Server - if len(tc.cfg.user) == 0 { - tc.cfg.user = common.DefaultUser + if len(tc.cfg.User) == 0 { + tc.cfg.User = common.DefaultUser } - if len(tc.cfg.passwd) == 0 { - tc.cfg.passwd = common.DefaultPassword + if len(tc.cfg.Passwd) == 0 { + tc.cfg.Passwd = common.DefaultPassword } locker.Lock() - err = wrapper.TaosSetConfig(tc.cfg.params) + err = wrapper.TaosSetConfig(tc.cfg.Params) locker.Unlock() if err != nil { return nil, err } locker.Lock() - tc.taos, err = wrapper.TaosConnect(tc.cfg.addr, tc.cfg.user, tc.cfg.passwd, tc.cfg.dbName, tc.cfg.port) + tc.taos, err = wrapper.TaosConnect(tc.cfg.Addr, tc.cfg.User, tc.cfg.Passwd, tc.cfg.DbName, tc.cfg.Port) locker.Unlock() if err != nil { return nil, err diff --git a/taosSql/driver.go b/taosSql/driver.go index f50a907..85ccfc3 100644 --- a/taosSql/driver.go +++ b/taosSql/driver.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "database/sql/driver" - "runtime" "sync" "github.com/taosdata/driver-go/v3/wrapper/handler" @@ -23,30 +22,33 @@ type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := parseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } c := &connector{ cfg: cfg, } - onceInitLock.Do(func() { - threads := cfg.cgoThread - if threads <= 0 { - threads = runtime.NumCPU() - } - locker = thread.NewLocker(threads) - }) - onceInitHandlerPool.Do(func() { - poolSize := cfg.cgoAsyncHandlerPoolSize - if poolSize <= 0 { - poolSize = 10000 - } - asyncHandlerPool = handler.NewHandlerPool(poolSize) - }) + return c.Connect(context.Background()) } func init() { sql.Register("taosSql", &TDengineDriver{}) } + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index 51a76be..152b796 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -1,12 +1,14 @@ package taosSql import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" "fmt" "log" + "net/url" "testing" "time" @@ -560,3 +562,91 @@ func TestChinese(t *testing.T) { } assert.Equal(t, 1, counter) } + +func TestNewConnector(t *testing.T) { + cfg, err := ParseDSN(dataSourceName) + assert.NoError(t, err) + conn, err := NewConnector(cfg) + assert.NoError(t, err) + db := sql.OpenDB(conn) + defer func() { + err := db.Close() + assert.NoError(t, err) + }() + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +func TestOpen(t *testing.T) { + tdDriver := &TDengineDriver{} + conn, err := tdDriver.Open(dataSourceName) + assert.NoError(t, err) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + assert.NoError(t, err) +} + +func TestSpecialPassword(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + tests := []struct { + name string + user string + pass string + }{ + { + name: "test_special1", + user: "test_special1", + pass: "!q@w#a$1%3^&*()-", + }, + { + name: "test_special2", + user: "test_special2", + pass: "_q+3=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@/tcp(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +} diff --git a/taosSql/dsn.go b/taosSql/dsn.go index 90cdb6d..78337b0 100644 --- a/taosSql/dsn.go +++ b/taosSql/dsn.go @@ -26,42 +26,42 @@ import ( ) var ( - errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} - errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} - errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} - errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} + ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} ) // Config is a configuration parsed from a DSN string. // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. -type config struct { - user string // Username - passwd string // Password (requires User) - net string // Network type - addr string // Network address (requires Net) - port int - dbName string // Database name - params map[string]string // Connection parameters - loc *time.Location // Location for time.Time values - interpolateParams bool // Interpolate placeholders into query string - configPath string - cgoThread int - cgoAsyncHandlerPoolSize int +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + Port int + DbName string // Database name + Params map[string]string // Connection parameters + Loc *time.Location // Location for time.Time values + InterpolateParams bool // Interpolate placeholders into query string + ConfigPath string + CgoThread int + CgoAsyncHandlerPoolSize int } // NewConfig creates a new Config and sets default values. -func newConfig() *config { - return &config{ - loc: time.UTC, - interpolateParams: true, +func NewConfig() *Config { + return &Config{ + Loc: time.UTC, + InterpolateParams: true, } } // ParseDSN parses the DSN string to a Config -func parseDSN(dsn string) (cfg *config, err error) { +func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = newConfig() + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -81,11 +81,11 @@ func parseDSN(dsn string) (cfg *config, err error) { // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] + cfg.Passwd = tryUnescape(dsn[k+1 : j]) break } } - cfg.user = dsn[:k] + cfg.User = tryUnescape(dsn[:k]) break } @@ -98,13 +98,13 @@ func parseDSN(dsn string) (cfg *config, err error) { // dsn[i-1] must be == ')' if an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped + return nil, ErrInvalidDSNUnescaped } //return nil, errInvalidDSNAddr } if dsn[j+1:k] == "cfg" { cfgPath := dsn[k+1 : i-1] - cfg.configPath, err = filepath.Abs(cfgPath) + cfg.ConfigPath, err = filepath.Abs(cfgPath) if err != nil { return nil, err } @@ -112,20 +112,20 @@ func parseDSN(dsn string) (cfg *config, err error) { } else { strList := strings.Split(dsn[k+1:i-1], ":") if len(strList) == 1 { - return nil, errInvalidDSNAddr + return nil, ErrInvalidDSNAddr } if len(strList[0]) != 0 { - cfg.addr = strList[0] - cfg.port, err = strconv.Atoi(strList[1]) + cfg.Addr = strList[0] + cfg.Port, err = strconv.Atoi(strList[1]) if err != nil { - return nil, errInvalidDSNPort + return nil, ErrInvalidDSNPort } } break } } } - cfg.net = dsn[j+1 : k] + cfg.Net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] @@ -138,14 +138,14 @@ func parseDSN(dsn string) (cfg *config, err error) { break } } - cfg.dbName = dsn[i+1 : j] + cfg.DbName = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -153,7 +153,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -164,7 +164,7 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } @@ -174,30 +174,30 @@ func parseDSNParams(cfg *config, params string) (err error) { if value, err = url.QueryUnescape(value); err != nil { return } - cfg.loc, err = time.LoadLocation(value) + cfg.Loc, err = time.LoadLocation(value) if err != nil { return } case "cgoThread": - cfg.cgoThread, err = strconv.Atoi(value) + cfg.CgoThread, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid cgoThread value: " + value} } case "cgoAsyncHandlerPoolSize": - cfg.cgoAsyncHandlerPoolSize, err = strconv.Atoi(value) + cfg.CgoAsyncHandlerPoolSize, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid cgoAsyncHandlerPoolSize value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -205,3 +205,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosSql/dsn_test.go b/taosSql/dsn_test.go index 1ad259e..c68a049 100644 --- a/taosSql/dsn_test.go +++ b/taosSql/dsn_test.go @@ -1,8 +1,8 @@ package taosSql import ( - "fmt" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -11,9 +11,13 @@ import ( // @date: 2022/1/27 16:18 // @description: test dsn parse func TestParseDsn(t *testing.T) { - tcs := []struct { + ShangHaiTimezone, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + tests := []struct { + name string dsn string errs string + want *Config user string passwd string net string @@ -23,39 +27,333 @@ func TestParseDsn(t *testing.T) { configPath string cgoThread int cgoAsyncHandlerPoolSize int - }{{}, - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@net(fqdn:6030)/dbname", user: "user", passwd: "passwd", net: "net", addr: "fqdn", port: 6030, dbName: "dbname"}, - {dsn: "user:passwd@net()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@net(:)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, - {dsn: "user:passwd@net(:0)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, - {dsn: "user:passwd@net(:0)/", user: "user", passwd: "passwd", net: "net"}, - {dsn: "net(:0)/wo", net: "net", dbName: "wo"}, - {dsn: "user:passwd@cfg(/home/taos)/db", user: "user", passwd: "passwd", net: "cfg", configPath: "/home/taos", dbName: "db"}, - {dsn: "user:passwd@cfg/db", user: "user", passwd: "passwd", net: "cfg", configPath: "", dbName: "db"}, - {dsn: "net(:0)/wo?firstEp=LAPTOP-NNKFTLTG.localdomain%3A6030&secondEp=LAPTOP-NNKFTLTG.localdomain%3A6030&fqdn=LAPTOP-NNKFTLTG.localdomain&serverPort=6030&configDir=%2Fetc%2Ftaos&logDir=%2Fvar%2Flog%2Ftaos&scriptDir=%2Fetc%2Ftaos&arbitrator=&numOfThreadsPerCore=1.000000&maxNumOfDistinctRes=10000000&rpcTimer=300&rpcForceTcp=0&rpcMaxTime=600&shellActivityTimer=3&compressMsgSize=-1&maxSQLLength=1048576&maxWildCardsLength=100&maxNumOfOrderedRes=100000&keepColumnName=0&timezone=Asia%2FShanghai+%28CST%2C+%2B0800%29&locale=C.UTF-8&charset=UTF-8&numOfLogLines=10000000&logKeepDays=0&asyncLog=1&debugFlag=0&rpcDebugFlag=131&tmrDebugFlag=131&cDebugFlag=131&jniDebugFlag=131&odbcDebugFlag=131&uDebugFlag=131&qDebugFlag=131&tsdbDebugFlag=131&gitinfo=TAOS_CFG_VTYPE_STRING&gitinfoOfInternal=TAOS_CFG_VTYPE_STRING&buildinfo=TAOS_CFG_VTYPE_STRING&version=TAOS_CFG_VTYPE_STRING&maxBinaryDisplayWidth=30&tempDir=%2Ftmp%2F", net: "net", dbName: "wo"}, - {dsn: "net(:0)/wo?cgoThread=8", net: "net", dbName: "wo", cgoThread: 8}, - {dsn: "net(:0)/wo?cgoThread=8&cgoAsyncHandlerPoolSize=10000", net: "net", dbName: "wo", cgoThread: 8, cgoAsyncHandlerPoolSize: 10000}, + }{ + { + name: "invalid", + dsn: "abcd", + errs: "invalid DSN: missing the slash separating the database name", + }, + { + name: "normal", + dsn: "user:passwd@net(fqdn:6030)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "net", + Addr: "fqdn", + Port: 6030, + DbName: "dbname", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "missing closing brace", + dsn: "user:passwd@net()/dbname", + errs: "invalid DSN: network address not terminated (missing closing brace)", + }, + { + name: "default addr", + dsn: "user:passwd@net(:)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "net", + Addr: "", + Port: 0, + DbName: "dbname", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "0port", + dsn: "user:passwd@net(:0)/dbname", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "net", + Addr: "", + Port: 0, + DbName: "dbname", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "no dbname", + dsn: "user:passwd@net(:0)/", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "net", + Addr: "", + Port: 0, + DbName: "", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "no auth", + dsn: "net(:0)/wo", + want: &Config{ + User: "", + Passwd: "", + Net: "net", + Addr: "", + Port: 0, + DbName: "wo", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "cfg", + dsn: "user:passwd@cfg(/home/taos)/db", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "cfg", + Addr: "", + Port: 0, + DbName: "db", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "/home/taos", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "no addr", + dsn: "user:passwd@cfg/db", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "cfg", + Addr: "", + Port: 0, + DbName: "db", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "options", + dsn: "net(:0)/wo?firstEp=LAPTOP-NNKFTLTG.localdomain%3A6030&secondEp=LAPTOP-NNKFTLTG.localdomain%3A6030&fqdn=LAPTOP-NNKFTLTG.localdomain&serverPort=6030&configDir=%2Fetc%2Ftaos&logDir=%2Fvar%2Flog%2Ftaos&scriptDir=%2Fetc%2Ftaos&arbitrator=&numOfThreadsPerCore=1.000000&maxNumOfDistinctRes=10000000&rpcTimer=300&rpcForceTcp=0&rpcMaxTime=600&shellActivityTimer=3&compressMsgSize=-1&maxSQLLength=1048576&maxWildCardsLength=100&maxNumOfOrderedRes=100000&keepColumnName=0&timezone=Asia%2FShanghai&locale=C.UTF-8&charset=UTF-8&numOfLogLines=10000000&logKeepDays=0&asyncLog=1&debugFlag=0&rpcDebugFlag=131&tmrDebugFlag=131&cDebugFlag=131&jniDebugFlag=131&odbcDebugFlag=131&uDebugFlag=131&qDebugFlag=131&tsdbDebugFlag=131&gitinfo=TAOS_CFG_VTYPE_STRING&gitinfoOfInternal=TAOS_CFG_VTYPE_STRING&buildinfo=TAOS_CFG_VTYPE_STRING&version=TAOS_CFG_VTYPE_STRING&maxBinaryDisplayWidth=30&tempDir=%2Ftmp%2F", + want: &Config{ + User: "", + Passwd: "", + Net: "net", + Addr: "", + Port: 0, + DbName: "wo", + Params: map[string]string{ + "firstEp": "LAPTOP-NNKFTLTG.localdomain:6030", + "secondEp": "LAPTOP-NNKFTLTG.localdomain:6030", + "fqdn": "LAPTOP-NNKFTLTG.localdomain", + "serverPort": "6030", + "configDir": "/etc/taos", + "logDir": "/var/log/taos", + "scriptDir": "/etc/taos", + "arbitrator": "", + "numOfThreadsPerCore": "1.000000", + "maxNumOfDistinctRes": "10000000", + "rpcTimer": "300", + "rpcForceTcp": "0", + "rpcMaxTime": "600", + "shellActivityTimer": "3", + "compressMsgSize": "-1", + "maxSQLLength": "1048576", + "maxWildCardsLength": "100", + "maxNumOfOrderedRes": "100000", + "keepColumnName": "0", + "timezone": "Asia/Shanghai", + "locale": "C.UTF-8", + "charset": "UTF-8", + "numOfLogLines": "10000000", + "logKeepDays": "0", + "asyncLog": "1", + "debugFlag": "0", + "rpcDebugFlag": "131", + "tmrDebugFlag": "131", + "cDebugFlag": "131", + "jniDebugFlag": "131", + "odbcDebugFlag": "131", + "uDebugFlag": "131", + "qDebugFlag": "131", + "tsdbDebugFlag": "131", + "gitinfo": "TAOS_CFG_VTYPE_STRING", + "gitinfoOfInternal": "TAOS_CFG_VTYPE_STRING", + "buildinfo": "TAOS_CFG_VTYPE_STRING", + "version": "TAOS_CFG_VTYPE_STRING", + "maxBinaryDisplayWidth": "30", + "tempDir": "/tmp/", + }, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "cgoThread", + dsn: "net(:0)/wo?cgoThread=8", + want: &Config{ + User: "", + Passwd: "", + Net: "net", + Addr: "", + Port: 0, + DbName: "wo", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 8, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "cgoAsyncHandlerPoolSize", + dsn: "net(:0)/wo?cgoThread=8&cgoAsyncHandlerPoolSize=10000", + want: &Config{ + User: "", + Passwd: "", + Net: "net", + Addr: "", + Port: 0, + DbName: "wo", + Params: nil, + Loc: time.UTC, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 8, + CgoAsyncHandlerPoolSize: 10000, + }, + }, + { + name: "loc", + dsn: "net(:0)/wo?cgoThread=8&loc=Asia%2FShanghai", + want: &Config{ + User: "", + Passwd: "", + Net: "net", + Addr: "", + Port: 0, + DbName: "wo", + Params: nil, + Loc: ShangHaiTimezone, + InterpolateParams: true, + ConfigPath: "", + CgoThread: 8, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "interpolateParams", + dsn: "user:passwd@net(:)/dbname?interpolateParams=false", + want: &Config{ + User: "user", + Passwd: "passwd", + Net: "net", + Addr: "", + Port: 0, + DbName: "dbname", + Params: nil, + Loc: time.UTC, + InterpolateParams: false, + ConfigPath: "", + CgoThread: 0, + CgoAsyncHandlerPoolSize: 0, + }, + }, + { + name: "special char", + dsn: "!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.:!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.@net(:)/dbname", + want: &Config{ + User: "!@#$%^&*()-_+=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@ws(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +} diff --git a/taosWS/dsn.go b/taosWS/dsn.go index aa8e4dd..dad3e6b 100644 --- a/taosWS/dsn.go +++ b/taosWS/dsn.go @@ -10,41 +10,41 @@ import ( ) var ( - errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} - errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} - errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} - errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} + ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} ) // Config is a configuration parsed from a DSN string. // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. -type config struct { - user string // Username - passwd string // Password (requires User) - net string // Network type - addr string // Network address (requires Net) - port int - dbName string // Database name - params map[string]string // Connection parameters - interpolateParams bool // Interpolate placeholders into query string - token string // cloud platform token - enableCompression bool // Enable write compression - readTimeout time.Duration // read message timeout - writeTimeout time.Duration // write message timeout +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + Port int + DbName string // Database name + Params map[string]string // Connection parameters + InterpolateParams bool // Interpolate placeholders into query string + Token string // cloud platform Token + EnableCompression bool // Enable write compression + ReadTimeout time.Duration // read message timeout + WriteTimeout time.Duration // write message timeout } // NewConfig creates a new Config and sets default values. -func newConfig() *config { - return &config{ - interpolateParams: true, +func NewConfig() *Config { + return &Config{ + InterpolateParams: true, } } // ParseDSN parses the DSN string to a Config -func parseDSN(dsn string) (cfg *config, err error) { - // New config with some default values - cfg = newConfig() +func ParseDSN(dsn string) (cfg *Config, err error) { + // New Config with some default values + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -64,11 +64,11 @@ func parseDSN(dsn string) (cfg *config, err error) { // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] + cfg.Passwd = tryUnescape(dsn[k+1 : j]) break } } - cfg.user = dsn[:k] + cfg.User = tryUnescape(dsn[:k]) break } @@ -81,25 +81,25 @@ func parseDSN(dsn string) (cfg *config, err error) { // dsn[i-1] must be == ')' if an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped + return nil, ErrInvalidDSNUnescaped } //return nil, errInvalidDSNAddr } strList := strings.Split(dsn[k+1:i-1], ":") if len(strList) == 1 { - return nil, errInvalidDSNAddr + return nil, ErrInvalidDSNAddr } if len(strList[0]) != 0 { - cfg.addr = strList[0] - cfg.port, err = strconv.Atoi(strList[1]) + cfg.Addr = strList[0] + cfg.Port, err = strconv.Atoi(strList[1]) if err != nil { - return nil, errInvalidDSNPort + return nil, ErrInvalidDSNPort } } break } } - cfg.net = dsn[j+1 : k] + cfg.Net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] @@ -112,14 +112,14 @@ func parseDSN(dsn string) (cfg *config, err error) { break } } - cfg.dbName = dsn[i+1 : j] + cfg.DbName = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -127,7 +127,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -138,34 +138,34 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "token": - cfg.token = value + cfg.Token = value case "enableCompression": - cfg.enableCompression, err = strconv.ParseBool(value) + cfg.EnableCompression, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid enableCompression value: " + value} } case "readTimeout": - cfg.readTimeout, err = time.ParseDuration(value) + cfg.ReadTimeout, err = time.ParseDuration(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} } case "writeTimeout": - cfg.writeTimeout, err = time.ParseDuration(value) + cfg.WriteTimeout, err = time.ParseDuration(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -173,3 +173,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index 7d7c7d7..627efe9 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -12,32 +12,63 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tests := []struct { + name string dsn string errs string - want *config + want *Config }{ - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@ws(fqdn:6041)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", addr: "fqdn", port: 6041, dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@ws()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@ws(:)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@ws(:0)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/", want: &config{user: "user", passwd: "passwd", net: "wss", interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, - {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, - {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &config{ - user: "user", - passwd: "passwd", - net: "wss", - readTimeout: 10 * time.Minute, - writeTimeout: 8 * time.Second, - interpolateParams: true, - enableCompression: true, + {name: "invalid DSN", dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, + {name: "common DSN", dsn: "user:passwd@ws(fqdn:6041)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", Addr: "fqdn", Port: 6041, DbName: "dbname", InterpolateParams: true}}, + {name: "missing closing brace", dsn: "user:passwd@ws()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, + {name: "default address", dsn: "user:passwd@ws(:)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", DbName: "dbname", InterpolateParams: true}}, + {name: "0 port", dsn: "user:passwd@ws(:0)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", DbName: "dbname", InterpolateParams: true}}, + {name: "wss protocol", dsn: "user:passwd@wss(:0)/", want: &Config{User: "user", Passwd: "passwd", Net: "wss", InterpolateParams: true}}, + {name: "params", dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &Config{User: "user", Passwd: "passwd", Net: "wss", Params: map[string]string{"test": "1"}}}, + {name: "token", dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &Config{User: "user", Passwd: "passwd", Net: "wss", Token: "token"}}, + {name: "readTimeout", dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &Config{User: "user", Passwd: "passwd", Net: "wss", ReadTimeout: 10 * time.Minute, WriteTimeout: 8 * time.Second, InterpolateParams: true}}, + {name: "compression", dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &Config{ + User: "user", + Passwd: "passwd", + Net: "wss", + ReadTimeout: 10 * time.Minute, + WriteTimeout: 8 * time.Second, + InterpolateParams: true, + EnableCompression: true, }}, + // encodeURIComponent('!@#$%^&*()-_+=[]{}:;>