Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0d8a16b

Browse files
authoredDec 9, 2024
feat: Series.isin supports bigframes.Series arg (#1195)
1 parent d638f7c commit 0d8a16b

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed
 

‎bigframes/core/blocks.py

+37
Original file line numberDiff line numberDiff line change
@@ -2019,6 +2019,43 @@ def concat(
20192019
result_block = result_block.reset_index()
20202020
return result_block
20212021

2022+
def isin(self, other: Block):
2023+
# TODO: Support multiple other columns and match on label
2024+
# TODO: Model as explicit "IN" subquery/join to better allow db to optimize
2025+
assert len(other.value_columns) == 1
2026+
unique_other_values = other.expr.select_columns(
2027+
[other.value_columns[0]]
2028+
).aggregate((), by_column_ids=(other.value_columns[0],))
2029+
block = self
2030+
# for each original column, join with other
2031+
for i in range(len(self.value_columns)):
2032+
block = block._isin_inner(block.value_columns[i], unique_other_values)
2033+
return block
2034+
2035+
def _isin_inner(self: Block, col: str, unique_values: core.ArrayValue) -> Block:
2036+
unique_values, const = unique_values.create_constant(
2037+
True, dtype=bigframes.dtypes.BOOL_DTYPE
2038+
)
2039+
expr, (l_map, r_map) = self._expr.relational_join(
2040+
unique_values, ((col, unique_values.column_ids[0]),), type="left"
2041+
)
2042+
expr, matches = expr.project_to_id(
2043+
ops.eq_op.as_expr(ex.const(True), r_map[const])
2044+
)
2045+
2046+
new_index_cols = tuple(l_map[idx_col] for idx_col in self.index_columns)
2047+
new_value_cols = tuple(
2048+
l_map[val_col] if val_col != col else matches
2049+
for val_col in self.value_columns
2050+
)
2051+
expr = expr.select_columns((*new_index_cols, *new_value_cols))
2052+
return Block(
2053+
expr,
2054+
index_columns=new_index_cols,
2055+
column_labels=self.column_labels,
2056+
index_labels=self._index_labels,
2057+
)
2058+
20222059
def merge(
20232060
self,
20242061
other: Block,

‎bigframes/ml/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _get_only_column(input: ArrayType) -> Union[pd.Series, bpd.Series]:
9292
label = typing.cast(Hashable, input.columns.tolist()[0])
9393
if isinstance(input, pd.DataFrame):
9494
return typing.cast(pd.Series, input[label])
95-
return typing.cast(bpd.Series, input[label])
95+
return typing.cast(bpd.Series, input[label]) # type: ignore
9696

9797

9898
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:

‎bigframes/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -718,12 +718,13 @@ def nsmallest(self, n: int = 5, keep: str = "first") -> Series:
718718
)
719719

720720
def isin(self, values) -> "Series" | None:
721+
if isinstance(values, (Series,)):
722+
self._block.isin(values._block)
721723
if not _is_list_like(values):
722724
raise TypeError(
723725
"only list-like objects are allowed to be passed to "
724726
f"isin(), you passed a [{type(values).__name__}]"
725727
)
726-
727728
return self._apply_unary_op(
728729
ops.IsInOp(values=tuple(values), match_nulls=True)
729730
).fillna(value=False)

‎tests/system/small/test_series.py

+40
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,46 @@ def test_isin(scalars_dfs, col_name, test_set):
12001200
)
12011201

12021202

1203+
@pytest.mark.parametrize(
1204+
(
1205+
"col_name",
1206+
"test_set",
1207+
),
1208+
[
1209+
(
1210+
"int64_col",
1211+
[314159, 2.0, 3, pd.NA],
1212+
),
1213+
(
1214+
"int64_col",
1215+
[2, 55555, 4],
1216+
),
1217+
(
1218+
"float64_col",
1219+
[-123.456, 1.25, pd.NA],
1220+
),
1221+
(
1222+
"int64_too",
1223+
[1, 2, pd.NA],
1224+
),
1225+
(
1226+
"string_col",
1227+
["Hello, World!", "Hi", "こんにちは"],
1228+
),
1229+
],
1230+
)
1231+
def test_isin_bigframes_values(scalars_dfs, col_name, test_set, session):
1232+
scalars_df, scalars_pandas_df = scalars_dfs
1233+
bf_result = (
1234+
scalars_df[col_name].isin(series.Series(test_set, session=session)).to_pandas()
1235+
)
1236+
pd_result = scalars_pandas_df[col_name].isin(test_set).astype("boolean")
1237+
pd.testing.assert_series_equal(
1238+
pd_result,
1239+
bf_result,
1240+
)
1241+
1242+
12031243
def test_isnull(scalars_dfs):
12041244
scalars_df, scalars_pandas_df = scalars_dfs
12051245
col_name = "float64_col"

0 commit comments

Comments
 (0)
Failed to load comments.