diff --git a/go/logic/inspect.go b/go/logic/inspect.go index bfdafae17..ea8c3adca 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -628,6 +628,7 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL columnName := m.GetString("COLUMN_NAME") columnType := m.GetString("COLUMN_TYPE") columnOctetLength := m.GetUint("CHARACTER_OCTET_LENGTH") + extra := m.GetString("EXTRA") for _, columnsList := range columnsLists { column := columnsList.GetColumn(columnName) if column == nil { @@ -660,6 +661,9 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL column.Type = sql.BinaryColumnType column.BinaryOctetLength = columnOctetLength } + if strings.Contains(extra, " GENERATED") { + column.IsVirtual = true + } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { column.Charset = charset } diff --git a/go/sql/builder.go b/go/sql/builder.go index 5384e5117..332aef100 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -546,6 +546,12 @@ func NewDMLUpdateQueryBuilder(databaseName, tableName string, tableColumns, shar if uniqueKeyColumns.Len() == 0 { return nil, fmt.Errorf("no unique key columns found in NewDMLUpdateQueryBuilder") } + // If unique key contains virtual columns, those column won't be in sharedColumns + // which only contains non-virtual columns + nonVirtualUniqueKeyColumns := uniqueKeyColumns.FilterBy(func(column Column) bool { return !column.IsVirtual }) + if !nonVirtualUniqueKeyColumns.IsSubsetOf(sharedColumns) { + return nil, fmt.Errorf("unique key columns is not a subset of shared columns in NewDMLUpdateQueryBuilder") + } databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) setClause, err := BuildSetPreparedClause(mappedSharedColumns) @@ -580,11 +586,6 @@ func NewDMLUpdateQueryBuilder(databaseName, tableName string, tableColumns, shar // BuildQuery builds the arguments array for a DML event UPDATE query. // It returns the query string, the shared arguments array, and the unique key arguments array. func (b *DMLUpdateQueryBuilder) BuildQuery(valueArgs, whereArgs []interface{}) (string, []interface{}, []interface{}, error) { - // TODO: move this check back to `NewDMLUpdateQueryBuilder()`, needs fix on generated columns. - if !b.uniqueKeyColumns.IsSubsetOf(b.sharedColumns) { - return "", nil, nil, fmt.Errorf("unique key columns is not a subset of shared columns in DMLUpdateQueryBuilder") - } - sharedArgs := make([]interface{}, 0, b.sharedColumns.Len()) for _, column := range b.sharedColumns.Columns() { tableOrdinal := b.tableColumns.Ordinals[column.Name] diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 37964586d..d43f65056 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -688,9 +688,7 @@ func TestBuildDMLUpdateQuery(t *testing.T) { { sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) uniqueKeyColumns := NewColumnList([]string{"age", "surprise"}) - builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns) - require.NoError(t, err) - _, _, _, err = builder.BuildQuery(valueArgs, whereArgs) + _, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns) require.Error(t, err) } { diff --git a/go/sql/types.go b/go/sql/types.go index 3be1a44ca..f7aac5f5f 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -40,6 +40,7 @@ type CharacterSetConversion struct { type Column struct { Name string IsUnsigned bool + IsVirtual bool Charset string Type ColumnType EnumValues string @@ -244,6 +245,16 @@ func (this *ColumnList) IsSubsetOf(other *ColumnList) bool { return true } +func (this *ColumnList) FilterBy(f func(Column) bool) *ColumnList { + filteredCols := make([]Column, 0, len(this.columns)) + for _, column := range this.columns { + if f(column) { + filteredCols = append(filteredCols, column) + } + } + return &ColumnList{Ordinals: this.Ordinals, columns: filteredCols} +} + func (this *ColumnList) Len() int { return len(this.columns) } diff --git a/localtests/generated-columns-unique/create.sql b/localtests/generated-columns-unique/create.sql index 7a63dd984..83afc807d 100644 --- a/localtests/generated-columns-unique/create.sql +++ b/localtests/generated-columns-unique/create.sql @@ -3,6 +3,7 @@ create table gh_ost_test ( id int auto_increment, `idb` varchar(36) CHARACTER SET utf8mb4 GENERATED ALWAYS AS (json_unquote(json_extract(`jsonobj`,_utf8mb4'$._id'))) STORED NOT NULL, `jsonobj` json NOT NULL, + updated datetime DEFAULT NULL, PRIMARY KEY (`id`,`idb`) ) auto_increment=1; @@ -25,6 +26,11 @@ begin insert into gh_ost_test (id, jsonobj) values (null, '{"_id":13}'); insert into gh_ost_test (id, jsonobj) values (null, '{"_id":17}'); insert into gh_ost_test (id, jsonobj) values (null, '{"_id":19}'); - insert into gh_ost_test (id, jsonobj) values (null, '{"_id":23}'); - insert into gh_ost_test (id, jsonobj) values (null, '{"_id":27}'); + + update gh_ost_test set updated=NOW() where idb=5; + update gh_ost_test set updated=NOW() where idb=7; + update gh_ost_test set updated=NOW() where idb=11; + update gh_ost_test set updated=NOW() where idb=13; + update gh_ost_test set updated=NOW() where idb=17; + update gh_ost_test set updated=NOW() where idb=19; end ;; diff --git a/localtests/generated-columns/create.sql b/localtests/generated-columns/create.sql index e244ca3c0..357d4a3b4 100644 --- a/localtests/generated-columns/create.sql +++ b/localtests/generated-columns/create.sql @@ -27,4 +27,6 @@ begin insert into gh_ost_test (id, a, b) values (null, 2,0); insert into gh_ost_test (id, a, b) values (null, 2,1); insert into gh_ost_test (id, a, b) values (null, 2,2); + update gh_ost_test set b=b+1 where id < 5; + update gh_ost_test set b=b-1 where id >= 5; end ;;