Skip to content

Commit

Permalink
enh: support special characters
Browse files Browse the repository at this point in the history
  • Loading branch information
huskar-t committed Dec 23, 2024
1 parent c5c3dac commit cac65df
Show file tree
Hide file tree
Showing 21 changed files with 513 additions and 289 deletions.
32 changes: 16 additions & 16 deletions taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ import (
var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary

type taosConn struct {
cfg *config
cfg *Config
client *http.Client
url *url.URL
baseRawQuery string
header map[string][]string
readBufferSize int
}

func newTaosConn(cfg *config) (*taosConn, error) {
readBufferSize := cfg.readBufferSize
func newTaosConn(cfg *Config) (*taosConn, error) {
readBufferSize := cfg.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = 4 << 10
}
Expand All @@ -47,9 +47,9 @@ func newTaosConn(cfg *config) (*taosConn, error) {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: cfg.disableCompression,
DisableCompression: cfg.DisableCompression,
}
if cfg.skipVerify {
if cfg.SkipVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
Expand All @@ -58,24 +58,24 @@ func newTaosConn(cfg *config) (*taosConn, error) {
Transport: transport,
}
path := "/rest/sql"
if len(cfg.dbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.dbName)
if len(cfg.DbName) != 0 {
path = fmt.Sprintf("%s/%s", path, cfg.DbName)
}
tc.url = &url.URL{
Scheme: cfg.net,
Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port),
Scheme: cfg.Net,
Host: fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port),
Path: path,
}
tc.header = map[string][]string{
"Connection": {"keep-alive"},
}
if cfg.token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token)
if cfg.Token != "" {
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.Token)
} else {
basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd))
basic := base64.StdEncoding.EncodeToString([]byte(cfg.User + ":" + cfg.Passwd))
tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)}
}
if !cfg.disableCompression {
if !cfg.DisableCompression {
tc.header["Accept-Encoding"] = []string{"gzip"}
}
return tc, nil
Expand Down Expand Up @@ -103,7 +103,7 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver

func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra round trips for preparing and closing a statement
Expand All @@ -129,7 +129,7 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive

func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if len(args) != 0 {
if !tc.cfg.interpolateParams {
if !tc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce round trip
Expand Down Expand Up @@ -202,7 +202,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (
}
respBody := resp.Body
defer ioutil.ReadAll(respBody)
if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
if !tc.cfg.DisableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
respBody, err = gzip.NewReader(resp.Body)
if err != nil {
return nil, err
Expand Down
22 changes: 11 additions & 11 deletions taosRestful/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@ import (
)

type connector struct {
cfg *config
cfg *Config
}

// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// Connect to Server
if len(c.cfg.user) == 0 {
c.cfg.user = common.DefaultUser
if len(c.cfg.User) == 0 {
c.cfg.User = common.DefaultUser
}
if len(c.cfg.passwd) == 0 {
c.cfg.passwd = common.DefaultPassword
if len(c.cfg.Passwd) == 0 {
c.cfg.Passwd = common.DefaultPassword
}
if c.cfg.port == 0 {
c.cfg.port = common.DefaultHttpPort
if c.cfg.Port == 0 {
c.cfg.Port = common.DefaultHttpPort
}
if len(c.cfg.net) == 0 {
c.cfg.net = "http"
if len(c.cfg.Net) == 0 {
c.cfg.Net = "http"
}
if len(c.cfg.addr) == 0 {
c.cfg.addr = "127.0.0.1"
if len(c.cfg.Addr) == 0 {
c.cfg.Addr = "127.0.0.1"
}
tc, err := newTaosConn(c.cfg)
return tc, err
Expand Down
2 changes: 1 addition & 1 deletion taosRestful/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func TestSSL(t *testing.T) {

func TestConnect(t *testing.T) {
conn := connector{
cfg: &config{},
cfg: &Config{},
}
db, err := conn.Connect(context.Background())
assert.NoError(t, err)
Expand Down
18 changes: 17 additions & 1 deletion taosRestful/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type TDengineDriver struct{}
// Open new Connection.
// the DSN string is formatted
func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := parseDSN(dsn)
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
Expand All @@ -26,3 +26,19 @@ func (d TDengineDriver) Open(dsn string) (driver.Conn, error) {
func init() {
sql.Register("taosRestful", &TDengineDriver{})
}

// NewConnector returns new driver.Connector.
func NewConnector(cfg *Config) (driver.Connector, error) {
return &connector{cfg: cfg}, nil
}

// OpenConnector implements driver.DriverContext.
func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
}
15 changes: 15 additions & 0 deletions taosRestful/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,18 @@ func TestChinese(t *testing.T) {
}
assert.Equal(t, 1, counter)
}

func TestNewConnector(t *testing.T) {
cfg, err := ParseDSN(dataSourceName)
assert.NoError(t, err)
conn, err := NewConnector(cfg)
assert.NoError(t, err)
db := sql.OpenDB(conn)
defer func() {
err := db.Close()
assert.NoError(t, err)
}()
if err := db.Ping(); err != nil {
t.Fatal(err)
}
}
93 changes: 50 additions & 43 deletions taosRestful/dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,43 @@ import (
)

var (
errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"}
errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"}
errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"}
errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"}
ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"}
ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"}
ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"}
ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"}
)

// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type config struct {
user string // Username
passwd string // Password (requires User)
net string // Network type
addr string // Network address (requires Net)
port int
dbName string // Database name
params map[string]string // Connection parameters
interpolateParams bool // Interpolate placeholders into query string
disableCompression bool
readBufferSize int
token string // cloud platform token
skipVerify bool
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
Port int
DbName string // Database name
Params map[string]string // Connection parameters
InterpolateParams bool // Interpolate placeholders into query string
DisableCompression bool
ReadBufferSize int
Token string // cloud platform Token
SkipVerify bool
}

// NewConfig creates a new Config and sets default values.
func newConfig() *config {
return &config{
interpolateParams: true,
disableCompression: true,
readBufferSize: 4 << 10,
func NewConfig() *Config {
return &Config{
InterpolateParams: true,
DisableCompression: true,
ReadBufferSize: 4 << 10,
}
}

// ParseDSN parses the DSN string to a Config
func parseDSN(dsn string) (cfg *config, err error) {
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
cfg = newConfig()
cfg = NewConfig()

// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
Expand All @@ -65,11 +65,11 @@ func parseDSN(dsn string) (cfg *config, err error) {
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
cfg.Passwd = tryUnescape(dsn[k+1 : j])
break
}
}
cfg.user = dsn[:k]
cfg.User = tryUnescape(dsn[:k])

break
}
Expand All @@ -82,25 +82,25 @@ func parseDSN(dsn string) (cfg *config, err error) {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
return nil, ErrInvalidDSNUnescaped
}
//return nil, errInvalidDSNAddr
}
strList := strings.Split(dsn[k+1:i-1], ":")
if len(strList) == 1 {
return nil, errInvalidDSNAddr
return nil, ErrInvalidDSNAddr
}
if len(strList[0]) != 0 {
cfg.addr = strList[0]
cfg.port, err = strconv.Atoi(strList[1])
cfg.Addr = strList[0]
cfg.Port, err = strconv.Atoi(strList[1])
if err != nil {
return nil, errInvalidDSNPort
return nil, ErrInvalidDSNPort
}
}
break
}
}
cfg.net = dsn[j+1 : k]
cfg.Net = dsn[j+1 : k]
}

// dbname[?param1=value1&...&paramN=valueN]
Expand All @@ -113,22 +113,22 @@ func parseDSN(dsn string) (cfg *config, err error) {
break
}
}
cfg.dbName = dsn[i+1 : j]
cfg.DbName = dsn[i+1 : j]

break
}
}

if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
return nil, ErrInvalidDSNNoSlash
}

return
}

// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
Expand All @@ -139,38 +139,45 @@ func parseDSNParams(cfg *config, params string) (err error) {
switch value := param[1]; param[0] {
// Enable client side placeholder substitution
case "interpolateParams":
cfg.interpolateParams, err = strconv.ParseBool(value)
cfg.InterpolateParams, err = strconv.ParseBool(value)
if err != nil {
return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value}
}
case "disableCompression":
cfg.disableCompression, err = strconv.ParseBool(value)
cfg.DisableCompression, err = strconv.ParseBool(value)
if err != nil {
return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value}
}
case "readBufferSize":
cfg.readBufferSize, err = strconv.Atoi(value)
cfg.ReadBufferSize, err = strconv.Atoi(value)
if err != nil {
return &errors.TaosError{Code: 0xffff, ErrStr: "invalid int value: " + value}
}
case "token":
cfg.token = value
cfg.Token = value
case "skipVerify":
cfg.skipVerify, err = strconv.ParseBool(value)
cfg.SkipVerify, err = strconv.ParseBool(value)
if err != nil {
return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value}
}
default:
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}

if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}

return
}

func tryUnescape(s string) string {
if res, err := url.QueryUnescape(s); err == nil {
return res
}
return s
}
Loading

0 comments on commit cac65df

Please sign in to comment.