Skip to content

Commit

Permalink
Refactor parsing of expressions and propagate column names
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 17, 2024
1 parent d7522de commit 9986d09
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 118 deletions.
7 changes: 5 additions & 2 deletions src/substrait/sql/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
catalog.load_standard_extensions(
pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions",
)

# TODO: Turn this into a command line tool to test more queries.
# We can probably provide a quick way to declare the schema using command line args,
# e.g. first_name=String,surname=String,age=I32 etc.
schema = proto.NamedStruct(
names=["first_name", "surname", "age"],
struct=proto.Type.Struct(
Expand Down Expand Up @@ -38,5 +42,4 @@
print("---- PROJECTION ----")
print(projection_expr)
print("---- FILTER ----")
print(filter_expr)
# parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)")
print(filter_expr)
208 changes: 92 additions & 116 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import sqlglot

from substrait import proto
Expand All @@ -19,21 +21,28 @@ def parse_sql_extended_expression(catalog, schema, sql):
if not isinstance(select, sqlglot.expressions.Select):
raise ValueError("a SELECT statement was expected")

invoked_functions_projection, projections = _substrait_projection_from_sqlglot(
catalog, schema, select.expressions
)
sqlglot_parser = SQLGlotParser(catalog, schema)

# Handle the projections in the SELECT statement.
project_expressions = []
projection_invoked_functions = set()
for sqlexpr in select.expressions:
invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
projection_invoked_functions.update(invoked_functions)
project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name]))
extension_uris, extensions = catalog.extensions_for_functions(
invoked_functions_projection
projection_invoked_functions
)
projection_extended_expr = proto.ExtendedExpression(
extension_uris=extension_uris,
extensions=extensions,
base_schema=schema,
referred_expr=projections,
referred_expr=project_expressions,
)

invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot(
catalog, schema, select.find(sqlglot.expressions.Where).this
# Handle WHERE clause in the SELECT statement.
invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot(
select.find(sqlglot.expressions.Where).this
)
extension_uris, extensions = catalog.extensions_for_functions(
invoked_functions_filter
Expand All @@ -48,122 +57,89 @@ def parse_sql_extended_expression(catalog, schema, sql):
return projection_extended_expr, filter_extended_expr


def _substrait_projection_from_sqlglot(catalog, schema, expressions):
if not expressions:
return set(), []
class SQLGlotParser:
def __init__(self, functions_catalog, schema):
self._functions_catalog = functions_catalog
self._schema = schema
self._counter = itertools.count()

# My understanding of ExtendedExpressions is that they are meant to directly
# point to the Expression that ProjectRel would contain, so we don't actually
# need a ProjectRel at all.
"""
projection_sub = proto.ProjectRel(
input=proto.Rel(
read=proto.ReadRel(
named_table=proto.ReadRel.NamedTable(names=["__table__"]),
base_schema=schema,
)
),
expressions=[],
)
"""

substrait_expressions = []
invoked_functions = set()
for sqlexpr in expressions:
output_names = []
if isinstance(sqlexpr, sqlglot.expressions.Alias):
output_names = [sqlexpr.output_name]
sqlexpr = sqlexpr.this
_, substrait_expr = _parse_expression(
catalog, schema, sqlexpr, invoked_functions
)
substrait_expr_reference = proto.ExpressionReference(
expression=substrait_expr, output_names=output_names
def expression_from_sqlglot(self, sqlglot_node):
invoked_functions = set()
output_name, _, substrait_expr = self._parse_expression(
sqlglot_node, invoked_functions
)
substrait_expressions.append(substrait_expr_reference)

return invoked_functions, substrait_expressions

return invoked_functions, output_name, substrait_expr

def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node):
if not sqlglot_node:
return set(), None

invoked_functions = set()
_, substrait_expr = _parse_expression(
catalog, schema, sqlglot_node, invoked_functions
)
return invoked_functions, substrait_expr


def _parse_expression(catalog, schema, expr, invoked_functions):
# TODO: Propagate up column names (output_names) so that the projections _always_ have an output_name
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return proto.Type(string=proto.Type.String()), proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
def _parse_expression(self, expr, invoked_functions):
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
)
elif expr.is_int:
return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
)
elif sqlglot.helper.is_float(expr.name):
return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
)
else:
raise ValueError(f"Unsupporter literal: {expr.text}")
elif isinstance(expr, sqlglot.expressions.Column):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
return column_name, schema_type, proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
)
)
elif expr.is_int:
return proto.Type(i32=proto.Type.I32()), proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
elif isinstance(expr, sqlglot.expressions.Alias):
_, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions)
return expr.output_name, aliased_type, aliased_expr
elif expr.key in SQL_BINARY_FUNCTIONS:
left_name, left_type, left = self._parse_expression(
expr.left, invoked_functions
)
elif sqlglot.helper.is_float(expr.name):
return proto.Type(fp32=proto.Type.FP32()), proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
right_name, right_type, right = self._parse_expression(
expr.right, invoked_functions
)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, left_type, left, right_type, right
)
invoked_functions.add(signature)
result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
return result_name, result_type, function_expression
else:
raise ValueError(f"Unsupporter literal: {expr.text}")
elif isinstance(expr, sqlglot.expressions.Column):
column_name = expr.output_name
schema_field = list(schema.names).index(column_name)
schema_type = schema.struct.types[schema_field]
return schema_type, proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
raise ValueError(
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
)
)
elif expr.key in SQL_BINARY_FUNCTIONS:
left_type, left = _parse_expression(
catalog, schema, expr.left, invoked_functions
)
right_type, right = _parse_expression(
catalog, schema, expr.right, invoked_functions
)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = _parse_function_invokation(
catalog, function_name, left_type, left, right_type, right
)
invoked_functions.add(signature)
return result_type, function_expression
else:
raise ValueError(
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
)


