Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ab49350

Browse files
authored Dec 19, 2023
fix: fix DataFrameGroupby.agg() issue with as_index=False (#273)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #271 🦕
1 parent 5092215 commit ab49350

File tree

7 files changed

+70
-66
lines changed

7 files changed

+70
-66
lines changed
 

‎bigframes/core/block_transforms.py

-1
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,6 @@ def value_counts(
332332
by_column_ids=columns,
333333
aggregations=[(dummy, agg_ops.count_op)],
334334
dropna=dropna,
335-
as_index=True,
336335
)
337336
count_id = agg_ids[0]
338337
if normalize:

‎bigframes/core/blocks.py

+16-37
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
_MONOTONIC_DECREASING = "monotonic_decreasing"
6767

6868

69-
LevelType = typing.Union[str, int]
69+
LevelType = typing.Hashable
7070
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
7171

7272

@@ -941,7 +941,6 @@ def aggregate(
941941
by_column_ids: typing.Sequence[str] = (),
942942
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp]] = (),
943943
*,
944-
as_index: bool = True,
945944
dropna: bool = True,
946945
) -> typing.Tuple[Block, typing.Sequence[str]]:
947946
"""
@@ -962,40 +961,21 @@ def aggregate(
962961
aggregate_labels = self._get_labels_for_columns(
963962
[agg[0] for agg in aggregations]
964963
)
965-
if as_index:
966-
names: typing.List[Label] = []
967-
for by_col_id in by_column_ids:
968-
if by_col_id in self.value_columns:
969-
names.append(self.col_id_to_label[by_col_id])
970-
else:
971-
names.append(self.col_id_to_index_name[by_col_id])
972-
return (
973-
Block(
974-
result_expr,
975-
index_columns=by_column_ids,
976-
column_labels=aggregate_labels,
977-
index_labels=names,
978-
),
979-
output_col_ids,
980-
)
981-
else: # as_index = False
982-
# If as_index=False, drop grouping levels, but keep grouping value columns
983-
by_value_columns = [
984-
col for col in by_column_ids if col in self.value_columns
985-
]
986-
by_column_labels = self._get_labels_for_columns(by_value_columns)
987-
labels = (*by_column_labels, *aggregate_labels)
988-
offsets_id = guid.generate_guid()
989-
result_expr_pruned = result_expr.select_columns(
990-
[*by_value_columns, *output_col_ids]
991-
).promote_offsets(offsets_id)
992-
993-
return (
994-
Block(
995-
result_expr_pruned, index_columns=[offsets_id], column_labels=labels
996-
),
997-
output_col_ids,
998-
)
964+
names: typing.List[Label] = []
965+
for by_col_id in by_column_ids:
966+
if by_col_id in self.value_columns:
967+
names.append(self.col_id_to_label[by_col_id])
968+
else:
969+
names.append(self.col_id_to_index_name[by_col_id])
970+
return (
971+
Block(
972+
result_expr,
973+
index_columns=by_column_ids,
974+
column_labels=aggregate_labels,
975+
index_labels=names,
976+
),
977+
output_col_ids,
978+
)
999979

1000980
def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
1001981
"""Gets aggregates immediately, and caches it"""
@@ -1324,7 +1304,6 @@ def pivot(
13241304
result_block, _ = block.aggregate(
13251305
by_column_ids=self.index_columns,
13261306
aggregations=aggregations,
1327-
as_index=True,
13281307
dropna=True,
13291308
)
13301309

‎bigframes/core/groupby/__init__.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ def _agg_string(self, func: str) -> df.DataFrame:
263263
agg_block, _ = self._block.aggregate(
264264
by_column_ids=self._by_col_ids,
265265
aggregations=aggregations,
266-
as_index=self._as_index,
267266
dropna=self._dropna,
268267
)
269-
return df.DataFrame(agg_block)
268+
dataframe = df.DataFrame(agg_block)
269+
return dataframe if self._as_index else self._convert_index(dataframe)
270270

271271
def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
272272
aggregations: typing.List[typing.Tuple[str, agg_ops.AggregateOp]] = []
@@ -285,7 +285,6 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
285285
agg_block, _ = self._block.aggregate(
286286
by_column_ids=self._by_col_ids,
287287
aggregations=aggregations,
288-
as_index=self._as_index,
289288
dropna=self._dropna,
290289
)
291290
if want_aggfunc_level:
@@ -297,7 +296,8 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
297296
)
298297
else:
299298
agg_block = agg_block.with_column_labels(pd.Index(column_labels))
300-
return df.DataFrame(agg_block)
299+
dataframe = df.DataFrame(agg_block)
300+
return dataframe if self._as_index else self._convert_index(dataframe)
301301

302302
def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
303303
aggregations = [
@@ -311,15 +311,15 @@ def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
311311
agg_block, _ = self._block.aggregate(
312312
by_column_ids=self._by_col_ids,
313313
aggregations=aggregations,
314-
as_index=self._as_index,
315314
dropna=self._dropna,
316315
)
317316
agg_block = agg_block.with_column_labels(
318317
pd.MultiIndex.from_tuples(
319318
column_labels, names=[*self._block.column_labels.names, None]
320319
)
321320
)
322-
return df.DataFrame(agg_block)
321+
dataframe = df.DataFrame(agg_block)
322+
return dataframe if self._as_index else self._convert_index(dataframe)
323323

324324
def _agg_named(self, **kwargs) -> df.DataFrame:
325325
aggregations = []
@@ -339,11 +339,21 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
339339
agg_block, _ = self._block.aggregate(
340340
by_column_ids=self._by_col_ids,
341341
aggregations=aggregations,
342-
as_index=self._as_index,
343342
dropna=self._dropna,
344343
)
345344
agg_block = agg_block.with_column_labels(column_labels)
346-
return df.DataFrame(agg_block)
345+
dataframe = df.DataFrame(agg_block)
346+
return dataframe if self._as_index else self._convert_index(dataframe)
347+
348+
def _convert_index(self, dataframe: df.DataFrame):
349+
"""Convert index levels to columns except where names conflict."""
350+
levels_to_drop = [
351+
level for level in dataframe.index.names if level in dataframe.columns
352+
]
353+
354+
if len(levels_to_drop) == dataframe.index.nlevels:
355+
return dataframe.reset_index(drop=True)
356+
return dataframe.droplevel(levels_to_drop).reset_index(drop=False)
347357

348358
aggregate = agg
349359

@@ -379,10 +389,10 @@ def _aggregate_all(
379389
result_block, _ = self._block.aggregate(
380390
by_column_ids=self._by_col_ids,
381391
aggregations=aggregations,
382-
as_index=self._as_index,
383392
dropna=self._dropna,
384393
)
385-
return df.DataFrame(result_block)
394+
dataframe = df.DataFrame(result_block)
395+
return dataframe if self._as_index else self._convert_index(dataframe)
386396

387397
def _apply_window_op(
388398
self,

‎bigframes/dataframe.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
# TODO(tbergeron): Convert to bytes-based limit
7373
MAX_INLINE_DF_SIZE = 5000
7474

75-
LevelType = typing.Union[str, int]
75+
LevelType = typing.Hashable
7676
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
7777
SingleItemValue = Union[bigframes.series.Series, int, float, Callable]
7878

@@ -1956,7 +1956,7 @@ def _stack_mono(self):
19561956

19571957
def _stack_multi(self, level: LevelsType = -1):
19581958
n_levels = self.columns.nlevels
1959-
if isinstance(level, int) or isinstance(level, str):
1959+
if not utils.is_list_like(level):
19601960
level = [level]
19611961
level_indices = []
19621962
for level_ref in level:
@@ -1966,7 +1966,7 @@ def _stack_multi(self, level: LevelsType = -1):
19661966
else:
19671967
level_indices.append(level_ref)
19681968
else: # str
1969-
level_indices.append(self.columns.names.index(level_ref))
1969+
level_indices.append(self.columns.names.index(level_ref)) # type: ignore
19701970

19711971
new_order = [
19721972
*[i for i in range(n_levels) if i not in level_indices],
@@ -1982,7 +1982,7 @@ def _stack_multi(self, level: LevelsType = -1):
19821982
return DataFrame(block)
19831983

19841984
def unstack(self, level: LevelsType = -1):
1985-
if isinstance(level, int) or isinstance(level, str):
1985+
if not utils.is_list_like(level):
19861986
level = [level]
19871987

19881988
block = self._block

‎bigframes/series.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,6 @@ def mode(self) -> Series:
841841
block, agg_ids = block.aggregate(
842842
by_column_ids=[self._value_column],
843843
aggregations=((self._value_column, agg_ops.count_op),),
844-
as_index=False,
845844
)
846845
value_count_col_id = agg_ids[0]
847846
block, max_value_count_col_id = block.apply_window_op(
@@ -855,14 +854,15 @@ def mode(self) -> Series:
855854
ops.eq_op,
856855
)
857856
block = block.filter(is_mode_col_id)
858-
mode_values_series = Series(
859-
block.select_column(self._value_column).assign_label(
860-
self._value_column, self.name
861-
)
862-
)
863-
return typing.cast(
864-
Series, mode_values_series.sort_values().reset_index(drop=True)
857+
# use temporary name for reset_index to avoid collision, restore after dropping extra columns
858+
block = (
859+
block.with_index_labels(["mode_temp_internal"])
860+
.order_by([OrderingColumnReference(self._value_column)])
861+
.reset_index(drop=False)
865862
)
863+
block = block.select_column(self._value_column).with_column_labels([self.name])
864+
mode_values_series = Series(block.select_column(self._value_column))
865+
return typing.cast(Series, mode_values_series)
866866

867867
def mean(self) -> float:
868868
return typing.cast(float, self._apply_aggregation(agg_ops.mean_op))

‎tests/system/small/test_groupby.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -122,23 +122,32 @@ def test_dataframe_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
122122
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
123123

124124

125+
@pytest.mark.parametrize(
126+
("as_index"),
127+
[
128+
(True),
129+
(False),
130+
],
131+
)
125132
def test_dataframe_groupby_agg_dict_with_list(
126-
scalars_df_index, scalars_pandas_df_index
133+
scalars_df_index, scalars_pandas_df_index, as_index
127134
):
128135
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
129136
bf_result = (
130137
scalars_df_index[col_names]
131-
.groupby("string_col")
138+
.groupby("string_col", as_index=as_index)
132139
.agg({"int64_too": ["mean", "max"], "string_col": "count"})
133140
)
134141
pd_result = (
135142
scalars_pandas_df_index[col_names]
136-
.groupby("string_col")
143+
.groupby("string_col", as_index=as_index)
137144
.agg({"int64_too": ["mean", "max"], "string_col": "count"})
138145
)
139146
bf_result_computed = bf_result.to_pandas()
140147

141-
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
148+
pd.testing.assert_frame_equal(
149+
pd_result, bf_result_computed, check_dtype=False, check_index_type=False
150+
)
142151

143152

144153
def test_dataframe_groupby_agg_dict_no_lists(scalars_df_index, scalars_pandas_df_index):

‎tests/system/small/test_multiindex.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -356,17 +356,24 @@ def test_multi_index_dataframe_groupby(scalars_df_index, scalars_pandas_df_index
356356
def test_multi_index_dataframe_groupby_level_aggregate(
357357
scalars_df_index, scalars_pandas_df_index, level, as_index
358358
):
359+
index_cols = ["int64_too", "bool_col"]
359360
bf_result = (
360-
scalars_df_index.set_index(["int64_too", "bool_col"])
361+
scalars_df_index.set_index(index_cols)
361362
.groupby(level=level, as_index=as_index)
362363
.mean(numeric_only=True)
363364
.to_pandas()
364365
)
365366
pd_result = (
366-
scalars_pandas_df_index.set_index(["int64_too", "bool_col"])
367+
scalars_pandas_df_index.set_index(index_cols)
367368
.groupby(level=level, as_index=as_index)
368369
.mean(numeric_only=True)
369370
)
371+
# For as_index=False, pandas currently drops index levels used as groupings.
372+
# In the future, pandas will include them in the result; bigframes already behaves this way.
373+
if not as_index:
374+
for col in index_cols:
375+
if col in bf_result.columns:
376+
bf_result = bf_result.drop(col, axis=1)
370377

371378
# Pandas will have int64 index, while bigquery will have Int64 when resetting
372379
pandas.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False)

0 commit comments

Comments
 (0)
Failed to load comments.