Skip to content

Commit

Permalink
dialect/sql/sqljson: inline boolean values (ent#3570)
Browse files Browse the repository at this point in the history
Some drivers like mysql encodes them as 0/1
  • Loading branch information
a8m authored May 29, 2023
1 parent 633d021 commit a8851db
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
9 changes: 8 additions & 1 deletion dialect/sql/sqljson/sqljson.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package sqljson
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"unicode"

Expand Down Expand Up @@ -95,7 +96,13 @@ func ValueEQ(column string, arg any, opts ...Option) *sql.Predicate {
return sql.P(func(b *sql.Builder) {
opts = normalizePG(b, arg, opts)
valuePath(b, column, opts...)
b.WriteOp(sql.OpEQ).Arg(arg)
b.WriteOp(sql.OpEQ)
// Inline boolean values, as some drivers (e.g., MySQL) encode them as 0/1.
if v, ok := arg.(bool); ok {
b.WriteString(strconv.FormatBool(v))
} else {
b.Arg(arg)
}
})
}

Expand Down
7 changes: 7 additions & 0 deletions dialect/sql/sqljson/sqljson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ func TestWritePath(t *testing.T) {
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b.c[1].d') = ?",
wantArgs: []any{"a"},
},
{
input: sql.Dialect(dialect.MySQL).
Select("*").
From(sql.Table("users")).
Where(sqljson.ValueEQ("a", true, sqljson.DotPath("b.c[1].d"))),
wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b.c[1].d') = true",
},
{
input: sql.Dialect(dialect.MySQL).
Select("*").
Expand Down
16 changes: 16 additions & 0 deletions entc/integration/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,22 @@ func Predicates(t *testing.T, client *ent.Client) {
}).CountX(ctx)
require.Equal(t, 4, n)
})

t.Run("Boolean", func(t *testing.T) {
users := client.User.Query().
Where(func(s *sql.Selector) {
s.Where(sqljson.ValueEQ(user.FieldT, true, sqljson.Path("b")))
}).
AllX(ctx)
require.Empty(t, users)
client.User.Create().SetT(&schema.T{B: true}).ExecX(ctx)
u1 := client.User.Query().
Where(func(s *sql.Selector) {
s.Where(sqljson.ValueEQ(user.FieldT, true, sqljson.Path("b")))
}).
OnlyX(ctx)
require.True(t, u1.T.B)
})
}

func Order(t *testing.T, client *ent.Client) {
Expand Down

0 comments on commit a8851db

Please sign in to comment.