Skip to content

Commit

Permalink
fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY (go-gorm#167)
Browse files Browse the repository at this point in the history
* fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY

* fix: migrator use ddl parser instead of regexp
  • Loading branch information
samuelncui authored Oct 8, 2023
1 parent af1b822 commit 139bd30
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 55 deletions.
54 changes: 54 additions & 0 deletions ddlmod.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
return &result, nil
}

func (d *ddl) clone() *ddl {
copied := new(ddl)
*copied = *d

copied.fields = make([]string, len(d.fields))
copy(copied.fields, d.fields)
copied.columns = make([]migrator.ColumnType, len(d.columns))
copy(copied.columns, d.columns)

return copied
}

func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
Expand All @@ -183,6 +195,21 @@ func (d *ddl) compile() string {
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}

func (d *ddl) renameTable(dst, src string) error {
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
if err != nil {
return err
}

replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
if replaced == d.head {
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}

d.head = replaced
return nil
}

func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")

Expand Down Expand Up @@ -240,3 +267,30 @@ func (d *ddl) getColumns() []string {
}
return res
}

func (d *ddl) alterColumn(name, sql string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")

for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return false
}
}

d.fields = append(d.fields, sql)
return true
}

func (d *ddl) removeColumn(name string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")

for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}

return false
}
83 changes: 29 additions & 54 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sqlite
import (
"database/sql"
"fmt"
"regexp"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {

func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
}

createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))

if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}

return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)

return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
})
})
}
Expand Down Expand Up @@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}

func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}

reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}

createSQL := reg.ReplaceAllString(rawDDL, "")

return createSQL, nil, nil
ddl.removeColumn(name)
return ddl, nil, nil
})
}

Expand All @@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)

return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
var (
constraintName string
constraintSql string
Expand All @@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
return nil, nil, nil
}

createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()

return createSQL, constraintValues, nil
ddl.addConstraint(constraintName, constraintSql)
return ddl, constraintValues, nil
})
})
}
Expand All @@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}

return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()

return createSQL, nil, nil
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
ddl.removeConstraint(name)
return ddl, nil, nil
})
})
}
Expand Down Expand Up @@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
return createSQL, nil
}

func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
func (m Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
Expand All @@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return err
}

newTableName := table + "__temp"

createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
if createSQL == "" {
return nil
}

tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
if createDDL == nil {
return nil
}

createDDL, err := parseDDL(createSQL)
if err != nil {
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}

columns := createDDL.getColumns()
createSQL := createDDL.compile()

return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
Expand Down
3 changes: 2 additions & 1 deletion sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
if field.AutoIncrement {
// doesn't check `PrimaryKey`, to keep backward compatibility
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
Expand Down

0 comments on commit 139bd30

Please sign in to comment.