diff --git a/common/reqid.go b/common/reqid.go index 02f1c72..15effae 100644 --- a/common/reqid.go +++ b/common/reqid.go @@ -1,6 +1,8 @@ package common import ( + "context" + "fmt" "math/bits" "os" "sync/atomic" @@ -78,3 +80,17 @@ func murmurHash32(data []byte, seed uint32) uint32 { return h1 } + +func GetReqIDFromCtx(ctx context.Context) (int64, error) { + var reqIDValue int64 + var ok bool + reqID := ctx.Value(ReqIDKey) + if reqID != nil { + reqIDValue, ok = reqID.(int64) + if !ok { + return 0, fmt.Errorf("invalid taos_req_id: %v, should be int64, got %T", reqID, reqID) + } + return reqIDValue, nil + } + return 0, nil +} diff --git a/taosRestful/connection.go b/taosRestful/connection.go index 4970184..5bef61b 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -27,6 +27,7 @@ type taosConn struct { cfg *config client *http.Client url *url.URL + baseRawQuery string header map[string][]string readBufferSize int } @@ -69,7 +70,7 @@ func newTaosConn(cfg *config) (*taosConn, error) { "Connection": {"keep-alive"}, } if cfg.token != "" { - tc.url.RawQuery = fmt.Sprintf("token=%s", cfg.token) + tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token) } else { basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd)) tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)} @@ -191,6 +192,18 @@ func (tc *taosConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. } func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (*common.TDEngineRestfulResp, error) { + reqIDValue, err := common.GetReqIDFromCtx(ctx) + if err != nil { + return nil, err + } + if reqIDValue == 0 { + reqIDValue = common.GetReqID() + } + if tc.baseRawQuery != "" { + tc.url.RawQuery = fmt.Sprintf("%s&req_id=%d", tc.baseRawQuery, reqIDValue) + } else { + tc.url.RawQuery = fmt.Sprintf("req_id=%d", reqIDValue) + } body := ioutil.NopCloser(strings.NewReader(sql)) req := &http.Request{ Method: http.MethodPost, diff --git a/taosSql/connection.go b/taosSql/connection.go index 288a74f..7277c05 100644 --- a/taosSql/connection.go +++ b/taosSql/connection.go @@ -79,12 +79,10 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver } func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - var reqIDValue int64 - reqID := ctx.Value(common.ReqIDKey) - if reqID != nil { - reqIDValue, _ = reqID.(int64) + reqIDValue, err := common.GetReqIDFromCtx(ctx) + if err != nil { + return nil, err } - if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -132,10 +130,9 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive } func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - var reqIDValue int64 - reqID := ctx.Value(common.ReqIDKey) - if reqID != nil { - reqIDValue, _ = reqID.(int64) + reqIDValue, err := common.GetReqIDFromCtx(ctx) + if err != nil { + return nil, err } if len(args) != 0 { if !tc.cfg.interpolateParams { diff --git a/taosWS/connection.go b/taosWS/connection.go index c465815..f73d04d 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -51,7 +51,6 @@ var ( type taosConn struct { buf *bytes.Buffer client *websocket.Conn - requestID uint64 writeLock sync.Mutex readTimeout time.Duration writeTimeout time.Duration @@ -69,10 +68,6 @@ type message struct { err error } -func (tc *taosConn) generateReqID() uint64 { - return atomic.AddUint64(&tc.requestID, 1) -} - func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl := &url.URL{ Scheme: cfg.net, @@ -98,7 +93,6 @@ func newTaosConn(cfg *config) (*taosConn, error) { tc := &taosConn{ buf: &bytes.Buffer{}, client: ws, - requestID: 0, readTimeout: cfg.readTimeout, writeTimeout: cfg.writeTimeout, cfg: cfg, @@ -170,10 +164,28 @@ func (tc *taosConn) isClosed() bool { } func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { + return tc.PrepareContext(context.Background(), query) +} + +func getReqID(ctx context.Context) (uint64, error) { + reqID, err := common.GetReqIDFromCtx(ctx) + if err != nil { + return 0, err + } + if reqID == 0 { + return uint64(common.GetReqID()), nil + } + return uint64(reqID), nil +} +func (tc *taosConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if tc.isClosed() { return nil, driver.ErrBadConn } - stmtID, err := tc.stmtInit() + reqID, err := getReqID(ctx) + if err != nil { + return nil, err + } + stmtID, err := tc.stmtInit(reqID) if err != nil { return nil, err } @@ -191,8 +203,7 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { return stmt, nil } -func (tc *taosConn) stmtInit() (uint64, error) { - reqID := tc.generateReqID() +func (tc *taosConn) stmtInit(reqID uint64) (uint64, error) { req := &StmtInitReq{ ReqID: reqID, } @@ -225,7 +236,7 @@ func (tc *taosConn) stmtInit() (uint64, error) { } func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtPrepareRequest{ ReqID: reqID, StmtID: stmtID, @@ -260,7 +271,7 @@ func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { } func (tc *taosConn) stmtClose(stmtID uint64) error { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtCloseRequest{ ReqID: reqID, StmtID: stmtID, @@ -286,7 +297,7 @@ func (tc *taosConn) stmtClose(stmtID uint64) error { } func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtGetColFieldsRequest{ ReqID: reqID, StmtID: stmtID, @@ -320,7 +331,7 @@ func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, er } func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) tc.buf.Reset() WriteUint64(tc.buf, reqID) WriteUint64(tc.buf, stmtID) @@ -365,7 +376,7 @@ func WriteUint16(buffer *bytes.Buffer, v uint16) { } func (tc *taosConn) stmtAddBatch(stmtID uint64) error { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtAddBatchRequest{ ReqID: reqID, StmtID: stmtID, @@ -399,7 +410,7 @@ func (tc *taosConn) stmtAddBatch(stmtID uint64) error { } func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtExecRequest{ ReqID: reqID, StmtID: stmtID, @@ -433,7 +444,7 @@ func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { } func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &StmtUseResultRequest{ ReqID: reqID, StmtID: stmtID, @@ -527,10 +538,14 @@ func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.Na return rs, err } -func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) { +func (tc *taosConn) doQuery(ctx context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) { if tc.isClosed() { return nil, driver.ErrBadConn } + reqID, err := getReqID(ctx) + if err != nil { + return nil, err + } if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -542,7 +557,6 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named } query = prepared } - reqID := tc.generateReqID() tc.buf.Reset() WriteUint64(tc.buf, reqID) // req id @@ -551,7 +565,7 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named WriteUint16(tc.buf, 1) // version WriteUint32(tc.buf, uint32(len(query))) // sql length tc.buf.WriteString(query) - err := tc.writeBinary(tc.buf.Bytes()) + err = tc.writeBinary(tc.buf.Bytes()) if err != nil { return nil, err } diff --git a/taosWS/rows.go b/taosWS/rows.go index 636f54c..5b192fc 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -86,7 +86,7 @@ func (rs *rows) Next(dest []driver.Value) error { } func (rs *rows) taosFetchBlock() error { - reqID := rs.conn.generateReqID() + reqID := uint64(common.GetReqID()) rs.buf.Reset() WriteUint64(rs.buf, reqID) // req id WriteUint64(rs.buf, rs.resultID) // message id @@ -139,7 +139,7 @@ func (rs *rows) taosFetchBlock() error { func (rs *rows) freeResult() error { tc := rs.conn - reqID := tc.generateReqID() + reqID := uint64(common.GetReqID()) req := &WSFreeResultReq{ ReqID: reqID, ID: rs.resultID,