Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically ref names of CTETables in DELETE and UPDATE statements #179

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
refs #176: Automatically ref names of CTETables in DELETE and UPDATE …
…statement
huandu committed Nov 2, 2024
commit 96c9b25b7d730743dca3fb83a0c1b113cfd4328f
8 changes: 4 additions & 4 deletions cte.go
Original file line number Diff line number Diff line change
@@ -138,12 +138,12 @@ func (cteb *CTEBuilder) TableNames() []string {
return tableNames
}

// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder right now.
func (cteb *CTEBuilder) tableNamesForSelect() []string {
// tableNamesForFrom returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder/UpdateBuilder/DeleteBuilder right now.
func (cteb *CTEBuilder) tableNamesForFrom() []string {
cnt := 0

// It's rare that the ShouldAddToTableList() returns true.
// ShouldAddToTableList() unlikely returns true.
// Count it before allocating any memory for better performance.
for _, query := range cteb.queries {
if query.ShouldAddToTableList() {
37 changes: 37 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
@@ -82,6 +82,43 @@ func ExampleCTEBuilder() {
// [users valid_users]
}

func ExampleCTEBuilder_update() {
builder := With(
CTETable("users", "user_id").As(
Select("user_id").From("vip_users"),
),
).Update("orders").Set(
"orders.transport_fee = 0",
).Where(
"users.user_id = orders.user_id",
)

sqlForMySQL, _ := builder.BuildWithFlavor(MySQL)
sqlForPostgreSQL, _ := builder.BuildWithFlavor(PostgreSQL)

fmt.Println(sqlForMySQL)
fmt.Println(sqlForPostgreSQL)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders FROM users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
}

func ExampleCTEBuilder_delete() {
sql := With(
CTETable("users", "user_id").As(
Select("user_id").From("cheaters"),
),
).DeleteFrom("awards").Where(
"users.user_id = awards.user_id",
).String()

fmt.Println(sql)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM cheaters) DELETE FROM awards, users WHERE users.user_id = awards.user_id
}

func TestCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
51 changes: 40 additions & 11 deletions delete.go
Original file line number Diff line number Diff line change
@@ -45,8 +45,10 @@ type DeleteBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
orderByCols []string
order string
limit int
@@ -60,24 +62,48 @@ type DeleteBuilder struct {
var _ Builder = new(DeleteBuilder)

// DeleteFrom sets table name in DELETE.
func DeleteFrom(table string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table)
func DeleteFrom(table ...string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table...)
}

// With sets WITH clause (the Common Table Expression) before DELETE.
func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder {
db.marker = deleteMarkerAfterWith
db.cteBuilder = db.Var(builder)
db.cteBuilderVar = db.Var(builder)
db.cteBuilder = builder
return db
}

// DeleteFrom sets table name in DELETE.
func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder {
db.table = Escape(table)
func (db *DeleteBuilder) DeleteFrom(table ...string) *DeleteBuilder {
db.tables = table
db.marker = deleteMarkerAfterDeleteFrom
return db
}

// TableNames returns all table names in this DELETE statement.
func (db *DeleteBuilder) TableNames() []string {
var additionalTableNames []string

if db.cteBuilder != nil {
additionalTableNames = db.cteBuilder.tableNamesForFrom()
}

var tableNames []string

if len(db.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(db.tables)+len(additionalTableNames))
copy(tableNames, db.tables)
copy(tableNames[len(db.tables):], additionalTableNames)
} else if len(db.tables) > 0 {
tableNames = db.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Where sets expressions of WHERE in DELETE.
func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder {
if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 {
@@ -146,17 +172,20 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
// BuildWithFlavor returns compiled DELETE string and args with flavor and initial args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {

buf := newStringBuilder()
db.injection.WriteTo(buf, deleteMarkerInit)

if db.cteBuilder != "" {
buf.WriteLeadingString(db.cteBuilder)
if db.cteBuilder != nil {
buf.WriteLeadingString(db.cteBuilderVar)
db.injection.WriteTo(buf, deleteMarkerAfterWith)
}

if len(db.table) > 0 {
tableNames := db.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("DELETE FROM ")
buf.WriteString(db.table)
buf.WriteStrings(tableNames, ", ")
}

db.injection.WriteTo(buf, deleteMarkerAfterDeleteFrom)
4 changes: 2 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
@@ -96,12 +96,12 @@ func Select(col ...string) *SelectBuilder {
return DefaultFlavor.NewSelectBuilder().Select(col...)
}

// TableNames returns all table names in a SELECT.
// TableNames returns all table names in this SELECT statement.
func (sb *SelectBuilder) TableNames() []string {
var additionalTableNames []string

if sb.cteBuilder != nil {
additionalTableNames = sb.cteBuilder.tableNamesForSelect()
additionalTableNames = sb.cteBuilder.tableNamesForFrom()
}

var tableNames []string
70 changes: 58 additions & 12 deletions update.go
Original file line number Diff line number Diff line change
@@ -47,8 +47,10 @@ type UpdateBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
assignments []string
orderByCols []string
order string
@@ -63,24 +65,46 @@ type UpdateBuilder struct {
var _ Builder = new(UpdateBuilder)

// Update sets table name in UPDATE.
func Update(table string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table)
func Update(table ...string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table...)
}

// With sets WITH clause (the Common Table Expression) before UPDATE.
func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder {
ub.marker = updateMarkerAfterWith
ub.cteBuilder = ub.Var(builder)
ub.cteBuilderVar = ub.Var(builder)
ub.cteBuilder = builder
return ub
}

// Update sets table name in UPDATE.
func (ub *UpdateBuilder) Update(table string) *UpdateBuilder {
ub.table = Escape(table)
func (ub *UpdateBuilder) Update(table ...string) *UpdateBuilder {
ub.tables = table
ub.marker = updateMarkerAfterUpdate
return ub
}

// TableNames returns all table names in this UPDATE statement.
func (ub *UpdateBuilder) TableNames() (tableNames []string) {
var additionalTableNames []string

if ub.cteBuilder != nil {
additionalTableNames = ub.cteBuilder.tableNamesForFrom()
}

if len(ub.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(ub.tables)+len(additionalTableNames))
copy(tableNames, ub.tables)
copy(tableNames[len(ub.tables):], additionalTableNames)
} else if len(ub.tables) > 0 {
tableNames = ub.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Set sets the assignments in SET.
func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder {
ub.assignments = assignment
@@ -212,14 +236,36 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
buf := newStringBuilder()
ub.injection.WriteTo(buf, updateMarkerInit)

if ub.cteBuilder != "" {
buf.WriteLeadingString(ub.cteBuilder)
if ub.cteBuilder != nil {
buf.WriteLeadingString(ub.cteBuilderVar)
ub.injection.WriteTo(buf, updateMarkerAfterWith)
}

if len(ub.table) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteString(ub.table)
switch flavor {
case MySQL:
// CTE table names should be written after UPDATE keyword in MySQL.
tableNames := ub.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(tableNames, ", ")
}

default:
if len(ub.tables) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(ub.tables, ", ")

// For ISO SQL, CTE table names should be written after FROM keyword.
if ub.cteBuilder != nil {
cteTableNames := ub.cteBuilder.tableNamesForFrom()

if len(cteTableNames) > 0 {
buf.WriteLeadingString("FROM ")
buf.WriteStrings(cteTableNames, ", ")
}
}
}
}

ub.injection.WriteTo(buf, updateMarkerAfterUpdate)