Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 37f8c32

Browse files
authoredNov 13, 2024
fix: dataframe fillna with scalar. (#1132)
* fix: dataframe fillna with string scalar. * update type supports * remove case that pandas has issue * update annotation
1 parent 8d4da15 commit 37f8c32

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed
 

‎bigframes/dataframe.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def _apply_binop(
734734
how: str = "outer",
735735
reverse: bool = False,
736736
):
737-
if isinstance(other, (float, int, bool)):
737+
if isinstance(other, bigframes.dtypes.LOCAL_SCALAR_TYPES):
738738
return self._apply_scalar_binop(other, op, reverse=reverse)
739739
elif isinstance(other, DataFrame):
740740
return self._apply_dataframe_binop(other, op, how=how, reverse=reverse)
@@ -752,7 +752,10 @@ def _apply_binop(
752752
)
753753

754754
def _apply_scalar_binop(
755-
self, other: float | int, op: ops.BinaryOp, reverse: bool = False
755+
self,
756+
other: bigframes.dtypes.LOCAL_SCALAR_TYPE,
757+
op: ops.BinaryOp,
758+
reverse: bool = False,
756759
) -> DataFrame:
757760
if reverse:
758761
expr = op.as_expr(

‎bigframes/dtypes.py

+19
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,25 @@
5959
# Used when storing Null expressions
6060
DEFAULT_DTYPE = FLOAT_DTYPE
6161

62+
LOCAL_SCALAR_TYPE = Union[
63+
bool,
64+
np.bool_,
65+
int,
66+
np.integer,
67+
float,
68+
np.floating,
69+
decimal.Decimal,
70+
str,
71+
np.str_,
72+
bytes,
73+
np.bytes_,
74+
datetime.datetime,
75+
pd.Timestamp,
76+
datetime.date,
77+
datetime.time,
78+
]
79+
LOCAL_SCALAR_TYPES = typing.get_args(LOCAL_SCALAR_TYPE)
80+
6281

6382
# Will have a few dtype variants: simple(eg. int, string, bool), complex (eg. list, struct), and virtual (eg. micro intervals, categorical)
6483
@dataclass(frozen=True)

‎tests/system/small/test_dataframe.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -1020,13 +1020,20 @@ def test_df_interpolate(scalars_dfs):
10201020
)
10211021

10221022

1023-
def test_df_fillna(scalars_dfs):
1023+
@pytest.mark.parametrize(
1024+
"col, fill_value",
1025+
[
1026+
(["int64_col", "float64_col"], 3),
1027+
(["string_col"], "A"),
1028+
(["datetime_col"], pd.Timestamp("2023-01-01")),
1029+
],
1030+
)
1031+
def test_df_fillna(scalars_dfs, col, fill_value):
10241032
scalars_df, scalars_pandas_df = scalars_dfs
1025-
df = scalars_df[["int64_col", "float64_col"]].fillna(3)
1026-
bf_result = df.to_pandas()
1027-
pd_result = scalars_pandas_df[["int64_col", "float64_col"]].fillna(3)
1033+
bf_result = scalars_df[col].fillna(fill_value).to_pandas()
1034+
pd_result = scalars_pandas_df[col].fillna(fill_value)
10281035

1029-
pandas.testing.assert_frame_equal(bf_result, pd_result)
1036+
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
10301037

10311038

10321039
def test_df_replace_scalar_scalar(scalars_dfs):

0 commit comments

Comments
 (0)
Failed to load comments.