Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 038139d

Browse files
authored Sep 20, 2024
fix: Fix miscasting issues with case_when (#1003)
1 parent 8520873 commit 038139d

File tree

4 files changed

+81
-48
lines changed

4 files changed

+81
-48
lines changed
 

‎bigframes/core/expression.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import bigframes.operations.aggregations as agg_ops
2626

2727

28-
def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression:
28+
def const(
29+
value: typing.Hashable, dtype: dtypes.ExpressionType = None
30+
) -> ScalarConstantExpression:
2931
return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value))
3032

3133

@@ -141,6 +143,9 @@ class ScalarConstantExpression(Expression):
141143
def is_const(self) -> bool:
142144
return True
143145

146+
def rename(self, name_mapping: Mapping[str, str]) -> ScalarConstantExpression:
147+
return self
148+
144149
def output_type(
145150
self, input_types: dict[str, bigframes.dtypes.Dtype]
146151
) -> dtypes.ExpressionType:
@@ -167,7 +172,7 @@ class UnboundVariableExpression(Expression):
167172
def unbound_variables(self) -> typing.Tuple[str, ...]:
168173
return (self.id,)
169174

170-
def rename(self, name_mapping: Mapping[str, str]) -> Expression:
175+
def rename(self, name_mapping: Mapping[str, str]) -> UnboundVariableExpression:
171176
if self.id in name_mapping:
172177
return UnboundVariableExpression(name_mapping[self.id])
173178
else:

‎bigframes/operations/base.py

+57-17
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import List, Sequence
18+
from typing import List, Sequence, Union
1919

2020
import bigframes_vendored.constants as constants
2121
import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
@@ -180,9 +180,10 @@ def _apply_binary_op(
180180
(self_col, other_col, block) = self._align(other_series, how=alignment)
181181

182182
name = self._name
183+
# Drop name if both objects have name attr, but they don't match
183184
if (
184185
hasattr(other, "name")
185-
and other.name != self._name
186+
and other_series.name != self._name
186187
and alignment == "outer"
187188
):
188189
name = None
@@ -208,41 +209,78 @@ def _apply_nary_op(
208209
ignore_self=False,
209210
):
210211
"""Applies an n-ary operator to the series and others."""
211-
values, block = self._align_n(others, ignore_self=ignore_self)
212-
block, result_id = block.apply_nary_op(
213-
values,
214-
op,
215-
self._name,
212+
values, block = self._align_n(
213+
others, ignore_self=ignore_self, cast_scalars=False
216214
)
215+
block, result_id = block.project_expr(op.as_expr(*values))
217216
return series.Series(block.select_column(result_id))
218217

219218
def _apply_binary_aggregation(
220219
self, other: series.Series, stat: agg_ops.BinaryAggregateOp
221220
) -> float:
222221
(left, right, block) = self._align(other, how="outer")
222+
assert isinstance(left, ex.UnboundVariableExpression)
223+
assert isinstance(right, ex.UnboundVariableExpression)
224+
return block.get_binary_stat(left.id, right.id, stat)
225+
226+
AlignedExprT = Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
223227

224-
return block.get_binary_stat(left, right, stat)
228+
@typing.overload
229+
def _align(
230+
self, other: series.Series, how="outer"
231+
) -> tuple[
232+
ex.UnboundVariableExpression,
233+
ex.UnboundVariableExpression,
234+
blocks.Block,
235+
]:
236+
...
225237

226-
def _align(self, other: series.Series, how="outer") -> tuple[str, str, blocks.Block]: # type: ignore
238+
@typing.overload
239+
def _align(
240+
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
241+
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
242+
...
243+
244+
def _align(
245+
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
246+
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
227247
"""Aligns the series value with another scalar or series object. Returns new left column id, right column id and joined tabled expression."""
228248
values, block = self._align_n(
229249
[
230250
other,
231251
],
232252
how,
233253
)
234-
return (values[0], values[1], block)
254+
return (typing.cast(ex.UnboundVariableExpression, values[0]), values[1], block)
255+
256+
def _align3(self, other1: series.Series | scalars.Scalar, other2: series.Series | scalars.Scalar, how="left") -> tuple[ex.UnboundVariableExpression, AlignedExprT, AlignedExprT, blocks.Block]: # type: ignore
257+
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
258+
values, index = self._align_n([other1, other2], how)
259+
return (
260+
typing.cast(ex.UnboundVariableExpression, values[0]),
261+
values[1],
262+
values[2],
263+
index,
264+
)
235265

236266
def _align_n(
237267
self,
238268
others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]],
239269
how="outer",
240270
ignore_self=False,
241-
) -> tuple[typing.Sequence[str], blocks.Block]:
271+
cast_scalars: bool = True,
272+
) -> tuple[
273+
typing.Sequence[
274+
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
275+
],
276+
blocks.Block,
277+
]:
242278
if ignore_self:
243-
value_ids: List[str] = []
279+
value_ids: List[
280+
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
281+
] = []
244282
else:
245-
value_ids = [self._value_column]
283+
value_ids = [ex.free_var(self._value_column)]
246284

247285
block = self._block
248286
for other in others:
@@ -252,14 +290,16 @@ def _align_n(
252290
get_column_right,
253291
) = block.join(other._block, how=how)
254292
value_ids = [
255-
*[get_column_left[value] for value in value_ids],
256-
get_column_right[other._value_column],
293+
*[value.rename(get_column_left) for value in value_ids],
294+
ex.free_var(get_column_right[other._value_column]),
257295
]
258296
else:
259297
# Will throw if can't interpret as scalar.
260298
dtype = typing.cast(bigframes.dtypes.Dtype, self._dtype)
261-
block, constant_col_id = block.create_constant(other, dtype=dtype)
262-
value_ids = [*value_ids, constant_col_id]
299+
value_ids = [
300+
*value_ids,
301+
ex.const(other, dtype=dtype if cast_scalars else None),
302+
]
263303
return (value_ids, block)
264304

265305
def _throw_if_null_index(self, opname: str):

‎bigframes/series.py

+9-24
Original file line numberDiff line numberDiff line change
@@ -445,23 +445,13 @@ def between(self, left, right, inclusive="both"):
445445
)
446446

