Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 27f8631

Browse files
authoredJul 10, 2024
feat: add stratify param support to ml.model_selection.train_test_split method (#815)
* feat: add stratify param to ml.model_selection.train_test_split * fix mypy * add notes for limit
1 parent eaa1db0 commit 27f8631

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed
 

‎bigframes/ml/model_selection.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
import typing
21-
from typing import List, Union
21+
from typing import cast, List, Union
2222

2323
from bigframes.ml import utils
2424
import bigframes.pandas as bpd
@@ -29,6 +29,7 @@ def train_test_split(
2929
test_size: Union[float, None] = None,
3030
train_size: Union[float, None] = None,
3131
random_state: Union[int, None] = None,
32+
stratify: Union[bpd.Series, None] = None,
3233
) -> List[Union[bpd.DataFrame, bpd.Series]]:
3334
"""Splits dataframes or series into random train and test subsets.
3435
@@ -46,6 +47,10 @@ def train_test_split(
4647
random_state (default None):
4748
A seed to use for randomly choosing the rows of the split. If not
4849
set, a random split will be generated each time.
50+
stratify: (bigframes.series.Series or None, default None):
51+
If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of the class labels with the original dataset.
52+
Default to None.
53+
Note: By setting the stratify parameter, the memory consumption and generated SQL will be linear to the unique values in the Series. May return errors if the unique values size is too large.
4954
5055
Returns:
5156
List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.
@@ -76,7 +81,38 @@ def train_test_split(
7681

7782
dfs = list(utils.convert_to_dataframe(*arrays))
7883

79-
split_dfs = dfs[0]._split(fracs=(train_size, test_size), random_state=random_state)
84+
def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]:
85+
"""Split a single DF accoding to the stratify Series."""
86+
stratify = stratify.rename("bigframes_stratify_col") # avoid name conflicts
87+
merged_df = df.join(stratify.to_frame(), how="outer")
88+
89+
train_dfs, test_dfs = [], []
90+
uniq = stratify.unique()
91+
for value in uniq:
92+
cur = merged_df[merged_df["bigframes_stratify_col"] == value]
93+
train, test = train_test_split(
94+
cur,
95+
test_size=test_size,
96+
train_size=train_size,
97+
random_state=random_state,
98+
)
99+
train_dfs.append(train)
100+
test_dfs.append(test)
101+
102+
train_df = cast(
103+
bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col")
104+
)
105+
test_df = cast(
106+
bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col")
107+
)
108+
return [train_df, test_df]
109+
110+
if stratify is None:
111+
split_dfs = dfs[0]._split(
112+
fracs=(train_size, test_size), random_state=random_state
113+
)
114+
else:
115+
split_dfs = _stratify_split(dfs[0], stratify)
80116
train_index = split_dfs[0].index
81117
test_index = split_dfs[1].index
82118

‎tests/system/small/ml/test_model_selection.py

+62
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,65 @@ def test_train_test_split_value_error(penguins_df_default_index, train_size, tes
234234
model_selection.train_test_split(
235235
X, y, train_size=train_size, test_size=test_size
236236
)
237+
238+
239+
def test_train_test_split_stratify(penguins_df_default_index):
240+
X = penguins_df_default_index[
241+
[
242+
"species",
243+
"island",
244+
"culmen_length_mm",
245+
]
246+
]
247+
y = penguins_df_default_index[["species"]]
248+
X_train, X_test, y_train, y_test = model_selection.train_test_split(
249+
X, y, stratify=penguins_df_default_index["species"]
250+
)
251+
252+
# Original distribution is [152, 124, 68]. All the categories follow 75/25 split
253+
train_counts = pd.Series(
254+
[114, 93, 51],
255+
index=pd.Index(
256+
[
257+
"Adelie Penguin (Pygoscelis adeliae)",
258+
"Gentoo penguin (Pygoscelis papua)",
259+
"Chinstrap penguin (Pygoscelis antarctica)",
260+
],
261+
name="species",
262+
),
263+
dtype="Int64",
264+
name="count",
265+
)
266+
test_counts = pd.Series(
267+
[38, 31, 17],
268+
index=pd.Index(
269+
[
270+
"Adelie Penguin (Pygoscelis adeliae)",
271+
"Gentoo penguin (Pygoscelis papua)",
272+
"Chinstrap penguin (Pygoscelis antarctica)",
273+
],
274+
name="species",
275+
),
276+
dtype="Int64",
277+
name="count",
278+
)
279+
pd.testing.assert_series_equal(
280+
X_train["species"].value_counts().to_pandas(),
281+
train_counts,
282+
check_index_type=False,
283+
)
284+
pd.testing.assert_series_equal(
285+
X_test["species"].value_counts().to_pandas(),
286+
test_counts,
287+
check_index_type=False,
288+
)
289+
pd.testing.assert_series_equal(
290+
y_train["species"].value_counts().to_pandas(),
291+
train_counts,
292+
check_index_type=False,
293+
)
294+
pd.testing.assert_series_equal(
295+
y_test["species"].value_counts().to_pandas(),
296+
test_counts,
297+
check_index_type=False,
298+
)

0 commit comments

Comments
 (0)
Failed to load comments.