From f0f283bb3d7ea9a2a2760959f0eccdb4f8c7e540 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Wed, 19 Oct 2016 17:37:38 -0500
Subject: [PATCH] Add test for Where() before Set() on update.

---
 lib/sqlbuilder/builder.go      | 16 ++++++++++++
 lib/sqlbuilder/builder_test.go | 46 ++++++++++++++++++++++++++++++++++
 lib/sqlbuilder/delete.go       |  4 +++
 lib/sqlbuilder/interfaces.go   |  6 +++++
 lib/sqlbuilder/select.go       | 20 ++++++---------
 lib/sqlbuilder/update.go       | 39 +++++++++++++++++++---------
 6 files changed, 107 insertions(+), 24 deletions(-)

diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go
index 9130dba0..b5eb8623 100644
--- a/lib/sqlbuilder/builder.go
+++ b/lib/sqlbuilder/builder.go
@@ -530,3 +530,19 @@ var (
 	_ = Builder(&sqlBuilder{})
 	_ = exprDB(&exprProxy{})
 )
+
+func joinArguments(args ...[]interface{}) []interface{} {
+	total := 0
+	for i := range args {
+		total += len(args[i])
+	}
+	if total == 0 {
+		return nil
+	}
+
+	flatten := make([]interface{}, 0, total)
+	for i := range args {
+		flatten = append(flatten, args[i]...)
+	}
+	return flatten
+}
diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go
index 315cd4f1..4ad9d8c6 100644
--- a/lib/sqlbuilder/builder_test.go
+++ b/lib/sqlbuilder/builder_test.go
@@ -654,6 +654,45 @@ func TestUpdate(t *testing.T) {
 		b.Update("artist").Set("name", "Artist").String(),
 	)
 
