From f0f283bb3d7ea9a2a2760959f0eccdb4f8c7e540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Wed, 19 Oct 2016 17:37:38 -0500 Subject: [PATCH] Add test for Where() before Set() on update. --- lib/sqlbuilder/builder.go | 16 ++++++++++++ lib/sqlbuilder/builder_test.go | 46 ++++++++++++++++++++++++++++++++++ lib/sqlbuilder/delete.go | 4 +++ lib/sqlbuilder/interfaces.go | 6 +++++ lib/sqlbuilder/select.go | 20 ++++++--------- lib/sqlbuilder/update.go | 39 +++++++++++++++++++--------- 6 files changed, 107 insertions(+), 24 deletions(-) diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 9130dba0..b5eb8623 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -530,3 +530,19 @@ var ( _ = Builder(&sqlBuilder{}) _ = exprDB(&exprProxy{}) ) + +func joinArguments(args ...[]interface{}) []interface{} { + total := 0 + for i := range args { + total += len(args[i]) + } + if total == 0 { + return nil + } + + flatten := make([]interface{}, 0, total) + for i := range args { + flatten = append(flatten, args[i]...) + } + return flatten +} diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 315cd4f1..4ad9d8c6 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -654,6 +654,45 @@ func TestUpdate(t *testing.T) { b.Update("artist").Set("name", "Artist").String(), ) + { + idSlice := []int64{8, 7, 6} + q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN ($3, $4, $5))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1, int64(8), int64(7), int64(6)}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + assert.Equal( `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), @@ -671,6 +710,13 @@ func TestUpdate(t *testing.T) { }{"Artist"}).Where(db.Cond{"id <": 5}).String(), ) + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Where(db.Cond{"id <": 5}).Set(struct { + Nombre string `db:"name"` + }{"Artist"}).String(), + ) + assert.Equal( `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, b.Update("artist").Set(struct { diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go index 5425af0e..d417784d 100644 --- a/lib/sqlbuilder/delete.go +++ b/lib/sqlbuilder/delete.go @@ -27,6 +27,10 @@ func (qd *deleter) Limit(limit int) Deleter { return qd } +func (qd *deleter) Arguments() []interface{} { + return qd.arguments +} + func (qd *deleter) Exec() (sql.Result, error) { return qd.builder.sess.StatementExec(qd.statement(), qd.arguments...) } diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index 478179d2..b94d9020 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -365,6 +365,9 @@ type Deleter interface { // fmt.Stringer provides `String() string`, you can use `String()` to compile // the `Inserter` into a string. fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} } // Updater represents an UPDATE statement. @@ -388,6 +391,9 @@ type Updater interface { // fmt.Stringer provides `String() string`, you can use `String()` to compile // the `Inserter` into a string. fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} } // Execer provides methods for executing statements that do not return results. diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 326e613e..131911df 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -121,18 +121,14 @@ func (qs *selector) Arguments() []interface{} { qs.mu.Lock() defer qs.mu.Unlock() - total := len(qs.tableArgs) + len(qs.columnsArgs) + len(qs.whereArgs) + len(qs.joinsArgs) + len(qs.groupByArgs) + len(qs.orderByArgs) - if total == 0 { - return nil - } - args := make([]interface{}, 0, total) - args = append(args, qs.tableArgs...) - args = append(args, qs.columnsArgs...) - args = append(args, qs.joinsArgs...) - args = append(args, qs.whereArgs...) - args = append(args, qs.groupByArgs...) - args = append(args, qs.orderByArgs...) - return args + return joinArguments( + qs.tableArgs, + qs.columnsArgs, + qs.joinsArgs, + qs.whereArgs, + qs.groupByArgs, + qs.orderByArgs, + ) } func (qs *selector) GroupBy(columns ...interface{}) Selector { diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index 6c73724b..cb1b09bd 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -2,18 +2,25 @@ package sqlbuilder import ( "database/sql" + "sync" "upper.io/db.v2/internal/sqladapter/exql" ) type updater struct { *stringer - builder *sqlBuilder - table string - columnValues *exql.ColumnValues - limit int - where *exql.Where - arguments []interface{} + builder *sqlBuilder + table string + + columnValues *exql.ColumnValues + columnValuesArgs []interface{} + + limit int + + where *exql.Where + whereArgs []interface{} + + mu sync.Mutex } func (qu *updater) Set(terms ...interface{}) Updater { @@ -36,28 +43,36 @@ func (qu *updater) Set(terms ...interface{}) Updater { cvs = append(cvs, cv) } - args = append(args, qu.arguments...) - qu.columnValues.Insert(cvs...) - qu.arguments = append(qu.arguments, args...) + qu.columnValuesArgs = append(qu.columnValuesArgs, args...) } else if len(terms) > 1 { cv, arguments := qu.builder.t.ToColumnValues(terms) qu.columnValues.Insert(cv.ColumnValues...) - qu.arguments = append(qu.arguments, arguments...) + qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...) } return qu } +func (qu *updater) Arguments() []interface{} { + qu.mu.Lock() + defer qu.mu.Unlock() + + return joinArguments( + qu.columnValuesArgs, + qu.whereArgs, + ) +} + func (qu *updater) Where(terms ...interface{}) Updater { where, arguments := qu.builder.t.ToWhereWithArguments(terms) qu.where = &where - qu.arguments = append(qu.arguments, arguments...) + qu.whereArgs = append(qu.whereArgs, arguments...) return qu } func (qu *updater) Exec() (sql.Result, error) { - return qu.builder.sess.StatementExec(qu.statement(), qu.arguments...) + return qu.builder.sess.StatementExec(qu.statement(), qu.Arguments()...) } func (qu *updater) Limit(limit int) Updater {