From 1176ba446758e26db0002fd4e5259472b7273e47 Mon Sep 17 00:00:00 2001 From: David Schneider Date: Sun, 12 May 2024 14:55:03 +0200 Subject: [PATCH] fix(sqlite): Fix ADD COLUMN without typename Use type name 'any' for ALTER TABLE t1 ADD COLUMN c1 where no type name for c1 is provided. This is the same logic as for CREATE TABLE. Fixes #3375 --- internal/engine/sqlite/catalog_test.go | 24 ++++++++++++++++++++++++ internal/engine/sqlite/convert.go | 19 ++++++++++--------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/internal/engine/sqlite/catalog_test.go b/internal/engine/sqlite/catalog_test.go index bf6dcd8316..a4f2861838 100644 --- a/internal/engine/sqlite/catalog_test.go +++ b/internal/engine/sqlite/catalog_test.go @@ -82,6 +82,30 @@ func TestUpdate(t *testing.T) { }, }, }, + { + ` + CREATE TABLE foo (bar text); + ALTER TABLE foo ADD COLUMN baz; + `, + &catalog.Schema{ + Name: "main", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "foo"}, + Columns: []*catalog.Column{ + { + Name: "bar", + Type: ast.TypeName{Name: "text"}, + }, + { + Name: "baz", + Type: ast.TypeName{Name: "any"}, + }, + }, + }, + }, + }, + }, { ` CREATE TABLE foo (bar text); diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 02d80bc48c..d97977bbb0 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -35,6 +35,13 @@ func identifier(id string) string { return strings.ToLower(id) } +func getTypeName(t parser.IType_nameContext) string { + if t == nil { + return "any" + } + return t.GetText() +} + func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } @@ -72,10 +79,8 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a Name: &name, Subtype: ast.AT_AddColumn, Def: &ast.ColumnDef{ - Colname: name, - TypeName: &ast.TypeName{ - Name: def.Type_name().GetText(), - }, + Colname: name, + TypeName: &ast.TypeName{Name: getTypeName(def.Type_name())}, IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()), }, }) @@ -113,14 +118,10 @@ func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) } for _, idef := range n.AllColumn_def() { if def, ok := idef.(*parser.Column_defContext); ok { - typeName := "any" - if def.Type_name() != nil { - typeName = def.Type_name().GetText() - } stmt.Cols = append(stmt.Cols, &ast.ColumnDef{ Colname: identifier(def.Column_name().GetText()), IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()), - TypeName: &ast.TypeName{Name: typeName}, + TypeName: &ast.TypeName{Name: getTypeName(def.Type_name())}, }) } }