+	{
+		idSlice := []int64{8, 7, 6}
+		q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
+		assert.Equal(
+			`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN ($3, $4, $5))`,
+			q.String(),
+		)
+		assert.Equal(
+			[]interface{}{10, 1, int64(8), int64(7), int64(6)},
+			q.Arguments(),
+		)
+	}
+
+	{
+		idSlice := []int64{}
+		q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
+		assert.Equal(
+			`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
+			q.String(),
+		)
+		assert.Equal(
+			[]interface{}{10, 1},
+			q.Arguments(),
+		)
+	}
+
+	{
+		idSlice := []int64{}
+		q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10})
+		assert.Equal(
+			`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
+			q.String(),
+		)
+		assert.Equal(
+			[]interface{}{10, 1},
+			q.Arguments(),
+		)
+	}
+
 	assert.Equal(
 		`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
 		b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
@@ -671,6 +710,13 @@ func TestUpdate(t *testing.T) {
 		}{"Artist"}).Where(db.Cond{"id <": 5}).String(),
 	)
 
+	assert.Equal(
+		`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
+		b.Update("artist").Where(db.Cond{"id <": 5}).Set(struct {
+			Nombre string `db:"name"`
+		}{"Artist"}).String(),
+	)
+
 	assert.Equal(
 		`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
 		b.Update("artist").Set(struct {
diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go
index 5425af0e..d417784d 100644
--- a/lib/sqlbuilder/delete.go
+++ b/lib/sqlbuilder/delete.go
@@ -27,6 +27,10 @@ func (qd *deleter) Limit(limit int) Deleter {
 	return qd
 }
 
+func (qd *deleter) Arguments() []interface{} {
+	return qd.arguments
+}
+
 func (qd *deleter) Exec() (sql.Result, error) {
 	return qd.builder.sess.StatementExec(qd.statement(), qd.arguments...)
 }
diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go
index 478179d2..b94d9020 100644
--- a/lib/sqlbuilder/interfaces.go
+++ b/lib/sqlbuilder/interfaces.go
@@ -365,6 +365,9 @@ type Deleter interface {
 	// fmt.Stringer provides `String() string`, you can use `String()` to compile
 	// the `Inserter` into a string.
 	fmt.Stringer
+
+	// Arguments returns the arguments that are prepared for this query.
+	Arguments() []interface{}
 }
 
 // Updater represents an UPDATE statement.
@@ -388,6 +391,9 @@ type Updater interface {
 	// fmt.Stringer provides `String() string`, you can use `String()` to compile
 	// the `Inserter` into a string.
 	fmt.Stringer
+
+	// Arguments returns the arguments that are prepared for this query.
+	Arguments() []interface{}
 }
 
 // Execer provides methods for executing statements that do not return results.
diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go
index 326e613e..131911df 100644
--- a/lib/sqlbuilder/select.go
+++ b/lib/sqlbuilder/select.go
@@ -121,18 +121,14 @@ func (qs *selector) Arguments() []interface{} {
 	qs.mu.Lock()
 	defer qs.mu.Unlock()
 
-	total := len(qs.tableArgs) + len(qs.columnsArgs) + len(qs.whereArgs) + len(qs.joinsArgs) + len(qs.groupByArgs) + len(qs.orderByArgs)
-	if total == 0 {
-		return nil
-	}
-	args := make([]interface{}, 0, total)
-	args = append(args, qs.tableArgs...)
-	args = append(args, qs.columnsArgs...)
-	args = append(args, qs.joinsArgs...)
-	args = append(args, qs.whereArgs...)
-	args = append(args, qs.groupByArgs...)
-	args = append(args, qs.orderByArgs...)
-	return args
+	return joinArguments(
+		qs.tableArgs,
+		qs.columnsArgs,
+		qs.joinsArgs,
+		qs.whereArgs,
+		qs.groupByArgs,
+		qs.orderByArgs,
+	)
 }
 
 func (qs *selector) GroupBy(columns ...interface{}) Selector {
diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go
index 6c73724b..cb1b09bd 100644
--- a/lib/sqlbuilder/update.go
+++ b/lib/sqlbuilder/update.go
@@ -2,18 +2,25 @@ package sqlbuilder
 
 import (
 	"database/sql"
+	"sync"
 
 	"upper.io/db.v2/internal/sqladapter/exql"
 )
 
 type updater struct {
 	*stringer
-	builder      *sqlBuilder
-	table        string
-	columnValues *exql.ColumnValues
-	limit        int
-	where        *exql.Where
-	arguments    []interface{}
+	builder *sqlBuilder
+	table   string
+
+	columnValues     *exql.ColumnValues
+	columnValuesArgs []interface{}
+
+	limit int
+
+	where     *exql.Where
+	whereArgs []interface{}
+
+	mu sync.Mutex
 }
 
 func (qu *updater) Set(terms ...interface{}) Updater {
@@ -36,28 +43,36 @@ func (qu *updater) Set(terms ...interface{}) Updater {
 			cvs = append(cvs, cv)
 		}
 
-		args = append(args, qu.arguments...)
-
 		qu.columnValues.Insert(cvs...)
-		qu.arguments = append(qu.arguments, args...)
+		qu.columnValuesArgs = append(qu.columnValuesArgs, args...)
 	} else if len(terms) > 1 {
 		cv, arguments := qu.builder.t.ToColumnValues(terms)
 		qu.columnValues.Insert(cv.ColumnValues...)
-		qu.arguments = append(qu.arguments, arguments...)
+		qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...)
 	}
 
 	return qu
 }
 
+func (qu *updater) Arguments() []interface{} {
+	qu.mu.Lock()
+	defer qu.mu.Unlock()
+
+	return joinArguments(
+		qu.columnValuesArgs,
+		qu.whereArgs,
+	)
+}
+
 func (qu *updater) Where(terms ...interface{}) Updater {
 	where, arguments := qu.builder.t.ToWhereWithArguments(terms)
 	qu.where = &where
-	qu.arguments = append(qu.arguments, arguments...)
+	qu.whereArgs = append(qu.whereArgs, arguments...)
 	return qu
 }
 
 func (qu *updater) Exec() (sql.Result, error) {
-	return qu.builder.sess.StatementExec(qu.statement(), qu.arguments...)
+	return qu.builder.sess.StatementExec(qu.statement(), qu.Arguments()...)
 }
 
 func (qu *updater) Limit(limit int) Updater {