diff --git a/src/pg_schema_column.go b/src/pg_schema_column.go index 2d4cd9e..7ab7921 100644 --- a/src/pg_schema_column.go +++ b/src/pg_schema_column.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "strconv" "strings" "time" @@ -102,6 +103,11 @@ func (pgSchemaColumn *PgSchemaColumn) FormatParquetValue(value string) *string { timestamp := strconv.FormatInt(parsedTime.UnixMilli(), 10) return ×tamp } + case "date": + parsedTime, err := time.Parse("2006-01-02", value) + PanicIfError(err) + date := fmt.Sprintf("%d", parsedTime.Unix()/86400) + return &date default: if strings.HasPrefix(pgSchemaColumn.UdtName, "_") { switch strings.TrimLeft(pgSchemaColumn.UdtName, "_") { @@ -141,7 +147,7 @@ func (pgSchemaColumn PgSchemaColumn) toIcebergSchemaField() IcebergSchemaField { } switch pgSchemaColumn.UdtName { - case "varchar", "char", "text", "jsonb", "json", "uuid": + case "varchar", "char", "text", "jsonb", "json", "uuid", "bpchar": icebergSchemaField.Type = "string" case "int2", "int4", "int8": icebergSchemaField.Type = "int" @@ -187,7 +193,7 @@ func (pgSchemaColumn *PgSchemaColumn) toParquetSchemaField() ParquetSchemaField } switch pgSchemaColumn.UdtName { - case "varchar", "char", "text", "bytea", "jsonb", "json": + case "varchar", "char", "text", "bytea", "jsonb", "json", "bpchar": parquetSchemaField.Type = "BYTE_ARRAY" parquetSchemaField.ConvertedType = "UTF8" case "int2", "int4", "int8": @@ -196,10 +202,14 @@ func (pgSchemaColumn *PgSchemaColumn) toParquetSchemaField() ParquetSchemaField parquetSchemaField.Type = "FLOAT" case "numeric": parquetSchemaField.Type = "FIXED_LEN_BYTE_ARRAY" - parquetSchemaField.Length = pgSchemaColumn.NumericPrecision + "." + pgSchemaColumn.NumericScale parquetSchemaField.ConvertedType = "DECIMAL" parquetSchemaField.Scale = pgSchemaColumn.NumericScale parquetSchemaField.Precision = pgSchemaColumn.NumericPrecision + scale, err := strconv.Atoi(pgSchemaColumn.NumericScale) + PanicIfError(err) + precision, err := strconv.Atoi(pgSchemaColumn.NumericPrecision) + PanicIfError(err) + parquetSchemaField.Length = strconv.Itoa(scale + precision) case "bool": parquetSchemaField.Type = "BOOLEAN" case "uuid": diff --git a/src/proxy.go b/src/proxy.go index f3f471a..caf923d 100644 --- a/src/proxy.go +++ b/src/proxy.go @@ -4,9 +4,13 @@ import ( "context" "database/sql" "errors" + "fmt" + "strconv" + "time" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" + duckDB "github.com/marcboeker/go-duckdb" pgQuery "github.com/pganalyze/pg_query_go/v5" ) @@ -106,11 +110,27 @@ func (proxy *Proxy) generateRowDescription(cols []*sql.ColumnType) *pgproto3.Row } func (proxy *Proxy) generateDataRow(rows *sql.Rows, cols []*sql.ColumnType) (*pgproto3.DataRow, error) { - values := make([][]byte, len(cols)) valuePtrs := make([]interface{}, len(cols)) - - for i := range values { - valuePtrs[i] = &values[i] + for i, col := range cols { + switch col.ScanType().Name() { + case "int32": + var value int32 + valuePtrs[i] = &value + case "int64": + var value int64 + valuePtrs[i] = &value + case "string": + var value string + valuePtrs[i] = &value + case "Time": + var value time.Time + valuePtrs[i] = &value + case "Decimal": + var value duckDB.Decimal + valuePtrs[i] = &value + default: + panic("Unsupported type") + } } err := rows.Scan(valuePtrs...) @@ -118,11 +138,30 @@ func (proxy *Proxy) generateDataRow(rows *sql.Rows, cols []*sql.ColumnType) (*pg return nil, err } - dataRow := pgproto3.DataRow{Values: values} - // Convert values to text format - for i := range values { - dataRow.Values[i] = []byte(string(values[i])) + var values [][]byte + for i, valuePtr := range valuePtrs { + switch value := valuePtr.(type) { + case *int32: + values = append(values, []byte(strconv.Itoa(int(*value)))) + case *int64: + values = append(values, []byte(strconv.Itoa(int(*value)))) + case *string: + values = append(values, []byte(*value)) + case *time.Time: + switch cols[i].DatabaseTypeName() { + case "DATE": + values = append(values, []byte(value.Format("2006-01-02"))) + default: + panic("Unsupported type") + } + case *duckDB.Decimal: + float64Value := (*value).Float64() + values = append(values, []byte(fmt.Sprintf("%v", float64Value))) + default: + panic("Unsupported type") + } } + dataRow := pgproto3.DataRow{Values: values} return &dataRow, nil }