You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardexpand all lines: bigframes/ml/model_selection.py
+38-2
Original file line number
Diff line number
Diff line change
@@ -18,7 +18,7 @@
18
18
19
19
20
20
import typing
21
-
from typing import List, Union
21
+
from typing import cast, List, Union
22
22
23
23
from bigframes.ml import utils
24
24
import bigframes.pandas as bpd
@@ -29,6 +29,7 @@ def train_test_split(
29
29
test_size: Union[float, None] =None,
30
30
train_size: Union[float, None] =None,
31
31
random_state: Union[int, None] =None,
32
+
stratify: Union[bpd.Series, None] =None,
32
33
) ->List[Union[bpd.DataFrame, bpd.Series]]:
33
34
"""Splits dataframes or series into random train and test subsets.
34
35
@@ -46,6 +47,10 @@ def train_test_split(
46
47
random_state (default None):
47
48
A seed to use for randomly choosing the rows of the split. If not
48
49
set, a random split will be generated each time.
50
+
stratify (bigframes.series.Series or None, default None):
51
+
If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of class labels as the original dataset.
52
+
Defaults to None.
53
+
Note: When the stratify parameter is set, memory consumption and the size of the generated SQL grow linearly with the number of unique values in the Series. Errors may be returned if the number of unique values is too large.
49
54
50
55
Returns:
51
56
List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.
0 commit comments