From 5ad9ee5f2df6f81e9043e29294a6f6822cba2c66 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Sun, 29 Sep 2024 14:40:21 +0800 Subject: [PATCH 01/35] enh: tmq add config `session.timeout.ms` and `max.poll.interval.ms` --- ws/tmq/config.go | 10 ++++++++++ ws/tmq/consumer.go | 34 ++++++++++++++++++++++++---------- ws/tmq/consumer_test.go | 2 ++ ws/tmq/proto.go | 2 ++ 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/ws/tmq/config.go b/ws/tmq/config.go index e119dcf..1f90ee7 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -23,6 +23,8 @@ type config struct { AutoReconnect bool ReconnectIntervalMs int ReconnectRetryCount int + SessionTimeoutMS string + MaxPollIntervalMS string } func newConfig(url string, chanLength uint) *config { @@ -99,3 +101,11 @@ func (c *config) setReconnectIntervalMs(reconnectIntervalMs int) { func (c *config) setReconnectRetryCount(reconnectRetryCount int) { c.ReconnectRetryCount = reconnectRetryCount } + +func (c *config) setSessionTimeoutMS(sessionTimeoutMS string) { + c.SessionTimeoutMS = sessionTimeoutMS +} + +func (c *config) setMaxPollIntervalMS(maxPollIntervalMS string) { + c.MaxPollIntervalMS = maxPollIntervalMS +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index a6edbac..b7fbe4f 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -42,6 +42,8 @@ type Consumer struct { offsetRest string snapshotEnable string withTableName string + sessionTimeoutMS string + maxPollIntervalMS string closeOnce sync.Once closeChan chan struct{} topics []string @@ -243,6 +245,14 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + sessionTimeoutMS, err := m.Get("session.timeout.ms", "") + if err != nil { + return nil, err + } + maxPollIntervalMS, err := m.Get("max.poll.interval.ms", "") + if err != nil { + return nil, err + } config := newConfig(url.(string), chanLen.(uint)) err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { @@ -265,6 +275,8 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { config.setAutoReconnect(autoReconnect.(bool)) config.setReconnectIntervalMs(reconnectIntervalMs.(int)) config.setReconnectRetryCount(reconnectRetryCount.(int)) + config.setSessionTimeoutMS(sessionTimeoutMS.(string)) + config.setMaxPollIntervalMS(maxPollIntervalMS.(string)) return config, nil } @@ -417,16 +429,18 @@ func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { } reqID := c.generateReqID() req := &SubscribeReq{ - ReqID: reqID, - User: c.user, - Password: c.password, - GroupID: c.groupID, - ClientID: c.clientID, - OffsetRest: c.offsetRest, - Topics: topics, - AutoCommit: "false", - SnapshotEnable: c.snapshotEnable, - WithTableName: c.withTableName, + ReqID: reqID, + User: c.user, + Password: c.password, + GroupID: c.groupID, + ClientID: c.clientID, + OffsetRest: c.offsetRest, + Topics: topics, + AutoCommit: "false", + SnapshotEnable: c.snapshotEnable, + WithTableName: c.withTableName, + SessionTimeoutMS: c.sessionTimeoutMS, + MaxPollIntervalMS: c.maxPollIntervalMS, } args, err := client.JsonI.Marshal(req) if err != nil { diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 37dd34b..4ab7a98 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -141,6 +141,8 @@ func TestConsumer(t *testing.T) { "enable.auto.commit": "true", "auto.commit.interval.ms": "5000", "msg.with.table.name": "true", + "session.timeout.ms": "12000", + "max.poll.interval.ms": "300000", }) if err != nil { t.Error(err) diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go index 3a17c8b..ce9d501 100644 --- a/ws/tmq/proto.go +++ b/ws/tmq/proto.go @@ -19,6 +19,8 @@ type SubscribeReq struct { AutoCommitIntervalMS string `json:"auto_commit_interval_ms"` SnapshotEnable string `json:"snapshot_enable"` WithTableName string `json:"with_table_name"` + SessionTimeoutMS string `json:"session_timeout_ms"` + MaxPollIntervalMS string `json:"max_poll_interval_ms"` } type SubscribeResp struct { From 947079db952524ce85af104aebcf2ec04cd37367 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 10 Oct 2024 10:24:59 +0800 Subject: [PATCH 02/35] enh: support int64 timestamp and avoid memory check exceptions --- common/stmt/stmt2.go | 13 +-- common/stmt/stmt2_test.go | 50 ++++++++++++ wrapper/stmt2.go | 167 +++++++++++++++++++++++--------------- wrapper/stmt2_test.go | 20 +++++ 4 files changed, 180 insertions(+), 70 deletions(-) diff --git a/common/stmt/stmt2.go b/common/stmt/stmt2.go index f169d95..82e92f9 100644 --- a/common/stmt/stmt2.go +++ b/common/stmt/stmt2.go @@ -346,12 +346,15 @@ func generateBindColData(data []driver.Value, colType *StmtField, tmpBuffer *byt isNull[i] = 1 writeUint64(tmpBuffer, 0) } else { - v, ok := data[i].(time.Time) - if !ok { - return nil, fmt.Errorf("data type not match, expect time.Time, but get %T, value:%v", data[i], data[i]) + switch v := data[i].(type) { + case int64: + writeUint64(tmpBuffer, uint64(v)) + case time.Time: + ts := common.TimeToTimestamp(v, precision) + writeUint64(tmpBuffer, uint64(ts)) + default: + return nil, fmt.Errorf("data type not match, expect int64 or time.Time, but get %T, value:%v", data[i], data[i]) } - ts := common.TimeToTimestamp(v, precision) - writeUint64(tmpBuffer, uint64(ts)) } } case common.TSDB_DATA_TYPE_BINARY, common.TSDB_DATA_TYPE_NCHAR, common.TSDB_DATA_TYPE_VARBINARY, common.TSDB_DATA_TYPE_GEOMETRY, common.TSDB_DATA_TYPE_JSON: diff --git a/common/stmt/stmt2_test.go b/common/stmt/stmt2_test.go index 686aa3a..dfd10a0 100644 --- a/common/stmt/stmt2_test.go +++ b/common/stmt/stmt2_test.go @@ -2373,6 +2373,52 @@ func TestMarshalBinary(t *testing.T) { want: nil, wantErr: true, }, + { + name: "int64 timestamp", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{int64(1726803356466)}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, + colType: nil, + }, + want: []byte{ + // total Length + 0x3a, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x1c, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // tags + // table length + 0x1a, 0x00, 0x00, 0x00, + //table 0 tags + //tag 0 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x09, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -2385,3 +2431,7 @@ func TestMarshalBinary(t *testing.T) { }) } } + +func TestT(t *testing.T) { + +} diff --git a/wrapper/stmt2.go b/wrapper/stmt2.go index 67a7cff..c9cd3f0 100644 --- a/wrapper/stmt2.go +++ b/wrapper/stmt2.go @@ -52,9 +52,9 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS tbNames := unsafe.Pointer(C.malloc(C.size_t(count) * C.size_t(PointerSize))) needFreePointer := []unsafe.Pointer{tbNames} defer func() { - for _, p := range needFreePointer { - if p != nil { - C.free(p) + for i := len(needFreePointer) - 1; i >= 0; i-- { + if needFreePointer[i] != nil { + C.free(needFreePointer[i]) } } }() @@ -89,11 +89,12 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS for j := 0; j < len(param.Tags); j++ { columnFormatTags[j] = []driver.Value{param.Tags[j]} } - tags, err := generateTaosStmt2BindsInsert(columnFormatTags, tagTypes, &needFreePointer) + tags, freePointer, err := generateTaosStmt2BindsInsert(columnFormatTags, tagTypes) + needFreePointer = append(needFreePointer, freePointer...) if err != nil { return taosError.NewError(0xffff, fmt.Sprintf("generate tags Bindv struct error: %s", err.Error())) } - *(**C.TAOS_STMT2_BIND)(currentTagP) = &(tags[0]) + *(**C.TAOS_STMT2_BIND)(currentTagP) = (*C.TAOS_STMT2_BIND)(tags) } else { *(**C.TAOS_STMT2_BIND)(currentTagP) = nil } @@ -101,16 +102,18 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS currentColP = pointer.AddUintptr(colList, uintptr(i)*PointerSize) if len(param.Cols) > 0 { var err error - var cols []C.TAOS_STMT2_BIND + var cols unsafe.Pointer + var freePointer []unsafe.Pointer if isInsert { - cols, err = generateTaosStmt2BindsInsert(param.Cols, colTypes, &needFreePointer) + cols, freePointer, err = generateTaosStmt2BindsInsert(param.Cols, colTypes) } else { - cols, err = generateTaosStmt2BindsQuery(param.Cols, &needFreePointer) + cols, freePointer, err = generateTaosStmt2BindsQuery(param.Cols) } + needFreePointer = append(needFreePointer, freePointer...) if err != nil { return taosError.NewError(0xffff, fmt.Sprintf("generate cols Bindv struct error: %s", err.Error())) } - *(**C.TAOS_STMT2_BIND)(currentColP) = &(cols[0]) + *(**C.TAOS_STMT2_BIND)(currentColP) = (*C.TAOS_STMT2_BIND)(cols) } else { *(**C.TAOS_STMT2_BIND)(currentColP) = nil } @@ -118,7 +121,6 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(colList)) cBindv.tags = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(tagList)) cBindv.tbnames = (**C.char)(tbNames) - code := int(C.taos_stmt2_bind_param(stmt, &cBindv, C.int32_t(colIdx))) if code != 0 { errStr := TaosStmt2Error(stmt) @@ -127,30 +129,32 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS return nil } -func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt.StmtField, needFreePointer *[]unsafe.Pointer) ([]C.TAOS_STMT2_BIND, error) { +func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt.StmtField) (unsafe.Pointer, []unsafe.Pointer, error) { + var needFreePointer []unsafe.Pointer if len(multiBind) != len(fieldTypes) { - return nil, fmt.Errorf("data and type length not match, data length: %d, type length: %d", len(multiBind), len(fieldTypes)) + return nil, needFreePointer, fmt.Errorf("data and type length not match, data length: %d, type length: %d", len(multiBind), len(fieldTypes)) } - binds := make([]C.TAOS_STMT2_BIND, len(multiBind)) + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(len(multiBind)) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + needFreePointer = append(needFreePointer, binds) rowLen := len(multiBind[0]) for columnIndex, columnData := range multiBind { if len(multiBind[columnIndex]) != rowLen { - return nil, fmt.Errorf("data length not match, column %d data length: %d, expect: %d", columnIndex, len(multiBind[columnIndex]), rowLen) + return nil, needFreePointer, fmt.Errorf("data length not match, column %d data length: %d, expect: %d", columnIndex, len(multiBind[columnIndex]), rowLen) } - bind := C.TAOS_STMT2_BIND{} + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(columnIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) bind.num = C.int(rowLen) nullList := unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) - *needFreePointer = append(*needFreePointer, nullList) + needFreePointer = append(needFreePointer, nullList) lengthList := unsafe.Pointer(C.calloc(C.size_t(C.uint(rowLen)), C.size_t(C.uint(4)))) - *needFreePointer = append(*needFreePointer, lengthList) + needFreePointer = append(needFreePointer, lengthList) var p unsafe.Pointer - *needFreePointer = append(*needFreePointer, p) columnType := fieldTypes[columnIndex].FieldType precision := int(fieldTypes[columnIndex].Precision) switch columnType { case common.TSDB_DATA_TYPE_BOOL: //1 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BOOL for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -160,7 +164,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(bool) if !ok { - return nil, fmt.Errorf("data type error, expect bool, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect bool, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(i)) if value { @@ -176,6 +180,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_TINYINT: //1 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -185,7 +190,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(int8) if !ok { - return nil, fmt.Errorf("data type error, expect int8, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect int8, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(i)) *(*C.int8_t)(current) = C.int8_t(value) @@ -197,6 +202,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_SMALLINT: //2 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -206,7 +212,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(int16) if !ok { - return nil, fmt.Errorf("data type error, expect int16, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect int16, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) *(*C.int16_t)(current) = C.int16_t(value) @@ -218,6 +224,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_INT: //4 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_INT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -227,7 +234,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(int32) if !ok { - return nil, fmt.Errorf("data type error, expect int32, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect int32, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.int32_t)(current) = C.int32_t(value) @@ -239,6 +246,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_BIGINT: //8 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -248,7 +256,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(int64) if !ok { - return nil, fmt.Errorf("data type error, expect int64, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect int64, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.int64_t)(current) = C.int64_t(value) @@ -260,6 +268,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_UTINYINT: //1 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -269,7 +278,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(uint8) if !ok { - return nil, fmt.Errorf("data type error, expect uint8, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect uint8, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(i)) *(*C.uint8_t)(current) = C.uint8_t(value) @@ -281,6 +290,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_USMALLINT: //2 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -290,7 +300,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(uint16) if !ok { - return nil, fmt.Errorf("data type error, expect uint16, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect uint16, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) *(*C.uint16_t)(current) = C.uint16_t(value) @@ -302,6 +312,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_UINT: //4 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -311,7 +322,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(uint32) if !ok { - return nil, fmt.Errorf("data type error, expect uint32, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect uint32, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.uint32_t)(current) = C.uint32_t(value) @@ -323,6 +334,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_UBIGINT: //8 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -332,7 +344,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(uint64) if !ok { - return nil, fmt.Errorf("data type error, expect uint64, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect uint64, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.uint64_t)(current) = C.uint64_t(value) @@ -344,6 +356,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_FLOAT: //4 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -353,7 +366,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(float32) if !ok { - return nil, fmt.Errorf("data type error, expect float32, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect float32, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.float)(current) = C.float(value) @@ -365,6 +378,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case common.TSDB_DATA_TYPE_DOUBLE: //8 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -376,7 +390,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.char)(currentNull) = C.char(0) value, ok := rowData.(float64) if !ok { - return nil, fmt.Errorf("data type error, expect float64, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect float64, but got %T, value: %v", rowData, value) } current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.double)(current) = C.double(value) @@ -408,12 +422,13 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) *(*C.int32_t)(l) = C.int32_t(len(value)) default: - return nil, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) } *(*C.char)(currentNull) = C.char(0) } } p = unsafe.Pointer(C.malloc(C.size_t(C.uint(totalLen)))) + needFreePointer = append(needFreePointer, p) for i, rowData := range columnData { if rowData != nil { switch value := rowData.(type) { @@ -423,13 +438,14 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt case []byte: C.memcpy(unsafe.Pointer(uintptr(p)+uintptr(colOffset[i])), unsafe.Pointer(&value[0]), C.size_t(len(value))) default: - return nil, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) + return nil, needFreePointer, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) } } } case common.TSDB_DATA_TYPE_TIMESTAMP: //8 p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_TIMESTAMP for i, rowData := range columnData { currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) @@ -439,11 +455,15 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt *(*C.int32_t)(l) = C.int32_t(0) } else { *(*C.char)(currentNull) = C.char(0) - value, ok := rowData.(time.Time) - if !ok { - return nil, fmt.Errorf("data type error, expect time.Time, but got %T, value: %v", rowData, value) + var ts int64 + switch value := rowData.(type) { + case time.Time: + ts = common.TimeToTimestamp(value, precision) + case int64: + ts = value + default: + return nil, needFreePointer, fmt.Errorf("data type error, expect time.Time or int64, but got %T, value: %v", rowData, rowData) } - ts := common.TimeToTimestamp(value, precision) current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.int64_t)(current) = C.int64_t(ts) @@ -455,37 +475,36 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt bind.buffer = p bind.length = (*C.int32_t)(lengthList) bind.is_null = (*C.char)(nullList) - binds[columnIndex] = bind } - return binds, nil + return binds, needFreePointer, nil } -func generateTaosStmt2BindsQuery(multiBind [][]driver.Value, needFreePointer *[]unsafe.Pointer) ([]C.TAOS_STMT2_BIND, error) { - binds := make([]C.TAOS_STMT2_BIND, len(multiBind)) +func generateTaosStmt2BindsQuery(multiBind [][]driver.Value) (unsafe.Pointer, []unsafe.Pointer, error) { + var needFreePointer []unsafe.Pointer + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(len(multiBind)) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + needFreePointer = append(needFreePointer, binds) for columnIndex, columnData := range multiBind { if len(columnData) != 1 { - return nil, fmt.Errorf("bind query data length must be 1, but column %d got %d", columnIndex, len(columnData)) + return nil, needFreePointer, fmt.Errorf("bind query data length must be 1, but column %d got %d", columnIndex, len(columnData)) } - bind := C.TAOS_STMT2_BIND{} + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(columnIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) data := columnData[0] bind.num = C.int(1) nullList := unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) - *needFreePointer = append(*needFreePointer, nullList) + needFreePointer = append(needFreePointer, nullList) var lengthList unsafe.Pointer - *needFreePointer = append(*needFreePointer, lengthList) var p unsafe.Pointer - *needFreePointer = append(*needFreePointer, p) - if data == nil { - return nil, fmt.Errorf("bind query data can not be nil") + return nil, needFreePointer, fmt.Errorf("bind query data can not be nil") } *(*C.char)(nullList) = C.char(0) switch rowData := data.(type) { case bool: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BOOL if rowData { *(*C.int8_t)(p) = C.int8_t(1) @@ -495,99 +514,115 @@ func generateTaosStmt2BindsQuery(multiBind [][]driver.Value, needFreePointer *[] case int8: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT *(*C.int8_t)(p) = C.int8_t(rowData) case int16: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT *(*C.int16_t)(p) = C.int16_t(rowData) case int32: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_INT *(*C.int32_t)(p) = C.int32_t(rowData) case int64: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT *(*C.int64_t)(p) = C.int64_t(rowData) case int: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT *(*C.int64_t)(p) = C.int64_t(int64(rowData)) case uint8: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT *(*C.uint8_t)(p) = C.uint8_t(rowData) case uint16: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT *(*C.uint16_t)(p) = C.uint16_t(rowData) case uint32: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UINT *(*C.uint32_t)(p) = C.uint32_t(rowData) case uint64: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT *(*C.uint64_t)(p) = C.uint64_t(rowData) case uint: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT *(*C.uint64_t)(p) = C.uint64_t(uint64(rowData)) case float32: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT *(*C.float)(p) = C.float(rowData) case float64: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE *(*C.double)(p) = C.double(rowData) case []byte: valueLength := len(rowData) p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BINARY C.memcpy(p, unsafe.Pointer(&rowData[0]), C.size_t(valueLength)) lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) *(*C.int32_t)(lengthList) = C.int32_t(valueLength) - case string: valueLength := len(rowData) p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BINARY x := *(*[]byte)(unsafe.Pointer(&rowData)) C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) *(*C.int32_t)(lengthList) = C.int32_t(valueLength) case time.Time: buffer := make([]byte, 0, 35) value := rowData.AppendFormat(buffer, time.RFC3339Nano) valueLength := len(value) p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BINARY x := *(*[]byte)(unsafe.Pointer(&value)) C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) *(*C.int32_t)(lengthList) = C.int32_t(valueLength) default: - return nil, fmt.Errorf("data type error, expect bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, []byte, string, time.Time, but got %T, value: %v", data, data) + return nil, needFreePointer, fmt.Errorf("data type error, expect bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, []byte, string, time.Time, but got %T, value: %v", data, data) } bind.buffer = p bind.length = (*C.int32_t)(lengthList) bind.is_null = (*C.char)(nullList) - binds[columnIndex] = bind } - return binds, nil + return binds, needFreePointer, nil } // TaosStmt2Exec int taos_stmt2_exec(TAOS_STMT2 *stmt, int *affected_rows); @@ -635,9 +670,9 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error } var freePointer []unsafe.Pointer defer func() { - for _, p := range freePointer { - if p != nil { - C.free(p) + for i := len(freePointer) - 1; i >= 0; i-- { + if freePointer[i] != nil { + C.free(freePointer[i]) } } }() @@ -726,7 +761,8 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error cBindv.tbnames = nil } if tagsOffset > 0 { - tags, err := generateStmt2Binds(count, tagCount, dataP, tagsOffset, &freePointer) + tags, needFreePointer, err := generateStmt2Binds(count, tagCount, dataP, tagsOffset) + freePointer = append(freePointer, needFreePointer...) if err != nil { return fmt.Errorf("generate tags error: %s", err.Error()) } @@ -735,7 +771,8 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error cBindv.tags = nil } if colsOffset > 0 { - cols, err := generateStmt2Binds(count, colCount, dataP, colsOffset, &freePointer) + cols, needFreePointer, err := generateStmt2Binds(count, colCount, dataP, colsOffset) + freePointer = append(freePointer, needFreePointer...) if err != nil { return fmt.Errorf("generate cols error: %s", err.Error()) } @@ -751,9 +788,10 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error return nil } -func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, fieldsOffset uint32, freePointer *[]unsafe.Pointer) (**C.TAOS_STMT2_BIND, error) { - bindsCList := C.malloc(C.size_t(uintptr(fieldCount) * PointerSize)) - *freePointer = append(*freePointer, bindsCList) +func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, fieldsOffset uint32) (unsafe.Pointer, []unsafe.Pointer, error) { + var freePointer []unsafe.Pointer + bindsCList := unsafe.Pointer(C.malloc(C.size_t(uintptr(count) * PointerSize))) + freePointer = append(freePointer, bindsCList) // dataLength [count]uint32 // length have checked in TaosStmt2BindBinary baseLengthPointer := pointer.AddUintptr(dataP, uintptr(fieldsOffset)) @@ -762,7 +800,8 @@ func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, f var bindsPointer unsafe.Pointer for tableIndex := uint32(0); tableIndex < count; tableIndex++ { bindsPointer = pointer.AddUintptr(bindsCList, uintptr(tableIndex)*PointerSize) - binds := make([]C.TAOS_STMT2_BIND, fieldCount) + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(fieldCount) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + freePointer = append(freePointer, binds) var bindDataP unsafe.Pointer var bindDataTotalLength uint32 var num int32 @@ -774,7 +813,7 @@ func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, f // totalLength bindDataTotalLength = *(*uint32)(bindDataP) bindDataP = pointer.AddUintptr(bindDataP, common.UInt32Size) - bind := C.TAOS_STMT2_BIND{} + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(fieldIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) // buffer_type bind.buffer_type = *(*C.int)(bindDataP) bindDataP = pointer.AddUintptr(bindDataP, common.Int32Size) @@ -808,13 +847,11 @@ func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, f // check bind data length bindDataLen := uintptr(bindDataP) - uintptr(dataPointer) if bindDataLen != uintptr(bindDataTotalLength) { - return nil, fmt.Errorf("bind data length not match, expect %d, but get %d, tableIndex:%d", bindDataTotalLength, bindDataLen, tableIndex) + return nil, freePointer, fmt.Errorf("bind data length not match, expect %d, but get %d, tableIndex:%d", bindDataTotalLength, bindDataLen, tableIndex) } - binds[fieldIndex] = bind dataPointer = bindDataP } - *(**C.TAOS_STMT2_BIND)(bindsPointer) = (*C.TAOS_STMT2_BIND)(&binds[0]) - + *(**C.TAOS_STMT2_BIND)(bindsPointer) = (*C.TAOS_STMT2_BIND)(binds) } - return (**C.TAOS_STMT2_BIND)(bindsCList), nil + return bindsCList, freePointer, nil } diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index 1a56f65..0348774 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -1127,6 +1127,26 @@ func TestStmt2BindData(t *testing.T) { {next2S, []byte("中文")}, }, }, + + { + name: "timestamp", + tbType: "ts timestamp, v timestamp", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + now, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, now}, + }, + }, } for i, tc := range tests { t.Run(tc.name, func(t *testing.T) { From 9084f46dff46d3c5c8ed28d8f8ed11b9cd00fe10 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 1 Nov 2024 09:54:04 +0800 Subject: [PATCH 03/35] fix: remove marshal stmt2 binary unnecessary serialization --- common/stmt/stmt2.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/common/stmt/stmt2.go b/common/stmt/stmt2.go index 82e92f9..09e4d71 100644 --- a/common/stmt/stmt2.go +++ b/common/stmt/stmt2.go @@ -95,9 +95,6 @@ func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, colType, t needCols = true binary.LittleEndian.PutUint32(header[ColCountPosition:], uint32(colCount)) } - if needTableNames { - binary.LittleEndian.PutUint32(header[TableNamesOffsetPosition:], uint32(colCount)) - } if !needTableNames && !needTags && !needCols { return nil, fmt.Errorf("no data") } From 8fb37f82db5140bb0316a9ae9eef038e5d7b5b53 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 31 Oct 2024 10:44:31 +0800 Subject: [PATCH 04/35] ci: fix TDengine build --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 63558d9..fe19301 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -74,7 +74,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DBUILD_DEPENDENCY_TESTS=0 -DVERNUMBER=3.9.9.9 make -j 4 - name: package From 21f5e03fa90ff1d37c5b45257d4c13eb68da8653 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 5 Nov 2024 11:09:21 +0800 Subject: [PATCH 05/35] ci: add ignore files --- .codecov.yml | 4 ++++ .github/workflows/go.yml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..e17b50f --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +ignore: + - "bench" + - "benchmark" + - "examples" \ No newline at end of file diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index fe19301..91b3266 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -174,7 +174,7 @@ jobs: - name: Upload coverage to Codecov if: ${{ matrix.go }} == 'stable' - uses: codecov/codecov-action@v4-beta + uses: codecov/codecov-action@v4 with: files: ./coverage.txt env: From ff31594ab2ed309d42fedbad336f3ed3effa1c86 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 5 Nov 2024 15:12:55 +0800 Subject: [PATCH 06/35] ci: add ignore files --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 91b3266..4efba15 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -173,7 +173,7 @@ jobs: run: sudo go test -v --count=1 -coverprofile=coverage.txt -covermode=atomic ./... - name: Upload coverage to Codecov - if: ${{ matrix.go }} == 'stable' + if: ${{ matrix.go == 'stable'}} uses: codecov/codecov-action@v4 with: files: ./coverage.txt From 575a092caf3df1e3758f8bd635d4dc15c40a2c59 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 5 Nov 2024 19:42:19 +0800 Subject: [PATCH 07/35] test: add unit test --- wrapper/whitelistcb_test.go | 57 +++++++++++++++++ ws/schemaless/schemaless_test.go | 30 +++++++++ ws/stmt/stmt_test.go | 2 + ws/tmq/consumer.go | 2 +- ws/tmq/consumer_test.go | 103 +++++++++++++++++++++++++++++++ 5 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 wrapper/whitelistcb_test.go diff --git a/wrapper/whitelistcb_test.go b/wrapper/whitelistcb_test.go new file mode 100644 index 0000000..86d01cd --- /dev/null +++ b/wrapper/whitelistcb_test.go @@ -0,0 +1,57 @@ +package wrapper + +import ( + "net" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +func TestWhitelistCallback_ErrorCode(t *testing.T) { + // Create a channel to receive the result + resultChan := make(chan *WhitelistResult, 1) + handle := cgo.NewHandle(resultChan) + // Simulate an error (code != 0) + go WhitelistCallback(handle.Pointer(), 1, nil, 0, nil) + + // Expect the result to have an error code + result := <-resultChan + assert.Equal(t, int32(1), result.ErrCode) + assert.Nil(t, result.IPNets) // No IPs should be returned +} + +func TestWhitelistCallback_Success(t *testing.T) { + // Prepare the test data: a list of byte slices representing IPs and masks + ipList := []byte{ + 192, 168, 1, 1, 24, // 192.168.1.1/24 + 0, 0, 0, + 10, 0, 0, 1, 16, // 10.0.0.1/16 + } + + // Create a channel to receive the result + resultChan := make(chan *WhitelistResult, 1) + + // Cast the byte slice to an unsafe pointer + pWhiteLists := unsafe.Pointer(&ipList[0]) + handle := cgo.NewHandle(resultChan) + // Simulate a successful callback (code == 0) + go WhitelistCallback(handle.Pointer(), 0, nil, 2, pWhiteLists) + + // Expect the result to have two IPNets + result := <-resultChan + assert.Equal(t, int32(0), result.ErrCode) + assert.Len(t, result.IPNets, 2) + + // Validate the first IPNet (192.168.1.1/24) + assert.Equal(t, net.IPv4(192, 168, 1, 1).To4(), result.IPNets[0].IP) + + ones, _ := result.IPNets[0].Mask.Size() + assert.Equal(t, 24, ones) + + // Validate the second IPNet (10.0.0.1/16) + assert.Equal(t, net.IPv4(10, 0, 0, 1).To4(), result.IPNets[1].IP) + ones, _ = result.IPNets[1].Mask.Size() + assert.Equal(t, 16, ones) +} diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index dc9caa6..63729ac 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -239,3 +239,33 @@ func TestSchemalessReconnect(t *testing.T) { err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) assert.NoError(t, err) } + +func TestWrongNewSchemaless(t *testing.T) { + s, err := NewSchemaless(NewConfig("://localhost:6041", 1, + SetUser("root"), + SetPassword("taosdata"), + )) + assert.Error(t, err) + assert.Nil(t, s) + + s, err = NewSchemaless(NewConfig("wrong://localhost:6041", 1, + SetUser("root"), + SetPassword("taosdata"), + )) + assert.Error(t, err) + assert.Nil(t, s) + + s, err = NewSchemaless(NewConfig("ws://localhost:6041", 1, + SetUser("root"), + SetPassword("wrongpassword"), + )) + assert.Error(t, err) + assert.Nil(t, s) + + s, err = NewSchemaless(NewConfig("ws://localhost:9999", 1, + SetUser("root"), + SetPassword("taosdata"), + )) + assert.Error(t, err) + assert.Nil(t, s) +} diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 652766e..a9b04cf 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -764,6 +764,7 @@ func TestSTMTQuery(t *testing.T) { err = rows.Next(values) if err != nil { if err == io.EOF { + rows.Close() break } assert.NoError(t, err) @@ -936,6 +937,7 @@ func TestSTMTQuery(t *testing.T) { err = rows.Next(values) if err != nil { if err == io.EOF { + rows.Close() break } assert.NoError(t, err) diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index b7fbe4f..6ef55f7 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -65,7 +65,7 @@ type WSError struct { } func (e *WSError) Error() string { - return fmt.Sprintf("websocket close with error %s", e.err) + return fmt.Sprintf("websocket close with error %v", e.err) } // NewConsumer create a tmq consumer diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 4ab7a98..2d1a401 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -731,6 +731,58 @@ func Test_configMapToConfigWrong(t *testing.T) { }, wantErr: "ws.message.writeWait cannot be less than 1 second", }, + { + name: "ws.autoReconnect", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.autoReconnect": 123, + }, + }, + wantErr: "ws.autoReconnect expects type bool, not int", + }, + //ws.reconnectIntervalMs + { + name: "ws.reconnectIntervalMs", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.reconnectIntervalMs": "not int", + }, + }, + wantErr: "ws.reconnectIntervalMs expects type int, not string", + }, + //ws.reconnectRetryCount + { + name: "ws.reconnectRetryCount", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.reconnectRetryCount": "not int", + }, + }, + wantErr: "ws.reconnectRetryCount expects type int, not string", + }, + { + name: "session.timeout.ms", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "session.timeout.ms": 123, + }, + }, + wantErr: "session.timeout.ms expects type string, not int", + }, + { + name: "max.poll.interval.ms", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "max.poll.interval.ms": 123, + }, + }, + wantErr: "max.poll.interval.ms expects type string, not int", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -977,3 +1029,54 @@ func TestSubscribeReconnect(t *testing.T) { } assert.True(t, haveMessage) } + +func TestWSError_Error(t *testing.T) { + // Test scenario where an error is provided + expectedErr := errors.New("connection lost") + wsErr := &WSError{err: expectedErr} + + // Call the Error() method and check if the format is correct + actualError := wsErr.Error() + + // The expected error string format + expectedError := "websocket close with error connection lost" + + // Assert that the error string is formatted correctly + assert.Equal(t, expectedError, actualError, "Error string should match the expected format") + + // Test scenario where no error is provided (nil error) + wsErrNil := &WSError{} + actualErrorNil := wsErrNil.Error() + + // Expected format when error is nil (shouldn't panic) + expectedErrorNil := "websocket close with error " + + // Assert that the error string handles nil properly + assert.Equal(t, expectedErrorNil, actualErrorNil, "Error string should handle nil error correctly") +} + +func TestWrongConsumer(t *testing.T) { + consumer, err := NewConsumer(&tmq.ConfigMap{}) + assert.Error(t, err) + assert.Nil(t, consumer) + + consumer, err = NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "auto.commit.interval.ms": "abc", + }) + assert.Error(t, err) + assert.Nil(t, consumer) + + consumer, err = NewConsumer(&tmq.ConfigMap{ + "ws.url": ":xxx", + }) + assert.Error(t, err) + assert.Nil(t, consumer) + + consumer, err = NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:9999", + }) + assert.Error(t, err) + assert.Nil(t, consumer) + +} From 5c8397c9fd17971846d0f81bea5707587c394c5a Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 16:08:22 +0800 Subject: [PATCH 08/35] test: add unit test --- af/insertstmt/stmt_test.go | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/af/insertstmt/stmt_test.go b/af/insertstmt/stmt_test.go index 0b50f39..2698b8d 100644 --- a/af/insertstmt/stmt_test.go +++ b/af/insertstmt/stmt_test.go @@ -99,6 +99,72 @@ func TestStmt(t *testing.T) { assert.Equal(t, int(1), affected) } +func TestStmtWithWithReqID(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + s := NewInsertStmt(conn) + defer s.Close() + err = s.Prepare("insert into ? values(?,?,?)") + assert.NoError(t, err) +} + +func TestPrepareError(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + s := NewInsertStmt(conn) + stmtHandle := s.stmt + defer wrapper.TaosStmtClose(stmtHandle) + s.stmt = nil + err = s.Prepare("insert into ? values(?,?,?)") + assert.Error(t, err) + s.stmt = stmtHandle + err = s.Prepare("select * from information_schema.ins_databases where name = ?") + assert.Error(t, err) +} + +func TestSetTableNameError(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + s := NewInsertStmt(conn) + stmtHandle := s.stmt + defer wrapper.TaosStmtClose(stmtHandle) + s.stmt = nil + err = s.SetTableName("test") + assert.Error(t, err) + + err = s.SetSubTableName("test") + assert.Error(t, err) + + err = s.SetTableNameWithTags("test", param.NewParam(1).AddBinary([]byte("test"))) + assert.Error(t, err) +} + +func TestAddBatchError(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + s := NewInsertStmt(conn) + stmtHandle := s.stmt + defer wrapper.TaosStmtClose(stmtHandle) + s.stmt = nil + err = s.AddBatch() + assert.Error(t, err) +} + +func TestExecuteError(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + s := NewInsertStmt(conn) + stmtHandle := s.stmt + defer wrapper.TaosStmtClose(stmtHandle) + s.stmt = nil + err = s.Execute() + assert.Error(t, err) +} func exec(conn unsafe.Pointer, sql string) error { res := wrapper.TaosQuery(conn, sql) From 26aea7fbabf454ce3af77c40317d790f62519888 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 16:38:44 +0800 Subject: [PATCH 09/35] test: add unit test --- common/parser/block_test.go | 2 +- common/sql_test.go | 39 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/common/parser/block_test.go b/common/parser/block_test.go index 42b2d26..61fd80e 100644 --- a/common/parser/block_test.go +++ b/common/parser/block_test.go @@ -680,7 +680,7 @@ func TestParseBlock(t *testing.T) { version := RawBlockGetVersion(block) t.Log(version) length := RawBlockGetLength(block) - assert.Equal(t, int32(447), length) + assert.Equal(t, int32(448), length) rows := RawBlockGetNumOfRows(block) assert.Equal(t, int32(2), rows) columns := RawBlockGetNumOfCols(block) diff --git a/common/sql_test.go b/common/sql_test.go index 543294d..05b0834 100644 --- a/common/sql_test.go +++ b/common/sql_test.go @@ -2,6 +2,7 @@ package common import ( "database/sql/driver" + "reflect" "testing" "time" ) @@ -95,3 +96,41 @@ func TestInterpolateParams(t *testing.T) { }) } } + +func TestValueArgsToNamedValueArgs(t *testing.T) { + tests := []struct { + name string + args []driver.Value + want []driver.NamedValue + }{ + { + name: "empty args", + args: []driver.Value{}, + want: []driver.NamedValue{}, + }, + { + name: "single arg", + args: []driver.Value{int64(1)}, + want: []driver.NamedValue{ + {Ordinal: 1, Value: int64(1)}, + }, + }, + { + name: "multiple args", + args: []driver.Value{int64(1), "test", nil}, + want: []driver.NamedValue{ + {Ordinal: 1, Value: int64(1)}, + {Ordinal: 2, Value: "test"}, + {Ordinal: 3, Value: nil}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ValueArgsToNamedValueArgs(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ValueArgsToNamedValueArgs() = %v, want %v", got, tt.want) + } + }) + } +} From 4e1e85792d04391abf73475e8fc25f6fa9e7e690 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 16:47:42 +0800 Subject: [PATCH 10/35] test: add unit test --- taosRestful/connector_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index eac38db..845179f 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -530,3 +530,14 @@ func TestSSL(t *testing.T) { } assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } + +func TestConnect(t *testing.T) { + conn := connector{ + cfg: &config{}, + } + db, err := conn.Connect(context.Background()) + assert.NoError(t, err) + db.Close() + driver := conn.Driver() + assert.Equal(t, &TDengineDriver{}, driver) +} From a96feb53ff30feaede23a022068293cb3a2eb1ca Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 17:03:47 +0800 Subject: [PATCH 11/35] enh: remove Exec and Query as they are deprecated --- taosRestful/connection.go | 30 ------------------------------ taosSql/connection.go | 8 -------- taosWS/connection.go | 10 ---------- taosWS/connection_test.go | 3 ++- 4 files changed, 2 insertions(+), 49 deletions(-) diff --git a/taosRestful/connection.go b/taosRestful/connection.go index 5bef61b..2dd5240 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -97,10 +97,6 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "restful does not support stmt"} } -func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return tc.ExecContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) -} - func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { return tc.execCtx(ctx, query, args) } @@ -127,32 +123,6 @@ func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.Nam return driver.RowsAffected(result.Data[0][0].(int32)), nil } -func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if len(args) != 0 { - if !tc.cfg.interpolateParams { - return nil, driver.ErrSkip - } - // try client-side prepare to reduce round trip - prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) - if err != nil { - return nil, err - } - query = prepared - } - result, err := tc.taosQuery(context.TODO(), query, tc.readBufferSize) - if err != nil { - return nil, err - } - if result == nil { - return nil, errors.New("wrong result") - } - // Read Result - rs := &rows{ - result: result, - } - return rs, err -} - func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { return tc.queryCtx(ctx, query, args) } diff --git a/taosSql/connection.go b/taosSql/connection.go index 7277c05..8cf54f1 100644 --- a/taosSql/connection.go +++ b/taosSql/connection.go @@ -66,10 +66,6 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { return stmt, nil } -func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return tc.ExecContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) -} - func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Result, err error) { if tc.taos == nil { return nil, driver.ErrBadConn @@ -118,10 +114,6 @@ func (tc *taosConn) processExecResult(result *handler.AsyncResult) (driver.Resul return driver.RowsAffected(affectRows), nil } -func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) -} - func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { if tc.taos == nil { return nil, driver.ErrBadConn diff --git a/taosWS/connection.go b/taosWS/connection.go index 236c7c5..cf1791e 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -24,9 +24,6 @@ var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary const ( WSConnect = "conn" - WSQuery = "query" - WSFetch = "fetch" - WSFetchBlock = "fetch_block" WSFreeResult = "free_result" STMTInit = "init" @@ -490,9 +487,6 @@ func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { } return rs, nil } -func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) -} func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { return tc.execCtx(ctx, query, args) @@ -509,10 +503,6 @@ func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.Nam return driver.RowsAffected(resp.AffectedRows), nil } -func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) -} - func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { return tc.queryCtx(ctx, query, args) } diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index e93d710..cb814e8 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -1,6 +1,7 @@ package taosWS import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -67,7 +68,7 @@ func TestBadConnection(t *testing.T) { // to test bad connection, we manually close the connection conn.Close() - _, err = conn.Query("select 1", nil) + _, err = conn.QueryContext(context.Background(), "select 1", nil) if err == nil { t.Fatalf("query should fail") } From 8a823195e5ca4329807b52dc26178433f8a10756 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 17:07:45 +0800 Subject: [PATCH 12/35] test: add unit test --- taosWS/connector_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/taosWS/connector_test.go b/taosWS/connector_test.go index d971a1d..4d19033 100644 --- a/taosWS/connector_test.go +++ b/taosWS/connector_test.go @@ -1,6 +1,7 @@ package taosWS import ( + "context" "database/sql" "fmt" "math/rand" @@ -456,3 +457,14 @@ func TestBatch(t *testing.T) { }) } } + +func TestConnect(t *testing.T) { + conn := connector{ + cfg: &config{}, + } + db, err := conn.Connect(context.Background()) + assert.NoError(t, err) + db.Close() + driver := conn.Driver() + assert.Equal(t, &TDengineDriver{}, driver) +} From 0fdf38a81b4c178da41be5a3dcb11440cafe5b16 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 17:14:10 +0800 Subject: [PATCH 13/35] test: add unit test --- wrapper/stmt_test.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index 328be38..be9871c 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -204,6 +204,17 @@ func TestStmt(t *testing.T) { t.Error(err) return } + isInsert, code := TaosStmtIsInsert(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + if !isInsert { + t.Errorf("expect insert stmt") + return + } code = TaosStmtBindParamBatch(insertStmt, tc.params, tc.bindType) if code != 0 { errStr := TaosStmtErrStr(insertStmt) @@ -1031,7 +1042,7 @@ func TestTaosStmtSetTags(t *testing.T) { t.Error(taosError.NewError(code, errStr)) return } - code = TaosStmtSetTBName(stmt, "test_wrapper.t1") + code = TaosStmtSetSubTBName(stmt, "test_wrapper.t1") if code != 0 { errStr := TaosStmtErrStr(stmt) t.Error(taosError.NewError(code, errStr)) @@ -1298,6 +1309,14 @@ func TestStmtJson(t *testing.T) { t.Error(err) return } + count, code := TaosStmtNumParams(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.Equal(t, 1, count) code = TaosStmtBindParam(stmt, param.NewParam(1).AddBigint(1).GetValues()) if code != 0 { errStr := TaosStmtErrStr(stmt) From 1c44bb07e29666f81773799cb7b4d79d848f16d6 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 17:15:43 +0800 Subject: [PATCH 14/35] test: add unit test --- wrapper/stmt_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index be9871c..b771ddd 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -1219,6 +1219,9 @@ func TestTaosStmtGetParam(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 6, dt) assert.Equal(t, 4, dl) + + _, _, err = TaosStmtGetParam(stmt, 4) // invalid index + assert.Error(t, err) } func TestStmtJson(t *testing.T) { From 55d4e795268b733749bc336f11d99f26dca3bd1a Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 18:01:41 +0800 Subject: [PATCH 15/35] test: add unit test --- taosWS/statement_test.go | 66 ++++++++++++++++++++++++++++++---------- wrapper/stmt2_test.go | 36 +++++++++++++++++++++- 2 files changed, 85 insertions(+), 17 deletions(-) diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go index 1ab008c..c5663a5 100644 --- a/taosWS/statement_test.go +++ b/taosWS/statement_test.go @@ -48,19 +48,21 @@ func TestStmtExec(t *testing.T) { "c10 float," + "c11 double," + "c12 binary(20)," + - "c13 nchar(20)" + + "c13 nchar(20)," + + "c14 varbinary(20)," + + "c15 geometry(100)" + ")") if err != nil { t.Error(err) return } - stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)") if err != nil { t.Error(err) return } - result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar", "varbinary", []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}) if err != nil { t.Error(err) return @@ -99,19 +101,21 @@ func TestStmtQuery(t *testing.T) { "c10 float," + "c11 double," + "c12 binary(20)," + - "c13 nchar(20)" + + "c13 nchar(20)," + + "c14 varbinary(20)," + + "c15 geometry(100)" + ")") if err != nil { t.Error(err) return } - stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)") if err != nil { t.Error(err) return } now := time.Now() - result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar", "varbinary", []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}) if err != nil { t.Error(err) return @@ -138,7 +142,7 @@ func TestStmtQuery(t *testing.T) { t.Error(err) return } - assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15"}, columns) count := 0 for rows.Next() { count += 1 @@ -157,6 +161,8 @@ func TestStmtQuery(t *testing.T) { c11 float64 c12 string c13 string + c14 string + c15 []byte ) err = rows.Scan(&ts, &c1, @@ -171,7 +177,9 @@ func TestStmtQuery(t *testing.T) { &c10, &c11, &c12, - &c13) + &c13, + &c14, + &c15) assert.NoError(t, err) assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) assert.Equal(t, true, c1) @@ -187,6 +195,8 @@ func TestStmtQuery(t *testing.T) { assert.Equal(t, float64(11), c11) assert.Equal(t, "binary", c12) assert.Equal(t, "nchar", c13) + assert.Equal(t, "varbinary", c14) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, c15) } assert.Equal(t, 1, count) } @@ -1007,6 +1017,27 @@ func TestStmtConvertExec(t *testing.T) { bind: []interface{}{now, "1970-01-01T00:00:00.001Z"}, expectValue: time.Unix(0, 1e6), }, + { + name: "varbinary_string_chinese", + tbType: "ts timestamp,v varbinary(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: []byte("中文"), + }, + { + name: "varbinary_bytes_chinese", + tbType: "ts timestamp,v varbinary(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: []byte("中文"), + }, + { + name: "geometry", + tbType: "ts timestamp,v geometry(100)", + pos: "?,?", + bind: []interface{}{now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1076,20 +1107,23 @@ func TestStmtConvertExec(t *testing.T) { t.Error(err) return } - v, err := values[0].(driver.Valuer).Value() - if err != nil { - t.Error(err) + value, ok := values[0].(driver.Valuer) + if ok { + v, err := value.Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } else { + data = append(data, *values[0].(*[]byte)) } - data = append(data, v) + } if len(data) != 1 { t.Errorf("expect %d got %d", 1, len(data)) return } - if data[0] != tt.expectValue { - t.Errorf("expect %v got %v", tt.expectValue, data[0]) - return - } + assert.Equal(t, tt.expectValue, data[0], "expect %v got %v", tt.expectValue, data[0]) }) } } diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index 0348774..5876b25 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -125,16 +125,50 @@ func TestStmt2BindData(t *testing.T) { {next2S, int32(2)}, }, }, + { + name: "gouint null 3 cols", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint(1), + nil, + uint(2), + }, + }, + }}, + expectValue: [][]driver.Value{ + {now, uint64(1)}, + {next1S, nil}, + {next2S, uint64(2)}, + }, + }, { name: "bool", tbType: "ts timestamp, v bool", pos: "?, ?", params: []*stmt.TaosStmt2BindData{{ - Cols: [][]driver.Value{{now}, {bool(true)}}, + Cols: [][]driver.Value{{now}, {true}}, }}, expectValue: [][]driver.Value{{now, true}}, }, + { + name: "bool false", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {false}}, + }}, + + expectValue: [][]driver.Value{{now, false}}, + }, { name: "bool null", tbType: "ts timestamp, v bool", From ce8c60c5644685627f2a0c66465bc8d6391a5ea7 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 18:07:44 +0800 Subject: [PATCH 16/35] test: add unit test --- taosWS/statement_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go index c5663a5..4c9e260 100644 --- a/taosWS/statement_test.go +++ b/taosWS/statement_test.go @@ -1031,6 +1031,13 @@ func TestStmtConvertExec(t *testing.T) { bind: []interface{}{now, []byte("中文")}, expectValue: []byte("中文"), }, + { + name: "varbinary_err", + tbType: "ts timestamp,v varbinary(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectError: true, + }, { name: "geometry", tbType: "ts timestamp,v geometry(100)", @@ -1038,6 +1045,13 @@ func TestStmtConvertExec(t *testing.T) { bind: []interface{}{now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, }, + { + name: "geometry_err", + tbType: "ts timestamp,v geometry(100)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectError: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From b34f3fc9796d483e2183f8ad62ae0fa89811371c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 18:11:38 +0800 Subject: [PATCH 17/35] test: add unit test --- wrapper/stmt2_test.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index 5876b25..1500977 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -125,30 +125,6 @@ func TestStmt2BindData(t *testing.T) { {next2S, int32(2)}, }, }, - { - name: "gouint null 3 cols", - tbType: "ts timestamp, v bigint unsigned", - pos: "?, ?", - params: []*stmt.TaosStmt2BindData{{ - Cols: [][]driver.Value{ - { - now, - next1S, - next2S, - }, - { - uint(1), - nil, - uint(2), - }, - }, - }}, - expectValue: [][]driver.Value{ - {now, uint64(1)}, - {next1S, nil}, - {next2S, uint64(2)}, - }, - }, { name: "bool", tbType: "ts timestamp, v bool", From 3d50d26d926cb7cdda2063a8a46034d8cba626ea Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 7 Nov 2024 19:07:48 +0800 Subject: [PATCH 18/35] test: add unit test --- ws/client/conn.go | 11 +++++++++ ws/client/conn_test.go | 24 +++++++++++++++++++ ws/schemaless/schemaless.go | 17 ++----------- ws/stmt/connector.go | 13 ++-------- ws/stmt/rows.go | 5 +--- ws/stmt/stmt.go | 45 ++++++---------------------------- ws/tmq/consumer.go | 48 +++++++------------------------------ 7 files changed, 56 insertions(+), 107 deletions(-) diff --git a/ws/client/conn.go b/ws/client/conn.go index 76506fc..9415c5c 100644 --- a/ws/client/conn.go +++ b/ws/client/conn.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" + errors2 "github.com/taosdata/driver-go/v3/errors" ) const ( @@ -211,3 +212,13 @@ func (c *Client) Close() { func (c *Client) handleError(err error) { c.errHandlerOnce.Do(func() { c.ErrorHandler(err) }) } + +func HandleResponseError(err error, code int, msg string) error { + if err != nil { + return err + } + if code != 0 { + return errors2.NewError(code, msg) + } + return nil +} diff --git a/ws/client/conn_test.go b/ws/client/conn_test.go index c7f9a7c..8c95ed9 100644 --- a/ws/client/conn_test.go +++ b/ws/client/conn_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "errors" "net/http" "net/http/httptest" "strings" @@ -10,6 +11,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + taosErrors "github.com/taosdata/driver-go/v3/errors" ) func TestEnvelopePool(t *testing.T) { @@ -96,3 +98,25 @@ func TestClient(t *testing.T) { t.Error("timeout") } } + +func TestHandleResponseError(t *testing.T) { + t.Run("Error not nil", func(t *testing.T) { + err := errors.New("some error") + result := HandleResponseError(err, 0, "ignored message") + assert.Equal(t, err, result, "Expected the original error to be returned") + }) + + t.Run("Error nil and non-zero code", func(t *testing.T) { + code := 123 + msg := "some error message" + expectedErr := taosErrors.NewError(code, msg) + + result := HandleResponseError(nil, code, msg) + assert.EqualError(t, result, expectedErr.Error(), "Expected a new error to be returned based on code and message") + }) + + t.Run("Error nil and zero code", func(t *testing.T) { + result := HandleResponseError(nil, 0, "ignored message") + assert.Nil(t, result, "Expected nil to be returned when there is no error and code is zero") + }) +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index f22c7a6..9d1582b 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -13,7 +13,6 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" - taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -181,13 +180,7 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in } var resp schemalessResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Schemaless) Close() { @@ -247,13 +240,7 @@ func connect(ws *websocket.Conn, user string, password string, db string, writeT } var resp wsConnectResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Schemaless) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index a08cd2e..c96c160 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -15,7 +15,6 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" - taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -161,13 +160,7 @@ func connect(ws *websocket.Conn, user string, password string, db string, writeT } var resp ConnectResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (c *Connector) handleTextMessage(message []byte) { @@ -369,12 +362,10 @@ func (c *Connector) Init() (*Stmt, error) { } var resp InitResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } return &Stmt{ id: resp.StmtID, connector: c, diff --git a/ws/stmt/rows.go b/ws/stmt/rows.go index 5247b55..88514de 100644 --- a/ws/stmt/rows.go +++ b/ws/stmt/rows.go @@ -11,7 +11,6 @@ import ( "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/common/pointer" - taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -101,12 +100,10 @@ func (rs *Rows) taosFetchBlock() error { } var resp WSFetchResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return err } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } if resp.Completed { rs.blockSize = 0 return nil diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index 373b763..59a1f88 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -6,7 +6,6 @@ import ( "github.com/taosdata/driver-go/v3/common/param" "github.com/taosdata/driver-go/v3/common/serializer" - taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -43,13 +42,7 @@ func (s *Stmt) Prepare(sql string) error { } var resp PrepareResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Stmt) SetTableName(name string) error { @@ -79,13 +72,7 @@ func (s *Stmt) SetTableName(name string) error { } var resp SetTableNameResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { @@ -114,13 +101,7 @@ func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { } var resp SetTagsResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) error { @@ -148,13 +129,7 @@ func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) erro } var resp BindResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Stmt) AddBatch() error { @@ -183,12 +158,10 @@ func (s *Stmt) AddBatch() error { } var resp AddBatchResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return err } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } return nil } @@ -218,12 +191,10 @@ func (s *Stmt) Exec() error { } var resp ExecResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return err } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } s.lastAffected = resp.Affected return nil } @@ -258,12 +229,10 @@ func (s *Stmt) UseResult() (*Rows, error) { } var resp UseResultResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } return &Rows{ buf: &bytes.Buffer{}, conn: s.connector, diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 6ef55f7..f59ee89 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -477,12 +477,10 @@ func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { } var resp SubscribeResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return err } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } c.topics = make([]string, len(topics)) copy(c.topics, topics) return nil @@ -643,12 +641,10 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { } var resp FetchJsonMetaResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } var meta tmq.Meta err = client.JsonI.Unmarshal(resp.Data, &meta) if err != nil { @@ -736,13 +732,7 @@ func (c *Consumer) doCommit() error { } var resp CommitResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (c *Consumer) Unsubscribe() error { @@ -773,13 +763,7 @@ func (c *Consumer) Unsubscribe() error { } var resp CommitResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { @@ -812,12 +796,10 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { } var resp AssignmentResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } topicName := topic for i := 0; i < len(resp.Assignment); i++ { offset := tmq.Offset(resp.Assignment[i].Offset) @@ -862,13 +844,7 @@ func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) erro } var resp OffsetSeekResp err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (offsets []tmq.TopicPartition, err error) { @@ -904,12 +880,10 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of } var resp CommittedResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } for i := 0; i < len(resp.Committed); i++ { offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, @@ -953,12 +927,10 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti } var resp CommitOffsetResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } } return c.Committed(offsets, 0) } @@ -996,12 +968,10 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi } var resp PositionResp err = client.JsonI.Unmarshal(respBytes, &resp) + err = client.HandleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } for i := 0; i < len(resp.Position); i++ { offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, From b833141adf897e9488610dbdefca41dfff79d0ed Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 10:58:04 +0800 Subject: [PATCH 19/35] test: add unit test --- taosWS/connection.go | 54 +++++++++++++-------------------------- taosWS/connection_test.go | 24 +++++++++++++++++ ws/stmt/stmt.go | 6 +---- 3 files changed, 43 insertions(+), 41 deletions(-) diff --git a/taosWS/connection.go b/taosWS/connection.go index cf1791e..10be9c4 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -226,12 +226,10 @@ func (tc *taosConn) stmtInit(reqID uint64) (uint64, error) { } var resp StmtInitResp err = tc.readTo(&resp) + err = handleResponseError(err, resp.Code, resp.Message) if err != nil { return 0, err } - if resp.Code != 0 { - return 0, taosErrors.NewError(resp.Code, resp.Message) - } return resp.StmtID, nil } @@ -261,12 +259,10 @@ func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { } var resp StmtPrepareResponse err = tc.readTo(&resp) + err = handleResponseError(err, resp.Code, resp.Message) if err != nil { return false, err } - if resp.Code != 0 { - return false, taosErrors.NewError(resp.Code, resp.Message) - } return resp.IsInsert, nil } @@ -321,12 +317,10 @@ func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, er } var resp StmtGetColFieldsResponse err = tc.readTo(&resp) + err = handleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } return resp.Fields, nil } @@ -343,13 +337,7 @@ func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { } var resp StmtBindResponse err = tc.readTo(&resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return handleResponseError(err, resp.Code, resp.Message) } func WriteUint64(buffer *bytes.Buffer, v uint64) { @@ -400,13 +388,7 @@ func (tc *taosConn) stmtAddBatch(stmtID uint64) error { } var resp StmtAddBatchResponse err = tc.readTo(&resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return handleResponseError(err, resp.Code, resp.Message) } func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { @@ -434,12 +416,10 @@ func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { } var resp StmtExecResponse err = tc.readTo(&resp) + err = handleResponseError(err, resp.Code, resp.Message) if err != nil { return 0, err } - if resp.Code != 0 { - return 0, taosErrors.NewError(resp.Code, resp.Message) - } return resp.Affected, nil } @@ -468,12 +448,10 @@ func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { } var resp StmtUseResultResponse err = tc.readTo(&resp) + err = handleResponseError(err, resp.Code, resp.Message) if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } rs := &rows{ buf: &bytes.Buffer{}, conn: tc, @@ -603,13 +581,7 @@ func (tc *taosConn) connect() error { } var resp WSConnectResp err = tc.readTo(&resp) - if err != nil { - return err - } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) - } - return nil + return handleResponseError(err, resp.Code, resp.Message) } func (tc *taosConn) writeText(data []byte) error { @@ -704,3 +676,13 @@ func formatBytes(bs []byte) string { buffer.WriteByte(']') return buffer.String() } + +func handleResponseError(err error, code int, msg string) error { + if err != nil { + return err + } + if code != 0 { + return taosErrors.NewError(code, msg) + } + return nil +} diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index cb814e8..da0b279 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -2,9 +2,11 @@ package taosWS import ( "context" + "errors" "testing" "github.com/stretchr/testify/assert" + taosErrors "github.com/taosdata/driver-go/v3/errors" ) // @author: xftan @@ -73,3 +75,25 @@ func TestBadConnection(t *testing.T) { t.Fatalf("query should fail") } } + +func TestHandleResponseError(t *testing.T) { + t.Run("Error not nil", func(t *testing.T) { + err := errors.New("some error") + result := handleResponseError(err, 0, "ignored message") + assert.Equal(t, err, result, "Expected the original error to be returned") + }) + + t.Run("Error nil and non-zero code", func(t *testing.T) { + code := 123 + msg := "some error message" + expectedErr := taosErrors.NewError(code, msg) + + result := handleResponseError(nil, code, msg) + assert.EqualError(t, result, expectedErr.Error(), "Expected a new error to be returned based on code and message") + }) + + t.Run("Error nil and zero code", func(t *testing.T) { + result := handleResponseError(nil, 0, "ignored message") + assert.Nil(t, result, "Expected nil to be returned when there is no error and code is zero") + }) +} diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index 59a1f88..fe55c8c 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -158,11 +158,7 @@ func (s *Stmt) AddBatch() error { } var resp AddBatchResp err = client.JsonI.Unmarshal(respBytes, &resp) - err = client.HandleResponseError(err, resp.Code, resp.Message) - if err != nil { - return err - } - return nil + return client.HandleResponseError(err, resp.Code, resp.Message) } func (s *Stmt) Exec() error { From 526835e7058dc6aa86a27c1ef602b4c0f70c039b Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 13:55:03 +0800 Subject: [PATCH 20/35] test: add unit test --- taosSql/connection.go | 31 ++++++++++++++----------------- taosSql/connection_test.go | 15 +++++++++++++++ taosSql/statement_test.go | 26 ++++++++++++++++++++++++++ taosWS/connection_test.go | 16 ++++++++++++++++ 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/taosSql/connection.go b/taosSql/connection.go index 8cf54f1..b625137 100644 --- a/taosSql/connection.go +++ b/taosSql/connection.go @@ -38,23 +38,13 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { stmtP := wrapper.TaosStmtInit(tc.taos) code := wrapper.TaosStmtPrepare(stmtP, query) locker.Unlock() - if code != 0 { - errStr := wrapper.TaosStmtErrStr(stmtP) - err := errors.NewError(code, errStr) - locker.Lock() - wrapper.TaosStmtClose(stmtP) - locker.Unlock() + if err := checkStmtError(code, stmtP); err != nil { return nil, err } locker.Lock() isInsert, code := wrapper.TaosStmtIsInsert(stmtP) locker.Unlock() - if code != 0 { - errStr := wrapper.TaosStmtErrStr(stmtP) - err := errors.NewError(code, errStr) - locker.Lock() - wrapper.TaosStmtClose(stmtP) - locker.Unlock() + if err := checkStmtError(code, stmtP); err != nil { return nil, err } stmt := &Stmt{ @@ -66,6 +56,18 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { return stmt, nil } +func checkStmtError(code int, stmtP unsafe.Pointer) error { + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtP) + err := errors.NewError(code, errStr) + locker.Lock() + wrapper.TaosStmtClose(stmtP) + locker.Unlock() + return err + } + return nil +} + func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Result, err error) { if tc.taos == nil { return nil, driver.ErrBadConn @@ -177,11 +179,6 @@ func (tc *taosConn) Ping(ctx context.Context) (err error) { return errors.ErrTscInvalidConnection } -// BeginTx implements driver.ConnBeginTx interface -func (tc *taosConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - return nil, &errors.TaosError{Code: 0xffff, ErrStr: "taosSql does not support transaction"} -} - func (tc *taosConn) taosQuery(sqlStr string, handler *handler.Handler, reqID int64) *handler.AsyncResult { locker.Lock() if reqID == 0 { diff --git a/taosSql/connection_test.go b/taosSql/connection_test.go index f473ec5..c6f5b65 100644 --- a/taosSql/connection_test.go +++ b/taosSql/connection_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "testing" + "github.com/stretchr/testify/assert" "github.com/taosdata/driver-go/v3/common" ) @@ -51,3 +52,17 @@ func TestTaosConn_ExecContext(t *testing.T) { t.Fatal("result miss") } } + +func TestWrongReqID(t *testing.T) { + ctx := context.WithValue(context.Background(), common.ReqIDKey, uint64(1234)) + db, err := sql.Open("taosSql", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + rs, err := db.QueryContext(ctx, "select 1") + assert.Error(t, err) + assert.Nil(t, rs) + _, err = db.ExecContext(ctx, "create database if not exists test_wrong_req_id") + assert.Error(t, err) +} diff --git a/taosSql/statement_test.go b/taosSql/statement_test.go index 7230d78..e1c30ba 100644 --- a/taosSql/statement_test.go +++ b/taosSql/statement_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/wrapper" ) // @author: xftan @@ -2157,3 +2158,28 @@ func TestStmtConvertQuery(t *testing.T) { }) } } + +func TestWrongStmt(t *testing.T) { + d := &TDengineDriver{} + conn, err := d.Open(dataSourceName) + assert.NoError(t, err) + defer conn.Close() + c := conn.(*taosConn) + cPointer := c.taos + c.taos = nil + defer func() { + c.taos = cPointer + }() + _, err = c.Prepare("") + assert.Error(t, err) + + p, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + code := wrapper.TaosStmtPrepare(p, "") + err = checkStmtError(code, p) + assert.NoError(t, err) + code = wrapper.TaosStmtExecute(p) + assert.NotEqual(t, 0, code) + err = checkStmtError(code, p) + assert.Error(t, err) +} diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index da0b279..d4b4d2d 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -97,3 +97,19 @@ func TestHandleResponseError(t *testing.T) { assert.Nil(t, result, "Expected nil to be returned when there is no error and code is zero") }) } + +func TestBegin(t *testing.T) { + cfg, err := parseDSN(dataSourceName) + if err != nil { + t.Fatalf("parseDSN error: %v", err) + } + conn, err := newTaosConn(cfg) + if err != nil { + t.Fatalf("newTaosConn error: %v", err) + } + defer conn.Close() + + tx, err := conn.Begin() + assert.Error(t, err) + assert.Nil(t, tx) +} From 3090bebd7c6da098eeb6652b2e9f24719d6308c8 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 14:12:36 +0800 Subject: [PATCH 21/35] test: add unit test --- af/tmq/consumer.go | 32 ++++++++++++-------------------- af/tmq/consumer_test.go | 6 ++++++ 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/af/tmq/consumer.go b/af/tmq/consumer.go index aab447c..ddc2e0a 100644 --- a/af/tmq/consumer.go +++ b/af/tmq/consumer.go @@ -65,12 +65,12 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err for _, topic := range topics { errCode := wrapper.TMQListAppend(topicList, topic) if errCode != 0 { - return c.tmqError(errCode) + return tmqError(errCode) } } errCode := wrapper.TMQSubscribe(c.cConsumer, topicList) if errCode != 0 { - return c.tmqError(errCode) + return tmqError(errCode) } return nil } @@ -79,7 +79,7 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err func (c *Consumer) Unsubscribe() error { errCode := wrapper.TMQUnsubscribe(c.cConsumer) if errCode != taosError.SUCCESS { - return c.tmqError(errCode) + return tmqError(errCode) } return nil } @@ -199,7 +199,7 @@ func (c *Consumer) getData(message unsafe.Pointer) ([]*tmq.Data, error) { func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { errCode := wrapper.TMQCommitSync(c.cConsumer, nil) if errCode != taosError.SUCCESS { - return nil, c.tmqError(errCode) + return nil, tmqError(errCode) } partitions, err := c.Assignment() if err != nil { @@ -208,18 +208,10 @@ func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { return c.Committed(partitions, 0) } -func (c *Consumer) doCommit(message unsafe.Pointer) ([]tmq.TopicPartition, error) { - errCode := wrapper.TMQCommitSync(c.cConsumer, message) - if errCode != taosError.SUCCESS { - return nil, c.tmqError(errCode) - } - return nil, nil -} - func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { errCode, list := wrapper.TMQSubscription(c.cConsumer) if errCode != taosError.SUCCESS { - return nil, c.tmqError(errCode) + return nil, tmqError(errCode) } defer wrapper.TMQListDestroy(list) size := wrapper.TMQListGetSize(list) @@ -227,7 +219,7 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { for _, topic := range topics { errCode, assignment := wrapper.TMQGetTopicAssignment(c.cConsumer, topic) if errCode != taosError.SUCCESS { - return nil, c.tmqError(errCode) + return nil, tmqError(errCode) } for i := 0; i < len(assignment); i++ { topicName := topic @@ -244,7 +236,7 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) error { errCode := wrapper.TMQOffsetSeek(c.cConsumer, *partition.Topic, partition.Partition, int64(partition.Offset)) if errCode != taosError.SUCCESS { - return c.tmqError(errCode) + return tmqError(errCode) } return nil } @@ -255,7 +247,7 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of cOffset := wrapper.TMQCommitted(c.cConsumer, *partitions[i].Topic, partitions[i].Partition) offset := tmq.Offset(cOffset) if !offset.Valid() { - return nil, c.tmqError(int32(offset)) + return nil, tmqError(int32(offset)) } offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, @@ -270,7 +262,7 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti for i := 0; i < len(offsets); i++ { errCode := wrapper.TMQCommitOffsetSync(c.cConsumer, *offsets[i].Topic, offsets[i].Partition, int64(offsets[i].Offset)) if errCode != taosError.SUCCESS { - return nil, c.tmqError(errCode) + return nil, tmqError(errCode) } } return c.Committed(offsets, 0) @@ -281,7 +273,7 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi for i := 0; i < len(partitions); i++ { position := wrapper.TMQPosition(c.cConsumer, *partitions[i].Topic, partitions[i].Partition) if position < 0 { - return nil, c.tmqError(int32(position)) + return nil, tmqError(int32(position)) } offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, @@ -296,12 +288,12 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi func (c *Consumer) Close() error { errCode := wrapper.TMQConsumerClose(c.cConsumer) if errCode != 0 { - return c.tmqError(errCode) + return tmqError(errCode) } return nil } -func (c *Consumer) tmqError(errCode int32) error { +func tmqError(errCode int32) error { errStr := wrapper.TMQErr2Str(errCode) return taosError.NewError(int(errCode), errStr) } diff --git a/af/tmq/consumer_test.go b/af/tmq/consumer_test.go index 47252b7..ebee650 100644 --- a/af/tmq/consumer_test.go +++ b/af/tmq/consumer_test.go @@ -493,3 +493,9 @@ func TestMeta(t *testing.T) { } } } + +func Test_tmqError(t *testing.T) { + err := tmqError(-1) + expectError := &errors.TaosError{Code: 65535, ErrStr: "fail"} + assert.Equal(t, expectError, err) +} From d8de000e39eeb52464963f19fc307b190c27d229 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 14:17:04 +0800 Subject: [PATCH 22/35] test: add unit test --- taosSql/statement_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/taosSql/statement_test.go b/taosSql/statement_test.go index e1c30ba..ea6b140 100644 --- a/taosSql/statement_test.go +++ b/taosSql/statement_test.go @@ -2175,11 +2175,13 @@ func TestWrongStmt(t *testing.T) { p, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) - code := wrapper.TaosStmtPrepare(p, "") - err = checkStmtError(code, p) + defer wrapper.TaosClose(p) + stmt := wrapper.TaosStmtInit(p) + code := wrapper.TaosStmtPrepare(stmt, "") + err = checkStmtError(code, stmt) assert.NoError(t, err) - code = wrapper.TaosStmtExecute(p) + code = wrapper.TaosStmtExecute(stmt) assert.NotEqual(t, 0, code) - err = checkStmtError(code, p) + err = checkStmtError(code, stmt) assert.Error(t, err) } From d9e5c616b3b824100e9bdac7f6c131a45ec02f63 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 12 Nov 2024 10:36:15 +0800 Subject: [PATCH 23/35] test: fix stmt bind geometry test --- wrapper/stmt_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index b771ddd..557b10d 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -154,7 +154,7 @@ func TestStmt(t *testing.T) { params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosGeometry{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}}, bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { Type: taosTypes.TaosGeometryType, - MaxLen: 3, + MaxLen: 100, }}, expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, }, //3 From 619c863023ed00f2f20e91fa6e0960aa6f789aad Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 12 Nov 2024 11:03:26 +0800 Subject: [PATCH 24/35] test: add cgo test --- wrapper/cgo/handle_test.go | 57 +++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/wrapper/cgo/handle_test.go b/wrapper/cgo/handle_test.go index 89bfd31..46cdc52 100644 --- a/wrapper/cgo/handle_test.go +++ b/wrapper/cgo/handle_test.go @@ -7,6 +7,8 @@ package cgo import ( "reflect" "testing" + + "github.com/stretchr/testify/assert" ) // @author: xftan @@ -56,34 +58,33 @@ func TestHandle(t *testing.T) { } } -//func TestInvalidHandle(t *testing.T) { -// t.Run("zero", func(t *testing.T) { -// h := Handle(0) -// -// defer func() { -// if r := recover(); r != nil { -// return -// } -// t.Fatalf("Delete of zero handle did not trigger a panic") -// }() -// -// h.Delete() -// }) -// -// t.Run("invalid", func(t *testing.T) { -// h := NewHandle(42) -// -// defer func() { -// if r := recover(); r != nil { -// h.Delete() -// return -// } -// t.Fatalf("Invalid handle did not trigger a panic") -// }() -// -// Handle(h + 1).Delete() -// }) -//} +func TestPointer(t *testing.T) { + v := 42 + h := NewHandle(&v) + p := h.Pointer() + assert.Equal(t, *(*Handle)(p), h) + h.Delete() + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("Pointer should panic") + }() + h.Pointer() +} + +func TestInvalidValue(t *testing.T) { + v := 42 + h := NewHandle(&v) + h.Delete() + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("Value should panic") + }() + h.Value() +} func BenchmarkHandle(b *testing.B) { b.Run("non-concurrent", func(b *testing.B) { From 782689d3b1944f06e8cdb3416f12bab5712f013f Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 15 Nov 2024 17:29:45 +0800 Subject: [PATCH 25/35] ci: fix TDengine build --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4efba15..5e81127 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -74,7 +74,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DBUILD_DEPENDENCY_TESTS=0 -DVERNUMBER=3.9.9.9 + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DBUILD_DEPENDENCY_TESTS=0 -DVERNUMBER=3.9.9.9 -DBUILD_DEPENDENCY_TESTS=0 make -j 4 - name: package From a05f8293f92fb87ef03310f942b0c3e49a30c08c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 18 Nov 2024 09:57:43 +0800 Subject: [PATCH 26/35] ci: fix TDengine build --- .github/workflows/go.yml | 52 ++++++++++---- .github/workflows/push.yml | 141 ------------------------------------- 2 files changed, 39 insertions(+), 154 deletions(-) delete mode 100644 .github/workflows/push.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5e81127..978ed77 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -6,6 +6,12 @@ on: - 'main' - '3.0' - '3.1' + push: + branches: + - 'main' + - '3.0' + - '3.1' + workflow_dispatch: inputs: tbBranch: @@ -16,7 +22,7 @@ on: jobs: build: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest name: Build outputs: commit_id: ${{ steps.get_commit_id.outputs.commit_id }} @@ -29,6 +35,14 @@ jobs: path: 'TDengine' ref: ${{ github.base_ref }} + - name: checkout TDengine by push + if: github.event_name == 'push' + uses: actions/checkout@v4 + with: + repository: 'taosdata/TDengine' + path: 'TDengine' + ref: ${{ github.ref_name }} + - name: checkout TDengine manually if: github.event_name == 'workflow_dispatch' uses: actions/checkout@v4 @@ -52,6 +66,14 @@ jobs: path: server.tar.gz key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ steps.get_commit_id.outputs.commit_id }} + - name: Cache server by push + if: github.event_name == 'push' + id: cache-server-push + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ steps.get_commit_id.outputs.commit_id }} + - name: Cache server manually if: github.event_name == 'workflow_dispatch' id: cache-server-manually @@ -60,27 +82,24 @@ jobs: path: server.tar.gz key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ steps.get_commit_id.outputs.commit_id }} - - name: prepare install - if: > - (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || - (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') - run: sudo apt install -y libgeos-dev - name: install TDengine if: > (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || - (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') || + (github.event_name == 'push' && steps.cache-server-push.outputs.cache-hit != 'true') run: | cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DBUILD_DEPENDENCY_TESTS=0 -DVERNUMBER=3.9.9.9 -DBUILD_DEPENDENCY_TESTS=0 + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DBUILD_DEPENDENCY_TESTS=0 -DVERNUMBER=3.9.9.9 make -j 4 - name: package if: > (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || - (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') || + (github.event_name == 'push' && steps.cache-server-push.outputs.cache-hit != 'true') run: | mkdir -p ./release cp ./TDengine/debug/build/bin/taos ./release/ @@ -107,7 +126,7 @@ jobs: tar -zcvf server.tar.gz ./release test: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest needs: build strategy: matrix: @@ -124,6 +143,16 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ github.base_ref }}- + - name: get cache server by push + if: github.event_name == 'push' + id: get-cache-server-push + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.ref_name }}- + - name: cache server manually if: github.event_name == 'workflow_dispatch' id: get-cache-server-manually @@ -134,8 +163,6 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ inputs.tbBranch }}- - - name: prepare install - run: sudo apt install -y libgeos-dev - name: install run: | @@ -173,7 +200,6 @@ jobs: run: sudo go test -v --count=1 -coverprofile=coverage.txt -covermode=atomic ./... - name: Upload coverage to Codecov - if: ${{ matrix.go == 'stable'}} uses: codecov/codecov-action@v4 with: files: ./coverage.txt diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml deleted file mode 100644 index 99e666c..0000000 --- a/.github/workflows/push.yml +++ /dev/null @@ -1,141 +0,0 @@ -name: push - -on: - push: - branches: - - 'main' - - '3.0' - - '3.1' - - -jobs: - build: - runs-on: ubuntu-22.04 - name: Build - outputs: - commit_id: ${{ steps.get_commit_id.outputs.commit_id }} - steps: - - name: checkout TDengine - uses: actions/checkout@v4 - with: - repository: 'taosdata/TDengine' - path: 'TDengine' - ref: ${{ github.ref_name }} - - - name: get_commit_id - id: get_commit_id - run: | - cd TDengine - echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT - - - - name: Cache server - id: cache-server - uses: actions/cache@v4 - with: - path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ steps.get_commit_id.outputs.commit_id }} - - - name: prepare install - if: steps.cache-server.outputs.cache-hit != 'true' - run: sudo apt install -y libgeos-dev - - - - name: install TDengine - if: steps.cache-server.outputs.cache-hit != 'true' - run: | - cd TDengine - mkdir debug - cd debug - cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 - make -j 4 - - - name: package - if: steps.cache-server.outputs.cache-hit != 'true' - run: | - mkdir -p ./release - cp ./TDengine/debug/build/bin/taos ./release/ - cp ./TDengine/debug/build/bin/taosd ./release/ - cp ./TDengine/tools/taosadapter/taosadapter ./release/ - cp ./TDengine/debug/build/lib/libtaos.so.3.9.9.9 ./release/ - cp ./TDengine/debug/build/lib/librocksdb.so.8.1.1 ./release/ ||: - cp ./TDengine/include/client/taos.h ./release/ - cat >./release/install.sh<start.sh< Date: Mon, 18 Nov 2024 10:59:47 +0800 Subject: [PATCH 27/35] ci: update actions/setup-go --- .github/workflows/compatibility.yml | 2 +- .github/workflows/go.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/compatibility.yml b/.github/workflows/compatibility.yml index 363a85e..1dcee8c 100644 --- a/.github/workflows/compatibility.yml +++ b/.github/workflows/compatibility.yml @@ -131,7 +131,7 @@ jobs: run: sudo taosadapter & - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache-dependency-path: go.sum diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 978ed77..3009faf 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -191,7 +191,7 @@ jobs: run: sudo taosadapter & - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache-dependency-path: go.sum From 055bb7c903b63714af3146503f22a3fcf1f51302 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 17 Dec 2024 18:18:46 +0800 Subject: [PATCH 28/35] feat: rename taos_stmt2_get_fields --- .github/workflows/go.yml | 76 ++- af/stmt2.go | 30 +- common/stmt/stmt2.go | 26 +- common/stmt/stmt2_test.go | 291 ++++++---- wrapper/notify_test.go | 20 +- wrapper/stmt2.go | 303 +++------- wrapper/stmt2_test.go | 1079 ++++++++++++++++++++++++++++++----- wrapper/stmt2binary.go | 272 +++++++++ wrapper/stmt_test.go | 2 +- wrapper/whitelistcb_test.go | 1 + 10 files changed, 1606 insertions(+), 494 deletions(-) create mode 100644 wrapper/stmt2binary.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3009faf..70674ff 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -204,4 +204,78 @@ jobs: with: files: ./coverage.txt env: - CODECOV_TOKEN: ${{ secrets.CODECOV_ORG_TOKEN }} \ No newline at end of file + CODECOV_TOKEN: ${{ secrets.CODECOV_ORG_TOKEN }} + + test_asan: + runs-on: ubuntu-latest + needs: build + strategy: + matrix: + go: [ 'stable' ] + name: Go-ASAN-${{ matrix.go }} + steps: + - name: get cache server by pr + if: github.event_name == 'pull_request' + id: get-cache-server-pr + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.base_ref }}- + + - name: get cache server by push + if: github.event_name == 'push' + id: get-cache-server-push + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.ref_name }}- + + - name: cache server manually + if: github.event_name == 'workflow_dispatch' + id: get-cache-server-manually + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ inputs.tbBranch }}- + + + - name: install + run: | + tar -zxvf server.tar.gz + cd release && sudo sh install.sh + + - name: checkout + uses: actions/checkout@v4 + + - name: copy taos cfg + run: | + sudo mkdir -p /etc/taos + sudo cp ./.github/workflows/taos.cfg /etc/taos/taos.cfg + sudo cp ./.github/workflows/taosadapter.toml /etc/taos/taosadapter.toml + + - name: shell + run: | + cat >start.sh<= 0; i-- { - if freePointer[i] != nil { - C.free(freePointer[i]) - } - } - }() - dataP := unsafe.Pointer(C.CBytes(data)) - freePointer = append(freePointer, dataP) - count := binary.LittleEndian.Uint32(data[stmt.CountPosition:]) - tagCount := binary.LittleEndian.Uint32(data[stmt.TagCountPosition:]) - colCount := binary.LittleEndian.Uint32(data[stmt.ColCountPosition:]) - tableNamesOffset := binary.LittleEndian.Uint32(data[stmt.TableNamesOffsetPosition:]) - tagsOffset := binary.LittleEndian.Uint32(data[stmt.TagsOffsetPosition:]) - colsOffset := binary.LittleEndian.Uint32(data[stmt.ColsOffsetPosition:]) - // check table names - if tableNamesOffset > 0 { - tableNameEnd := tableNamesOffset + count*2 - // table name lengths out of range - if tableNameEnd > totalLength { - return fmt.Errorf("table name lengths out of range, total length: %d, tableNamesLengthEnd: %d", totalLength, tableNameEnd) - } - for i := uint32(0); i < count; i++ { - tableNameLength := binary.LittleEndian.Uint16(data[tableNamesOffset+i*2:]) - tableNameEnd += uint32(tableNameLength) - } - if tableNameEnd > totalLength { - return fmt.Errorf("table names out of range, total length: %d, tableNameTotalLength: %d", totalLength, tableNameEnd) - } - } - // check tags - if tagsOffset > 0 { - if tagCount == 0 { - return fmt.Errorf("tag count is zero, but tags offset is not zero") - } - tagsEnd := tagsOffset + count*4 - if tagsEnd > totalLength { - return fmt.Errorf("tags lengths out of range, total length: %d, tagsLengthEnd: %d", totalLength, tagsEnd) - } - for i := uint32(0); i < count; i++ { - tagLength := binary.LittleEndian.Uint32(data[tagsOffset+i*4:]) - if tagLength == 0 { - return fmt.Errorf("tag length is zero, data index: %d", i) - } - tagsEnd += tagLength - } - if tagsEnd > totalLength { - return fmt.Errorf("tags out of range, total length: %d, tagsTotalLength: %d", totalLength, tagsEnd) - } +// TaosStmt2GetFields int taos_stmt2_get_fields(TAOS_STMT2 *stmt, int *count, TAOS_FIELD_ALL **fields); +func TaosStmt2GetFields(stmt2 unsafe.Pointer) (code, count int, fields unsafe.Pointer) { + code = int(C.taos_stmt2_get_fields(stmt2, (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_ALL)(unsafe.Pointer(&fields)))) + return +} + +//typedef struct TAOS_FIELD_ALL { +//char name[65]; +//int8_t type; +//uint8_t precision; +//uint8_t scale; +//int32_t bytes; +//TAOS_FIELD_T field_type; +//} TAOS_FIELD_ALL; + +func Stmt2ParseAllFields(num int, fields unsafe.Pointer) []*stmt.Stmt2AllField { + if num <= 0 { + return nil } - // check cols - if colsOffset > 0 { - if colCount == 0 { - return fmt.Errorf("col count is zero, but cols offset is not zero") - } - colsEnd := colsOffset + count*4 - if colsEnd > totalLength { - return fmt.Errorf("cols lengths out of range, total length: %d, colsLengthEnd: %d", totalLength, colsEnd) - } - for i := uint32(0); i < count; i++ { - colLength := binary.LittleEndian.Uint32(data[colsOffset+i*4:]) - if colLength == 0 { - return fmt.Errorf("col length is zero, data: %d", i) - } - colsEnd += colLength - } - if colsEnd > totalLength { - return fmt.Errorf("cols out of range, total length: %d, colsTotalLength: %d", totalLength, colsEnd) - } + if fields == nil { + return nil } - cBindv := C.TAOS_STMT2_BINDV{} - cBindv.count = C.int(count) - if tableNamesOffset > 0 { - tableNameLengthP := pointer.AddUintptr(dataP, uintptr(tableNamesOffset)) - cTableNames := C.malloc(C.size_t(uintptr(count) * PointerSize)) - freePointer = append(freePointer, cTableNames) - tableDataP := pointer.AddUintptr(tableNameLengthP, uintptr(count)*2) - var tableNamesArrayP unsafe.Pointer - for i := uint32(0); i < count; i++ { - tableNamesArrayP = pointer.AddUintptr(cTableNames, uintptr(i)*PointerSize) - *(**C.char)(tableNamesArrayP) = (*C.char)(tableDataP) - tableNameLength := *(*uint16)(pointer.AddUintptr(tableNameLengthP, uintptr(i*2))) - if tableNameLength == 0 { - return fmt.Errorf("table name length is zero, data index: %d", i) + result := make([]*stmt.Stmt2AllField, num) + buf := bytes.NewBufferString("") + for i := 0; i < num; i++ { + r := &stmt.Stmt2AllField{} + field := *(*C.TAOS_FIELD_ALL)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_ALL*C.int(i)))) + for _, c := range field.name { + if c == 0 { + break } - tableDataP = pointer.AddUintptr(tableDataP, uintptr(tableNameLength)) + buf.WriteByte(byte(c)) } - cBindv.tbnames = (**C.char)(cTableNames) - } else { - cBindv.tbnames = nil - } - if tagsOffset > 0 { - tags, needFreePointer, err := generateStmt2Binds(count, tagCount, dataP, tagsOffset) - freePointer = append(freePointer, needFreePointer...) - if err != nil { - return fmt.Errorf("generate tags error: %s", err.Error()) - } - cBindv.tags = (**C.TAOS_STMT2_BIND)(tags) - } else { - cBindv.tags = nil - } - if colsOffset > 0 { - cols, needFreePointer, err := generateStmt2Binds(count, colCount, dataP, colsOffset) - freePointer = append(freePointer, needFreePointer...) - if err != nil { - return fmt.Errorf("generate cols error: %s", err.Error()) - } - cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(cols) - } else { - cBindv.bind_cols = nil - } - code := int(C.taos_stmt2_bind_param(stmt2, &cBindv, C.int32_t(colIdx))) - if code != 0 { - errStr := TaosStmt2Error(stmt2) - return taosError.NewError(code, errStr) + r.Name = buf.String() + buf.Reset() + r.FieldType = int8(field._type) + r.Precision = uint8(field.precision) + r.Scale = uint8(field.scale) + r.Bytes = int32(field.bytes) + r.BindType = int8(field.field_type) + result[i] = r } - return nil + return result } -func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, fieldsOffset uint32) (unsafe.Pointer, []unsafe.Pointer, error) { - var freePointer []unsafe.Pointer - bindsCList := unsafe.Pointer(C.malloc(C.size_t(uintptr(count) * PointerSize))) - freePointer = append(freePointer, bindsCList) - // dataLength [count]uint32 - // length have checked in TaosStmt2BindBinary - baseLengthPointer := pointer.AddUintptr(dataP, uintptr(fieldsOffset)) - // dataBuffer - dataPointer := pointer.AddUintptr(baseLengthPointer, uintptr(count)*4) - var bindsPointer unsafe.Pointer - for tableIndex := uint32(0); tableIndex < count; tableIndex++ { - bindsPointer = pointer.AddUintptr(bindsCList, uintptr(tableIndex)*PointerSize) - binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(fieldCount) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) - freePointer = append(freePointer, binds) - var bindDataP unsafe.Pointer - var bindDataTotalLength uint32 - var num int32 - var haveLength byte - var bufferLength uint32 - for fieldIndex := uint32(0); fieldIndex < fieldCount; fieldIndex++ { - // field data - bindDataP = dataPointer - // totalLength - bindDataTotalLength = *(*uint32)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, common.UInt32Size) - bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(fieldIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) - // buffer_type - bind.buffer_type = *(*C.int)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, common.Int32Size) - // num - num = *(*int32)(bindDataP) - bind.num = C.int(num) - bindDataP = pointer.AddUintptr(bindDataP, common.Int32Size) - // is_null - bind.is_null = (*C.char)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, uintptr(num)) - // haveLength - haveLength = *(*byte)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, common.Int8Size) - if haveLength == 0 { - bind.length = nil - } else { - // length [num]int32 - bind.length = (*C.int32_t)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, common.Int32Size*uintptr(num)) - } - // bufferLength - bufferLength = *(*uint32)(bindDataP) - bindDataP = pointer.AddUintptr(bindDataP, common.UInt32Size) - // buffer - if bufferLength == 0 { - bind.buffer = nil - } else { - bind.buffer = bindDataP - } - bindDataP = pointer.AddUintptr(bindDataP, uintptr(bufferLength)) - // check bind data length - bindDataLen := uintptr(bindDataP) - uintptr(dataPointer) - if bindDataLen != uintptr(bindDataTotalLength) { - return nil, freePointer, fmt.Errorf("bind data length not match, expect %d, but get %d, tableIndex:%d", bindDataTotalLength, bindDataLen, tableIndex) - } - dataPointer = bindDataP - } - *(**C.TAOS_STMT2_BIND)(bindsPointer) = (*C.TAOS_STMT2_BIND)(binds) +// stringHeader instead of reflect.StringHeader +type stringHeader struct { + data unsafe.Pointer + len int +} + +// sliceHeader instead of reflect.SliceHeader +type sliceHeader struct { + data unsafe.Pointer + len int + cap int +} + +// ToUnsafeBytes converts s to a byte slice without memory allocations. +// +// The returned byte slice is valid only until s is reachable and unmodified. +func ToUnsafeBytes(s string) (b []byte) { + if len(s) == 0 { + return []byte{} } - return bindsCList, freePointer, nil + hdr := (*sliceHeader)(unsafe.Pointer(&b)) + hdr.data = (*stringHeader)(unsafe.Pointer(&s)).data + hdr.cap = len(s) + hdr.len = len(s) + return b } diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index 1500977..bf47c2b 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -1193,17 +1193,18 @@ func TestStmt2BindData(t *testing.T) { return } assert.True(t, isInsert) - code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cfields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } + defer TaosStmt2FreeFields(insertStmt, cfields) assert.Equal(t, 2, count) - fields := StmtParseFields(count, cfields) - err = TaosStmt2BindParam(insertStmt, true, tc.params, fields, nil, -1) + fields := Stmt2ParseAllFields(count, cfields) + err = TaosStmt2BindParam(insertStmt, true, tc.params, fields, -1) if err != nil { t.Error(err) return @@ -2367,7 +2368,7 @@ func TestStmt2BindBinary(t *testing.T) { return } assert.True(t, isInsert) - code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cfields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) @@ -2376,8 +2377,8 @@ func TestStmt2BindBinary(t *testing.T) { } defer TaosStmt2FreeFields(insertStmt, cfields) assert.Equal(t, 2, count) - fields := StmtParseFields(count, cfields) - bs, err := stmt.MarshalStmt2Binary(tc.params, true, fields, nil) + fields := Stmt2ParseAllFields(count, cfields) + bs, err := stmt.MarshalStmt2Binary(tc.params, true, fields) if err != nil { t.Error("marshal binary error:", err) return @@ -2496,22 +2497,12 @@ func TestStmt2AllType(t *testing.T) { params := []*stmt.TaosStmt2BindData{{ TableName: "ctb1", }} - err = TaosStmt2BindParam(insertStmt, true, params, nil, nil, -1) + err = TaosStmt2BindParam(insertStmt, true, params, nil, -1) if err != nil { t.Error(err) return } - code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - assert.Equal(t, 1, count) - assert.Equal(t, unsafe.Pointer(nil), cTablefields) - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -2520,28 +2511,17 @@ func TestStmt2AllType(t *testing.T) { return } assert.True(t, isInsert) - code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cFields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } - defer TaosStmt2FreeFields(insertStmt, cColFields) - assert.Equal(t, 16, count) - colFields := StmtParseFields(count, cColFields) - t.Log(colFields) - code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - defer TaosStmt2FreeFields(insertStmt, cTagfields) - assert.Equal(t, 16, count) - tagFields := StmtParseFields(count, cTagfields) - t.Log(tagFields) + defer TaosStmt2FreeFields(insertStmt, cFields) + assert.Equal(t, 32, count) + fields := Stmt2ParseAllFields(count, cFields) + t.Log(fields) now := time.Now() //colTypes := []int8{ // common.TSDB_DATA_TYPE_TIMESTAMP, @@ -2681,7 +2661,7 @@ func TestStmt2AllType(t *testing.T) { }, }} - err = TaosStmt2BindParam(insertStmt, true, params2, colFields, tagFields, -1) + err = TaosStmt2BindParam(insertStmt, true, params2, fields, -1) if err != nil { t.Error(err) return @@ -2787,7 +2767,7 @@ func TestStmt2AllTypeBytes(t *testing.T) { params := []*stmt.TaosStmt2BindData{{ TableName: "ctb1", }} - bs, err := stmt.MarshalStmt2Binary(params, true, nil, nil) + bs, err := stmt.MarshalStmt2Binary(params, true, nil) if err != nil { t.Error(err) return @@ -2798,16 +2778,6 @@ func TestStmt2AllTypeBytes(t *testing.T) { return } - code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - assert.Equal(t, 1, count) - assert.Equal(t, unsafe.Pointer(nil), cTablefields) - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -2816,28 +2786,18 @@ func TestStmt2AllTypeBytes(t *testing.T) { return } assert.True(t, isInsert) - code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - defer TaosStmt2FreeFields(insertStmt, cColFields) - assert.Equal(t, 16, count) - colFields := StmtParseFields(count, cColFields) - t.Log(colFields) - code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) + + code, count, cFields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } - defer TaosStmt2FreeFields(insertStmt, cTagfields) - assert.Equal(t, 16, count) - tagFields := StmtParseFields(count, cTagfields) - t.Log(tagFields) + defer TaosStmt2FreeFields(insertStmt, cFields) + assert.Equal(t, 32, count) + fields := Stmt2ParseAllFields(count, cFields) + t.Log(fields) now := time.Now() //colTypes := []int8{ // common.TSDB_DATA_TYPE_TIMESTAMP, @@ -2976,7 +2936,7 @@ func TestStmt2AllTypeBytes(t *testing.T) { }, }, }} - bs, err = stmt.MarshalStmt2Binary(params2, true, colFields, tagFields) + bs, err = stmt.MarshalStmt2Binary(params2, true, fields) if err != nil { t.Error(err) return @@ -3060,13 +3020,15 @@ func TestStmt2Query(t *testing.T) { } assert.True(t, isInsert) now := time.Now().Round(time.Millisecond) - colTypes := []*stmt.StmtField{ + colTypes := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_COL, }, } params := []*stmt.TaosStmt2BindData{ @@ -3097,7 +3059,7 @@ func TestStmt2Query(t *testing.T) { }, }, } - err = TaosStmt2BindParam(stmt2, true, params, colTypes, nil, -1) + err = TaosStmt2BindParam(stmt2, true, params, colTypes, -1) if err != nil { t.Error(err) return @@ -3145,7 +3107,7 @@ func TestStmt2Query(t *testing.T) { }, } - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -3254,13 +3216,15 @@ func TestStmt2QueryBytes(t *testing.T) { } assert.True(t, isInsert) now := time.Now().Round(time.Millisecond) - colTypes := []*stmt.StmtField{ + colTypes := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_COL, }, } params := []*stmt.TaosStmt2BindData{ @@ -3291,7 +3255,7 @@ func TestStmt2QueryBytes(t *testing.T) { }, }, } - bs, err := stmt.MarshalStmt2Binary(params, true, colTypes, nil) + bs, err := stmt.MarshalStmt2Binary(params, true, colTypes) if err != nil { t.Error(err) return @@ -3343,7 +3307,7 @@ func TestStmt2QueryBytes(t *testing.T) { }, }, } - bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + bs, err = stmt.MarshalStmt2Binary(params, false, nil) if err != nil { t.Error(err) return @@ -3458,23 +3422,23 @@ func TestStmt2QueryAllType(t *testing.T) { handler := cgo.NewHandle(caller) stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_BOOL}, - {FieldType: common.TSDB_DATA_TYPE_TINYINT}, - {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - {FieldType: common.TSDB_DATA_TYPE_BIGINT}, - {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, - {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_UINT}, - {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, - {FieldType: common.TSDB_DATA_TYPE_FLOAT}, - {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, - {FieldType: common.TSDB_DATA_TYPE_BINARY}, - {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, - {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, - {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + colTypes := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BOOL, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR, BindType: stmt.TAOS_FIELD_COL}, } now := time.Now() @@ -3578,7 +3542,7 @@ func TestStmt2QueryAllType(t *testing.T) { return } assert.True(t, isInsert) - err = TaosStmt2BindParam(stmt2, true, params2, colTypes, nil, -1) + err = TaosStmt2BindParam(stmt2, true, params2, colTypes, -1) if err != nil { t.Error(err) return @@ -3636,7 +3600,7 @@ func TestStmt2QueryAllType(t *testing.T) { }, }, } - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -3732,23 +3696,23 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { handler := cgo.NewHandle(caller) stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_BOOL}, - {FieldType: common.TSDB_DATA_TYPE_TINYINT}, - {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - {FieldType: common.TSDB_DATA_TYPE_BIGINT}, - {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, - {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_UINT}, - {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, - {FieldType: common.TSDB_DATA_TYPE_FLOAT}, - {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, - {FieldType: common.TSDB_DATA_TYPE_BINARY}, - {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, - {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, - {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + colTypes := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BOOL, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR, BindType: stmt.TAOS_FIELD_COL}, } now := time.Now() @@ -3852,7 +3816,7 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { return } assert.True(t, isInsert) - bs, err := stmt.MarshalStmt2Binary(params2, true, colTypes, nil) + bs, err := stmt.MarshalStmt2Binary(params2, true, colTypes) if err != nil { t.Error(err) return @@ -3915,7 +3879,7 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { }, }, } - bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + bs, err = stmt.MarshalStmt2Binary(params, false, nil) if err != nil { t.Error(err) return @@ -4024,14 +3988,12 @@ func TestStmt2Json(t *testing.T) { {int32(1)}, }, }} - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - } - tagTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_JSON}, + types := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_JSON, BindType: stmt.TAOS_FIELD_TAG}, } - err = TaosStmt2BindParam(stmt2, true, params, colTypes, tagTypes, -1) + err = TaosStmt2BindParam(stmt2, true, params, types, -1) if err != nil { t.Error(err) return @@ -4058,7 +4020,7 @@ func TestStmt2Json(t *testing.T) { {int32(1)}, }, }} - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -4185,21 +4147,21 @@ func TestStmt2BindMultiTables(t *testing.T) { Tags: []driver.Value{int32(3)}, }, } - colType := []*stmt.StmtField{ + fields := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: stmt.TAOS_FIELD_COL, }, - } - tagType := []*stmt.StmtField{ { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_TAG, }, } - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -4209,7 +4171,7 @@ func TestStmt2BindMultiTables(t *testing.T) { } assert.True(t, isInsert) - err = TaosStmt2BindParam(insertStmt, true, binds, colType, tagType, -1) + err = TaosStmt2BindParam(insertStmt, true, binds, fields, -1) if err != nil { t.Error(err) return @@ -4288,6 +4250,18 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { args args wantErr assert.ErrorAssertionFunc }{ + { + name: "wrong data length", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, { name: "normal table name", args: args{ @@ -4901,40 +4875,6 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { }, wantErr: assert.Error, }, - { - name: "wrong param count", - args: args{ - sql: "insert into test1 values (?,?)", - data: []byte{ - // total Length - 0x3A, 0x00, 0x00, 0x00, - // tableCount - 0x01, 0x00, 0x00, 0x00, - // TagCount - 0x00, 0x00, 0x00, 0x00, - // ColCount - 0x01, 0x00, 0x00, 0x00, - // TableNamesOffset - 0x00, 0x00, 0x00, 0x00, - // TagsOffset - 0x00, 0x00, 0x00, 0x00, - // ColOffset - 0x1c, 0x00, 0x00, 0x00, - // cols - 0x1a, 0x00, 0x00, 0x00, - - 0x1a, 0x00, 0x00, 0x00, - 0x09, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - 0x00, - 0x08, 0x00, 0x00, 0x00, - 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, - }, - colIdx: -1, - }, - wantErr: assert.Error, - }, { name: "bind binary", args: args{ @@ -5074,3 +5014,836 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { }) } } + +func TestTaosStmt2GetStbFields(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_stb_fields") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database test_stmt2_stb_fields precision 'ns'") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_stb_fields") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists commontb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))") + if err != nil { + t.Error(err) + return + } + expectMap := map[string]*stmt.Stmt2AllField{ + "tts": { + Name: "tts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv1": { + Name: "tv1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv2": { + Name: "tv2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv3": { + Name: "tv3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv4": { + Name: "tv4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv5": { + Name: "tv5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv6": { + Name: "tv6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv7": { + Name: "tv7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv8": { + Name: "tv8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv9": { + Name: "tv9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv10": { + Name: "tv10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv11": { + Name: "tv11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv12": { + Name: "tv12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv13": { + Name: "tv13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv14": { + Name: "tv14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv15": { + Name: "tv15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_TAG, + }, + "ts": { + Name: "ts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v1": { + Name: "v1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v2": { + Name: "v2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v3": { + Name: "v3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v4": { + Name: "v4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v5": { + Name: "v5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v6": { + Name: "v6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v7": { + Name: "v7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v8": { + Name: "v8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v9": { + Name: "v9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v10": { + Name: "v10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v11": { + Name: "v11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v12": { + Name: "v12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v13": { + Name: "v13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v14": { + Name: "v14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_COL, + }, + "v15": { + Name: "v15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_COL, + }, + "tbname": { + Name: "tbname", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 271, + BindType: stmt.TAOS_FIELD_TBNAME, + }, + } + tests := []struct { + name string + sql string + expect []string + }{ + { + name: "with subTableName", + sql: "insert into tb1 using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "using stb", + sql: "insert into ? using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tbname", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "tbname as value", + sql: "insert into all_stb (tbname,tts,tv1,tv2,tv3,tv4,tv5,tv6,tv7,tv8,tv9,tv10,tv11,tv12,tv13,tv14,tv15,ts,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tbname", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "tbname as value random", + sql: "insert into all_stb (ts,v1,v2,v3,v4,v5,v6,tts,tv1,tv2,tv3,tv4,tv5,tv6,tv7,tv8,tv9,tv10,tv11,tv12,tv13,tv14,tbname,tv15,v7,v8,v9,v10,v11,v12,v13,v14,v15) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"ts", "v1", "v2", "v3", "v4", "v5", "v6", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tbname", "tv15", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "common table", + sql: "insert into commontb values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + } + for _, tt := range tests { + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xed123, false, false, handler) + defer TaosStmt2Close(stmt2) + code := TaosStmt2Prepare(stmt2, tt.sql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + code, count, fields := TaosStmt2GetFields(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + fs := Stmt2ParseAllFields(count, fields) + TaosStmt2FreeFields(stmt2, fields) + expect := make([]*stmt.Stmt2AllField, len(tt.expect)) + for i := 0; i < len(tt.expect); i++ { + assert.Equal(t, expectMap[tt.expect[i]].Name, fs[i].Name) + assert.Equal(t, expectMap[tt.expect[i]].FieldType, fs[i].FieldType) + assert.Equal(t, expectMap[tt.expect[i]].Bytes, fs[i].Bytes) + assert.Equal(t, expectMap[tt.expect[i]].BindType, fs[i].BindType) + if expectMap[tt.expect[i]].FieldType == common.TSDB_DATA_TYPE_TIMESTAMP { + assert.Equal(t, expectMap[tt.expect[i]].Precision, fs[i].Precision) + } + expect[i] = expectMap[tt.expect[i]] + } + } + + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xfd123, false, false, handler) + defer TaosStmt2Close(stmt2) + code := TaosStmt2Prepare(stmt2, "select * from commontb where ts = ? and v1 = ?") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + code, count, fields := TaosStmt2GetFields(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + TaosStmt2FreeFields(stmt2, fields) + assert.Equal(t, 2, count) +} + +func TestWrongParseStmt2StbFields(t *testing.T) { + fs := Stmt2ParseAllFields(0, nil) + assert.Nil(t, fs) + fs = Stmt2ParseAllFields(2, nil) + assert.Nil(t, fs) +} + +func TestStmt2BindTbnameAsValue(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_bind_tbname_as_value") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_bind_tbname_as_value precision 'ns' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_bind_tbname_as_value") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xff1234, false, false, handler) + prepareInsertSql := "insert into all_stb (ts ,v1 ,v2 ,v3 ,v4 ,v5 ,v6 ,v7 ,v8 ,v9 ,v10,v11,v12,v13,v14,v15,tbname,tts,tv1 ,tv2 ,tv3 ,tv4 ,tv5 ,tv6 ,tv7 ,tv8 ,tv9 ,tv10,tv11,tv12,tv13,tv14,tv15) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + code := TaosStmt2Prepare(insertStmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + + code, count, cFields := TaosStmt2GetFields(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cFields) + assert.Equal(t, 33, count) + fields := Stmt2ParseAllFields(count, cFields) + assert.Equal(t, 33, len(fields)) + expectMap := map[string]*stmt.Stmt2AllField{ + "tts": { + Name: "tts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv1": { + Name: "tv1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv2": { + Name: "tv2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv3": { + Name: "tv3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv4": { + Name: "tv4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv5": { + Name: "tv5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv6": { + Name: "tv6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv7": { + Name: "tv7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv8": { + Name: "tv8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv9": { + Name: "tv9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv10": { + Name: "tv10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv11": { + Name: "tv11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv12": { + Name: "tv12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv13": { + Name: "tv13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv14": { + Name: "tv14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv15": { + Name: "tv15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_TAG, + }, + "ts": { + Name: "ts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v1": { + Name: "v1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v2": { + Name: "v2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v3": { + Name: "v3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v4": { + Name: "v4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v5": { + Name: "v5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v6": { + Name: "v6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v7": { + Name: "v7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v8": { + Name: "v8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v9": { + Name: "v9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v10": { + Name: "v10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v11": { + Name: "v11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v12": { + Name: "v12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v13": { + Name: "v13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v14": { + Name: "v14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_COL, + }, + "v15": { + Name: "v15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_COL, + }, + "tbname": { + Name: "tbname", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 271, + BindType: stmt.TAOS_FIELD_TBNAME, + }, + } + + for i := 0; i < 33; i++ { + expect := expectMap[fields[i].Name] + assert.Equal(t, expect, fields[i]) + } + + now := time.Now() + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + Tags: []driver.Value{ + // TIMESTAMP + now, + // BOOL + true, + // TINYINT + int8(1), + // SMALLINT + int16(1), + // INT + int32(1), + // BIGINT + int64(1), + // UTINYINT + uint8(1), + // USMALLINT + uint16(1), + // UINT + uint32(1), + // UBIGINT + uint64(1), + // FLOAT + float32(1.2), + // DOUBLE + float64(1.2), + // BINARY + []byte("binary"), + // VARBINARY + []byte("varbinary"), + // GEOMETRY + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // NCHAR + "nchar", + }, + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + bs, err := stmt.MarshalStmt2Binary(params2, true, fields) + assert.NoError(t, err) + err = TaosStmt2BindBinary(insertStmt, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + assert.Equal(t, 3, r.affected) + + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} diff --git a/wrapper/stmt2binary.go b/wrapper/stmt2binary.go new file mode 100644 index 0000000..8e7115f --- /dev/null +++ b/wrapper/stmt2binary.go @@ -0,0 +1,272 @@ +package wrapper + +/* +#include +#include +#include +#include + +int +go_generate_stmt2_binds(char *data, uint32_t count, uint32_t field_count, uint32_t field_offset, + TAOS_STMT2_BIND *bind_struct, + TAOS_STMT2_BIND **bind_ptr, char *err_msg) { + uint32_t *base_length = (uint32_t *) (data + field_offset); + char *data_ptr = (char *) (base_length + count); + for (int table_index = 0; table_index < count; table_index++) { + bind_ptr[table_index] = bind_struct + table_index * field_count; + char *bind_data_ptr; + for (uint32_t field_index = 0; field_index < field_count; field_index++) { + bind_data_ptr = data_ptr; + TAOS_STMT2_BIND *bind = bind_ptr[table_index] + field_index; + // total length + uint32_t bind_data_totalLength = *(uint32_t *) bind_data_ptr; + bind_data_ptr += 4; + // buffer_type + bind->buffer_type = *(int *) bind_data_ptr; + bind_data_ptr += 4; + // num + bind->num = *(int *) bind_data_ptr; + bind_data_ptr += 4; + // is_null + bind->is_null = (char *) bind_data_ptr; + bind_data_ptr += bind->num; + // have_length + char have_length = *(char *) bind_data_ptr; + bind_data_ptr += 1; + if (have_length == 0) { + bind->length = NULL; + } else { + bind->length = (int32_t *) bind_data_ptr; + bind_data_ptr += bind->num * 4; + } + // buffer_length + int32_t buffer_length = *(int32_t *) bind_data_ptr; + bind_data_ptr += 4; + // buffer + if (buffer_length > 0) { + bind->buffer = (void *) bind_data_ptr; + bind_data_ptr += buffer_length; + } else { + bind->buffer = NULL; + } + // check bind data length + if (bind_data_ptr - data_ptr != bind_data_totalLength) { + snprintf(err_msg, 128, "bind data length error, tableIndex: %d, fieldIndex: %d", table_index, field_index); + return -1; + } + data_ptr = bind_data_ptr; + } + } + return 0; +} + + +int go_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *err_msg) { + uint32_t *header = (uint32_t *) data; + uint32_t total_length = header[0]; + uint32_t count = header[1]; + uint32_t tag_count = header[2]; + uint32_t col_count = header[3]; + uint32_t table_names_offset = header[4]; + uint32_t tags_offset = header[5]; + uint32_t cols_offset = header[6]; + // check table names + if (table_names_offset > 0) { + uint32_t table_name_end = table_names_offset + count * 2; + if (table_name_end > total_length) { + snprintf(err_msg, 128, "table name lengths out of range, total length: %d, tableNamesLengthEnd: %d", total_length, + table_name_end); + return -1; + } + uint16_t *table_name_length_ptr = (uint16_t *) (data + table_names_offset); + for (int32_t i = 0; i < count; ++i) { + if (table_name_length_ptr[i] == 0) { + snprintf(err_msg, 128, "table name length is 0, tableIndex: %d", i); + return -1; + } + table_name_end += (uint32_t) table_name_length_ptr[i]; + } + if (table_name_end > total_length) { + snprintf(err_msg, 128, "table names out of range, total length: %d, tableNameTotalLength: %d", total_length, + table_name_end); + return -1; + } + } + // check tags + if (tags_offset > 0) { + if (tag_count == 0) { + snprintf(err_msg, 128, "tag count is 0, but tags offset is not 0"); + return -1; + } + uint32_t tag_end = tags_offset + count * 4; + if (tag_end > total_length) { + snprintf(err_msg, 128, "tags out of range, total length: %d, tagEnd: %d", total_length, tag_end); + return -1; + } + uint32_t *tab_length_ptr = (uint32_t *) (data + tags_offset); + for (int32_t i = 0; i < count; ++i) { + if (tab_length_ptr[i] == 0) { + snprintf(err_msg, 128, "tag length is 0, tableIndex: %d", i); + return -1; + } + tag_end += tab_length_ptr[i]; + } + if (tag_end > total_length) { + snprintf(err_msg, 128, "tags out of range, total length: %d, tagsTotalLength: %d", total_length, tag_end); + return -1; + } + } + // check cols + if (cols_offset > 0) { + if (col_count == 0) { + snprintf(err_msg, 128, "col count is 0, but cols offset is not 0"); + return -1; + } + uint32_t colEnd = cols_offset + count * 4; + if (colEnd > total_length) { + snprintf(err_msg, 128, "cols out of range, total length: %d, colEnd: %d", total_length, colEnd); + return -1; + } + uint32_t *col_length_ptr = (uint32_t *) (data + cols_offset); + for (int32_t i = 0; i < count; ++i) { + if (col_length_ptr[i] == 0) { + snprintf(err_msg, 128, "col length is 0, tableIndex: %d", i); + return -1; + } + colEnd += col_length_ptr[i]; + } + if (colEnd > total_length) { + snprintf(err_msg, 128, "cols out of range, total length: %d, colsTotalLength: %d", total_length, colEnd); + return -1; + } + } + // generate bindv struct + TAOS_STMT2_BINDV bind_v; + bind_v.count = (int) count; + if (table_names_offset > 0) { + uint16_t *table_name_length_ptr = (uint16_t *) (data + table_names_offset); + char *table_name_data_ptr = (char *) (table_name_length_ptr) + 2 * count; + char **table_name = (char **) malloc(sizeof(char *) * count); + if (table_name == NULL) { + snprintf(err_msg, 128, "malloc tableName error"); + return -1; + } + for (int i = 0; i < count; i++) { + table_name[i] = table_name_data_ptr; + table_name_data_ptr += table_name_length_ptr[i]; + } + bind_v.tbnames = table_name; + } else { + bind_v.tbnames = NULL; + } + uint32_t bind_struct_count = 0; + uint32_t bind_ptr_count = 0; + if (tags_offset == 0) { + bind_v.tags = NULL; + } else { + bind_struct_count += count * tag_count; + bind_ptr_count += count; + } + if (cols_offset == 0) { + bind_v.bind_cols = NULL; + } else { + bind_struct_count += count * col_count; + bind_ptr_count += count; + } + TAOS_STMT2_BIND *bind_struct = NULL; + TAOS_STMT2_BIND **bind_ptr = NULL; + if (bind_struct_count == 0) { + bind_v.tags = NULL; + bind_v.bind_cols = NULL; + } else { + // []TAOS_STMT2_BIND bindStruct + bind_struct = (TAOS_STMT2_BIND *) malloc(sizeof(TAOS_STMT2_BIND) * bind_struct_count); + if (bind_struct == NULL) { + snprintf(err_msg, 128, "malloc bind struct error"); + free(bind_v.tbnames); + return -1; + } + // []TAOS_STMT2_BIND *bindPtr + bind_ptr = (TAOS_STMT2_BIND **) malloc(sizeof(TAOS_STMT2_BIND *) * bind_ptr_count); + if (bind_ptr == NULL) { + snprintf(err_msg, 128, "malloc bind pointer error"); + free(bind_struct); + free(bind_v.tbnames); + return -1; + } + uint32_t struct_index = 0; + uint32_t ptr_index = 0; + if (tags_offset > 0) { + int code = go_generate_stmt2_binds(data, count, tag_count, tags_offset, bind_struct, bind_ptr, err_msg); + if (code != 0) { + free(bind_struct); + free(bind_ptr); + free(bind_v.tbnames); + return code; + } + bind_v.tags = bind_ptr; + struct_index += count * tag_count; + ptr_index += count; + } + if (cols_offset > 0) { + TAOS_STMT2_BIND *col_bind_struct = bind_struct + struct_index; + TAOS_STMT2_BIND **col_bind_ptr = bind_ptr + ptr_index; + int code = go_generate_stmt2_binds(data, count, col_count, cols_offset, col_bind_struct, col_bind_ptr, + err_msg); + if (code != 0) { + free(bind_struct); + free(bind_ptr); + free(bind_v.tbnames); + return code; + } + bind_v.bind_cols = col_bind_ptr; + } + } + int code = taos_stmt2_bind_param(stmt, &bind_v, col_idx); + if (code != 0) { + char *msg = taos_stmt2_error(stmt); + snprintf(err_msg, 128, "%s", msg); + } + if (bind_v.tbnames != NULL) { + free(bind_v.tbnames); + } + if (bind_struct != NULL) { + free(bind_struct); + } + if (bind_ptr != NULL) { + free(bind_ptr); + } + return code; +} +*/ +import "C" +import ( + "encoding/binary" + "fmt" + "unsafe" + + "github.com/taosdata/driver-go/v3/common/stmt" + taosError "github.com/taosdata/driver-go/v3/errors" +) + +// TaosStmt2BindBinary bind binary data to stmt2 +func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error { + if len(data) < stmt.DataPosition { + return fmt.Errorf("data length is less than 28") + } + totalLength := binary.LittleEndian.Uint32(data[stmt.TotalLengthPosition:]) + if totalLength != uint32(len(data)) { + return fmt.Errorf("total length not match, expect %d, but get %d", len(data), totalLength) + } + dataP := C.CBytes(data) + defer C.free(dataP) + errMsg := (*C.char)(C.malloc(128)) + defer C.free(unsafe.Pointer(errMsg)) + + code := C.go_stmt2_bind_binary(stmt2, (*C.char)(dataP), C.int32_t(colIdx), errMsg) + if code != 0 { + msg := C.GoString(errMsg) + return taosError.NewError(int(code), msg) + } + return nil +} diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index 557b10d..8ce160a 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -898,7 +898,7 @@ func TestGetFieldsCommonTable(t *testing.T) { return } code, num, _ := TaosStmtGetTagFields(stmt) - assert.Equal(t, 0, code) + assert.NotEqual(t, 0, code) assert.Equal(t, 0, num) code, columnCount, columnsP := TaosStmtGetColFields(stmt) if code != 0 { diff --git a/wrapper/whitelistcb_test.go b/wrapper/whitelistcb_test.go index 86d01cd..9afdb0b 100644 --- a/wrapper/whitelistcb_test.go +++ b/wrapper/whitelistcb_test.go @@ -28,6 +28,7 @@ func TestWhitelistCallback_Success(t *testing.T) { 192, 168, 1, 1, 24, // 192.168.1.1/24 0, 0, 0, 10, 0, 0, 1, 16, // 10.0.0.1/16 + 0, 0, 0, } // Create a channel to receive the result From ab39d6aef8aefdc71fbe625cd117f23b7a2809f2 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 17 Dec 2024 19:01:29 +0800 Subject: [PATCH 29/35] fix: stmt bind test error --- af/conn_test.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/af/conn_test.go b/af/conn_test.go index e3a9ccf..c8202f9 100644 --- a/af/conn_test.go +++ b/af/conn_test.go @@ -581,8 +581,22 @@ func TestFastInsertWithSetSubTableName(t *testing.T) { params []*param2.Param bindType *param2.ColumnType }{ - {"set_table_name_sub_int", "1,'int'", "ts timestamp, `value` int", "?, ?", []*param2.Param{param2.NewParam(1).AddTimestamp(now, common.PrecisionMicroSecond), param2.NewParam(1).AddInt(1)}, param2.NewColumnType(2).AddTimestamp().AddInt()}, - {"set_table_name_sub_nchar", "2,'nchar'", "ts timestamp, `value` nchar(8)", "?, ?", []*param2.Param{param2.NewParam(1).AddTimestamp(time.Now(), common.PrecisionMicroSecond), param2.NewParam(1).AddNchar("ttt")}, param2.NewColumnType(2).AddTimestamp().AddNchar(1)}, + { + sTableName: "set_table_name_sub_int", + tags: "1,'int'", + tbType: "ts timestamp, `value` int", + pos: "?, ?", + params: []*param2.Param{param2.NewParam(1).AddTimestamp(now, common.PrecisionMicroSecond), param2.NewParam(1).AddInt(1)}, + bindType: param2.NewColumnType(2).AddTimestamp().AddInt(), + }, + { + sTableName: "set_table_name_sub_nchar", + tags: "2,'nchar'", + tbType: "ts timestamp, `value` nchar(8)", + pos: "?, ?", + params: []*param2.Param{param2.NewParam(1).AddTimestamp(time.Now(), common.PrecisionMicroSecond), param2.NewParam(1).AddNchar("ttt")}, + bindType: param2.NewColumnType(2).AddTimestamp().AddNchar(5), + }, } { tbName := fmt.Sprintf("test_fast_insert_with_sub_table_name_%02d", i) tbType := tc.tbType From 87b58ebe3fc36ec89d81cec16d63d3fb242969b7 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 17 Dec 2024 19:47:27 +0800 Subject: [PATCH 30/35] test: add unit test for stmt2 --- af/stmt2_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++ wrapper/stmt2_test.go | 12 +++++++++++ 2 files changed, 58 insertions(+) diff --git a/af/stmt2_test.go b/af/stmt2_test.go index f5d6ad8..1c327a4 100644 --- a/af/stmt2_test.go +++ b/af/stmt2_test.go @@ -283,3 +283,49 @@ func TestStmt2(t *testing.T) { assert.ErrorIs(t, err, io.EOF) } + +func TestStmt2_Prepare(t *testing.T) { + conn, err := Open("", "root", "taosdata", "", 0) + if !assert.NoError(t, err) { + return + } + stmt2 := conn.Stmt2(0x123456789, false) + if stmt2 == nil { + t.Errorf("Expected stmt to be not nil") + return + } + defer func() { + err = stmt2.Close() + assert.NoError(t, err) + }() + _, err = conn.Exec("create database if not exists stmt2_prepare_wrong_test") + if !assert.NoError(t, err) { + return + } + defer func() { + _, err = conn.Exec("drop database if exists stmt2_prepare_wrong_test") + assert.NoError(t, err) + }() + _, err = conn.Exec("use stmt2_prepare_wrong_test") + if !assert.NoError(t, err) { + return + } + err = stmt2.Prepare("insert into not_exist_table values(?,?,?)") + assert.Error(t, err) + _, err = conn.Exec("create table t (ts timestamp, b int, c int)") + assert.NoError(t, err) + err = stmt2.Prepare("") + assert.NoError(t, err) + err = stmt2.Bind([]*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + {time.Now()}, + {int32(1)}, + {int32(2)}, + }, + }, + }) + assert.Error(t, err) + err = stmt2.Prepare("insert into t values(?,?,?)") + assert.Error(t, err) +} diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index bf47c2b..6e8b774 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -1,6 +1,7 @@ package wrapper import ( + "bytes" "database/sql/driver" "fmt" "testing" @@ -5847,3 +5848,14 @@ func TestStmt2BindTbnameAsValue(t *testing.T) { return } } + +func TestToUnsafeBytes(t *testing.T) { + s := "str" + if !bytes.Equal([]byte("str"), ToUnsafeBytes(s)) { + t.Fatalf(`[]bytes(%s) doesnt equal to %s `, s, s) + } + s = "" + if !bytes.Equal([]byte(""), ToUnsafeBytes(s)) { + t.Fatalf(`[]bytes(%s) doesnt equal to %s `, s, s) + } +} From 2cd6f2c4a9376dfb365f07446b4d7e7a0afc8549 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 18 Dec 2024 10:15:48 +0800 Subject: [PATCH 31/35] test: add unit test for stmt2 --- wrapper/stmt2_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/wrapper/stmt2_test.go b/wrapper/stmt2_test.go index 6e8b774..fefb69f 100644 --- a/wrapper/stmt2_test.go +++ b/wrapper/stmt2_test.go @@ -5849,6 +5849,66 @@ func TestStmt2BindTbnameAsValue(t *testing.T) { } } +func TestStmt2BindError(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_bind_error") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_bind_error precision 'ns' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_bind_error") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xff1234, false, false, handler) + defer func() { + code := TaosStmt2Close(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + }() + fields := []*stmt.Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + BindType: stmt.TAOS_FIELD_COL, + Precision: TSDB_SML_TIMESTAMP_NANO_SECONDS, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_COL, + }, + } + params := []*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + {time.Now()}, + {int32(1)}, + }, + }, + } + // without prepare + err = TaosStmt2BindParam(stmt2, false, params, fields, -1) + assert.Error(t, err) +} + func TestToUnsafeBytes(t *testing.T) { s := "str" if !bytes.Equal([]byte("str"), ToUnsafeBytes(s)) { From cac65df071d26f11a90f31f330dc576f778a890e Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 23 Dec 2024 09:30:20 +0800 Subject: [PATCH 32/35] enh: support special characters --- taosRestful/connection.go | 32 ++++++------ taosRestful/connector.go | 22 ++++----- taosRestful/connector_test.go | 2 +- taosRestful/driver.go | 18 ++++++- taosRestful/driver_test.go | 15 ++++++ taosRestful/dsn.go | 93 +++++++++++++++++++---------------- taosRestful/dsn_test.go | 82 +++++++++++++++++++++--------- taosSql/connection.go | 6 +-- taosSql/connector.go | 35 +++++++++---- taosSql/driver.go | 34 +++++++------ taosSql/driver_test.go | 15 ++++++ taosSql/dsn.go | 91 ++++++++++++++++++---------------- taosSql/dsn_test.go | 83 ++++++++++++++++++++----------- taosWS/connection.go | 28 +++++------ taosWS/connection_test.go | 8 +-- taosWS/connector.go | 30 +++++------ taosWS/connector_test.go | 2 +- taosWS/driver.go | 18 ++++++- taosWS/driver_test.go | 15 ++++++ taosWS/dsn.go | 91 ++++++++++++++++++---------------- taosWS/dsn_test.go | 82 ++++++++++++++++++++++-------- 21 files changed, 513 insertions(+), 289 deletions(-) diff --git a/taosRestful/connection.go b/taosRestful/connection.go index 2dd5240..d8a12a8 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -24,7 +24,7 @@ import ( var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary type taosConn struct { - cfg *config + cfg *Config client *http.Client url *url.URL baseRawQuery string @@ -32,8 +32,8 @@ type taosConn struct { readBufferSize int } -func newTaosConn(cfg *config) (*taosConn, error) { - readBufferSize := cfg.readBufferSize +func newTaosConn(cfg *Config) (*taosConn, error) { + readBufferSize := cfg.ReadBufferSize if readBufferSize <= 0 { readBufferSize = 4 << 10 } @@ -47,9 +47,9 @@ func newTaosConn(cfg *config) (*taosConn, error) { IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - DisableCompression: cfg.disableCompression, + DisableCompression: cfg.DisableCompression, } - if cfg.skipVerify { + if cfg.SkipVerify { transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } @@ -58,24 +58,24 @@ func newTaosConn(cfg *config) (*taosConn, error) { Transport: transport, } path := "/rest/sql" - if len(cfg.dbName) != 0 { - path = fmt.Sprintf("%s/%s", path, cfg.dbName) + if len(cfg.DbName) != 0 { + path = fmt.Sprintf("%s/%s", path, cfg.DbName) } tc.url = &url.URL{ - Scheme: cfg.net, - Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), + Scheme: cfg.Net, + Host: fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port), Path: path, } tc.header = map[string][]string{ "Connection": {"keep-alive"}, } - if cfg.token != "" { - tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token) + if cfg.Token != "" { + tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.Token) } else { - basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd)) + basic := base64.StdEncoding.EncodeToString([]byte(cfg.User + ":" + cfg.Passwd)) tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)} } - if !cfg.disableCompression { + if !cfg.DisableCompression { tc.header["Accept-Encoding"] = []string{"gzip"} } return tc, nil @@ -103,7 +103,7 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if len(args) != 0 { - if !tc.cfg.interpolateParams { + if !tc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra round trips for preparing and closing a statement @@ -129,7 +129,7 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { if len(args) != 0 { - if !tc.cfg.interpolateParams { + if !tc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce round trip @@ -202,7 +202,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) ( } respBody := resp.Body defer ioutil.ReadAll(respBody) - if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + if !tc.cfg.DisableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { respBody, err = gzip.NewReader(resp.Body) if err != nil { return nil, err diff --git a/taosRestful/connector.go b/taosRestful/connector.go index 3ae7396..0919ee4 100644 --- a/taosRestful/connector.go +++ b/taosRestful/connector.go @@ -8,27 +8,27 @@ import ( ) type connector struct { - cfg *config + cfg *Config } // Connect implements driver.Connector interface. // Connect returns a connection to the database. func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Connect to Server - if len(c.cfg.user) == 0 { - c.cfg.user = common.DefaultUser + if len(c.cfg.User) == 0 { + c.cfg.User = common.DefaultUser } - if len(c.cfg.passwd) == 0 { - c.cfg.passwd = common.DefaultPassword + if len(c.cfg.Passwd) == 0 { + c.cfg.Passwd = common.DefaultPassword } - if c.cfg.port == 0 { - c.cfg.port = common.DefaultHttpPort + if c.cfg.Port == 0 { + c.cfg.Port = common.DefaultHttpPort } - if len(c.cfg.net) == 0 { - c.cfg.net = "http" + if len(c.cfg.Net) == 0 { + c.cfg.Net = "http" } - if len(c.cfg.addr) == 0 { - c.cfg.addr = "127.0.0.1" + if len(c.cfg.Addr) == 0 { + c.cfg.Addr = "127.0.0.1" } tc, err := newTaosConn(c.cfg) return tc, err diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index 845179f..aca016c 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -533,7 +533,7 @@ func TestSSL(t *testing.T) { func TestConnect(t *testing.T) { conn := connector{ - cfg: &config{}, + cfg: &Config{}, } db, err := conn.Connect(context.Background()) assert.NoError(t, err) diff --git a/taosRestful/driver.go b/taosRestful/driver.go index 54313f3..3ca3f8f 100644 --- a/taosRestful/driver.go +++ b/taosRestful/driver.go @@ -13,7 +13,7 @@ type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := parseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } @@ -26,3 +26,19 @@ func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { func init() { sql.Register("taosRestful", &TDengineDriver{}) } + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index c214d70..043d17c 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -289,3 +289,18 @@ func TestChinese(t *testing.T) { } assert.Equal(t, 1, counter) } + +func TestNewConnector(t *testing.T) { + cfg, err := ParseDSN(dataSourceName) + assert.NoError(t, err) + conn, err := NewConnector(cfg) + assert.NoError(t, err) + db := sql.OpenDB(conn) + defer func() { + err := db.Close() + assert.NoError(t, err) + }() + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} diff --git a/taosRestful/dsn.go b/taosRestful/dsn.go index 612962a..4e2f259 100644 --- a/taosRestful/dsn.go +++ b/taosRestful/dsn.go @@ -9,43 +9,43 @@ import ( ) var ( - errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} - errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} - errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} - errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} + ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} ) // Config is a configuration parsed from a DSN string. // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. -type config struct { - user string // Username - passwd string // Password (requires User) - net string // Network type - addr string // Network address (requires Net) - port int - dbName string // Database name - params map[string]string // Connection parameters - interpolateParams bool // Interpolate placeholders into query string - disableCompression bool - readBufferSize int - token string // cloud platform token - skipVerify bool +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + Port int + DbName string // Database name + Params map[string]string // Connection parameters + InterpolateParams bool // Interpolate placeholders into query string + DisableCompression bool + ReadBufferSize int + Token string // cloud platform Token + SkipVerify bool } // NewConfig creates a new Config and sets default values. -func newConfig() *config { - return &config{ - interpolateParams: true, - disableCompression: true, - readBufferSize: 4 << 10, +func NewConfig() *Config { + return &Config{ + InterpolateParams: true, + DisableCompression: true, + ReadBufferSize: 4 << 10, } } // ParseDSN parses the DSN string to a Config -func parseDSN(dsn string) (cfg *config, err error) { +func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = newConfig() + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -65,11 +65,11 @@ func parseDSN(dsn string) (cfg *config, err error) { // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] + cfg.Passwd = tryUnescape(dsn[k+1 : j]) break } } - cfg.user = dsn[:k] + cfg.User = tryUnescape(dsn[:k]) break } @@ -82,25 +82,25 @@ func parseDSN(dsn string) (cfg *config, err error) { // dsn[i-1] must be == ')' if an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped + return nil, ErrInvalidDSNUnescaped } //return nil, errInvalidDSNAddr } strList := strings.Split(dsn[k+1:i-1], ":") if len(strList) == 1 { - return nil, errInvalidDSNAddr + return nil, ErrInvalidDSNAddr } if len(strList[0]) != 0 { - cfg.addr = strList[0] - cfg.port, err = strconv.Atoi(strList[1]) + cfg.Addr = strList[0] + cfg.Port, err = strconv.Atoi(strList[1]) if err != nil { - return nil, errInvalidDSNPort + return nil, ErrInvalidDSNPort } } break } } - cfg.net = dsn[j+1 : k] + cfg.Net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] @@ -113,14 +113,14 @@ func parseDSN(dsn string) (cfg *config, err error) { break } } - cfg.dbName = dsn[i+1 : j] + cfg.DbName = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -128,7 +128,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -139,34 +139,34 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "disableCompression": - cfg.disableCompression, err = strconv.ParseBool(value) + cfg.DisableCompression, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "readBufferSize": - cfg.readBufferSize, err = strconv.Atoi(value) + cfg.ReadBufferSize, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid int value: " + value} } case "token": - cfg.token = value + cfg.Token = value case "skipVerify": - cfg.skipVerify, err = strconv.ParseBool(value) + cfg.SkipVerify, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -174,3 +174,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosRestful/dsn_test.go b/taosRestful/dsn_test.go index 2ec3fbd..b6351e0 100644 --- a/taosRestful/dsn_test.go +++ b/taosRestful/dsn_test.go @@ -1,8 +1,9 @@ package taosRestful import ( - "fmt" "testing" + + "github.com/stretchr/testify/assert" ) // @author: xftan @@ -10,6 +11,7 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tcs := []struct { + name string dsn string errs string user string @@ -21,20 +23,45 @@ func TestParseDsn(t *testing.T) { token string skipVerify bool }{{}, - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, - {dsn: "user:passwd@http()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@http(:)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {dsn: "user:passwd@http(:0)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, - {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, + {name: "invalid", dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, + {name: "normal", dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, + {name: "invalid addr", dsn: "user:passwd@http()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, + {name: "default addr", dsn: "user:passwd@http(:)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, + {name: "0port", dsn: "user:passwd@http(:0)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, + {name: "no db", dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, + {name: "params", dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, + {name: "token", dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, + {name: "skipVerify", dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, + { + name: "skipVerify", + dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", + user: "user", + passwd: "passwd", + net: "https", + token: "token", + skipVerify: true, + }, + { + name: "special char", + dsn: "!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.:!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.@https(:)/dbname", + user: "!@#$%^&*()-_+=[]{}:;> 0 { + if c.cfg.Net == "cfg" && len(c.cfg.ConfigPath) > 0 { once.Do(func() { locker.Lock() - code := wrapper.TaosOptions(common.TSDB_OPTION_CONFIGDIR, c.cfg.configPath) + code := wrapper.TaosOptions(common.TSDB_OPTION_CONFIGDIR, c.cfg.ConfigPath) locker.Unlock() if code != 0 { err = errors.NewError(code, wrapper.TaosErrorStr(nil)) @@ -37,20 +54,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } // Connect to Server - if len(tc.cfg.user) == 0 { - tc.cfg.user = common.DefaultUser + if len(tc.cfg.User) == 0 { + tc.cfg.User = common.DefaultUser } - if len(tc.cfg.passwd) == 0 { - tc.cfg.passwd = common.DefaultPassword + if len(tc.cfg.Passwd) == 0 { + tc.cfg.Passwd = common.DefaultPassword } locker.Lock() - err = wrapper.TaosSetConfig(tc.cfg.params) + err = wrapper.TaosSetConfig(tc.cfg.Params) locker.Unlock() if err != nil { return nil, err } locker.Lock() - tc.taos, err = wrapper.TaosConnect(tc.cfg.addr, tc.cfg.user, tc.cfg.passwd, tc.cfg.dbName, tc.cfg.port) + tc.taos, err = wrapper.TaosConnect(tc.cfg.Addr, tc.cfg.User, tc.cfg.Passwd, tc.cfg.DbName, tc.cfg.Port) locker.Unlock() if err != nil { return nil, err diff --git a/taosSql/driver.go b/taosSql/driver.go index f50a907..85ccfc3 100644 --- a/taosSql/driver.go +++ b/taosSql/driver.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "database/sql/driver" - "runtime" "sync" "github.com/taosdata/driver-go/v3/wrapper/handler" @@ -23,30 +22,33 @@ type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := parseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } c := &connector{ cfg: cfg, } - onceInitLock.Do(func() { - threads := cfg.cgoThread - if threads <= 0 { - threads = runtime.NumCPU() - } - locker = thread.NewLocker(threads) - }) - onceInitHandlerPool.Do(func() { - poolSize := cfg.cgoAsyncHandlerPoolSize - if poolSize <= 0 { - poolSize = 10000 - } - asyncHandlerPool = handler.NewHandlerPool(poolSize) - }) + return c.Connect(context.Background()) } func init() { sql.Register("taosSql", &TDengineDriver{}) } + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d TDengineDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index 51a76be..587f078 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -560,3 +560,18 @@ func TestChinese(t *testing.T) { } assert.Equal(t, 1, counter) } + +func TestNewConnector(t *testing.T) { + cfg, err := ParseDSN(dataSourceName) + assert.NoError(t, err) + conn, err := NewConnector(cfg) + assert.NoError(t, err) + db := sql.OpenDB(conn) + defer func() { + err := db.Close() + assert.NoError(t, err) + }() + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} diff --git a/taosSql/dsn.go b/taosSql/dsn.go index 90cdb6d..78337b0 100644 --- a/taosSql/dsn.go +++ b/taosSql/dsn.go @@ -26,42 +26,42 @@ import ( ) var ( - errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} - errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} - errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} - errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} + ErrInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + ErrInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + ErrInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + ErrInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} ) // Config is a configuration parsed from a DSN string. // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. -type config struct { - user string // Username - passwd string // Password (requires User) - net string // Network type - addr string // Network address (requires Net) - port int - dbName string // Database name - params map[string]string // Connection parameters - loc *time.Location // Location for time.Time values - interpolateParams bool // Interpolate placeholders into query string - configPath string - cgoThread int - cgoAsyncHandlerPoolSize int +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + Port int + DbName string // Database name + Params map[string]string // Connection parameters + Loc *time.Location // Location for time.Time values + InterpolateParams bool // Interpolate placeholders into query string + ConfigPath string + CgoThread int + CgoAsyncHandlerPoolSize int } // NewConfig creates a new Config and sets default values. -func newConfig() *config { - return &config{ - loc: time.UTC, - interpolateParams: true, +func NewConfig() *Config { + return &Config{ + Loc: time.UTC, + InterpolateParams: true, } } // ParseDSN parses the DSN string to a Config -func parseDSN(dsn string) (cfg *config, err error) { +func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = newConfig() + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -81,11 +81,11 @@ func parseDSN(dsn string) (cfg *config, err error) { // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] + cfg.Passwd = tryUnescape(dsn[k+1 : j]) break } } - cfg.user = dsn[:k] + cfg.User = tryUnescape(dsn[:k]) break } @@ -98,13 +98,13 @@ func parseDSN(dsn string) (cfg *config, err error) { // dsn[i-1] must be == ')' if an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped + return nil, ErrInvalidDSNUnescaped } //return nil, errInvalidDSNAddr } if dsn[j+1:k] == "cfg" { cfgPath := dsn[k+1 : i-1] - cfg.configPath, err = filepath.Abs(cfgPath) + cfg.ConfigPath, err = filepath.Abs(cfgPath) if err != nil { return nil, err } @@ -112,20 +112,20 @@ func parseDSN(dsn string) (cfg *config, err error) { } else { strList := strings.Split(dsn[k+1:i-1], ":") if len(strList) == 1 { - return nil, errInvalidDSNAddr + return nil, ErrInvalidDSNAddr } if len(strList[0]) != 0 { - cfg.addr = strList[0] - cfg.port, err = strconv.Atoi(strList[1]) + cfg.Addr = strList[0] + cfg.Port, err = strconv.Atoi(strList[1]) if err != nil { - return nil, errInvalidDSNPort + return nil, ErrInvalidDSNPort } } break } } } - cfg.net = dsn[j+1 : k] + cfg.Net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] @@ -138,14 +138,14 @@ func parseDSN(dsn string) (cfg *config, err error) { break } } - cfg.dbName = dsn[i+1 : j] + cfg.DbName = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -153,7 +153,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -164,7 +164,7 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } @@ -174,30 +174,30 @@ func parseDSNParams(cfg *config, params string) (err error) { if value, err = url.QueryUnescape(value); err != nil { return } - cfg.loc, err = time.LoadLocation(value) + cfg.Loc, err = time.LoadLocation(value) if err != nil { return } case "cgoThread": - cfg.cgoThread, err = strconv.Atoi(value) + cfg.CgoThread, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid cgoThread value: " + value} } case "cgoAsyncHandlerPoolSize": - cfg.cgoAsyncHandlerPoolSize, err = strconv.Atoi(value) + cfg.CgoAsyncHandlerPoolSize, err = strconv.Atoi(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid cgoAsyncHandlerPoolSize value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -205,3 +205,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosSql/dsn_test.go b/taosSql/dsn_test.go index 1ad259e..f22325d 100644 --- a/taosSql/dsn_test.go +++ b/taosSql/dsn_test.go @@ -1,7 +1,6 @@ package taosSql import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -11,7 +10,8 @@ import ( // @date: 2022/1/27 16:18 // @description: test dsn parse func TestParseDsn(t *testing.T) { - tcs := []struct { + tests := []struct { + name string dsn string errs string user string @@ -23,39 +23,66 @@ func TestParseDsn(t *testing.T) { configPath string cgoThread int cgoAsyncHandlerPoolSize int - }{{}, - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@net(fqdn:6030)/dbname", user: "user", passwd: "passwd", net: "net", addr: "fqdn", port: 6030, dbName: "dbname"}, - {dsn: "user:passwd@net()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@net(:)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, - {dsn: "user:passwd@net(:0)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, - {dsn: "user:passwd@net(:0)/", user: "user", passwd: "passwd", net: "net"}, - {dsn: "net(:0)/wo", net: "net", dbName: "wo"}, - {dsn: "user:passwd@cfg(/home/taos)/db", user: "user", passwd: "passwd", net: "cfg", configPath: "/home/taos", dbName: "db"}, - {dsn: "user:passwd@cfg/db", user: "user", passwd: "passwd", net: "cfg", configPath: "", dbName: "db"}, - {dsn: "net(:0)/wo?firstEp=LAPTOP-NNKFTLTG.localdomain%3A6030&secondEp=LAPTOP-NNKFTLTG.localdomain%3A6030&fqdn=LAPTOP-NNKFTLTG.localdomain&serverPort=6030&configDir=%2Fetc%2Ftaos&logDir=%2Fvar%2Flog%2Ftaos&scriptDir=%2Fetc%2Ftaos&arbitrator=&numOfThreadsPerCore=1.000000&maxNumOfDistinctRes=10000000&rpcTimer=300&rpcForceTcp=0&rpcMaxTime=600&shellActivityTimer=3&compressMsgSize=-1&maxSQLLength=1048576&maxWildCardsLength=100&maxNumOfOrderedRes=100000&keepColumnName=0&timezone=Asia%2FShanghai+%28CST%2C+%2B0800%29&locale=C.UTF-8&charset=UTF-8&numOfLogLines=10000000&logKeepDays=0&asyncLog=1&debugFlag=0&rpcDebugFlag=131&tmrDebugFlag=131&cDebugFlag=131&jniDebugFlag=131&odbcDebugFlag=131&uDebugFlag=131&qDebugFlag=131&tsdbDebugFlag=131&gitinfo=TAOS_CFG_VTYPE_STRING&gitinfoOfInternal=TAOS_CFG_VTYPE_STRING&buildinfo=TAOS_CFG_VTYPE_STRING&version=TAOS_CFG_VTYPE_STRING&maxBinaryDisplayWidth=30&tempDir=%2Ftmp%2F", net: "net", dbName: "wo"}, - {dsn: "net(:0)/wo?cgoThread=8", net: "net", dbName: "wo", cgoThread: 8}, - {dsn: "net(:0)/wo?cgoThread=8&cgoAsyncHandlerPoolSize=10000", net: "net", dbName: "wo", cgoThread: 8, cgoAsyncHandlerPoolSize: 10000}, + }{ + {name: "invalid", dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, + {name: "normal", dsn: "user:passwd@net(fqdn:6030)/dbname", user: "user", passwd: "passwd", net: "net", addr: "fqdn", port: 6030, dbName: "dbname"}, + {name: "missing closing brace", dsn: "user:passwd@net()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, + {name: "default addr", dsn: "user:passwd@net(:)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, + {name: "0port", dsn: "user:passwd@net(:0)/dbname", user: "user", passwd: "passwd", net: "net", dbName: "dbname"}, + {name: "no dbname", dsn: "user:passwd@net(:0)/", user: "user", passwd: "passwd", net: "net"}, + {name: "no auth", dsn: "net(:0)/wo", net: "net", dbName: "wo"}, + {name: "cfg", dsn: "user:passwd@cfg(/home/taos)/db", user: "user", passwd: "passwd", net: "cfg", configPath: "/home/taos", dbName: "db"}, + {name: "no addr", dsn: "user:passwd@cfg/db", user: "user", passwd: "passwd", net: "cfg", configPath: "", dbName: "db"}, + {name: "options", dsn: "net(:0)/wo?firstEp=LAPTOP-NNKFTLTG.localdomain%3A6030&secondEp=LAPTOP-NNKFTLTG.localdomain%3A6030&fqdn=LAPTOP-NNKFTLTG.localdomain&serverPort=6030&configDir=%2Fetc%2Ftaos&logDir=%2Fvar%2Flog%2Ftaos&scriptDir=%2Fetc%2Ftaos&arbitrator=&numOfThreadsPerCore=1.000000&maxNumOfDistinctRes=10000000&rpcTimer=300&rpcForceTcp=0&rpcMaxTime=600&shellActivityTimer=3&compressMsgSize=-1&maxSQLLength=1048576&maxWildCardsLength=100&maxNumOfOrderedRes=100000&keepColumnName=0&timezone=Asia%2FShanghai+%28CST%2C+%2B0800%29&locale=C.UTF-8&charset=UTF-8&numOfLogLines=10000000&logKeepDays=0&asyncLog=1&debugFlag=0&rpcDebugFlag=131&tmrDebugFlag=131&cDebugFlag=131&jniDebugFlag=131&odbcDebugFlag=131&uDebugFlag=131&qDebugFlag=131&tsdbDebugFlag=131&gitinfo=TAOS_CFG_VTYPE_STRING&gitinfoOfInternal=TAOS_CFG_VTYPE_STRING&buildinfo=TAOS_CFG_VTYPE_STRING&version=TAOS_CFG_VTYPE_STRING&maxBinaryDisplayWidth=30&tempDir=%2Ftmp%2F", net: "net", dbName: "wo"}, + {name: "cgoThread", dsn: "net(:0)/wo?cgoThread=8", net: "net", dbName: "wo", cgoThread: 8}, + {name: "cgoAsyncHandlerPoolSize", dsn: "net(:0)/wo?cgoThread=8&cgoAsyncHandlerPoolSize=10000", net: "net", dbName: "wo", cgoThread: 8, cgoAsyncHandlerPoolSize: 10000}, + { + name: "special char", + dsn: "!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.:!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.@net(:)/dbname", + user: "!@#$%^&*()-_+=[]{}:;> 0 { - return nil, errInvalidDSNNoSlash + return nil, ErrInvalidDSNNoSlash } return @@ -127,7 +127,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { +func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -138,34 +138,34 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution case "interpolateParams": - cfg.interpolateParams, err = strconv.ParseBool(value) + cfg.InterpolateParams, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} } case "token": - cfg.token = value + cfg.Token = value case "enableCompression": - cfg.enableCompression, err = strconv.ParseBool(value) + cfg.EnableCompression, err = strconv.ParseBool(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid enableCompression value: " + value} } case "readTimeout": - cfg.readTimeout, err = time.ParseDuration(value) + cfg.ReadTimeout, err = time.ParseDuration(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} } case "writeTimeout": - cfg.writeTimeout, err = time.ParseDuration(value) + cfg.WriteTimeout, err = time.ParseDuration(value) if err != nil { return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} } default: // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) + if cfg.Params == nil { + cfg.Params = make(map[string]string) } - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { return } } @@ -173,3 +173,10 @@ func parseDSNParams(cfg *config, params string) (err error) { return } + +func tryUnescape(s string) string { + if res, err := url.QueryUnescape(s); err == nil { + return res + } + return s +} diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index 7d7c7d7..627efe9 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -12,32 +12,63 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tests := []struct { + name string dsn string errs string - want *config + want *Config }{ - {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {dsn: "user:passwd@ws(fqdn:6041)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", addr: "fqdn", port: 6041, dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@ws()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {dsn: "user:passwd@ws(:)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@ws(:0)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/", want: &config{user: "user", passwd: "passwd", net: "wss", interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, - {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, - {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, - {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &config{ - user: "user", - passwd: "passwd", - net: "wss", - readTimeout: 10 * time.Minute, - writeTimeout: 8 * time.Second, - interpolateParams: true, - enableCompression: true, + {name: "invalid DSN", dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, + {name: "common DSN", dsn: "user:passwd@ws(fqdn:6041)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", Addr: "fqdn", Port: 6041, DbName: "dbname", InterpolateParams: true}}, + {name: "missing closing brace", dsn: "user:passwd@ws()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, + {name: "default address", dsn: "user:passwd@ws(:)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", DbName: "dbname", InterpolateParams: true}}, + {name: "0 port", dsn: "user:passwd@ws(:0)/dbname", want: &Config{User: "user", Passwd: "passwd", Net: "ws", DbName: "dbname", InterpolateParams: true}}, + {name: "wss protocol", dsn: "user:passwd@wss(:0)/", want: &Config{User: "user", Passwd: "passwd", Net: "wss", InterpolateParams: true}}, + {name: "params", dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &Config{User: "user", Passwd: "passwd", Net: "wss", Params: map[string]string{"test": "1"}}}, + {name: "token", dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &Config{User: "user", Passwd: "passwd", Net: "wss", Token: "token"}}, + {name: "readTimeout", dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &Config{User: "user", Passwd: "passwd", Net: "wss", ReadTimeout: 10 * time.Minute, WriteTimeout: 8 * time.Second, InterpolateParams: true}}, + {name: "compression", dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &Config{ + User: "user", + Passwd: "passwd", + Net: "wss", + ReadTimeout: 10 * time.Minute, + WriteTimeout: 8 * time.Second, + InterpolateParams: true, + EnableCompression: true, }}, + // encodeURIComponent('!@#$%^&*()-_+=[]{}:;> Date: Mon, 23 Dec 2024 14:47:12 +0800 Subject: [PATCH 33/35] test: add test for open --- taosRestful/driver_test.go | 14 ++++++++++++++ taosSql/driver_test.go | 14 ++++++++++++++ taosWS/driver_test.go | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index 043d17c..c071186 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -1,6 +1,7 @@ package taosRestful import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -304,3 +305,16 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } } + +func TestOpen(t *testing.T) { + tdDriver := &TDengineDriver{} + conn, err := tdDriver.Open(dataSourceName) + assert.NoError(t, err) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + assert.NoError(t, err) +} diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index 587f078..add2f65 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -1,6 +1,7 @@ package taosSql import ( + "context" "database/sql" "database/sql/driver" "encoding/json" @@ -575,3 +576,16 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } } + +func TestOpen(t *testing.T) { + tdDriver := &TDengineDriver{} + conn, err := tdDriver.Open(dataSourceName) + assert.NoError(t, err) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + assert.NoError(t, err) +} diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index 9989285..5d45836 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -1,6 +1,7 @@ package taosWS import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -295,3 +296,16 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } } + +func TestOpen(t *testing.T) { + tdDriver := &TDengineDriver{} + conn, err := tdDriver.Open(dataSourceName) + assert.NoError(t, err) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + assert.NoError(t, err) +} From d292fd188cfec1a22bae8932f1f346ac30baa793 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 23 Dec 2024 15:50:13 +0800 Subject: [PATCH 34/35] test: add test for dsn --- taosRestful/dsn_test.go | 278 ++++++++++++++++++++++++++------ taosSql/dsn_test.go | 339 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 533 insertions(+), 84 deletions(-) diff --git a/taosRestful/dsn_test.go b/taosRestful/dsn_test.go index b6351e0..0595cb9 100644 --- a/taosRestful/dsn_test.go +++ b/taosRestful/dsn_test.go @@ -11,52 +11,239 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tcs := []struct { - name string - dsn string - errs string - user string - passwd string - net string - addr string - port int - dbName string - token string - skipVerify bool - }{{}, - {name: "invalid", dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, - {name: "normal", dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, - {name: "invalid addr", dsn: "user:passwd@http()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, - {name: "default addr", dsn: "user:passwd@http(:)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {name: "0port", dsn: "user:passwd@http(:0)/dbname", user: "user", passwd: "passwd", net: "http", dbName: "dbname"}, - {name: "no db", dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, - {name: "params", dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, - {name: "token", dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, - {name: "skipVerify", dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, + name string + dsn string + errs string + want *Config + }{ { - name: "skipVerify", - dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", - user: "user", - passwd: "passwd", - net: "https", - token: "token", - skipVerify: true, + name: "invalid", + dsn: "abcd", + errs: "invalid DSN: missing the slash separating the database name", }, { - name: "special char", - dsn: "!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.:!%40%23%24%25%5E%26*()-_%2B%3D%5B%5D%7B%7D%3A%3B%3E%3C%3F%7C~%2C.@https(:)/dbname", - user: "!@#$%^&*()-_+=[]{}:;> Date: Wed, 25 Dec 2024 15:53:31 +0800 Subject: [PATCH 35/35] test: add test for special password --- taosRestful/driver_test.go | 61 ++++++++++++++++++++++++++++++++++++++ taosSql/driver_test.go | 61 ++++++++++++++++++++++++++++++++++++++ taosWS/driver_test.go | 61 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+) diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index c071186..9c156bb 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log" + "net/url" "testing" "time" @@ -318,3 +319,63 @@ func TestOpen(t *testing.T) { err = pinger.Ping(context.Background()) assert.NoError(t, err) } + +func TestSpecialPassword(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + tests := []struct { + name string + user string + pass string + }{ + { + name: "test_special1_rs", + user: "test_special1_rs", + pass: "!q@w#a$1%3^&*()-", + }, + { + name: "test_special2_rs", + user: "test_special2_rs", + pass: "_q+3=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@http(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +} diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index add2f65..152b796 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log" + "net/url" "testing" "time" @@ -589,3 +590,63 @@ func TestOpen(t *testing.T) { err = pinger.Ping(context.Background()) assert.NoError(t, err) } + +func TestSpecialPassword(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + tests := []struct { + name string + user string + pass string + }{ + { + name: "test_special1", + user: "test_special1", + pass: "!q@w#a$1%3^&*()-", + }, + { + name: "test_special2", + user: "test_special2", + pass: "_q+3=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@/tcp(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +} diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index 5d45836..a9fc22d 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log" + "net/url" "testing" "time" @@ -309,3 +310,63 @@ func TestOpen(t *testing.T) { err = pinger.Ping(context.Background()) assert.NoError(t, err) } + +func TestSpecialPassword(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + tests := []struct { + name string + user string + pass string + }{ + { + name: "test_special1_ws", + user: "test_special1_ws", + pass: "!q@w#a$1%3^&*()-", + }, + { + name: "test_special2_ws", + user: "test_special2_ws", + pass: "_q+3=[]{}:;><3?|~,w.", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + dropSql := fmt.Sprintf("drop user %s", test.user) + _, _ = db.Exec(dropSql) + }() + createSql := fmt.Sprintf("create user %s pass '%s'", test.user, test.pass) + _, err := db.Exec(createSql) + assert.NoError(t, err) + escapedPass := url.QueryEscape(test.pass) + newDsn := fmt.Sprintf("%s:%s@ws(%s:%d)/%s", test.user, escapedPass, host, port, "") + db2, err := sql.Open(driverName, newDsn) + if err != nil { + t.Errorf("error on: sql.open %s", err.Error()) + return + } + defer db2.Close() + rows, err := db2.Query("select 1") + assert.NoError(t, err) + var i int + for rows.Next() { + err := rows.Scan(&i) + assert.NoError(t, err) + assert.Equal(t, 1, i) + } + if i != 1 { + t.Errorf("query failed") + } + }) + } +}