Skip to content

Commit

Permalink
fix(ScalarFunction): add (currently) required error enum
Browse files Browse the repository at this point in the history
This is _probably_ going to be changed to an optional arg in substrait
but as of this commit it's required still.
  • Loading branch information
gforsyth committed Oct 19, 2022
1 parent 489beb8 commit 135320b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
17 changes: 16 additions & 1 deletion ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,12 +476,27 @@ def value_op(
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
error_args = []
# TODO(gforsyth): remove this brittle workaround after extension parsing is ready
# TODO(gforsyth): sending in `ERROR` for floating point ops doesn't match
# the substrait spec but is what pyarrow currently expects.
if (
isinstance(op, ops.BinaryOp)
and not isinstance(op, ops.Comparison)
and isinstance(op.left.type(), (dt.Integer, dt.Floating))
and isinstance(op.right.type(), (dt.Integer, dt.Floating))
):
error_args.append(
stalg.FunctionArgument(enum=stalg.FunctionArgument.Enum(specified="ERROR"))
)

# given the details of `op` -> function id
return stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
function_reference=compiler.function_id(expr),
output_type=translate(expr.type()),
arguments=[
arguments=error_args
+ [
stalg.FunctionArgument(value=translate(arg, compiler, **kwargs))
for arg in op.args
if isinstance(arg, ir.Expr)
Expand Down
1 change: 1 addition & 0 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def test_translate_table_expansion(compiler):
"scalarFunction": {
"functionReference": 1,
"arguments": [
{"enum": {"specified": "ERROR"}},
{
"value": {
"selection": {
Expand Down

0 comments on commit 135320b

Please sign in to comment.