From ffb012000ea7b8bbf155722667fe3fce1dfc5ee6 Mon Sep 17 00:00:00 2001
From: Trevor Bergeron
Date: Sat, 14 Dec 2024 00:34:06 +0000
Subject: [PATCH] perf: Prevent deferring filters with remote functions

Remote function ops are expensive. When projection() is asked to compute
an expression that contains a RemoteFunctionOp and the value still has
entries in self._predicates, call _reproject_to_table() first and run the
projection on the result, so that filters get applied before the remote
function is evaluated.

Add expression.contains_op(), which walks an expression tree and reports
whether any OpExpression in it has an op that is an instance of the given
op type.
---
Review notes (text in this section is not part of the commit message):

* The line structure and indentation of this patch were restored from a
  whitespace-collapsed copy. Indentation was inferred from Python syntax
  and checked against the hunk line counts; run `git apply --check`
  against the target tree before relying on it.
* The From: header carries no e-mail address; `git am` may reject the
  patch for that reason. Use `git apply`, or restore the address.
* contains_op() has no docstring and no return annotation, although every
  path returns a bool.
* The projection() hunk also adds a blank line before `return result`,
  which is unrelated to the behavior change.

 bigframes/core/compile/compiled.py | 11 +++++++++++
 bigframes/core/expression.py       |  9 +++++++++
 2 files changed, 20 insertions(+)

diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py
index d4c814145b..8f9514bc1c 100644
--- a/bigframes/core/compile/compiled.py
+++ b/bigframes/core/compile/compiled.py
@@ -48,6 +48,7 @@
 import bigframes.core.sql
 from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
 import bigframes.dtypes
+import bigframes.operations
 import bigframes.operations.aggregations as agg_ops
 
 ORDER_ID_COLUMN = "bigframes_ordering_id"
@@ -132,12 +133,22 @@ def projection(
         expression_id_pairs: typing.Tuple[typing.Tuple[ex.Expression, str], ...],
     ) -> T:
         """Apply an expression to the ArrayValue and assign the output to a column."""
+
+        # Remote ops are expensive so force reprojection before to ensure filters get applied first
+        any_remote_op = any(
+            ex.contains_op(expr, bigframes.operations.RemoteFunctionOp)
+            for expr, _ in expression_id_pairs
+        )
+        if any_remote_op and len(self._predicates) > 0:
+            return self._reproject_to_table().projection(expression_id_pairs)
+
         bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
         new_values = [
             op_compiler.compile_expression(expression, bindings).name(id)
             for expression, id in expression_id_pairs
         ]
         result = self._select(tuple([*self._columns, *new_values]))  # type: ignore
+
         return result
 
     def selection(
diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py
index 3b7828bbf0..429d2d42b3 100644
--- a/bigframes/core/expression.py
+++ b/bigframes/core/expression.py
@@ -40,6 +40,15 @@ def free_var(id: str) -> UnboundVariableExpression:
     return UnboundVariableExpression(id)
 
 
+def contains_op(expr: Expression, op: type[bigframes.operations.ScalarOp]):
+    if isinstance(expr, OpExpression):
+        if isinstance(expr.op, op):
+            return True
+        else:
+            return any(contains_op(subexpr, op) for subexpr in expr.inputs)
+    return False
+
+
 @dataclasses.dataclass(frozen=True)
 class Aggregation(abc.ABC):
     """Represents windowing or aggregation over a column."""