Skip to content

Commit

Permalink
Merge pull request #290 from taosdata/enh/xftan/TD-31941-3.1
Browse files Browse the repository at this point in the history
enh: support input request id in rest and ws
  • Loading branch information
sheyanjie-qq authored Sep 6, 2024
2 parents 3cda922 + ee20317 commit 4da4002
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 31 deletions.
16 changes: 16 additions & 0 deletions common/reqid.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package common

import (
"context"
"fmt"
"math/bits"
"os"
"sync/atomic"
Expand Down Expand Up @@ -78,3 +80,17 @@ func murmurHash32(data []byte, seed uint32) uint32 {

return h1
}

func GetReqIDFromCtx(ctx context.Context) (int64, error) {
var reqIDValue int64
var ok bool
reqID := ctx.Value(ReqIDKey)
if reqID != nil {
reqIDValue, ok = reqID.(int64)
if !ok {
return 0, fmt.Errorf("invalid taos_req_id: %v, should be int64, got %T", reqID, reqID)
}
return reqIDValue, nil
}
return 0, nil
}
15 changes: 14 additions & 1 deletion taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type taosConn struct {
cfg *config
client *http.Client
url *url.URL
baseRawQuery string
header map[string][]string
readBufferSize int
}
Expand Down Expand Up @@ -69,7 +70,7 @@ func newTaosConn(cfg *config) (*taosConn, error) {
"Connection": {"keep-alive"},
}
if cfg.token != "" {
tc.url.RawQuery = fmt.Sprintf("token=%s", cfg.token)
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token)
} else {
basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd))
tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)}
Expand Down Expand Up @@ -191,6 +192,18 @@ func (tc *taosConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
}

func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (*common.TDEngineRestfulResp, error) {
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}
if reqIDValue == 0 {
reqIDValue = common.GetReqID()
}
if tc.baseRawQuery != "" {
tc.url.RawQuery = fmt.Sprintf("%s&req_id=%d", tc.baseRawQuery, reqIDValue)
} else {
tc.url.RawQuery = fmt.Sprintf("req_id=%d", reqIDValue)
}
body := ioutil.NopCloser(strings.NewReader(sql))
req := &http.Request{
Method: http.MethodPost,
Expand Down
15 changes: 6 additions & 9 deletions taosSql/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,10 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver
}

func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var reqIDValue int64
reqID := ctx.Value(common.ReqIDKey)
if reqID != nil {
reqIDValue, _ = reqID.(int64)
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}

if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand Down Expand Up @@ -132,10 +130,9 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive
}

func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var reqIDValue int64
reqID := ctx.Value(common.ReqIDKey)
if reqID != nil {
reqIDValue, _ = reqID.(int64)
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
Expand Down
52 changes: 33 additions & 19 deletions taosWS/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ var (
type taosConn struct {
buf *bytes.Buffer
client *websocket.Conn
requestID uint64
writeLock sync.Mutex
readTimeout time.Duration
writeTimeout time.Duration
Expand All @@ -69,10 +68,6 @@ type message struct {
err error
}

func (tc *taosConn) generateReqID() uint64 {
return atomic.AddUint64(&tc.requestID, 1)
}

func newTaosConn(cfg *config) (*taosConn, error) {
endpointUrl := &url.URL{
Scheme: cfg.net,
Expand All @@ -98,7 +93,6 @@ func newTaosConn(cfg *config) (*taosConn, error) {
tc := &taosConn{
buf: &bytes.Buffer{},
client: ws,
requestID: 0,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
cfg: cfg,
Expand Down Expand Up @@ -170,10 +164,28 @@ func (tc *taosConn) isClosed() bool {
}

func (tc *taosConn) Prepare(query string) (driver.Stmt, error) {
return tc.PrepareContext(context.Background(), query)
}

func getReqID(ctx context.Context) (uint64, error) {
reqID, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return 0, err
}
if reqID == 0 {
return uint64(common.GetReqID()), nil
}
return uint64(reqID), nil
}
func (tc *taosConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if tc.isClosed() {
return nil, driver.ErrBadConn
}
stmtID, err := tc.stmtInit()
reqID, err := getReqID(ctx)
if err != nil {
return nil, err
}
stmtID, err := tc.stmtInit(reqID)
if err != nil {
return nil, err
}
Expand All @@ -191,8 +203,7 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) {
return stmt, nil
}

func (tc *taosConn) stmtInit() (uint64, error) {
reqID := tc.generateReqID()
func (tc *taosConn) stmtInit(reqID uint64) (uint64, error) {
req := &StmtInitReq{
ReqID: reqID,
}
Expand Down Expand Up @@ -225,7 +236,7 @@ func (tc *taosConn) stmtInit() (uint64, error) {
}

func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtPrepareRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -260,7 +271,7 @@ func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) {
}

func (tc *taosConn) stmtClose(stmtID uint64) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtCloseRequest{
ReqID: reqID,
StmtID: stmtID,
Expand All @@ -286,7 +297,7 @@ func (tc *taosConn) stmtClose(stmtID uint64) error {
}

func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtGetColFieldsRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -320,7 +331,7 @@ func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, er
}

func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
tc.buf.Reset()
WriteUint64(tc.buf, reqID)
WriteUint64(tc.buf, stmtID)
Expand Down Expand Up @@ -365,7 +376,7 @@ func WriteUint16(buffer *bytes.Buffer, v uint16) {
}

func (tc *taosConn) stmtAddBatch(stmtID uint64) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtAddBatchRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -399,7 +410,7 @@ func (tc *taosConn) stmtAddBatch(stmtID uint64) error {
}

func (tc *taosConn) stmtExec(stmtID uint64) (int, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtExecRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -433,7 +444,7 @@ func (tc *taosConn) stmtExec(stmtID uint64) (int, error) {
}

func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtUseResultRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -527,10 +538,14 @@ func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.Na
return rs, err
}

func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) {
func (tc *taosConn) doQuery(ctx context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) {
if tc.isClosed() {
return nil, driver.ErrBadConn
}
reqID, err := getReqID(ctx)
if err != nil {
return nil, err
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand All @@ -542,7 +557,6 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named
}
query = prepared
}
reqID := tc.generateReqID()
tc.buf.Reset()

WriteUint64(tc.buf, reqID) // req id
Expand All @@ -551,7 +565,7 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named
WriteUint16(tc.buf, 1) // version
WriteUint32(tc.buf, uint32(len(query))) // sql length
tc.buf.WriteString(query)
err := tc.writeBinary(tc.buf.Bytes())
err = tc.writeBinary(tc.buf.Bytes())
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions taosWS/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (rs *rows) Next(dest []driver.Value) error {
}

func (rs *rows) taosFetchBlock() error {
reqID := rs.conn.generateReqID()
reqID := uint64(common.GetReqID())
rs.buf.Reset()
WriteUint64(rs.buf, reqID) // req id
WriteUint64(rs.buf, rs.resultID) // message id
Expand Down Expand Up @@ -139,7 +139,7 @@ func (rs *rows) taosFetchBlock() error {

func (rs *rows) freeResult() error {
tc := rs.conn
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &WSFreeResultReq{
ReqID: reqID,
ID: rs.resultID,
Expand Down

0 comments on commit 4da4002

Please sign in to comment.