Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 012081a

Browse files
authored Feb 3, 2025
perf: Prevent inlining of remote ops (#1347)
1 parent 417de3a commit 012081a

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed
 

‎bigframes/core/compile/compiled.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,19 @@ def projection(
110110
expression_id_pairs: typing.Tuple[typing.Tuple[ex.Expression, str], ...],
111111
) -> UnorderedIR:
112112
"""Apply an expression to the ArrayValue and assign the output to a column."""
113+
cannot_inline = any(expr.expensive for expr, _ in expression_id_pairs)
114+
113115
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
114116
new_values = [
115117
op_compiler.compile_expression(expression, bindings).name(id)
116118
for expression, id in expression_id_pairs
117119
]
118-
return UnorderedIR(self._table, (*self._columns, *new_values))
120+
result = UnorderedIR(self._table, (*self._columns, *new_values))
121+
if cannot_inline:
122+
return result._reproject_to_table()
123+
else:
124+
# Cheap ops can defer "SELECT" and inline into later ops
125+
return result
119126

120127
def selection(
121128
self,
@@ -174,13 +181,12 @@ def _to_ibis_expr(
174181
Returns:
175182
An ibis expression representing the data held by the ArrayValue object.
176183
"""
177-
columns = list(self._columns)
178184
# Special case for empty tables, since we can't create an empty
179185
# projection.
180-
if not columns:
186+
if not self._columns:
181187
return bigframes_vendored.ibis.memtable([])
182188

183-
table = self._table.select(columns)
189+
table = self._table.select(self._columns)
184190
if fraction is not None:
185191
table = table.filter(
186192
bigframes_vendored.ibis.random() < ibis_types.literal(fraction)

‎bigframes/core/expression.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import dataclasses
1919
import itertools
2020
import typing
21-
from typing import Mapping, TypeVar, Union
21+
from typing import Generator, Mapping, TypeVar, Union
2222

2323
import pandas as pd
2424

@@ -155,6 +155,16 @@ class Expression(abc.ABC):
155155
def free_variables(self) -> typing.Tuple[str, ...]:
156156
return ()
157157

158+
@property
159+
def children(self) -> typing.Tuple[Expression, ...]:
"""Direct sub-expressions of this expression; this implementation has none."""
160+
return ()
161+
162+
@property
def expensive(self) -> bool:
    """Whether any operation reached by ``walk()`` is flagged expensive.

    Only ``OpExpression`` nodes are consulted; each one contributes the
    ``expensive`` flag of its ``op``.
    """
    for node in self.walk():
        if isinstance(node, OpExpression) and node.op.expensive:
            return True
    return False
167+
158168
@property
159169
@abc.abstractmethod
160170
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
@@ -216,6 +226,11 @@ def is_identity(self) -> bool:
216226
"""True for identity operation that does not transform input."""
217227
return False
218228

229+
def walk(self) -> Generator[Expression, None, None]:
    """Yield this expression and every expression nested beneath it.

    Traversal is pre-order and depth-first: a node is yielded before its
    children, and children are visited in the order given by ``children``.
    """
    yield self
    for child in self.children:
        # Recurse through walk() instead of reading child.children directly;
        # otherwise the child itself is skipped and traversal stops at the
        # grandchildren.
        yield from child.walk()
233+
219234

220235
@dataclasses.dataclass(frozen=True)
221236
class ScalarConstantExpression(Expression):
@@ -389,6 +404,10 @@ def free_variables(self) -> typing.Tuple[str, ...]:
389404
def is_const(self) -> bool:
390405
return all(child.is_const for child in self.inputs)
391406

407+
@property
408+
def children(self):
"""Return this expression's ``inputs`` as its child expressions."""
409+
return self.inputs
410+
392411
def output_type(
393412
self, input_types: dict[ids.ColumnId, dtypes.ExpressionType]
394413
) -> dtypes.ExpressionType:

‎bigframes/operations/base_ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def deterministic(self) -> bool:
4848
"""Whether the operation is deterministic (given deterministic inputs)."""
4949
...
5050

51+
@property
52+
def expensive(self) -> bool:
53+
"""Whether the operation is expensive to calculate. Such ops shouldn't be inlined if referenced in multiple places."""
54+
...
55+
5156

5257
@dataclasses.dataclass(frozen=True)
5358
class ScalarOp:
@@ -73,6 +78,10 @@ def deterministic(self) -> bool:
7378
"""Whether the operation is deterministic (given deterministic inputs)."""
7479
return True
7580

81+
@property
82+
def expensive(self) -> bool:
"""Report this op as not expensive to calculate."""
83+
return False
84+
7685

7786
@dataclasses.dataclass(frozen=True)
7887
class NaryOp(ScalarOp):

‎bigframes/operations/remote_function_ops.py

+12
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class RemoteFunctionOp(base_ops.UnaryOp):
2525
func: typing.Callable
2626
apply_on_null: bool
2727

28+
@property
29+
def expensive(self) -> bool:
"""Report this remote function op as expensive to calculate."""
30+
return True
31+
2832
def output_type(self, *input_types):
2933
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
3034
if hasattr(self.func, "output_dtype"):
@@ -45,6 +49,10 @@ class BinaryRemoteFunctionOp(base_ops.BinaryOp):
4549
name: typing.ClassVar[str] = "binary_remote_function"
4650
func: typing.Callable
4751

52+
@property
53+
def expensive(self) -> bool:
"""Report this binary remote function op as expensive to calculate."""
54+
return True
55+
4856
def output_type(self, *input_types):
4957
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
5058
if hasattr(self.func, "output_dtype"):
@@ -65,6 +73,10 @@ class NaryRemoteFunctionOp(base_ops.NaryOp):
6573
name: typing.ClassVar[str] = "nary_remote_function"
6674
func: typing.Callable
6775

76+
@property
77+
def expensive(self) -> bool:
"""Report this n-ary remote function op as expensive to calculate."""
78+
return True
79+
6880
def output_type(self, *input_types):
6981
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
7082
if hasattr(self.func, "output_dtype"):

‎third_party/bigframes_vendored/ibis/backends/sql/rewrites.py

+2
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ def merge_select_select(_, **kwargs):
245245
ops.InSubquery,
246246
ops.Unnest,
247247
ops.Impure,
248+
# This is used for remote functions, which we don't want to copy
249+
ops.ScalarUDF,
248250
)
249251
if _.find_below(blocking, filter=ops.Value):
250252
return _

0 commit comments

Comments
 (0)
Failed to load comments.