diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 7dce1e966..f557c1e24 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -107,17 +107,19 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na // BeginTx like common transaction. but it just exec XA START func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if !tm.IsGlobalTx(ctx) { + tx, err := c.Conn.BeginTx(ctx, opts) + return tx, err + } + c.autoCommit = false c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts c.txCtx.ResourceID = c.res.resourceID - - if tm.IsGlobalTx(ctx) { - c.txCtx.XID = tm.GetXID(ctx) - c.txCtx.TransactionMode = types.XAMode - } + c.txCtx.XID = tm.GetXID(ctx) + c.txCtx.TransactionMode = types.XAMode tx, err := c.Conn.BeginTx(ctx, opts) if err != nil { @@ -170,10 +172,14 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { } func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.ExecResult, error)) (types.ExecResult, error) { - var err error + var ( + tx driver.Tx + err error + ) + currentAutoCommit := c.autoCommit - if c.txCtx.TransactionMode != types.Local && c.autoCommit { - _, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) + if c.txCtx.TransactionMode != types.Local && tm.IsGlobalTx(ctx) && c.autoCommit { + tx, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) if err != nil { return nil, err } @@ -201,7 +207,7 @@ func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.Ex return nil, err } - if currentAutoCommit { + if tx != nil && currentAutoCommit { if err := c.Commit(ctx); err != nil { log.Errorf("xa connection proxy commit failure xid:%s, err:%v", c.txCtx.XID, err) // XA End & Rollback diff --git a/pkg/datasource/sql/xa/mysql_xa_connection_test.go b/pkg/datasource/sql/xa/mysql_xa_connection_test.go index 08a370f1c..b897a907f 100644 --- a/pkg/datasource/sql/xa/mysql_xa_connection_test.go +++ b/pkg/datasource/sql/xa/mysql_xa_connection_test.go @@ -66,7 +66,11 @@ func TestMysqlXAConn_Commit(t *testing.T) { mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { // check if the xid is nil - if len(strings.Split(strings.Trim(query, " "), " ")) != 3 { + xidSplits := strings.Split(strings.Trim(query, " "), " ") + if len(xidSplits) != 3 { + return nil, errors.New("xid is nil") + } + if xidSplits[2] == "''" { return nil, errors.New("xid is nil") } return nil, nil