diff --git a/go/logic/applier.go b/go/logic/applier.go index caa759cd2..37c306a2a 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -14,10 +14,13 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/openark/golib/log" + "context" + "database/sql/driver" + + "github.com/github/gh-ost/go/mysql" + drivermysql "github.com/go-sql-driver/mysql" "github.com/openark/golib/sqlutils" ) @@ -77,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { func (this *Applier) InitDBConnections(maxConns int) (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { + uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil { return err } this.db.SetMaxOpenConns(maxConns) @@ -1209,45 +1213,80 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { var totalDelta int64 + ctx := context.Background() err := func() error { - tx, err := this.db.Begin() + conn, err := this.db.Conn(ctx) if err != nil { return err } + defer conn.Close() + + sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'" + sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery()) + if _, err := conn.ExecContext(ctx, sessionQuery); err != nil { + return err + } + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } rollback := func(err error) error { tx.Rollback() return err } - sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'" - sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery()) - - if _, err := tx.Exec(sessionQuery); err != nil { - return rollback(err) - } + buildResults := make([]*dmlBuildResult, 0, len(dmlEvents)) + nArgs := 0 for _, dmlEvent := range dmlEvents { for _, buildResult := range this.buildDMLEventQuery(dmlEvent) { if buildResult.err != nil { return rollback(buildResult.err) } - result, err := tx.Exec(buildResult.query, buildResult.args...) + nArgs += len(buildResult.args) + buildResults = append(buildResults, buildResult) + } + } - if err != nil { - err = fmt.Errorf("%w; query=%s; args=%+v", err, buildResult.query, buildResult.args) - return rollback(err) + // We batch together the DML queries into multi-statements to minimize network trips. + // We have to use the raw driver connection to access the rows affected + // for each statement in the multi-statement. + execErr := conn.Raw(func(driverConn any) error { + ex := driverConn.(driver.ExecerContext) + nvc := driverConn.(driver.NamedValueChecker) + + multiArgs := make([]driver.NamedValue, 0, nArgs) + multiQueryBuilder := strings.Builder{} + for _, buildResult := range buildResults { + for _, arg := range buildResult.args { + nv := driver.NamedValue{Value: driver.Value(arg)} + nvc.CheckNamedValue(&nv) + multiArgs = append(multiArgs, nv) } - rowsAffected, err := result.RowsAffected() - if err != nil { - log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err) - rowsAffected = 1 - } - // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). - // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event - totalDelta += buildResult.rowsDelta * rowsAffected + multiQueryBuilder.WriteString(buildResult.query) + multiQueryBuilder.WriteString(";\n") } + + res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs) + if err != nil { + err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) + return err + } + + mysqlRes := res.(drivermysql.Result) + + // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). + // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event + for i, rowsAffected := range mysqlRes.AllRowsAffected() { + totalDelta += buildResults[i].rowsDelta * rowsAffected + } + return nil + }) + + if execErr != nil { + return rollback(execErr) } if err := tx.Commit(); err != nil { return err