From af1b82215d2ae9ae9dd38e4f3a7a1ed5ee6102c9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 16 Aug 2023 17:47:56 +0800 Subject: [PATCH 1/4] Fix QuoteTo --- sqlite.go | 53 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/sqlite.go b/sqlite.go index 8617f00..4f0da2e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "strconv" - "strings" "gorm.io/gorm/callbacks" @@ -143,19 +142,51 @@ func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - if strings.Contains(str, ".") { - for idx, str := range strings.Split(str, ".") { - if idx > 0 { - writer.WriteString(".`") + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '`': + continuousBacktick++ + if continuousBacktick == 2 { + writer.WriteString("``") + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteString("`") + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteString("`") + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString("``") } - writer.WriteString(str) - writer.WriteByte('`') + + writer.WriteByte(v) } - } else { - writer.WriteString(str) - writer.WriteByte('`') + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString("``") } + writer.WriteString("`") } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { From 139bd307e5272318b407f578ed99ecc726cab544 Mon Sep 17 00:00:00 2001 From: Samuel N Cui Date: Sun, 8 Oct 2023 10:47:18 +0800 Subject: [PATCH 2/4] fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY (#167) * fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY * fix: migrator use ddl parser instead of regexp --- ddlmod.go | 54 ++++++++++++++++++++++++++++++++++ migrator.go | 83 +++++++++++++++++++---------------------------------- sqlite.go | 3 +- 3 files changed, 85 insertions(+), 55 deletions(-) diff --git a/ddlmod.go b/ddlmod.go index 39cc13a..50e9655 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) { return &result, nil } +func (d *ddl) clone() *ddl { + copied := new(ddl) + *copied = *d + + copied.fields = make([]string, len(d.fields)) + copy(copied.fields, d.fields) + copied.columns = make([]migrator.ColumnType, len(d.columns)) + copy(copied.columns, d.columns) + + return copied +} + func (d *ddl) compile() string { if len(d.fields) == 0 { return d.head @@ -183,6 +195,21 @@ func (d *ddl) compile() string { return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ",")) } +func (d *ddl) renameTable(dst, src string) error { + tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*") + if err != nil { + return err + } + + replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst)) + if replaced == d.head { + return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head) + } + + d.head = replaced + return nil +} + func (d *ddl) addConstraint(name string, sql string) { reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") @@ -240,3 +267,30 @@ func (d *ddl) getColumns() []string { } return res } + +func (d *ddl) alterColumn(name, sql string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields[i] = sql + return false + } + } + + d.fields = append(d.fields, sql) + return true +} + +func (d *ddl) removeColumn(name string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields = append(d.fields[:i], d.fields[i+1:]...) + return true + } + } + + return false +} diff --git a/migrator.go b/migrator.go index fd2eeb4..1b85d6d 100644 --- a/migrator.go +++ b/migrator.go @@ -3,7 +3,6 @@ package sqlite import ( "database/sql" "fmt" - "regexp" "strings" "gorm.io/gorm" @@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool { func (m Migrator) AlterColumn(value interface{}, name string) error { return m.RunWithoutForeignKey(func() error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { - // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)` - reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)") - if err != nil { - return "", nil, err + if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) { + return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile()) } - createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName)) - - if createSQL == rawDDL { - return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL) - } - - return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil + return ddl, []interface{}{m.FullDataTypeOf(field)}, nil } - return "", nil, fmt.Errorf("failed to alter field with name %v", name) + + return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name) }) }) } @@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } func (m Migrator) DropColumn(value interface{}, name string) error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } - reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,") - if err != nil { - return "", nil, err - } - - createSQL := reg.ReplaceAllString(rawDDL, "") - - return createSQL, nil, nil + ddl.removeColumn(name) + return ddl, nil, nil }) } @@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { var ( constraintName string constraintSql string @@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { constraintSql = "CONSTRAINT ? CHECK (?)" constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} } else { - return "", nil, nil + return nil, nil, nil } - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.addConstraint(constraintName, constraintSql) - createSQL := createDDL.compile() - - return createSQL, constraintValues, nil + ddl.addConstraint(constraintName, constraintSql) + return ddl, constraintValues, nil }) }) } @@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { } return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.removeConstraint(name) - createSQL := createDDL.compile() - - return createSQL, nil, nil + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { + ddl.removeConstraint(name) + return ddl, nil, nil }) }) } @@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) { return createSQL, nil } -func (m Migrator) recreateTable(value interface{}, tablePtr *string, - getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error { +func (m Migrator) recreateTable( + value interface{}, tablePtr *string, + getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error), +) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { table := stmt.Table if tablePtr != nil { @@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string, return err } - newTableName := table + "__temp" - - createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt) + originDDL, err := parseDDL(rawDDL) if err != nil { return err } - if createSQL == "" { - return nil - } - tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*") + createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt) if err != nil { return err } - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + if createDDL == nil { + return nil + } - createDDL, err := parseDDL(createSQL) - if err != nil { + newTableName := table + "__temp" + if err := createDDL.renameTable(newTableName, table); err != nil { return err } + columns := createDDL.getColumns() + createSQL := createDDL.compile() return m.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil { diff --git a/sqlite.go b/sqlite.go index 4f0da2e..abcb3ae 100644 --- a/sqlite.go +++ b/sqlite.go @@ -198,7 +198,8 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Bool: return "numeric" case schema.Int, schema.Uint: - if field.AutoIncrement && !field.PrimaryKey { + if field.AutoIncrement { + // doesn't check `PrimaryKey`, to keep backward compatibility // https://www.sqlite.org/autoinc.html return "integer PRIMARY KEY AUTOINCREMENT" } else { From 8172ddb5129575927e716c13a70a572e9f9de6a5 Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Sun, 8 Oct 2023 04:47:42 +0200 Subject: [PATCH 3/4] add from as valid clause for update (#166) --- sqlite.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlite.go b/sqlite.go index abcb3ae..dc76d11 100644 --- a/sqlite.go +++ b/sqlite.go @@ -55,7 +55,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { if compareVersion(version, "3.35.0") >= 0 { callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, LastInsertIDReversed: true, }) From 74475fc966dda86870919e65deced0c6c0c89b5e Mon Sep 17 00:00:00 2001 From: ChrisPortman Date: Sun, 8 Oct 2023 13:52:14 +1100 Subject: [PATCH 4/4] Issue #158 SQLite fields are nullible unless `NOT NULL` (#159) This change sets the default value of ColumnType.Nullable() to true. If the column is explicitly `MOT NULL` then it will be set false. This is consistent with the SQLite documentation: https://www.sqlitetutorial.net/sqlite-not-null-constraint/ Co-authored-by: Chris Carter --- ddlmod.go | 2 +- ddlmod_test.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ddlmod.go b/ddlmod.go index 50e9655..5357061 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) { ColumnTypeValue: sql.NullString{String: matches[2], Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, - NullableValue: sql.NullBool{Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, } diff --git a/ddlmod_test.go b/ddlmod_test.go index 763c3ce..963c271 100644 --- a/ddlmod_test.go +++ b/ddlmod_test.go @@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) { "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, {"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }}, {"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{ {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, @@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) { {"with_special_characters", []string{ "CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")", }, 1, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, { @@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) { NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, - NullableValue: sql.NullBool{Bool: false, Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, @@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) { NameValue: sql.NullString{String: "dark_mode", Valid: true}, DataTypeValue: sql.NullString{String: "numeric", Valid: true}, ColumnTypeValue: sql.NullString{String: "numeric", Valid: true}, - NullableValue: sql.NullBool{Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{String: "true", Valid: true}, UniqueValue: sql.NullBool{Bool: false, Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},