def _parse_function_invokation(catalog, function_name, left_type, left, right_type, right):
    """Build the scalar-function expression for a two-argument function call.

    The lookup signature is ``<function_name>:<left kind>_<right kind>``, where
    each kind is whatever ``WhichOneof('kind')`` reports for the operand type.
    If the catalog raises ``KeyError`` for that signature, the generic
    ``<function_name>:any1_any1`` signature is looked up instead; a
    ``KeyError`` from that second lookup propagates to the caller.

    Returns a ``(signature, result_type, expression)`` tuple where
    ``signature`` is the one that was actually resolved and ``result_type``
    is simply ``left_type``.

    NOTE(review): reusing the left operand type as the result type looks
    questionable for comparison-style functions -- confirm against the catalog.
    """
    left_kind = left_type.WhichOneof('kind')
    right_kind = right_type.WhichOneof('kind')
    signature = f"{function_name}:{left_kind}_{right_kind}"
    try:
        function_anchor = catalog.function_anchor(signature)
    except KeyError:
        # No function registered for these exact types: retry with the
        # generic any1_any1 variant.
        signature = f"{function_name}:any1_any1"
        function_anchor = catalog.function_anchor(signature)

    call_arguments = [
        proto.FunctionArgument(value=operand) for operand in (left, right)
    ]
    scalar_call = proto.Expression.ScalarFunction(
        function_reference=function_anchor,
        arguments=call_arguments,
    )
    return signature, left_type, proto.Expression(scalar_function=scalar_call)
def _parse_function_invokation(self, function_name, left_type, left, right_type, right):
    """Build the scalar-function expression for a two-argument function call.

    The lookup signature is ``<function_name>:<left kind>_<right kind>``, where
    each kind is whatever ``WhichOneof('kind')`` reports for the operand type.
    If ``self._functions_catalog`` raises ``KeyError`` for that signature, the
    generic ``<function_name>:any1_any1`` signature is looked up instead; a
    ``KeyError`` from that second lookup propagates to the caller.

    Returns a ``(signature, result_type, expression)`` tuple where
    ``signature`` is the one that was actually resolved and ``result_type``
    is simply ``left_type``.

    NOTE(review): reusing the left operand type as the result type looks
    questionable for comparison-style functions -- confirm against the catalog.
    """
    left_kind = left_type.WhichOneof('kind')
    right_kind = right_type.WhichOneof('kind')
    signature = f"{function_name}:{left_kind}_{right_kind}"
    try:
        function_anchor = self._functions_catalog.function_anchor(signature)
    except KeyError:
        # No function registered for these exact types: retry with the
        # generic any1_any1 variant.
        signature = f"{function_name}:any1_any1"
        function_anchor = self._functions_catalog.function_anchor(signature)

    call_arguments = [
        proto.FunctionArgument(value=operand) for operand in (left, right)
    ]
    scalar_call = proto.Expression.ScalarFunction(
        function_reference=function_anchor,
        arguments=call_arguments,
    )
    return signature, left_type, proto.Expression(scalar_function=scalar_call)

0 comments on commit 9986d09

Please sign in to comment.