Skip to content

Commit

Permalink
Validate schema when parsing PG function calls, extract constants
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 6, 2025
1 parent 1c5d54e commit ef143d9
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 85 deletions.
5 changes: 5 additions & 0 deletions src/custom_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,8 @@ func (pgSchemaTable PgSchemaTable) ToIcebergSchemaTable() IcebergSchemaTable {
Table: pgSchemaTable.Table,
}
}

type PgSchemaFunction struct {
Schema string
Function string
}
37 changes: 37 additions & 0 deletions src/pg_constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

const (
PG_SCHEMA_INFORMATION_SCHEMA = "information_schema"
PG_SCHEMA_PG_CATALOG = "pg_catalog"
PG_SCHEMA_PUBLIC = "public"

PG_FUNCTION_ARRAY_UPPER = "array_upper"
PG_FUNCTION_PG_GET_INDEXDEF = "pg_get_indexdef"
PG_FUNCTION_PG_GET_KEYWORDS = "pg_get_keywords"
PG_FUNCTION_PG_IS_IN_RECOVERY = "pg_is_in_recovery"
PG_FUNCTION_PG_SHOW_ALL_SETTINGS = "pg_show_all_settings"
PG_FUNCTION_QUOTE_INDENT = "quote_ident"
PG_FUNCTION_PG_GET_EXPR = "pg_get_expr"
PG_FUNCTION_SET_CONFIG = "set_config"
PG_FUNCTION_ROW_TO_JSON = "row_to_json"
PG_FUNCTION_ARRAY_TO_STRING = "array_to_string"
PG_FUNCTION_PG_EXPANDARRAY = "_pg_expandarray"

PG_TABLE_PG_AUTH_MEMBERS = "pg_auth_members"
PG_TABLE_PG_CLASS = "pg_class"
PG_TABLE_PG_DATABASE = "pg_database"
PG_TABLE_PG_EXTENSION = "pg_extension"
PG_TABLE_PG_INHERITS = "pg_inherits"
PG_TABLE_PG_MATVIEWS = "pg_matviews"
PG_TABLE_PG_NAMESPACE = "pg_namespace"
PG_TABLE_PG_REPLICATION_SLOTS = "pg_replication_slots"
PG_TABLE_PG_ROLES = "pg_roles"
PG_TABLE_PG_SHADOW = "pg_shadow"
PG_TABLE_PG_SHDESCRIPTION = "pg_shdescription"
PG_TABLE_PG_STATIO_USER_TABLES = "pg_statio_user_tables"
PG_TABLE_PG_STAT_ACTIVITY = "pg_stat_activity"
PG_TABLE_PG_STAT_GSSAPI = "pg_stat_gssapi"
PG_TABLE_PG_STAT_USER_TABLES = "pg_stat_user_tables"
PG_TABLE_PG_USER = "pg_user"
PG_TABLE_TABLES = "tables"
)
2 changes: 0 additions & 2 deletions src/pg_schema_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ const (
PG_TRUE = "YES"
PG_FALSE = "FALSE"

PG_SCHEMA_PG_CATALOG = "pg_catalog"

PG_DATA_TYPE_ARRAY = "ARRAY"

PARQUET_SCHEMA_REPETITION_TYPE_REQUIRED = "REQUIRED"
Expand Down
18 changes: 16 additions & 2 deletions src/query_parser_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,22 @@ func (parser *QueryParserSelect) NestedFunctionCalls(functionCall *pgQuery.FuncC
return nestedFunctionCalls
}

func (parser *QueryParserSelect) FunctionName(functionCall *pgQuery.FuncCall) string {
return functionCall.Funcname[len(functionCall.Funcname)-1].GetString_().Sval
func (parser *QueryParserSelect) SchemaFunction(functionCall *pgQuery.FuncCall) PgSchemaFunction {
if len(functionCall.Funcname) == 1 {
return PgSchemaFunction{
Schema: "",
Function: functionCall.Funcname[0].GetString_().Sval,
}
}

if len(functionCall.Funcname) == 2 {
return PgSchemaFunction{
Schema: functionCall.Funcname[0].GetString_().Sval,
Function: functionCall.Funcname[1].GetString_().Sval,
}
}

return PgSchemaFunction{}
}

// quote_ident(str) -> concat("\""+str+"\"")
Expand Down
33 changes: 5 additions & 28 deletions src/query_parser_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ import (
pgQuery "github.com/pganalyze/pg_query_go/v5"
)

const (
// PG_SCHEMA_PG_CATALOG = "pg_catalog" Already defined in pg_schema_column.go
PG_SCHEMA_INFORMATION_SCHEMA = "information_schema"

PG_FUNCTION_PG_GET_KEYWORDS = "pg_get_keywords"
PG_FUNCTION_ARRAY_UPPER = "array_upper"
PG_FUNCTION_PG_SHOW_ALL_SETTINGS = "pg_show_all_settings"
PG_FUNCTION_PG_IS_IN_RECOVERY = "pg_is_in_recovery"
)

type QueryParserTable struct {
config *Config
utils *QueryParserUtils
Expand Down Expand Up @@ -247,26 +237,13 @@ func (parser *QueryParserTable) MakePgGetKeywordsNode(node *pgQuery.Node) *pgQue
return parser.utils.MakeSubselectWithRowsNode(PG_FUNCTION_PG_GET_KEYWORDS, columns, rows, alias)
}

// array_upper(array, 1)
func (parser *QueryParserTable) IsArrayUpperFunction(funcCallNode *pgQuery.FuncCall) bool {
if len(funcCallNode.Funcname) != 1 {
return false
}

funcName := funcCallNode.Funcname[0].GetString_().Sval

if funcName == PG_FUNCTION_ARRAY_UPPER {
dimension := funcCallNode.Args[1].GetAConst().GetIval().Ival
if dimension == 1 {
return true
}
}

return false
}

// array_upper(array, 1) -> len(array)
func (parser *QueryParserTable) MakeArrayUpperNode(funcCallNode *pgQuery.FuncCall) *pgQuery.FuncCall {
dimension := funcCallNode.Args[1].GetAConst().GetIval().Ival
if dimension != 1 {
return funcCallNode
}

return pgQuery.MakeFuncCallNode(
[]*pgQuery.Node{
pgQuery.MakeStrNode("len"),
Expand Down
41 changes: 16 additions & 25 deletions src/select_remapper_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@ import (
pgQuery "github.com/pganalyze/pg_query_go/v5"
)

const (
PG_FUNCTION_QUOTE_INDENT = "quote_ident"
PG_FUNCTION_PG_GET_EXPR = "pg_get_expr"
PG_FUNCTION_SET_CONFIG = "set_config"
PG_FUNCTION_ROW_TO_JSON = "row_to_json"
PG_FUNCTION_ARRAY_TO_STRING = "array_to_string"
PG_FUNCTION_PG_EXPANDARRAY = "_pg_expandarray"
)

var REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME = map[string]string{
"version": "PostgreSQL " + PG_VERSION + ", compiled by Bemi",
"pg_get_userbyid": "bemidb",
Expand Down Expand Up @@ -53,18 +44,18 @@ func (remapper *SelectRemapperSelect) RemapSelect(targetNode *pgQuery.Node) *pgQ
return targetNode
}

originalFunctionName := remapper.parserSelect.FunctionName(functionCall)
schemaFunction := remapper.parserSelect.SchemaFunction(functionCall)

// set_config(setting_name, new_value, is_local) -> new_value
if originalFunctionName == PG_FUNCTION_SET_CONFIG {
if schemaFunction.Function == PG_FUNCTION_SET_CONFIG {
remapper.parserSelect.RemapSetConfigFunction(targetNode, functionCall)
return targetNode
}

renamedNameFunction := remapper.remappedFunctionName(functionCall)
if renamedNameFunction != nil {
functionCall = renamedNameFunction
remapper.parserSelect.SetDefaultTargetName(targetNode, originalFunctionName)
remapper.parserSelect.SetDefaultTargetName(targetNode, schemaFunction.Function)
}

remappedArgsFunction := remapper.remappedFunctionArgs(functionCall)
Expand All @@ -75,7 +66,7 @@ func (remapper *SelectRemapperSelect) RemapSelect(targetNode *pgQuery.Node) *pgQ
constantNode := remapper.remappedToConstant(functionCall)
if constantNode != nil {
remapper.parserSelect.OverrideTargetValue(targetNode, constantNode)
remapper.parserSelect.SetDefaultTargetName(targetNode, originalFunctionName)
remapper.parserSelect.SetDefaultTargetName(targetNode, schemaFunction.Function)
}

functionCall = remapper.remapNestedFunctionCalls(functionCall) // recursive
Expand All @@ -91,12 +82,12 @@ func (remapper *SelectRemapperSelect) remappedInderectionFunctionCall(targetNode
return nil
}

functionName := parser.FunctionName(functionCall)
schemaFunction := parser.SchemaFunction(functionCall)

switch functionName {
switch {

// (information_schema._pg_expandarray(array)).n -> unnest(anyarray) AS n
case PG_FUNCTION_PG_EXPANDARRAY:
case schemaFunction.Schema == PG_SCHEMA_INFORMATION_SCHEMA && schemaFunction.Function == PG_FUNCTION_PG_EXPANDARRAY:
inderectionColumnName := targetNode.GetResTarget().Val.GetAIndirection().Indirection[0].GetString_().Sval
newTargetNode := parser.RemapInderectionToFunctionCall(targetNode, parser.RemapPgExpandArray(functionCall))
remapper.parserSelect.SetDefaultTargetName(newTargetNode, inderectionColumnName)
Expand All @@ -108,20 +99,20 @@ func (remapper *SelectRemapperSelect) remappedInderectionFunctionCall(targetNode
}

func (remapper *SelectRemapperSelect) remappedFunctionName(functionCall *pgQuery.FuncCall) *pgQuery.FuncCall {
functionName := remapper.parserSelect.FunctionName(functionCall)
schemaFunction := remapper.parserSelect.SchemaFunction(functionCall)

switch functionName {
switch {

// quote_ident(str) -> concat("\""+str+"\"")
case PG_FUNCTION_QUOTE_INDENT:
case schemaFunction.Function == PG_FUNCTION_QUOTE_INDENT:
return remapper.parserSelect.RemapQuoteIdentToConcat(functionCall)

// array_to_string(array, separator) -> main.array_to_string(array, separator)
case PG_FUNCTION_ARRAY_TO_STRING:
case schemaFunction.Function == PG_FUNCTION_ARRAY_TO_STRING:
return remapper.parserSelect.RemapArrayToString(functionCall)

// row_to_json(col) -> to_json(col)
case PG_FUNCTION_ROW_TO_JSON:
case schemaFunction.Function == PG_FUNCTION_ROW_TO_JSON:
return remapper.parserSelect.RemapRowToJson(functionCall)

default:
Expand All @@ -130,10 +121,10 @@ func (remapper *SelectRemapperSelect) remappedFunctionName(functionCall *pgQuery
}

func (remapper *SelectRemapperSelect) remappedFunctionArgs(functionCall *pgQuery.FuncCall) *pgQuery.FuncCall {
functionName := remapper.parserSelect.FunctionName(functionCall)
schemaFunction := remapper.parserSelect.SchemaFunction(functionCall)

// pg_get_expr(pg_node_tree, relation_oid, pretty_bool) -> pg_get_expr(pg_node_tree, relation_oid)
if functionName == PG_FUNCTION_PG_GET_EXPR {
if schemaFunction.Schema == PG_SCHEMA_PG_CATALOG && schemaFunction.Function == PG_FUNCTION_PG_GET_EXPR {
return remapper.parserSelect.RemoveThirdArgumentFromPgGetExpr(functionCall)
}

Expand Down Expand Up @@ -165,8 +156,8 @@ func (remapper *SelectRemapperSelect) remapNestedFunctionCalls(functionCall *pgQ
}

func (remapper *SelectRemapperSelect) remappedToConstant(functionCall *pgQuery.FuncCall) *pgQuery.Node {
functionName := remapper.parserSelect.FunctionName(functionCall)
constant, ok := REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME[functionName]
schemaFunction := remapper.parserSelect.SchemaFunction(functionCall)
constant, ok := REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME[schemaFunction.Function]
if ok {
return pgQuery.MakeAConstStrNode(constant, 0)
}
Expand Down
40 changes: 12 additions & 28 deletions src/select_remapper_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,10 @@ import (
pgQuery "github.com/pganalyze/pg_query_go/v5"
)

const (
PG_SCHEMA_PUBLIC = "public"

PG_TABLE_PG_INHERITS = "pg_inherits"
PG_TABLE_PG_SHDESCRIPTION = "pg_shdescription"
PG_TABLE_PG_STATIO_USER_TABLES = "pg_statio_user_tables"
PG_TABLE_PG_SHADOW = "pg_shadow"
PG_TABLE_PG_NAMESPACE = "pg_namespace"
PG_TABLE_PG_ROLES = "pg_roles"
PG_TABLE_PG_CLASS = "pg_class"
PG_TABLE_PG_EXTENSION = "pg_extension"
PG_TABLE_PG_REPLICATION_SLOTS = "pg_replication_slots"
PG_TABLE_PG_DATABASE = "pg_database"
PG_TABLE_PG_STAT_GSSAPI = "pg_stat_gssapi"
PG_TABLE_PG_AUTH_MEMBERS = "pg_auth_members"
PG_TABLE_PG_USER = "pg_user"
PG_TABLE_PG_STAT_ACTIVITY = "pg_stat_activity"
PG_TABLE_PG_MATVIEWS = "pg_matviews"
PG_TABLE_PG_STAT_USER_TABLES = "pg_stat_user_tables"

PG_TABLE_TABLES = "tables"
)

type SelectRemapperTable struct {
parserTable *QueryParserTable
parserWhere *QueryParserWhere
parserSelect *QueryParserSelect
icebergSchemaTables []IcebergSchemaTable
icebergReader *IcebergReader
duckdb *Duckdb
Expand All @@ -42,6 +20,7 @@ func NewSelectRemapperTable(config *Config, icebergReader *IcebergReader, duckdb
remapper := &SelectRemapperTable{
parserTable: NewQueryParserTable(config),
parserWhere: NewQueryParserWhere(config),
parserSelect: NewQueryParserSelect(config),
icebergReader: icebergReader,
duckdb: duckdb,
config: config,
Expand Down Expand Up @@ -194,13 +173,18 @@ func (remapper *SelectRemapperTable) RemapTableFunction(node *pgQuery.Node) *pgQ
}

// FROM PG_FUNCTION(PG_NESTED_FUNCTION())
func (remapper *SelectRemapperTable) RemapNestedTableFunction(funcCallNode *pgQuery.FuncCall) *pgQuery.FuncCall {
func (remapper *SelectRemapperTable) RemapNestedTableFunction(functionCall *pgQuery.FuncCall) *pgQuery.FuncCall {
schemaFunction := remapper.parserSelect.SchemaFunction(functionCall)

switch {

// array_upper(values, 1) -> len(values)
if remapper.parserTable.IsArrayUpperFunction(funcCallNode) {
return remapper.parserTable.MakeArrayUpperNode(funcCallNode)
}
case schemaFunction.Function == PG_FUNCTION_ARRAY_UPPER:
return remapper.parserTable.MakeArrayUpperNode(functionCall)

return funcCallNode
default:
return functionCall
}
}

func (remapper *SelectRemapperTable) RemapWhereClauseForTable(qSchemaTable QuerySchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
Expand Down

0 comments on commit ef143d9

Please sign in to comment.