Skip to content

Commit

Permalink
Merge pull request #277 from upper/issue-276
Browse files Browse the repository at this point in the history
Add test for Where() before Set() on update
  • Loading branch information
José Carlos authored Oct 20, 2016
2 parents 69048c2 + f0f283b commit 9e3ad92
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 24 deletions.
16 changes: 16 additions & 0 deletions lib/sqlbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,19 @@ var (
_ = Builder(&sqlBuilder{})
_ = exprDB(&exprProxy{})
)

func joinArguments(args ...[]interface{}) []interface{} {
total := 0
for i := range args {
total += len(args[i])
}
if total == 0 {
return nil
}

flatten := make([]interface{}, 0, total)
for i := range args {
flatten = append(flatten, args[i]...)
}
return flatten
}
46 changes: 46 additions & 0 deletions lib/sqlbuilder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,45 @@ func TestUpdate(t *testing.T) {
b.Update("artist").Set("name", "Artist").String(),
)

{
idSlice := []int64{8, 7, 6}
q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
assert.Equal(
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN ($3, $4, $5))`,
q.String(),
)
assert.Equal(
[]interface{}{10, 1, int64(8), int64(7), int64(6)},
q.Arguments(),
)
}

{
idSlice := []int64{}
q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
assert.Equal(
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
q.String(),
)
assert.Equal(
[]interface{}{10, 1},
q.Arguments(),
)
}

{
idSlice := []int64{}
q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10})
assert.Equal(
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
q.String(),
)
assert.Equal(
[]interface{}{10, 1},
q.Arguments(),
)
}

assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
Expand All @@ -671,6 +710,13 @@ func TestUpdate(t *testing.T) {
}{"Artist"}).Where(db.Cond{"id <": 5}).String(),
)

assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Where(db.Cond{"id <": 5}).Set(struct {
Nombre string `db:"name"`
}{"Artist"}).String(),
)

assert.Equal(
`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
b.Update("artist").Set(struct {
Expand Down
4 changes: 4 additions & 0 deletions lib/sqlbuilder/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ func (qd *deleter) Limit(limit int) Deleter {
return qd
}

func (qd *deleter) Arguments() []interface{} {
return qd.arguments
}

func (qd *deleter) Exec() (sql.Result, error) {
return qd.builder.sess.StatementExec(qd.statement(), qd.arguments...)
}
Expand Down
6 changes: 6 additions & 0 deletions lib/sqlbuilder/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ type Deleter interface {
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Inserter` into a string.
fmt.Stringer

// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
}

// Updater represents an UPDATE statement.
Expand All @@ -388,6 +391,9 @@ type Updater interface {
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Inserter` into a string.
fmt.Stringer

// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
}

// Execer provides methods for executing statements that do not return results.
Expand Down
20 changes: 8 additions & 12 deletions lib/sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,14 @@ func (qs *selector) Arguments() []interface{} {
qs.mu.Lock()
defer qs.mu.Unlock()

total := len(qs.tableArgs) + len(qs.columnsArgs) + len(qs.whereArgs) + len(qs.joinsArgs) + len(qs.groupByArgs) + len(qs.orderByArgs)
if total == 0 {
return nil
}
args := make([]interface{}, 0, total)
args = append(args, qs.tableArgs...)
args = append(args, qs.columnsArgs...)
args = append(args, qs.joinsArgs...)
args = append(args, qs.whereArgs...)
args = append(args, qs.groupByArgs...)
args = append(args, qs.orderByArgs...)
return args
return joinArguments(
qs.tableArgs,
qs.columnsArgs,
qs.joinsArgs,
qs.whereArgs,
qs.groupByArgs,
qs.orderByArgs,
)
}

func (qs *selector) GroupBy(columns ...interface{}) Selector {
Expand Down
39 changes: 27 additions & 12 deletions lib/sqlbuilder/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@ package sqlbuilder

import (
"database/sql"
"sync"

"upper.io/db.v2/internal/sqladapter/exql"
)

type updater struct {
*stringer
builder *sqlBuilder
table string
columnValues *exql.ColumnValues
limit int
where *exql.Where
arguments []interface{}
builder *sqlBuilder
table string

columnValues *exql.ColumnValues
columnValuesArgs []interface{}

limit int

where *exql.Where
whereArgs []interface{}

mu sync.Mutex
}

func (qu *updater) Set(terms ...interface{}) Updater {
Expand All @@ -36,28 +43,36 @@ func (qu *updater) Set(terms ...interface{}) Updater {
cvs = append(cvs, cv)
}

args = append(args, qu.arguments...)

qu.columnValues.Insert(cvs...)
qu.arguments = append(qu.arguments, args...)
qu.columnValuesArgs = append(qu.columnValuesArgs, args...)
} else if len(terms) > 1 {
cv, arguments := qu.builder.t.ToColumnValues(terms)
qu.columnValues.Insert(cv.ColumnValues...)
qu.arguments = append(qu.arguments, arguments...)
qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...)
}

return qu
}

func (qu *updater) Arguments() []interface{} {
qu.mu.Lock()
defer qu.mu.Unlock()

return joinArguments(
qu.columnValuesArgs,
qu.whereArgs,
)
}

func (qu *updater) Where(terms ...interface{}) Updater {
where, arguments := qu.builder.t.ToWhereWithArguments(terms)
qu.where = &where
qu.arguments = append(qu.arguments, arguments...)
qu.whereArgs = append(qu.whereArgs, arguments...)
return qu
}

func (qu *updater) Exec() (sql.Result, error) {
return qu.builder.sess.StatementExec(qu.statement(), qu.arguments...)
return qu.builder.sess.StatementExec(qu.statement(), qu.Arguments()...)
}

func (qu *updater) Limit(limit int) Updater {
Expand Down

0 comments on commit 9e3ad92

Please sign in to comment.