Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8d39187

Browse files
authored Apr 3, 2024
fix: plot.scatter s parameter cannot accept float-like column (#563)
Fixes internal b/330574847 🦕
1 parent 2fce51f commit 8d39187

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed
 

‎bigframes/operations/_matplotlib/core.py

+12 -7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import abc
1616
import typing
17-
import uuid
1817

1918
import pandas as pd
2019

@@ -115,6 +114,18 @@ def _compute_plot_data(self):
115114
if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
116115
sample[c] = sample[c].astype("object")
117116

117+
# To avoid Matplotlib's automatic conversion of `Float64` or `Int64` columns
118+
# to `object` types (which breaks float-like behavior), this code proactively
119+
# converts the column to a compatible format.
120+
s = self.kwargs.get("s", None)
121+
if pd.core.dtypes.common.is_integer(s):
122+
s = self.data.columns[s]
123+
if self._is_column_name(s, sample):
124+
if sample[s].dtype == dtypes.INT_DTYPE:
125+
sample[s] = sample[s].astype("int64")
126+
elif sample[s].dtype == dtypes.FLOAT_DTYPE:
127+
sample[s] = sample[s].astype("float64")
128+
118129
return sample
119130

120131
def _is_sequence_arg(self, arg):
@@ -130,9 +141,3 @@ def _is_column_name(self, arg, data):
130141
and pd.core.dtypes.common.is_hashable(arg)
131142
and arg in data.columns
132143
)
133-
134-
def _generate_new_column_name(self, data):
135-
col_name = None
136-
while col_name is None or col_name in data.columns:
137-
col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
138-
return col_name

‎tests/system/small/operations/test_plotting.py

+26
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,32 @@ def test_scatter_args_c(c):
240240
)
241241

242242

243+
@pytest.mark.parametrize(
244+
("s"),
245+
[
246+
pytest.param([10, 34, 50], id="int"),
247+
pytest.param([1.0, 3.4, 5.0], id="float"),
248+
pytest.param(
249+
[True, True, False], id="bool", marks=pytest.mark.xfail(raises=ValueError)
250+
),
251+
],
252+
)
253+
def test_scatter_args_s(s):
254+
data = {
255+
"a": [1, 2, 3],
256+
"b": [1, 2, 3],
257+
}
258+
data["s"] = s
259+
df = bpd.DataFrame(data)
260+
pd_df = pd.DataFrame(data)
261+
262+
ax = df.plot.scatter(x="a", y="b", s="s")
263+
pd_ax = pd_df.plot.scatter(x="a", y="b", s="s")
264+
tm.assert_numpy_array_equal(
265+
ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes()
266+
)
267+
268+
243269
@pytest.mark.parametrize(
244270
("arg_name"),
245271
[

0 commit comments

Comments
 (0)
Failed to load comments.