447447
def case_when(self, caselist) -> Series:
448+
cases = list(itertools.chain(*caselist, (True, self)))
448449
return self._apply_nary_op(
449450
ops.case_when_op,
450-
tuple(
451-
itertools.chain(
452-
itertools.chain(*caselist),
453-
# Fallback to current value if no other matches.
454-
(
455-
# We make a Series with a constant value to avoid casts to
456-
# types other than boolean.
457-
Series(True, index=self.index, dtype=pandas.BooleanDtype()),
458-
self,
459-
),
460-
),
461-
),
451+
cases,
462452
# Self is already included in "others".
463453
ignore_self=True,
464-
)
454+
).rename(self.name)
465455

466456
@validations.requires_ordering()
467457
def cumsum(self) -> Series:
@@ -1116,8 +1106,8 @@ def ne(self, other: object) -> Series:
11161106

11171107
def where(self, cond, other=None):
11181108
value_id, cond_id, other_id, block = self._align3(cond, other)
1119-
block, result_id = block.apply_ternary_op(
1120-
value_id, cond_id, other_id, ops.where_op
1109+
block, result_id = block.project_expr(
1110+
ops.where_op.as_expr(value_id, cond_id, other_id)
11211111
)
11221112
return Series(block.select_column(result_id).with_column_labels([self.name]))
11231113

@@ -1129,8 +1119,8 @@ def clip(self, lower, upper):
11291119
if upper is None:
11301120
return self._apply_binary_op(lower, ops.maximum_op, alignment="left")
11311121
value_id, lower_id, upper_id, block = self._align3(lower, upper)
1132-
block, result_id = block.apply_ternary_op(
1133-
value_id, lower_id, upper_id, ops.clip_op
1122+
block, result_id = block.project_expr(
1123+
ops.clip_op.as_expr(value_id, lower_id, upper_id),
11341124
)
11351125
return Series(block.select_column(result_id).with_column_labels([self.name]))
11361126

@@ -1242,8 +1232,8 @@ def __getitem__(self, indexer):
12421232
return self.iloc[indexer]
12431233
if isinstance(indexer, Series):
12441234
(left, right, block) = self._align(indexer, "left")
1245-
block = block.filter_by_id(right)
1246-
block = block.select_column(left)
1235+
block = block.filter(right)
1236+
block = block.select_column(left.id)
12471237
return Series(block)
12481238
return self.loc[indexer]
12491239

@@ -1262,11 +1252,6 @@ def __getattr__(self, key: str):
12621252
else:
12631253
raise AttributeError(key)
12641254

1265-
def _align3(self, other1: Series | scalars.Scalar, other2: Series | scalars.Scalar, how="left") -> tuple[str, str, str, blocks.Block]: # type: ignore
1266-
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
1267-
values, index = self._align_n([other1, other2], how)
1268-
return (values[0], values[1], values[2], index)
1269-
12701255
def _apply_aggregation(
12711256
self, op: agg_ops.UnaryAggregateOp | agg_ops.NullaryAggregateOp
12721257
) -> Any:

‎tests/system/small/test_series.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -2709,27 +2709,30 @@ def test_between(scalars_df_index, scalars_pandas_df_index, left, right, inclusi
27092709
)
27102710

27112711

2712-
def test_case_when(scalars_df_index, scalars_pandas_df_index):
2712+
def test_series_case_when(scalars_dfs_maybe_ordered):
27132713
pytest.importorskip(
27142714
"pandas",
27152715
minversion="2.2.0",
27162716
reason="case_when added in pandas 2.2.0",
27172717
)
2718+
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered
27182719

2719-
bf_series = scalars_df_index["int64_col"]
2720-
pd_series = scalars_pandas_df_index["int64_col"]
2720+
bf_series = scalars_df["int64_col"]
2721+
pd_series = scalars_pandas_df["int64_col"]
27212722

27222723
# TODO(tswast): pandas case_when appears to assume True when a value is
27232724
# null. I suspect this should be considered a bug in pandas.
27242725
bf_result = bf_series.case_when(
27252726
[
2726-
((bf_series > 100).fillna(True), 1000),
2727+
((bf_series > 100).fillna(True), bf_series - 1),
2728+
((bf_series > 0).fillna(True), pd.NA),
27272729
((bf_series < -100).fillna(True), -1000),
27282730
]
27292731
).to_pandas()
27302732
pd_result = pd_series.case_when(
27312733
[
2732-
(pd_series > 100, 1000),
2734+
(pd_series > 100, pd_series - 1),
2735+
(pd_series > 0, pd.NA),
27332736
(pd_series < -100, -1000),
27342737
]
27352738
)

0 commit comments

Comments
 (0)
Failed to load comments.