Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enh: support input request id in rest and ws #290

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions common/reqid.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package common

import (
"context"
"fmt"
"math/bits"
"os"
"sync/atomic"
Expand Down Expand Up @@ -78,3 +80,17 @@ func murmurHash32(data []byte, seed uint32) uint32 {

return h1
}

func GetReqIDFromCtx(ctx context.Context) (int64, error) {
var reqIDValue int64
var ok bool
reqID := ctx.Value(ReqIDKey)
if reqID != nil {
reqIDValue, ok = reqID.(int64)
if !ok {
return 0, fmt.Errorf("invalid taos_req_id: %v, should be int64, got %T", reqID, reqID)
}
return reqIDValue, nil
}
return 0, nil
}
15 changes: 14 additions & 1 deletion taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type taosConn struct {
cfg *config
client *http.Client
url *url.URL
baseRawQuery string
header map[string][]string
readBufferSize int
}
Expand Down Expand Up @@ -69,7 +70,7 @@ func newTaosConn(cfg *config) (*taosConn, error) {
"Connection": {"keep-alive"},
}
if cfg.token != "" {
tc.url.RawQuery = fmt.Sprintf("token=%s", cfg.token)
tc.baseRawQuery = fmt.Sprintf("token=%s", cfg.token)
} else {
basic := base64.StdEncoding.EncodeToString([]byte(cfg.user + ":" + cfg.passwd))
tc.header["Authorization"] = []string{fmt.Sprintf("Basic %s", basic)}
Expand Down Expand Up @@ -191,6 +192,18 @@ func (tc *taosConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
}

func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) (*common.TDEngineRestfulResp, error) {
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}
if reqIDValue == 0 {
reqIDValue = common.GetReqID()
}
if tc.baseRawQuery != "" {
tc.url.RawQuery = fmt.Sprintf("%s&req_id=%d", tc.baseRawQuery, reqIDValue)
} else {
tc.url.RawQuery = fmt.Sprintf("req_id=%d", reqIDValue)
}
body := ioutil.NopCloser(strings.NewReader(sql))
req := &http.Request{
Method: http.MethodPost,
Expand Down
15 changes: 6 additions & 9 deletions taosSql/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,10 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver
}

func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var reqIDValue int64
reqID := ctx.Value(common.ReqIDKey)
if reqID != nil {
reqIDValue, _ = reqID.(int64)
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}

if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand Down Expand Up @@ -132,10 +130,9 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive
}

func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var reqIDValue int64
reqID := ctx.Value(common.ReqIDKey)
if reqID != nil {
reqIDValue, _ = reqID.(int64)
reqIDValue, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return nil, err
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
Expand Down
52 changes: 33 additions & 19 deletions taosWS/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ var (
type taosConn struct {
buf *bytes.Buffer
client *websocket.Conn
requestID uint64
writeLock sync.Mutex
readTimeout time.Duration
writeTimeout time.Duration
Expand All @@ -69,10 +68,6 @@ type message struct {
err error
}

func (tc *taosConn) generateReqID() uint64 {
return atomic.AddUint64(&tc.requestID, 1)
}

func newTaosConn(cfg *config) (*taosConn, error) {
endpointUrl := &url.URL{
Scheme: cfg.net,
Expand All @@ -98,7 +93,6 @@ func newTaosConn(cfg *config) (*taosConn, error) {
tc := &taosConn{
buf: &bytes.Buffer{},
client: ws,
requestID: 0,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
cfg: cfg,
Expand Down Expand Up @@ -170,10 +164,28 @@ func (tc *taosConn) isClosed() bool {
}

func (tc *taosConn) Prepare(query string) (driver.Stmt, error) {
return tc.PrepareContext(context.Background(), query)
}

func getReqID(ctx context.Context) (uint64, error) {
reqID, err := common.GetReqIDFromCtx(ctx)
if err != nil {
return 0, err
}
if reqID == 0 {
return uint64(common.GetReqID()), nil
}
return uint64(reqID), nil
}
func (tc *taosConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if tc.isClosed() {
return nil, driver.ErrBadConn
}
stmtID, err := tc.stmtInit()
reqID, err := getReqID(ctx)
if err != nil {
return nil, err
}
stmtID, err := tc.stmtInit(reqID)
if err != nil {
return nil, err
}
Expand All @@ -191,8 +203,7 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) {
return stmt, nil
}

func (tc *taosConn) stmtInit() (uint64, error) {
reqID := tc.generateReqID()
func (tc *taosConn) stmtInit(reqID uint64) (uint64, error) {
req := &StmtInitReq{
ReqID: reqID,
}
Expand Down Expand Up @@ -225,7 +236,7 @@ func (tc *taosConn) stmtInit() (uint64, error) {
}

func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtPrepareRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -260,7 +271,7 @@ func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) {
}

func (tc *taosConn) stmtClose(stmtID uint64) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtCloseRequest{
ReqID: reqID,
StmtID: stmtID,
Expand All @@ -286,7 +297,7 @@ func (tc *taosConn) stmtClose(stmtID uint64) error {
}

func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtGetColFieldsRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -320,7 +331,7 @@ func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, er
}

func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
tc.buf.Reset()
WriteUint64(tc.buf, reqID)
WriteUint64(tc.buf, stmtID)
Expand Down Expand Up @@ -365,7 +376,7 @@ func WriteUint16(buffer *bytes.Buffer, v uint16) {
}

func (tc *taosConn) stmtAddBatch(stmtID uint64) error {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtAddBatchRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -399,7 +410,7 @@ func (tc *taosConn) stmtAddBatch(stmtID uint64) error {
}

func (tc *taosConn) stmtExec(stmtID uint64) (int, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtExecRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -433,7 +444,7 @@ func (tc *taosConn) stmtExec(stmtID uint64) (int, error) {
}

func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) {
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &StmtUseResultRequest{
ReqID: reqID,
StmtID: stmtID,
Expand Down Expand Up @@ -527,10 +538,14 @@ func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.Na
return rs, err
}

func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) {
func (tc *taosConn) doQuery(ctx context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) {
if tc.isClosed() {
return nil, driver.ErrBadConn
}
reqID, err := getReqID(ctx)
if err != nil {
return nil, err
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand All @@ -542,7 +557,6 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named
}
query = prepared
}
reqID := tc.generateReqID()
tc.buf.Reset()

WriteUint64(tc.buf, reqID) // req id
Expand All @@ -551,7 +565,7 @@ func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.Named
WriteUint16(tc.buf, 1) // version
WriteUint32(tc.buf, uint32(len(query))) // sql length
tc.buf.WriteString(query)
err := tc.writeBinary(tc.buf.Bytes())
err = tc.writeBinary(tc.buf.Bytes())
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions taosWS/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (rs *rows) Next(dest []driver.Value) error {
}

func (rs *rows) taosFetchBlock() error {
reqID := rs.conn.generateReqID()
reqID := uint64(common.GetReqID())
rs.buf.Reset()
WriteUint64(rs.buf, reqID) // req id
WriteUint64(rs.buf, rs.resultID) // message id
Expand Down Expand Up @@ -139,7 +139,7 @@ func (rs *rows) taosFetchBlock() error {

func (rs *rows) freeResult() error {
tc := rs.conn
reqID := tc.generateReqID()
reqID := uint64(common.GetReqID())
req := &WSFreeResultReq{
ReqID: reqID,
ID: rs.resultID,
Expand Down
Loading