diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 1502719c..fd421b7a 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -42,4 +42,5 @@ type DataType struct { Name string Kind DataTypeKind IsUnsigned bool + Dimensions int // The number of array dimensions } diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index abb21bad..2d86b50e 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -65,6 +65,7 @@ select not attr.attnotnull as "column.isNullable", attr.attgenerated = 's' as "column.isGenerated", attr.atthasdef as "column.hasDefault", + attr.attndims as "dataType.dimensions", (case when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base' when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' diff --git a/generator/template/model_template.go b/generator/template/model_template.go index f89ebd1b..d3277216 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -6,6 +6,7 @@ import ( "github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/google/uuid" "github.com/jackc/pgtype" + "github.com/lib/pq" "path" "reflect" "strings" @@ -249,7 +250,7 @@ func getUserDefinedType(column metadata.Column) string { switch column.DataType.Kind { case metadata.EnumType: return dbidentifier.ToGoIdentifier(column.DataType.Name) - case metadata.UserDefinedType, metadata.ArrayType: + case metadata.UserDefinedType: return "string" } @@ -268,6 +269,11 @@ func getGoType(column metadata.Column) interface{} { // toGoType returns model type for column info. func toGoType(column metadata.Column) interface{} { + // We don't support multi-dimensional arrays + if column.DataType.Dimensions > 1 { + return "" + } + switch strings.ToLower(column.DataType.Name) { case "user-defined", "enum": return "" @@ -333,6 +339,16 @@ func toGoType(column metadata.Column) interface{} { return pgtype.Int8range{} case "numrange": return pgtype.Numrange{} + case "bool[]", "boolean[]": + return pq.BoolArray{} + case "integer[]", "int4[]": + return pq.Int32Array{} + case "bigint[]", "int8[]": + return pq.Int64Array{} + case "bytea[]": + return pq.ByteaArray{} + case "text[]", "jsonb[]", "json[]": + return pq.StringArray{} default: fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") return "" diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index fe7fba59..8431a8ca 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -145,53 +145,101 @@ func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilde // getSqlBuilderColumnType returns type of jet sql builder column func getSqlBuilderColumnType(columnMetaData metadata.Column) string { if columnMetaData.DataType.Kind != metadata.BaseType && - columnMetaData.DataType.Kind != metadata.RangeType { + columnMetaData.DataType.Kind != metadata.RangeType && + columnMetaData.DataType.Kind != metadata.ArrayType { return "String" } - switch strings.ToLower(columnMetaData.DataType.Name) { + typeName := columnMetaData.DataType.Name + columnName := columnMetaData.Name + + var columnType string + var supported bool + + if columnMetaData.DataType.Kind == metadata.ArrayType { + if columnMetaData.DataType.Dimensions > 1 { + fmt.Println("- [SQL Builder] Unsupported sql array with multiple dimensions column '" + columnName + " " + typeName + "', using StringColumn instead.") + return "String" + } + + columnType, supported = sqlArrayToColumnType(strings.TrimSuffix(typeName, "[]")) + } else { + columnType, supported = sqlToColumnType(typeName) + } + + if !supported { + fmt.Printf("- [SQL Builder] Unsupported SQL column '" + columnName + " " + typeName + "', using StringColumn instead.\n") + return "String" + } + + return columnType +} + +// sqlArrayToColumnType maps the type of an SQL array column type to a go jet sql builder column. Note that you don't +// pass the brackets `[]`, signifying an SQL array type, into this function. The second return value returns whether the +// given type is supported +func sqlArrayToColumnType(typeName string) (string, bool) { + switch strings.ToLower(typeName) { + case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", + "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", + "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", + "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + return "StringArray", true + case "smallint", "integer", "bigint", "int2", "int4", "int8", + "tinyint", "mediumint", "int", "year": //MySQL + return "IntegerArray", true case "boolean", "bool": - return "Bool" + return "BoolArray", true + default: + return "", false + } +} + +// sqlToColumnType maps the type of a SQL column type to a go jet sql builder column. The second return value returns +// whether the given type is supported. +func sqlToColumnType(typeName string) (string, bool) { + switch strings.ToLower(typeName) { + case "boolean", "bool": + return "Bool", true case "smallint", "integer", "bigint", "int2", "int4", "int8", "tinyint", "mediumint", "int", "year": //MySQL - return "Integer" + return "Integer", true case "date": - return "Date" + return "Date", true case "timestamp without time zone", "timestamp", "datetime": //MySQL: - return "Timestamp" + return "Timestamp", true case "timestamp with time zone", "timestamptz": - return "Timestampz" + return "Timestampz", true case "time without time zone", "time": //MySQL - return "Time" + return "Time", true case "time with time zone", "timetz": - return "Timez" + return "Timez", true case "interval": - return "Interval" + return "Interval", true case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL - return "String" + return "String", true case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL - return "Float" + return "Float", true case "daterange": - return "DateRange" + return "DateRange", true case "tsrange": - return "TimestampRange" + return "TimestampRange", true case "tstzrange": - return "TimestampzRange" + return "TimestampzRange", true case "int4range": - return "Int4Range" + return "Int4Range", true case "int8range": - return "Int8Range" + return "Int8Range", true case "numrange": - return "NumericRange" + return "NumericRange", true default: - fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") - return "String" + return "", false } } diff --git a/internal/jet/array_expression.go b/internal/jet/array_expression.go new file mode 100644 index 00000000..9ac562c6 --- /dev/null +++ b/internal/jet/array_expression.go @@ -0,0 +1,93 @@ +package jet + +// Array interface +type Array[E Expression] interface { + Expression + + EQ(rhs Array[E]) BoolExpression + NOT_EQ(rhs Array[E]) BoolExpression + LT(rhs Array[E]) BoolExpression + GT(rhs Array[E]) BoolExpression + LT_EQ(rhs Array[E]) BoolExpression + GT_EQ(rhs Array[E]) BoolExpression + + CONTAINS(rhs Array[E]) BoolExpression + IS_CONTAINED_BY(rhs Array[E]) BoolExpression + OVERLAP(rhs Array[E]) BoolExpression + CONCAT(rhs Array[E]) Array[E] + CONCAT_ELEMENT(E) Array[E] + + AT(expression IntegerExpression) E +} + +type arrayInterfaceImpl[E Expression] struct { + parent Array[E] +} + +type BinaryBoolOp func(Expression, Expression) BoolExpression + +func (a arrayInterfaceImpl[E]) EQ(rhs Array[E]) BoolExpression { + return Eq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) NOT_EQ(rhs Array[E]) BoolExpression { + return NotEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) LT(rhs Array[E]) BoolExpression { + return Lt(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) GT(rhs Array[E]) BoolExpression { + return Gt(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) LT_EQ(rhs Array[E]) BoolExpression { + return LtEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) GT_EQ(rhs Array[E]) BoolExpression { + return GtEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) CONTAINS(rhs Array[E]) BoolExpression { + return Contains(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) IS_CONTAINED_BY(rhs Array[E]) BoolExpression { + return IsContainedBy(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) OVERLAP(rhs Array[E]) BoolExpression { + return Overlap(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) CONCAT(rhs Array[E]) Array[E] { + return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||")) +} + +func (a arrayInterfaceImpl[E]) CONCAT_ELEMENT(rhs E) Array[E] { + return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||")) +} + +func (a arrayInterfaceImpl[E]) AT(expression IntegerExpression) E { + return arrayElementTypeCaster[E](a.parent, arraySubscriptExpr(a.parent, expression)) +} + +type arrayExpressionWrapper[E Expression] struct { + arrayInterfaceImpl[E] + Expression +} + +func newArrayExpressionWrap[E Expression](expression Expression) Array[E] { + arrayExpressionWrapper := arrayExpressionWrapper[E]{Expression: expression} + arrayExpressionWrapper.arrayInterfaceImpl.parent = &arrayExpressionWrapper + return &arrayExpressionWrapper +} + +// ArrayExp is array expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as array expression. +// Does not add sql cast to generated sql builder output. +func ArrayExp[E Expression](expression Expression) Array[E] { + return newArrayExpressionWrap[E](expression) +} diff --git a/internal/jet/array_expression_test.go b/internal/jet/array_expression_test.go new file mode 100644 index 00000000..508b5f96 --- /dev/null +++ b/internal/jet/array_expression_test.go @@ -0,0 +1,59 @@ +package jet + +import ( + "github.com/lib/pq" + "testing" +) + +func TestArrayExpressionEQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.EQ(table2ColArray), "(table1.col_array_string = table2.col_array_string)") +} + +func TestArrayExpressionNOT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.NOT_EQ(table2ColArray), "(table1.col_array_string != table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.NOT_EQ(StringArray([]string{"x"})), "(table1.col_array_string != $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionLT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.LT(table2ColArray), "(table1.col_array_string < table2.col_array_string)") +} + +func TestArrayExpressionGT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.GT(table2ColArray), "(table1.col_array_string > table2.col_array_string)") +} + +func TestArrayExpressionLT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.LT_EQ(table2ColArray), "(table1.col_array_string <= table2.col_array_string)") +} + +func TestArrayExpressionGT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.GT_EQ(table2ColArray), "(table1.col_array_string >= table2.col_array_string)") +} + +func TestArrayExpressionCONTAINS(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONTAINS(table2ColArray), "(table1.col_array_string @> table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.CONTAINS(StringArray([]string{"x"})), "(table1.col_array_string @> $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionCONTAINED_BY(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(table2ColArray), "(table1.col_array_string <@ table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(StringArray([]string{"x"})), "(table1.col_array_string <@ $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionOVERLAP(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.OVERLAP(table2ColArray), "(table1.col_array_string && table2.col_array_string)") +} + +func TestArrayExpressionCONCAT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONCAT(table2ColArray), "(table1.col_array_string || table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.CONCAT(StringArray([]string{"x"})), "(table1.col_array_string || $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionCONCAT_ELEMENT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(StringExp(table2ColArray.AT(Int(1)))), "(table1.col_array_string || table2.col_array_string[$1])", int64(1)) + assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(String("x")), "(table1.col_array_string || $1)", "x") +} + +func TestArrayExpressionAT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.AT(Int(1)), "table1.col_array_string[$1]", int64(1)) +} diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index a7320615..2c47b103 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -121,6 +121,46 @@ func IntegerColumn(name string) ColumnInteger { //------------------------------------------------------// +type ColumnArray[E Expression] interface { + Array[E] + Column + + From(subQuery SelectTable) ColumnArray[E] + SET(stringExp Array[E]) ColumnAssigment +} + +type arrayColumnImpl[E Expression] struct { + arrayInterfaceImpl[E] + + ColumnExpressionImpl +} + +func (a arrayColumnImpl[E]) From(subQuery SelectTable) ColumnArray[E] { + newArrayColumn := ArrayColumn[E](a.name) + newArrayColumn.setTableName(a.tableName) + newArrayColumn.setSubQuery(subQuery) + + return newArrayColumn +} + +func (a *arrayColumnImpl[E]) SET(stringExp Array[E]) ColumnAssigment { + return columnAssigmentImpl{ + column: a, + expression: stringExp, + } +} + +// StringColumn creates named string column. +func ArrayColumn[E Expression](name string) ColumnArray[E] { + arrayColumn := &arrayColumnImpl[E]{} + arrayColumn.arrayInterfaceImpl.parent = arrayColumn + arrayColumn.ColumnExpressionImpl = NewColumnImpl(name, "", arrayColumn) + + return arrayColumn +} + +//------------------------------------------------------// + // ColumnString is interface for SQL text, character, character varying // bytea, uuid columns and enums types. type ColumnString interface { diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index 059d722d..38d9e96b 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -1,6 +1,7 @@ package jet import ( + "github.com/lib/pq" "testing" ) @@ -8,6 +9,42 @@ var subQuery = &selectTableImpl{ alias: "sub_query", } +func TestNewArrayColumnString(t *testing.T) { + stringArrayColumn := ArrayColumn[StringExpression]("colArray").From(subQuery) + assertClauseSerialize(t, stringArrayColumn, `sub_query."colArray"`) + assertClauseSerialize(t, stringArrayColumn.EQ(StringArray([]string{"X"})), `(sub_query."colArray" = $1)`, pq.StringArray{"X"}) + assertProjectionSerialize(t, stringArrayColumn, `sub_query."colArray" AS "colArray"`) + + arrayColumn2 := table1ColStringArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_string"`) + assertClauseSerialize(t, arrayColumn2.EQ(StringArray([]string{"X"})), `(sub_query."table1.col_array_string" = $1)`, pq.StringArray{"X"}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_string" AS "table1.col_array_string"`) +} + +func TestNewArrayColumnBool(t *testing.T) { + boolArrayColumn := ArrayColumn[BoolExpression]("colArrayBool").From(subQuery) + assertClauseSerialize(t, boolArrayColumn, `sub_query."colArrayBool"`) + assertClauseSerialize(t, boolArrayColumn.EQ(BoolArray([]bool{true})), `(sub_query."colArrayBool" = $1)`, pq.BoolArray{true}) + assertProjectionSerialize(t, boolArrayColumn, `sub_query."colArrayBool" AS "colArrayBool"`) + + arrayColumn2 := table1ColBoolArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool"`) + assertClauseSerialize(t, arrayColumn2.EQ(BoolArray([]bool{true})), `(sub_query."table1.col_array_bool" = $1)`, pq.BoolArray{true}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool" AS "table1.col_array_bool"`) +} + +func TestNewArrayColumnInteger(t *testing.T) { + intArrayColumn := ArrayColumn[IntegerExpression]("colArrayInt").From(subQuery) + assertClauseSerialize(t, intArrayColumn, `sub_query."colArrayInt"`) + assertClauseSerialize(t, intArrayColumn.EQ(Int32Array([]int32{42})), `(sub_query."colArrayInt" = $1)`, pq.Int32Array{42}) + assertProjectionSerialize(t, intArrayColumn, `sub_query."colArrayInt" AS "colArrayInt"`) + + arrayColumn2 := table1ColIntArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_int"`) + assertClauseSerialize(t, arrayColumn2.EQ(Int32Array([]int32{42})), `(sub_query."table1.col_array_int" = $1)`, pq.Int32Array{42}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_int" AS "table1.col_array_int"`) +} + func TestNewBoolColumn(t *testing.T) { boolColumn := BoolColumn("colBool").From(subQuery) assertClauseSerialize(t, boolColumn, `sub_query."colBool"`) diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 05b1797f..e1ec13bf 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -316,6 +316,29 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, } } +//type arraySubscriptExpression struct { +// ExpressionInterfaceImpl +// array Expression +// subscript IntegerExpression +//} +// +//func (a arraySubscriptExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +// if !contains(options, NoWrap) { +// out.WriteString("(") +// } +// a.array.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper +// out.WriteString("[") +// a.subscript.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper +// out.WriteString("]") +// if !contains(options, NoWrap) { +// out.WriteString(")") +// } +//} + +func arraySubscriptExpr(array Expression, subscript IntegerExpression) Expression { + return CustomExpression(array, Token("["), subscript, Token("]")) +} + type skipParenthesisWrap struct { Expression } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 7e498806..25e00732 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -651,6 +651,76 @@ func LEAST(value Expression, values ...Expression) Expression { return NewFunc("LEAST", allValues, nil) } +// -------------------- Array Expressions Functions ------------------// + +// ANY should be used in combination with a boolean operator. The result of ANY is "true" if any true result is obtained +func ANY[E Expression](arr Array[E]) E { + return arrayElementTypeCaster(arr, Func("ANY", arr)) +} + +// ALL should be used in combination with a boolean operator. TThe result of ALL is “true” if all comparisons yield true +func ALL[E Expression](arr Array[E]) E { + return arrayElementTypeCaster(arr, Func("ALL", arr)) +} + +func ARRAY_APPEND[E Expression](arr Array[E], el E) Array[E] { + return arrayTypeCaster[E](arr, Func("array_append", arr, el)) +} + +func ARRAY_CAT[E Expression](arr1, arr2 Array[E]) Array[E] { + return arrayTypeCaster[E](arr1, Func("array_cat", arr1, arr2)) +} + +func ARRAY_PREPEND[E Expression](el E, arr Array[E]) Array[E] { + return ArrayExp[E](Func("array_prepend", el, arr)) +} + +func ARRAY_LENGTH[E Expression](arr Array[E], el IntegerExpression) IntegerExpression { + return IntExp(Func("array_length", arr, el)) +} + +func arrayTypeCaster[E Expression](arrayExp Expression, exp Expression) Array[E] { + var i Expression + switch arrayExp.(type) { + case Array[StringExpression]: + i = ArrayExp[StringExpression](exp) + case Array[Int4Expression]: + i = ArrayExp[Int4Expression](exp) + case Array[Int8Expression]: + i = ArrayExp[Int8Expression](exp) + case Array[IntegerExpression]: + i = ArrayExp[IntegerExpression](exp) + case Array[BoolExpression]: + i = ArrayExp[BoolExpression](exp) + } + return i.(Array[E]) +} + +func arrayElementTypeCaster[E Expression](arrayExp Array[E], exp Expression) E { + var i Expression + switch arrayExp.(type) { + case Array[StringExpression]: + i = StringExp(exp) + case Array[IntegerExpression]: + i = IntExp(exp) + case Array[BoolExpression]: + i = BoolExp(exp) + } + + return i.(E) +} + +func ARRAY[E Expression](elems ...E) Array[E] { + var args = make([]Serializer, len(elems)) + for i, each := range elems { + args[i] = each + } + return ArrayExp[E](CustomExpression(Token("ARRAY["), ListSerializer{ + Serializers: args, + Separator: ",", + }, Token("]"))) +} + //--------------------------------------------------------------------// type funcExpressionImpl struct { diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d6f0b415..f4bae09d 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -2,6 +2,7 @@ package jet import ( "fmt" + "github.com/lib/pq" "time" ) @@ -160,6 +161,24 @@ func Decimal(value string) FloatExpression { return &floatLiteral } +// ---------------------------------------------------// + +func BoolArray(values []bool) Array[BoolExpression] { + return ArrayExp[BoolExpression](literal(pq.BoolArray(values))) +} + +func Int64Array(values []int64) Array[IntegerExpression] { + return ArrayExp[IntegerExpression](literal(pq.Int64Array(values))) +} + +func Int32Array(values []int32) Array[IntegerExpression] { + return ArrayExp[IntegerExpression](literal(pq.Int32Array(values))) +} + +func StringArray(values []string) Array[StringExpression] { + return ArrayExp[StringExpression](literal(pq.StringArray(values))) +} + // ---------------------------------------------------// type stringLiteral struct { stringInterfaceImpl diff --git a/internal/jet/operators.go b/internal/jet/operators.go index c453c3e0..46b36ec3 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -74,6 +74,11 @@ func Contains(lhs Expression, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "@>") } +// IsContainedBy returns a representation of "a <@ b" +func IsContainedBy(lhs Expression, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "<@") +} + // Overlap returns a representation of "a && b" func Overlap(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "&&") diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 46f47ad4..dce93968 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/v2/internal/3rdparty/pq" "github.com/go-jet/jet/v2/internal/utils/is" "github.com/google/uuid" + pq2 "github.com/lib/pq" "reflect" "sort" "strconv" @@ -81,11 +82,11 @@ func (s *SQLBuilder) write(data []byte) { } func isPreSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' + return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' || b == '[' } func isPostSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' + return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' || b == '[' || b == ']' } // WriteAlias is used to add alias to output SQL @@ -226,6 +227,8 @@ func argToString(value interface{}) string { case string: return stringQuote(bindVal) + case []string: + return stringArrayQuote(bindVal) case []byte: return stringQuote(string(bindVal)) case uuid.UUID: @@ -253,6 +256,13 @@ func argToString(value interface{}) string { } } +func stringArrayQuote(val []string) string { + // We'll rely on the internals of pq2.StringArray here. We know it will never return an error, and the returned + // value is a string + dv, _ := pq2.StringArray(val).Value() + return dv.(string) +} + func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: @@ -301,3 +311,7 @@ func shouldQuoteIdentifier(identifier string) bool { func stringQuote(value string) string { return `'` + strings.Replace(value, "'", "''", -1) + `'` } + +func stringDoubleQuote(value string) string { + return `"` + strings.Replace(value, `"`, `""`, -1) + `"` +} diff --git a/internal/jet/string_expression_test.go b/internal/jet/string_expression_test.go index 0f461acc..19837d33 100644 --- a/internal/jet/string_expression_test.go +++ b/internal/jet/string_expression_test.go @@ -76,6 +76,14 @@ func TestStringNOT_REGEXP_LIKE(t *testing.T) { assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN") } +func TestStringANY_EQ(t *testing.T) { + assertClauseSerialize(t, table2ColStr.EQ(ANY[StringExpression](table1ColStringArray)), "(table2.col_str = ANY(table1.col_array_string))") +} + +func TestStringALL_EQ(t *testing.T) { + assertClauseSerialize(t, table2ColStr.EQ(ALL[StringExpression](table1ColStringArray)), "(table2.col_str = ALL(table1.col_array_string))") +} + func TestStringExp(t *testing.T) { assertClauseSerialize(t, StringExp(table2ColFloat), "table2.col_float") assertClauseSerialize(t, StringExp(table2ColFloat).NOT_LIKE(String("abc")), "(table2.col_float NOT LIKE $1)", "abc") diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 70b21c77..0f4ff8a6 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -15,19 +15,22 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests }) var ( - table1Col1 = IntegerColumn("col1") - table1ColInt = IntegerColumn("col_int") - table1ColFloat = FloatColumn("col_float") - table1Col3 = IntegerColumn("col3") - table1ColTime = TimeColumn("col_time") - table1ColTimez = TimezColumn("col_timez") - table1ColTimestamp = TimestampColumn("col_timestamp") - table1ColTimestampz = TimestampzColumn("col_timestampz") - table1ColBool = BoolColumn("col_bool") - table1ColDate = DateColumn("col_date") - table1ColRange = RangeColumn[Int8Expression]("col_range") + table1Col1 = IntegerColumn("col1") + table1ColInt = IntegerColumn("col_int") + table1ColFloat = FloatColumn("col_float") + table1Col3 = IntegerColumn("col3") + table1ColTime = TimeColumn("col_time") + table1ColTimez = TimezColumn("col_timez") + table1ColTimestamp = TimestampColumn("col_timestamp") + table1ColTimestampz = TimestampzColumn("col_timestampz") + table1ColBool = BoolColumn("col_bool") + table1ColDate = DateColumn("col_date") + table1ColRange = RangeColumn[Int8Expression]("col_range") + table1ColStringArray = ArrayColumn[StringExpression]("col_array_string") + table1ColBoolArray = ArrayColumn[BoolExpression]("col_array_bool") + table1ColIntArray = ArrayColumn[IntegerExpression]("col_array_int") ) -var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz) +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz, table1ColStringArray, table1ColBoolArray, table1ColIntArray) var ( table2Col3 = IntegerColumn("col3") @@ -42,8 +45,9 @@ var ( table2ColTimestampz = TimestampzColumn("col_timestampz") table2ColDate = DateColumn("col_date") table2ColRange = RangeColumn[Int8Expression]("col_range") + table2ColArray = ArrayColumn[StringExpression]("col_array_string") ) -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz, table2ColArray) var ( table3Col1 = IntegerColumn("col1") diff --git a/postgres/columns.go b/postgres/columns.go index 819da380..6d3239a8 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -101,6 +101,24 @@ type ColumnInt8Range jet.ColumnRange[jet.Int8Expression] // Int8RangeColumn creates named range with range column var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression] +// ColumnStringArray is interface of column +type ColumnStringArray jet.ColumnArray[StringExpression] + +// StringArrayColumn creates named string array column +var StringArrayColumn = jet.ArrayColumn[StringExpression] + +// ColumnIntegerArray is interface of column +type ColumnIntegerArray jet.ColumnArray[IntegerExpression] + +// IntegerArrayColumn creates named integer array column +var IntegerArrayColumn = jet.ArrayColumn[IntegerExpression] + +// ColumnBoolArray is interface of column +type ColumnBoolArray jet.ColumnArray[BoolExpression] + +// BoolArrayColumn creates named bool array column +var BoolArrayColumn = jet.ArrayColumn[BoolExpression] + //------------------------------------------------------// // ColumnInterval is interface of PostgreSQL interval columns. diff --git a/postgres/expressions.go b/postgres/expressions.go index 98729100..8fa67258 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -9,15 +9,24 @@ type Expression = jet.Expression // BoolExpression interface type BoolExpression = jet.BoolExpression +// BoolArrayExpression interface +type BoolArrayExpression = jet.Array[BoolExpression] + // StringExpression interface type StringExpression = jet.StringExpression +// StringArrayExpression interface +type StringArrayExpression = jet.Array[StringExpression] + // NumericExpression interface type NumericExpression = jet.NumericExpression // IntegerExpression interface type IntegerExpression = jet.IntegerExpression +// IntegerArrayExpression interface +type IntegerArrayExpression = jet.Array[IntegerExpression] + // FloatExpression is interface type FloatExpression = jet.FloatExpression diff --git a/postgres/functions.go b/postgres/functions.go index 7b6d1e16..aee52dc2 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -265,7 +265,7 @@ var TO_ASCII = jet.TO_ASCII // TO_HEX converts number to its equivalent hexadecimal representation var TO_HEX = jet.TO_HEX -//----------Data Type Formatting Functions ----------------------// +//---------- Range Functions ----------------------// // LOWER_BOUND returns range expressions lower bound func LOWER_BOUND[T Expression](expression jet.Range[T]) T { @@ -277,7 +277,40 @@ func UPPER_BOUND[T Expression](expression jet.Range[T]) T { return jet.UPPER_BOUND[T](expression) } -//----------Data Type Formatting Functions ----------------------// +// ---------- Array Functions ----------------------// + +// ANY should be used in combination with a boolean operator. The result of ANY is "true" if any true result is obtained +func ANY[T Expression](expression jet.Array[T]) T { + return jet.ANY[T](expression) +} + +// ALL should be used in combination with a boolean operator. TThe result of ALL is “true” if all comparisons yield true +func ALL[T Expression](expression jet.Array[T]) T { + return jet.ALL[T](expression) +} + +func ARRAY_APPEND[T Expression](arr jet.Array[T], el T) jet.Array[T] { + return jet.ARRAY_APPEND(arr, el) +} + +func ARRAY_CAT[T Expression](arr1, arr2 jet.Array[T]) jet.Array[T] { + return jet.ARRAY_CAT(arr1, arr2) +} + +func ARRAY_LENGTH[T Expression](expression jet.Array[T], dim IntegerExpression) IntegerExpression { + return jet.ARRAY_LENGTH(expression, dim) +} + +func ARRAY_PREPEND[T Expression](el T, arr jet.Array[T]) jet.Array[T] { + return jet.ARRAY_PREPEND(el, arr) +} + +// ARRAY constructor +func ARRAY[T Expression](elems ...T) jet.Array[T] { + return jet.ARRAY[T](elems...) +} + +//---------- Data Type Formatting Functions ----------------------// // TO_CHAR converts expression to string with format var TO_CHAR = jet.TO_CHAR diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 25300c27..4aa9c503 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -175,27 +175,30 @@ RETURNING table1.col1 AS "table1.col1", } func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { - stmt := table1.INSERT(table1Col1, table1ColBool). - VALUES("one", "two"). - VALUES("1", "2"). + stmt := table1.INSERT(table1Col1, table1ColBool, table1ColStringArray). + VALUES("one", "two", "three"). + VALUES("1", "2", "3"). ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( SET(table1ColBool.SET(Bool(false)), table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + table1ColStringArray.SET(StringArray([]string{"one"})), + ColumnList{table1Col1, table1ColBool, table1ColStringArray}.SET(jet.ROW(Int(2), String("two"), StringArray([]string{"two"}))), ).WHERE(table1Col1.GT(Int(2))), ). - RETURNING(table1Col1, table1ColBool) + RETURNING(table1Col1, table1ColBool, table1ColStringArray) assertDebugStatementSql(t, stmt, ` -INSERT INTO db.table1 (col1, col_bool) -VALUES ('one', 'two'), - ('1', '2') +INSERT INTO db.table1 (col1, col_bool, col_string_array) +VALUES ('one', 'two', 'three'), + ('1', '2', '3') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE SET col_bool = FALSE::boolean, col_int = 1, - (col1, col_bool) = ROW(2, 'two'::text) + col_string_array = '{"one"}', + (col1, col_bool, col_string_array) = ROW(2, 'two'::text, '{"two"}') WHERE table1.col1 > 2 RETURNING table1.col1 AS "table1.col1", - table1.col_bool AS "table1.col_bool"; + table1.col_bool AS "table1.col_bool", + table1.col_string_array AS "table1.col_string_array"; `) } diff --git a/postgres/literal.go b/postgres/literal.go index e3a95b3b..30190a58 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -11,6 +11,11 @@ func Bool(value bool) BoolExpression { return CAST(jet.Bool(value)).AS_BOOL() } +// BoolArray creates new bool array literal expression +func BoolArray(elements []bool) BoolArrayExpression { + return jet.BoolArray(elements) +} + // Int is constructor for 64 bit signed integer expressions literals. var Int = jet.Int @@ -29,11 +34,21 @@ func Int32(value int32) IntegerExpression { return CAST(jet.Int32(value)).AS_INTEGER() } +// Int32Array creates new 32 bit signed integer literal expression +func Int32Array(elements []int32) IntegerArrayExpression { + return jet.Int32Array(elements) +} + // Int64 is constructor for 64 bit signed integer expressions literals. func Int64(value int64) IntegerExpression { return CAST(jet.Int(value)).AS_BIGINT() } +// Int64Array creates new 64 bit signed integer literal expression +func Int64Array(elements []int64) IntegerArrayExpression { + return jet.Int64Array(elements) +} + // Uint8 is constructor for 8 bit unsigned integer expressions literals. func Uint8(value uint8) IntegerExpression { return CAST(jet.Uint8(value)).AS_SMALLINT() @@ -65,6 +80,11 @@ func String(value string) StringExpression { return CAST(jet.String(value)).AS_TEXT() } +// StringArray creates new string array literal expression +func StringArray(elements []string) StringArrayExpression { + return jet.StringArray(elements) +} + // Json creates new json literal expression func Json(value interface{}) StringExpression { switch value.(type) { diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 96bb13b0..d89b17bd 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -18,6 +18,8 @@ var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") var table1ColInterval = IntervalColumn("col_interval") var table1ColRange = Int8RangeColumn("col_range") +var table1ColStringArray = StringArrayColumn("col_string_array") +var table1ColIntArray = IntegerArrayColumn("col_int_array") var table1 = NewTable( "db", @@ -34,6 +36,8 @@ var table1 = NewTable( table1ColTimestampz, table1ColInterval, table1ColRange, + table1ColStringArray, + table1ColIntArray, ) var table2Col3 = IntegerColumn("col3") @@ -49,8 +53,10 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") var table2ColInterval = IntervalColumn("col_interval") var table2ColRange = Int8RangeColumn("col_range") +var table2ColStringArray = StringArrayColumn("col_string_array") +var table2ColIntArray = IntegerArrayColumn("col_int_array") -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange, table2ColStringArray, table2ColIntArray) var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9b3af507..7c1c116a 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -1,4 +1,3 @@ -version: '3' services: postgres: image: postgres:14.1 @@ -13,7 +12,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0.27 + image: mysql/mysql-server:8.0.27 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index d41feee2..b15056f1 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "github.com/lib/pq" "testing" "time" @@ -1361,11 +1362,11 @@ var allTypesRow0 = model.AllTypes{ JSON: `{"a": 1, "b": 3}`, JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: testutils.StringPtr("{1,2,3}"), - IntegerArray: "{1,2,3}", - TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"), - TextArray: "{breakfast,consulting}", - JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, + IntegerArrayPtr: &pq.Int32Array{1, 2, 3}, + IntegerArray: pq.Int32Array{1, 2, 3}, + TextArrayPtr: &pq.StringArray{"breakfast", "consulting"}, + TextArray: pq.StringArray{"breakfast", "consulting"}, + JsonbArray: pq.StringArray{`{"a": 1, "b": 2}`, `{"a": 3, "b": 4}`}, TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: &moodSad, @@ -1430,10 +1431,10 @@ var allTypesRow1 = model.AllTypes{ JsonbPtr: nil, Jsonb: `{"a": 1, "b": 3}`, IntegerArrayPtr: nil, - IntegerArray: "{1,2,3}", + IntegerArray: pq.Int32Array{1, 2, 3}, TextArrayPtr: nil, - TextArray: "{breakfast,consulting}", - JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, + TextArray: pq.StringArray{"breakfast", "consulting"}, + JsonbArray: pq.StringArray{`{"a": 1, "b": 2}`, `{"a": 3, "b": 4}`}, TextMultiDimArrayPtr: nil, TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: nil, diff --git a/tests/postgres/array_test.go b/tests/postgres/array_test.go new file mode 100644 index 00000000..17f95891 --- /dev/null +++ b/tests/postgres/array_test.go @@ -0,0 +1,277 @@ +package postgres + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" + "github.com/google/go-cmp/cmp" + "github.com/lib/pq" + "github.com/stretchr/testify/require" + "math/big" + "testing" +) + +func TestArrayTableSelect(t *testing.T) { + skipForCockroachDB(t) + + textArray := StringArray([]string{"a"}) + boolArray := BoolArray([]bool{true}) + int4Array := Int32Array([]int32{1, 2}) + int8Array := Int64Array([]int64{10, 11}) + + query := SELECT( + SampleArrays.AllColumns, + SampleArrays.TextArray.EQ(SampleArrays.TextArray).AS("sample.text_eq"), + SampleArrays.BoolArray.EQ(boolArray).AS("sample.bool_eq"), + SampleArrays.TextArray.NOT_EQ(textArray).AS("sample.text_neq"), + SampleArrays.Int4Array.LT(int4Array).IS_TRUE().AS("sample.int4_lt"), + SampleArrays.Int8Array.LT_EQ(int8Array).IS_FALSE().AS("sample.int8_lteq"), + SampleArrays.TextArray.GT(textArray).AS("sample.text_gt"), + SampleArrays.Int4Array.GT_EQ(int4Array).AS("sample.bool_gteq"), + Int32(22).EQ(ANY[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_eq_any"), + Int32(22).NOT_EQ(ANY[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_neq_any"), + Int32(22).EQ(ALL[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_eq_all"), + SampleArrays.Int8Array.CONTAINS(Int64Array([]int64{75364})).AS("sample.int8cont"), + SampleArrays.Int8Array.IS_CONTAINED_BY(Int64Array([]int64{75364})).AS("sample.int8cont_by"), + SampleArrays.Int4Array.OVERLAP(int4Array).AS("sample.int4_overlap"), + SampleArrays.BoolArray.CONCAT(boolArray).AS("sample.bool_concat"), + SampleArrays.TextArray.CONCAT_ELEMENT(String("z")).AS("sample.text_concat_el"), + SampleArrays.TextArray.AT(Int32(1)).AS("sample.text_at"), + ARRAY_APPEND[StringExpression](SampleArrays.TextArray, String("after")).AS("sample.text_append"), + ARRAY_CAT[StringExpression](SampleArrays.TextArray, textArray).AS("sample.text_cat"), + ARRAY_LENGTH[StringExpression](SampleArrays.TextArray, Int32(1)).AS("sample.text_length"), + ARRAY_PREPEND[StringExpression](String("before"), SampleArrays.TextArray).AS("sample.text_prepend"), + ).FROM( + SampleArrays, + ).WHERE( + SampleArrays.BoolArray.CONTAINS(BoolArray([]bool{true})), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array", + (sample_arrays.text_array = sample_arrays.text_array) AS "sample.text_eq", + (sample_arrays.bool_array = $1) AS "sample.bool_eq", + (sample_arrays.text_array != $2) AS "sample.text_neq", + (sample_arrays.int4_array < $3) IS TRUE AS "sample.int4_lt", + (sample_arrays.int8_array <= $4) IS FALSE AS "sample.int8_lteq", + (sample_arrays.text_array > $5) AS "sample.text_gt", + (sample_arrays.int4_array >= $6) AS "sample.bool_gteq", + ($7::integer = ANY(sample_arrays.int4_array)) AS "sample.int4_eq_any", + ($8::integer != ANY(sample_arrays.int4_array)) AS "sample.int4_neq_any", + ($9::integer = ALL(sample_arrays.int4_array)) AS "sample.int4_eq_all", + (sample_arrays.int8_array @> $10) AS "sample.int8cont", + (sample_arrays.int8_array <@ $11) AS "sample.int8cont_by", + (sample_arrays.int4_array && $12) AS "sample.int4_overlap", + (sample_arrays.bool_array || $13) AS "sample.bool_concat", + (sample_arrays.text_array || $14::text) AS "sample.text_concat_el", + sample_arrays.text_array[$15::integer] AS "sample.text_at", + array_append(sample_arrays.text_array, $16::text) AS "sample.text_append", + array_cat(sample_arrays.text_array, $17) AS "sample.text_cat", + array_length(sample_arrays.text_array, $18::integer) AS "sample.text_length", + array_prepend($19::text, sample_arrays.text_array) AS "sample.text_prepend" +FROM test_sample.sample_arrays +WHERE sample_arrays.bool_array @> $20; +`) + + type sample struct { + model.SampleArrays + TextEq bool + BoolEq bool + TextNeq bool + Int4Lt bool + Int8Lteq bool + TextGt bool + BoolGteq bool + Int4EqAny bool + Int4NeqAny bool + Int4EqAll bool + Int8Cont bool + Int8ContBy bool + Int4Overlap bool + BoolConcat pq.BoolArray + TextConcatEl pq.StringArray + TextAt string + TextAppend pq.StringArray + TextCat pq.StringArray + TextLength int32 + TextPrepend pq.StringArray + } + + var dest sample + err := query.Query(db, &dest) + require.NoError(t, err) + + expectedRow := sample{ + SampleArrays: sampleArrayRow, + TextEq: true, + BoolEq: true, + TextNeq: true, + Int4Lt: false, + Int8Lteq: true, + TextGt: true, + BoolGteq: true, + Int4EqAny: false, + Int4NeqAny: true, + Int4EqAll: false, + Int8Cont: false, + Int8ContBy: false, + Int4Overlap: true, + BoolConcat: pq.BoolArray{true, true}, + TextConcatEl: pq.StringArray{"a", "b", "z"}, + TextAt: "a", + TextAppend: pq.StringArray{"a", "b", "after"}, + TextCat: pq.StringArray{"a", "b", "a"}, + TextLength: 2, + TextPrepend: pq.StringArray{"before", "a", "b"}, + } + + testutils.AssertDeepEqual(t, dest, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestArraySelectColumnsFromSubQuery(t *testing.T) { + skipForCockroachDB(t) + + subQuery := SELECT( + SampleArrays.AllColumns, + SampleArrays.Int4Array.AS("array4"), + ).FROM( + SampleArrays, + ).AsTable("sub_query") + + int4Array := IntegerArrayColumn("array4").From(subQuery) + + stmt := SELECT( + subQuery.AllColumns(), + int4Array, + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT sub_query."sample_arrays.text_array" AS "sample_arrays.text_array", + sub_query."sample_arrays.bool_array" AS "sample_arrays.bool_array", + sub_query."sample_arrays.int4_array" AS "sample_arrays.int4_array", + sub_query."sample_arrays.int8_array" AS "sample_arrays.int8_array", + sub_query.array4 AS "array4", + sub_query.array4 AS "array4" +FROM ( + SELECT sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array", + sample_arrays.int4_array AS "array4" + FROM test_sample.sample_arrays + ) AS sub_query; +`) + + var dest struct { + model.SampleArrays + Array4 pq.Int32Array + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest.SampleArrays.Int4Array, sampleArrayRow.Int4Array) + testutils.AssertDeepEqual(t, dest.SampleArrays.Int8Array, sampleArrayRow.Int8Array) + testutils.AssertDeepEqual(t, dest.Array4, sampleArrayRow.Int4Array) +} + +func TestArrayTable_InsertColumn(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleArrays.INSERT(SampleArrays.AllColumns). + VALUES( + ARRAY(String("A"), String("B")), + ARRAY(Bool(true)), + ARRAY(Int32(1)), + ARRAY(Int64(2)), + ). + MODEL( + sampleArrayRow, + ). + RETURNING(SampleArrays.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_arrays (text_array, bool_array, int4_array, int8_array) +VALUES (ARRAY['A'::text,'B'::text], ARRAY[TRUE::boolean], ARRAY[1::integer], ARRAY[2::bigint]), + ('{"a","b"}', '{t}', '{1,2,3}', '{10,11,12}') +RETURNING sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array"; +` + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.SampleArrays + err := insertQuery.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + testutils.AssertDeepEqual(t, sampleArrayRow, dest[1], cmp.AllowUnexported(big.Int{})) + }) +} + +func TestArrayTableUpdate(t *testing.T) { + skipForCockroachDB(t) + + t.Run("using model", func(t *testing.T) { + stmt := SampleArrays.UPDATE(SampleArrays.AllColumns). + MODEL(sampleArrayRow). + WHERE(String("a").EQ(ANY[StringExpression](SampleArrays.TextArray))). + RETURNING(SampleArrays.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE test_sample.sample_arrays +SET (text_array, bool_array, int4_array, int8_array) = ($1, $2, $3, $4) +WHERE $5::text = ANY(sample_arrays.text_array) +RETURNING sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.SampleArrays + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + testutils.AssertDeepEqual(t, sampleArrayRow, dest[0], cmp.AllowUnexported(big.Int{})) + }) + }) + + t.Run("update using SET", func(t *testing.T) { + stmt := SampleArrays.UPDATE(). + SET( + SampleArrays.Int4Array.SET(ARRAY(Int32(-10), Int32(11))), + SampleArrays.Int8Array.SET(ARRAY(Int64(-1200), Int64(7800))), + ). + WHERE(String("a").EQ(ANY[StringExpression](SampleArrays.TextArray))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.sample_arrays +SET int4_array = ARRAY[-10::integer,11::integer], + int8_array = ARRAY[-1200::bigint,7800::bigint] +WHERE 'a'::text = ANY(sample_arrays.text_array); +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExec(t, stmt, tx, 1) + }) + }) + +} + +var sampleArrayRow = model.SampleArrays{ + TextArray: pq.StringArray([]string{"a", "b"}), + BoolArray: pq.BoolArray([]bool{true}), + Int4Array: pq.Int32Array([]int32{1, 2, 3}), + Int8Array: pq.Int64Array([]int64{10, 11, 12}), +} diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 4d87295d..819a19bd 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -447,7 +447,7 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { require.Contains(t, data, "\"database/sql\"") require.Contains(t, data, "Description sql.NullString") require.Contains(t, data, "ReleaseYear sql.NullInt32") - require.Contains(t, data, "SpecialFeatures sql.NullString") + require.Contains(t, data, "SpecialFeatures *pq.StringArray") } func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index fe1407f3..05254e6c 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -602,12 +602,12 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", - "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") + "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go", "sample_arrays.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", - "components.go", "vulnerabilities.go", "sample_ranges.go") + "components.go", "vulnerabilities.go", "sample_ranges.go", "sample_arrays.go") testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) @@ -677,6 +677,7 @@ package model import ( "github.com/google/uuid" + "github.com/lib/pq" "time" ) @@ -735,11 +736,11 @@ type AllTypes struct { JSON string JsonbPtr *string Jsonb string - IntegerArrayPtr *string - IntegerArray string - TextArrayPtr *string - TextArray string - JsonbArray string + IntegerArrayPtr *pq.Int32Array + IntegerArray pq.Int32Array + TextArrayPtr *pq.StringArray + TextArray pq.StringArray + JsonbArray pq.StringArray TextMultiDimArrayPtr *string TextMultiDimArray string MoodPtr *Mood @@ -821,11 +822,11 @@ type allTypesTable struct { JSON postgres.ColumnString JsonbPtr postgres.ColumnString Jsonb postgres.ColumnString - IntegerArrayPtr postgres.ColumnString - IntegerArray postgres.ColumnString - TextArrayPtr postgres.ColumnString - TextArray postgres.ColumnString - JsonbArray postgres.ColumnString + IntegerArrayPtr postgres.ColumnIntegerArray + IntegerArray postgres.ColumnIntegerArray + TextArrayPtr postgres.ColumnStringArray + TextArray postgres.ColumnStringArray + JsonbArray postgres.ColumnStringArray TextMultiDimArrayPtr postgres.ColumnString TextMultiDimArray postgres.ColumnString MoodPtr postgres.ColumnString @@ -924,11 +925,11 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JSONColumn = postgres.StringColumn("json") JsonbPtrColumn = postgres.StringColumn("jsonb_ptr") JsonbColumn = postgres.StringColumn("jsonb") - IntegerArrayPtrColumn = postgres.StringColumn("integer_array_ptr") - IntegerArrayColumn = postgres.StringColumn("integer_array") - TextArrayPtrColumn = postgres.StringColumn("text_array_ptr") - TextArrayColumn = postgres.StringColumn("text_array") - JsonbArrayColumn = postgres.StringColumn("jsonb_array") + IntegerArrayPtrColumn = postgres.IntegerArrayColumn("integer_array_ptr") + IntegerArrayColumn = postgres.IntegerArrayColumn("integer_array") + TextArrayPtrColumn = postgres.StringArrayColumn("text_array_ptr") + TextArrayColumn = postgres.StringArrayColumn("text_array") + JsonbArrayColumn = postgres.StringArrayColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") MoodPtrColumn = postgres.StringColumn("mood_ptr") diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 24b5949d..e617c845 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "github.com/lib/pq" "github.com/volatiletech/null/v8" "testing" "time" @@ -967,10 +968,11 @@ func TestScanIntoCustomBaseTypes(t *testing.T) { ReplacementCost MyFloat64 Rating *model.MpaaRating LastUpdate MyTime - SpecialFeatures *MyString + SpecialFeatures pq.StringArray Fulltext MyString } + // We'll skip special features, because it's a slice and it does not implement sql.Scanner stmt := SELECT( Film.AllColumns, ).FROM( @@ -979,14 +981,12 @@ func TestScanIntoCustomBaseTypes(t *testing.T) { Film.FilmID.ASC(), ).LIMIT(3) - var films []model.Film - - err := stmt.Query(db, &films) - require.NoError(t, err) - var myFilms []film + err := stmt.Query(db, &myFilms) + require.NoError(t, err) - err = stmt.Query(db, &myFilms) + var films []model.Film + err = stmt.Query(db, &films) require.NoError(t, err) require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms)) @@ -1160,7 +1160,7 @@ var film1 = model.Film{ ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: &pq.StringArray{"Deleted Scenes", "Behind the Scenes"}, Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } @@ -1176,7 +1176,7 @@ var film2 = model.Film{ ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: &pq.StringArray{"Trailers", "Deleted Scenes"}, Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 048db147..ccd5287b 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "github.com/lib/pq" "testing" "time" @@ -1837,7 +1838,7 @@ ORDER BY film.film_id ASC; Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: &pq.StringArray{"Trailers", "Deleted Scenes"}, Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -2793,7 +2794,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; err := stmt.Query(db, &dest) require.NoError(t, err) - //jsonSave("./testdata/quick-start-dest.json", dest) + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/quick-start-dest.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") var dest2 []struct { @@ -2806,7 +2807,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; err = stmt.Query(db, &dest2) require.NoError(t, err) - //jsonSave("./testdata/quick-start-dest2.json", dest2) + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/quick-start-dest2.json") testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") } @@ -3382,7 +3383,10 @@ func TestRecursionScanNxM(t *testing.T) { "ReplacementCost": 20.99, "Rating": "PG", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "SpecialFeatures": [ + "Deleted Scenes", + "Behind the Scenes" + ], "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", "Actors": [ { @@ -3406,7 +3410,10 @@ func TestRecursionScanNxM(t *testing.T) { "ReplacementCost": 9.99, "Rating": "R", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "SpecialFeatures": [ + "Trailers", + "Deleted Scenes" + ], "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", "Actors": [ { @@ -3454,7 +3461,10 @@ func TestRecursionScanNxM(t *testing.T) { "ReplacementCost": 20.99, "Rating": "PG", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "SpecialFeatures": [ + "Deleted Scenes", + "Behind the Scenes" + ], "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", "Actors": null }, @@ -3470,7 +3480,10 @@ func TestRecursionScanNxM(t *testing.T) { "ReplacementCost": 9.99, "Rating": "R", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "SpecialFeatures": [ + "Trailers", + "Deleted Scenes" + ], "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", "Actors": null }