Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enh: support special characters #311

Merged
merged 4 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ import (
var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary

type taosConn struct {
cfg *config
cfg *Config
client *http.Client
url *url.URL
baseRawQuery string
header map[string][]string
readBufferSize int
}

func newTaosConn(cfg *config) (*taosConn, error) {
readBufferSize := cfg.readBufferSize
func newTaosConn(cfg *Config) (*taosConn, error) {
readBufferSize := cfg.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = 4 << 10
}
Expand All @@ -47,9 +47,9 @@ func newTaosConn(cfg *config) (*taosConn, error) {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: cfg.disableCompression,
DisableCompression: cfg.DisableCompression,
}
if cfg.skipVerify {
if cfg.SkipVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
Expand All @@ -58,24 +58,24 @@ func newTaosConn(cfg *config) (*taosConn, error) {
Transport: transport,
}
path := "/rest/sql"
if len(cfg.dbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.dbName)
if len(cfg.DbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.DbName)
}
tc.url = &url.URL{
Scheme: cfg.net,
Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port),
Scheme: cfg.Net,
Host: fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port),
Path: path,
}
tc.header = map[string][]string{
"Connection": {"keep-alive"},
}
if cfg.token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token)
if cfg.Token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.Token)
} else {
basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd))
basic := base64.StdEncoding.EncodeToString([]byte(cfg.User + ":" + cfg.Passwd))
tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)}
}
if !cfg.disableCompression {
if !cfg.DisableCompression {
tc.header["Accept-Encoding"] = []string{"gzip"}
}
return tc, nil
Expand Down Expand Up @@ -103,7 +103,7 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver

func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra round trips for preparing and closing a statement
Expand All @@ -129,7 +129,7 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive

func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce round trip
Expand Down Expand Up @@ -202,7 +202,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (
}
respBody := resp.Body
defer ioutil.ReadAll(respBody)
if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
if !tc.cfg.DisableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
respBody, err = gzip.NewReader(resp.Body)
if err != nil {
return nil, err
Expand Down
22 changes: 11 additions & 11 deletions taosRestful/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@ import (
)

type connector struct {
cfg *config
cfg *Config
}

// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// Connect to Server
if len(c.cfg.user) == 0 {
c.cfg.user = common.DefaultUser
if len(c.cfg.User) == 0 {
c.cfg.User = common.DefaultUser
}
if len(c.cfg.passwd) == 0 {
c.cfg.passwd = common.DefaultPassword
if len(c.cfg.Passwd) == 0 {
c.cfg.Passwd = common.DefaultPassword
}
if c.cfg.port == 0 {
c.cfg.port = common.DefaultHttpPort
if c.cfg.Port == 0 {
c.cfg.Port = common.DefaultHttpPort
}
if len(c.cfg.net) == 0 {
c.cfg.net = "http"
if len(c.cfg.Net) == 0 {
c.cfg.Net = "http"
}
if len(c.cfg.addr) == 0 {
c.cfg.addr = "127.0.0.1"
if len(c.cfg.Addr) == 0 {
c.cfg.Addr = "127.0.0.1"
}
tc, err := newTaosConn(c.cfg)
return tc, err
Expand Down
2 changes: 1 addition & 1 deletion taosRestful/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func TestSSL(t *testing.T) {

func TestConnect(t *testing.T) {
conn := connector{
cfg: &config{},
cfg: &Config{},
}
db, err := conn.Connect(context.Background())
assert.NoError(t, err)
Expand Down
18 changes: 17 additions & 1 deletion taosRestful/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type TDengineDriver struct{}
// Open new Connection.
// the DSN string is formatted
func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := parseDSN(dsn)
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
Expand All @@ -26,3 +26,19 @@ func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
func init() {
sql.Register("taosRestful", &TDengineDriver{})
}

// NewConnector returns new driver.Connector.
func NewConnector(cfg *Config) (driver.Connector, error) {
return &connector{cfg: cfg}, nil
}

// OpenConnector implements driver.DriverContext.
func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
}
90 changes: 90 additions & 0 deletions taosRestful/driver_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package taosRestful

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -289,3 +291,91 @@ func TestChinese(t *testing.T) {
}
assert.Equal(t, 1, counter)
}

func TestNewConnector(t *testing.T) {
cfg, err := ParseDSN(dataSourceName)
assert.NoError(t, err)
conn, err := NewConnector(cfg)
assert.NoError(t, err)
db := sql.OpenDB(conn)
defer func() {
err := db.Close()
assert.NoError(t, err)
}()
if err := db.Ping(); err != nil {
t.Fatal(err)
}
}

func TestOpen(t *testing.T) {
tdDriver := &TDengineDriver{}
conn, err := tdDriver.Open(dataSourceName)
assert.NoError(t, err)
defer func() {
err := conn.Close()
assert.NoError(t, err)
}()
pinger := conn.(driver.Pinger)
err = pinger.Ping(context.Background())
assert.NoError(t, err)
}

func TestSpecialPassword(t *testing.T) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
t.Fatalf("error on: sql.open %s", err.Error())
return
}
defer db.Close()
tests := []struct {
name string
user string
pass string
}{
{
name: "test_special1_rs",
user: "test_special1_rs",
pass: "!q@w#a$1%3^&*()-",
},
{
name: "test_special2_rs",
user: "test_special2_rs",
pass: "_q+3=[]{}:;><?|~",
},
{
name: "test_special3_rs",
user: "test_special3_rs",
pass: "1><3?|~,w.",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer func() {
dropSql := fmt.Sprintf("drop user %s", test.user)
_, _ = db.Exec(dropSql)
}()
createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass)
_, err := db.Exec(createSql)
assert.NoError(t, err)
escapedPass := url.QueryEscape(test.pass)
newDsn := fmt.Sprintf("%s:%s@http(%s:%d)/%s", test.user, escapedPass, host, port, "")
db2, err := sql.Open(driverName, newDsn)
if err != nil {
t.Errorf("error on: sql.open %s", err.Error())
return
}
defer db2.Close()
rows, err := db2.Query("select 1")
assert.NoError(t, err)
var i int
for rows.Next() {
err := rows.Scan(&i)
assert.NoError(t, err)
assert.Equal(t, 1, i)
}
if i != 1 {
t.Errorf("query failed")
}
})
}
}
Loading
Loading