Skip to content

Commit

Permalink
Merge pull request #311 from taosdata/enh/xftan/TD-33257
Browse files Browse the repository at this point in the history
enh: support special characters
  • Loading branch information
sheyanjie-qq authored Dec 29, 2024
2 parents c5c3dac + a524b1c commit 5c95f6d
Show file tree
Hide file tree
Showing 21 changed files with 1,200 additions and 302 deletions.
32 changes: 16 additions & 16 deletions taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ import (
var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary

type taosConn struct {
cfg *config
cfg *Config
client *http.Client
url *url.URL
baseRawQuery string
header map[string][]string
readBufferSize int
}

func newTaosConn(cfg *config) (*taosConn, error) {
readBufferSize := cfg.readBufferSize
func newTaosConn(cfg *Config) (*taosConn, error) {
readBufferSize := cfg.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = 4 << 10
}
Expand All @@ -47,9 +47,9 @@ func newTaosConn(cfg *config) (*taosConn, error) {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: cfg.disableCompression,
DisableCompression: cfg.DisableCompression,
}
if cfg.skipVerify {
if cfg.SkipVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
Expand All @@ -58,24 +58,24 @@ func newTaosConn(cfg *config) (*taosConn, error) {
Transport: transport,
}
path := "/rest/sql"
if len(cfg.dbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.dbName)
if len(cfg.DbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.DbName)
}
tc.url = &url.URL{
Scheme: cfg.net,
Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port),
Scheme: cfg.Net,
Host: fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port),
Path: path,
}
tc.header = map[string][]string{
"Connection": {"keep-alive"},
}
if cfg.token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token)
if cfg.Token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.Token)
} else {
basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd))
basic := base64.StdEncoding.EncodeToString([]byte(cfg.User + ":" + cfg.Passwd))
tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)}
}
if !cfg.disableCompression {
if !cfg.DisableCompression {
tc.header["Accept-Encoding"] = []string{"gzip"}
}
return tc, nil
Expand Down Expand Up @@ -103,7 +103,7 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver

func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra round trips for preparing and closing a statement
Expand All @@ -129,7 +129,7 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive

func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce round trip
Expand Down Expand Up @@ -202,7 +202,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (
}
respBody := resp.Body
defer ioutil.ReadAll(respBody)
if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
if !tc.cfg.DisableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
respBody, err = gzip.NewReader(resp.Body)
if err != nil {
return nil, err
Expand Down
22 changes: 11 additions & 11 deletions taosRestful/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@ import (
)

type connector struct {
cfg *config
cfg *Config
}

// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// Connect to Server
if len(c.cfg.user) == 0 {
c.cfg.user = common.DefaultUser
if len(c.cfg.User) == 0 {
c.cfg.User = common.DefaultUser
}
if len(c.cfg.passwd) == 0 {
c.cfg.passwd = common.DefaultPassword
if len(c.cfg.Passwd) == 0 {
c.cfg.Passwd = common.DefaultPassword
}
if c.cfg.port == 0 {
c.cfg.port = common.DefaultHttpPort
if c.cfg.Port == 0 {
c.cfg.Port = common.DefaultHttpPort
}
if len(c.cfg.net) == 0 {
c.cfg.net = "http"
if len(c.cfg.Net) == 0 {
c.cfg.Net = "http"
}
if len(c.cfg.addr) == 0 {
c.cfg.addr = "127.0.0.1"
if len(c.cfg.Addr) == 0 {
c.cfg.Addr = "127.0.0.1"
}
tc, err := newTaosConn(c.cfg)
return tc, err
Expand Down
2 changes: 1 addition & 1 deletion taosRestful/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func TestSSL(t *testing.T) {

func TestConnect(t *testing.T) {
conn := connector{
cfg: &config{},
cfg: &Config{},
}
db, err := conn.Connect(context.Background())
assert.NoError(t, err)
Expand Down
18 changes: 17 additions & 1 deletion taosRestful/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type TDengineDriver struct{}
// Open new Connection.
// the DSN string is formatted
func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := parseDSN(dsn)
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
Expand All @@ -26,3 +26,19 @@ func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
func init() {
sql.Register("taosRestful", &TDengineDriver{})
}

// NewConnector returns new driver.Connector.
func NewConnector(cfg *Config) (driver.Connector, error) {
return &connector{cfg: cfg}, nil
}

// OpenConnector implements driver.DriverContext.
func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
}
90 changes: 90 additions & 0 deletions taosRestful/driver_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package taosRestful

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -289,3 +291,91 @@ func TestChinese(t *testing.T) {
}
assert.Equal(t, 1, counter)
}

func TestNewConnector(t *testing.T) {
cfg, err := ParseDSN(dataSourceName)
assert.NoError(t, err)
conn, err := NewConnector(cfg)
assert.NoError(t, err)
db := sql.OpenDB(conn)
defer func() {
err := db.Close()
assert.NoError(t, err)
}()
if err := db.Ping(); err != nil {
t.Fatal(err)
}
}

func TestOpen(t *testing.T) {
tdDriver := &TDengineDriver{}
conn, err := tdDriver.Open(dataSourceName)
assert.NoError(t, err)
defer func() {
err := conn.Close()
assert.NoError(t, err)
}()
pinger := conn.(driver.Pinger)
err = pinger.Ping(context.Background())
assert.NoError(t, err)
}

func TestSpecialPassword(t *testing.T) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
t.Fatalf("error on: sql.open %s", err.Error())
return
}
defer db.Close()
tests := []struct {
name string
user string
pass string
}{
{
name: "test_special1_rs",
user: "test_special1_rs",
pass: "!q@w#a$1%3^&*()-",
},
{
name: "test_special2_rs",
user: "test_special2_rs",
pass: "_q+3=[]{}:;><?|~",
},
{
name: "test_special3_rs",
user: "test_special3_rs",
pass: "1><3?|~,w.",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer func() {
dropSql := fmt.Sprintf("drop user %s", test.user)
_, _ = db.Exec(dropSql)
}()
createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass)
_, err := db.Exec(createSql)
assert.NoError(t, err)
escapedPass := url.QueryEscape(test.pass)
newDsn := fmt.Sprintf("%s:%s@http(%s:%d)/%s", test.user, escapedPass, host, port, "")
db2, err := sql.Open(driverName, newDsn)
if err != nil {
t.Errorf("error on: sql.open %s", err.Error())
return
}
defer db2.Close()
rows, err := db2.Query("select 1")
assert.NoError(t, err)
var i int
for rows.Next() {
err := rows.Scan(&i)
assert.NoError(t, err)
assert.Equal(t, 1, i)
}
if i != 1 {
t.Errorf("query failed")
}
})
}
}
Loading

0 comments on commit 5c95f6d

Please sign in to comment.