Skip to content

Commit

Permalink
fix(projection): use emit and not project for column selection
Browse files Browse the repository at this point in the history
fix(emit): recompute output mapping for nested projection

In nested projections, `output_mapping` refers to column indices that are
offset by the number of columns in the previous emit, so we recompute the
`output_mapping` counter if there's already an emit present.
  • Loading branch information
gforsyth committed Oct 19, 2022
1 parent c3d791e commit 489beb8
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 3 deletions.
29 changes: 26 additions & 3 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import datetime
import decimal
import functools
import itertools
import operator
import uuid
from typing import Any, Mapping, MutableMapping, Sequence, TypeVar
Expand All @@ -20,6 +21,7 @@
import ibis.expr.types as ir
import toolz
from ibis import util
from ibis.util import to_op_dag

from ..proto.substrait.ibis import algebra_pb2 as stalg
from ..proto.substrait.ibis import type_pb2 as stt
Expand Down Expand Up @@ -637,6 +639,7 @@ def unbound_table(
read=stalg.ReadRel(
# TODO: filter,
# TODO: projection,
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
base_schema=translate(op.schema),
named_table=stalg.ReadRel.NamedTable(names=[op.name]),
)
Expand Down Expand Up @@ -694,7 +697,6 @@ def selection(
child_rel_field_offsets=child_rel_field_offsets,
**kwargs,
)

# filter
if op.predicates:
relation = stalg.Rel(
Expand All @@ -709,18 +711,39 @@ def selection(
)
)

# projection
# projection / emit
selections = [
col
for sel in op.selections
for col in (
sel.get_columns(sel.columns) if isinstance(sel, ir.TableExpr) else [sel]
)
]
if op.selections:

# TODO: there has to be a better way to get a list of unbound tables
# underlying an expression
unbound_tables = {
t for t in to_op_dag(op.to_expr()).keys() if isinstance(t, ops.UnboundTable)
}
mapping_counter = itertools.count(sum(map(lambda t: len(t.schema), unbound_tables)))
if selections:

if relation.project.common.ListFields():
# if there is already an `emit` in RelCommon then we're stacking
# projections and we need to update the output_mapping to refer to
# the number of fields present in the most recent emit
mapping_counter = itertools.count(
len(relation.project.common.emit.output_mapping)
)

relation = stalg.Rel(
project=stalg.ProjectRel(
input=relation,
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter) for _ in selections]
)
),
expressions=[
translate(
selection,
Expand Down
99 changes: 99 additions & 0 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def test_translate_table_expansion(compiler):
result = translate(expr, compiler)
expected = {
"project": {
"common": {"emit": {"outputMapping": [2, 3, 4]}},
"input": {
"read": {
"common": {"direct": {}},
"baseSchema": {
"names": ["a", "b"],
"struct": {
Expand Down Expand Up @@ -153,6 +155,103 @@ def test_translate_table_expansion(compiler):
assert to_dict(result) == expected


def test_emit_mutate_select_all(compiler):
    """Check the plan produced for ``mutate`` on a three-column table.

    The translated relation must be a project over a direct read of
    ``table0`` whose expressions are the three original columns followed by
    ``a + 1``, with an emit output mapping of ``[3, 4, 5, 6]``.
    """

    def nullable(type_name):
        # Dict form of a nullable Substrait type as rendered by ``to_dict``.
        return {type_name: {"nullability": "NULLABILITY_NULLABLE"}}

    def field_ref(struct_field):
        # Dict form of a direct field reference against the root relation.
        # ``struct_field`` is passed verbatim: the first column is expected
        # as an empty dict, later ones as ``{"field": <index>}``.
        return {
            "selection": {
                "directReference": {"structField": struct_field},
                "rootReference": {},
            }
        }

    table = ibis.table([("a", "int64"), ("b", "char"), ("c", "int32")], name="table0")
    result = translate(table.mutate(d=table.a + 1), compiler)

    read_rel = {
        "common": {"direct": {}},
        "baseSchema": {
            "names": ["a", "b", "c"],
            "struct": {
                "types": [nullable("i64"), nullable("string"), nullable("i32")],
                "nullability": "NULLABILITY_REQUIRED",
            },
        },
        "namedTable": {"names": ["table0"]},
    }

    # The new column ``d``: a scalar function applied to column ``a`` and the
    # literal 1.
    add_one = {
        "scalarFunction": {
            "functionReference": 1,
            "outputType": nullable("i64"),
            "arguments": [
                {"enum": {"specified": "ERROR"}},
                {"value": field_ref({})},
                {"value": {"literal": {"i8": 1}}},
            ],
        }
    }

    expected = {
        "project": {
            "common": {"emit": {"outputMapping": [3, 4, 5, 6]}},
            "input": {"read": read_rel},
            "expressions": [
                field_ref({}),
                field_ref({"field": 1}),
                field_ref({"field": 2}),
                add_one,
            ],
        }
    }

    assert to_dict(result) == expected


def test_emit_nested_projection_output_mapping(compiler):
    """Check emit output mappings as projections are stacked on each other.

    Each step projects from the expression produced by the previous step, so
    the expected mapping starts at the column count of whatever that step
    selected from.
    """
    schema = [(column, "int64") for column in ("a", "b", "c", "d")]
    expr = ibis.table(schema, name="t")

    # (columns to select, expected output mapping), applied cumulatively.
    steps = [
        # root table has 4 columns -> mapping starts at 4, 4 entries
        (("a", "b", "c", "d"), [4, 5, 6, 7]),
        # previous emit has 4 columns -> mapping starts at 4, 3 entries
        (("a", "b", "c"), [4, 5, 6]),
        # previous emit has 3 columns -> mapping starts at 3, 2 entries
        (("a", "b"), [3, 4]),
    ]

    for columns, expected_mapping in steps:
        expr = expr[columns]
        result = translate(expr, compiler)
        assert result.project.common.emit.output_mapping == expected_mapping


def test_ibis_schema_to_substrait_schema():
input = ibis.schema(
[
Expand Down

0 comments on commit 489beb8

Please sign in to comment.