Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 30c8883

Browse files
sycaigcf-owl-bot[bot]
andauthoredOct 23, 2024
feat: add support for pandas series & data frames as inputs for ml models. (#1088)
* support pandas dataframes and series as model inputs. * polish code and add tests * clean up code * fix type hints * fix lint * fix a bug that was introduced in the last commit * use type alias in type hints and update docs * add session parameter in the converters * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fix default parameter issue * fix type error --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 32fab96 commit 30c8883

29 files changed

+363
-238
lines changed
 

‎bigframes/ml/base.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
"""
2323

2424
import abc
25-
from typing import cast, Optional, TypeVar, Union
25+
from typing import cast, Optional, TypeVar
2626

2727
import bigframes_vendored.sklearn.base
2828

2929
from bigframes.ml import core
30+
import bigframes.ml.utils as utils
3031
import bigframes.pandas as bpd
3132

3233

@@ -157,8 +158,8 @@ class SupervisedTrainablePredictor(TrainablePredictor):
157158

158159
def fit(
159160
self: _T,
160-
X: Union[bpd.DataFrame, bpd.Series],
161-
y: Union[bpd.DataFrame, bpd.Series],
161+
X: utils.ArrayType,
162+
y: utils.ArrayType,
162163
) -> _T:
163164
return self._fit(X, y)
164165

@@ -172,8 +173,8 @@ class UnsupervisedTrainablePredictor(TrainablePredictor):
172173

173174
def fit(
174175
self: _T,
175-
X: Union[bpd.DataFrame, bpd.Series],
176-
y: Optional[Union[bpd.DataFrame, bpd.Series]] = None,
176+
X: utils.ArrayType,
177+
y: Optional[utils.ArrayType] = None,
177178
) -> _T:
178179
return self._fit(X, y)
179180

@@ -243,8 +244,8 @@ def transform(self, X):
243244

244245
def fit_transform(
245246
self,
246-
X: Union[bpd.DataFrame, bpd.Series],
247-
y: Optional[Union[bpd.DataFrame, bpd.Series]] = None,
247+
X: utils.ArrayType,
248+
y: Optional[utils.ArrayType] = None,
248249
) -> bpd.DataFrame:
249250
return self.fit(X, y).transform(X)
250251

@@ -264,6 +265,6 @@ def transform(self, y):
264265

265266
def fit_transform(
266267
self,
267-
y: Union[bpd.DataFrame, bpd.Series],
268+
y: utils.ArrayType,
268269
) -> bpd.DataFrame:
269270
return self.fit(y).transform(y)

‎bigframes/ml/cluster.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import bigframes_vendored.sklearn.cluster._kmeans
2323
from google.cloud import bigquery
24+
import pandas as pd
2425

2526
import bigframes
2627
from bigframes.core import log_adapter
@@ -101,7 +102,7 @@ def _bqml_options(self) -> dict:
101102

102103
def _fit(
103104
self,
104-
X: Union[bpd.DataFrame, bpd.Series],
105+
X: utils.ArrayType,
105106
y=None, # ignored
106107
transforms: Optional[List[str]] = None,
107108
) -> KMeans:
@@ -125,17 +126,20 @@ def cluster_centers_(self) -> bpd.DataFrame:
125126

126127
def predict(
127128
self,
128-
X: Union[bpd.DataFrame, bpd.Series],
129+
X: utils.ArrayType,
129130
) -> bpd.DataFrame:
130131
if not self._bqml_model:
131132
raise RuntimeError("A model must be fitted before predict")
132133

133-
(X,) = utils.convert_to_dataframe(X)
134+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
134135

135136
return self._bqml_model.predict(X)
136137

137138
def detect_anomalies(
138-
self, X: Union[bpd.DataFrame, bpd.Series], *, contamination: float = 0.1
139+
self,
140+
X: Union[bpd.DataFrame, bpd.Series, pd.DataFrame, pd.Series],
141+
*,
142+
contamination: float = 0.1,
139143
) -> bpd.DataFrame:
140144
"""Detect the anomaly data points of the input.
141145
@@ -156,7 +160,7 @@ def detect_anomalies(
156160
if not self._bqml_model:
157161
raise RuntimeError("A model must be fitted before detect_anomalies")
158162

159-
(X,) = utils.convert_to_dataframe(X)
163+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
160164

161165
return self._bqml_model.detect_anomalies(
162166
X, options={"contamination": contamination}
@@ -181,12 +185,12 @@ def to_gbq(self, model_name: str, replace: bool = False) -> KMeans:
181185

182186
def score(
183187
self,
184-
X: Union[bpd.DataFrame, bpd.Series],
188+
X: utils.ArrayType,
185189
y=None, # ignored
186190
) -> bpd.DataFrame:
187191
if not self._bqml_model:
188192
raise RuntimeError("A model must be fitted before score")
189193

190-
(X,) = utils.convert_to_dataframe(X)
194+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
191195

192196
return self._bqml_model.evaluate(X)

‎bigframes/ml/compose.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def _compile_to_sql(
332332

333333
def fit(
334334
self,
335-
X: Union[bpd.DataFrame, bpd.Series],
335+
X: utils.ArrayType,
336336
y=None, # ignored
337337
) -> ColumnTransformer:
338338
(X,) = utils.convert_to_dataframe(X)
@@ -347,11 +347,11 @@ def fit(
347347
self._extract_output_names()
348348
return self
349349

350-
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
350+
def transform(self, X: utils.ArrayType) -> bpd.DataFrame:
351351
if not self._bqml_model:
352352
raise RuntimeError("Must be fitted before transform")
353353

354-
(X,) = utils.convert_to_dataframe(X)
354+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
355355

356356
df = self._bqml_model.transform(X)
357357
return typing.cast(

‎bigframes/ml/decomposition.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _bqml_options(self) -> dict:
8484

8585
def _fit(
8686
self,
87-
X: Union[bpd.DataFrame, bpd.Series],
87+
X: utils.ArrayType,
8888
y=None,
8989
transforms: Optional[List[str]] = None,
9090
) -> PCA:
@@ -129,16 +129,19 @@ def explained_variance_ratio_(self) -> bpd.DataFrame:
129129
["principal_component_id", "explained_variance_ratio"]
130130
]
131131

132-
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
132+
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
133133
if not self._bqml_model:
134134
raise RuntimeError("A model must be fitted before predict")
135135

136-
(X,) = utils.convert_to_dataframe(X)
136+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
137137

138138
return self._bqml_model.predict(X)
139139

140140
def detect_anomalies(
141-
self, X: Union[bpd.DataFrame, bpd.Series], *, contamination: float = 0.1
141+
self,
142+
X: utils.ArrayType,
143+
*,
144+
contamination: float = 0.1,
142145
) -> bpd.DataFrame:
143146
"""Detect the anomaly data points of the input.
144147
@@ -159,7 +162,7 @@ def detect_anomalies(
159162
if not self._bqml_model:
160163
raise RuntimeError("A model must be fitted before detect_anomalies")
161164

162-
(X,) = utils.convert_to_dataframe(X)
165+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
163166

164167
return self._bqml_model.detect_anomalies(
165168
X, options={"contamination": contamination}

‎bigframes/ml/ensemble.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Dict, List, Literal, Optional, Union
20+
from typing import Dict, List, Literal, Optional
2121

2222
import bigframes_vendored.sklearn.ensemble._forest
2323
import bigframes_vendored.xgboost.sklearn
@@ -142,8 +142,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
142142

143143
def _fit(
144144
self,
145-
X: Union[bpd.DataFrame, bpd.Series],
146-
y: Union[bpd.DataFrame, bpd.Series],
145+
X: utils.ArrayType,
146+
y: utils.ArrayType,
147147
transforms: Optional[List[str]] = None,
148148
) -> XGBRegressor:
149149
X, y = utils.convert_to_dataframe(X, y)
@@ -158,24 +158,24 @@ def _fit(
158158

159159
def predict(
160160
self,
161-
X: Union[bpd.DataFrame, bpd.Series],
161+
X: utils.ArrayType,
162162
) -> bpd.DataFrame:
163163
if not self._bqml_model:
164164
raise RuntimeError("A model must be fitted before predict")
165-
(X,) = utils.convert_to_dataframe(X)
165+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
166166

167167
return self._bqml_model.predict(X)
168168

169169
def score(
170170
self,
171-
X: Union[bpd.DataFrame, bpd.Series],
172-
y: Union[bpd.DataFrame, bpd.Series],
171+
X: utils.ArrayType,
172+
y: utils.ArrayType,
173173
):
174-
X, y = utils.convert_to_dataframe(X, y)
175-
176174
if not self._bqml_model:
177175
raise RuntimeError("A model must be fitted before score")
178176

177+
X, y = utils.convert_to_dataframe(X, y, session=self._bqml_model.session)
178+
179179
input_data = (
180180
X.join(y, how="outer") if (X is not None) and (y is not None) else None
181181
)
@@ -291,8 +291,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
291291

292292
def _fit(
293293
self,
294-
X: Union[bpd.DataFrame, bpd.Series],
295-
y: Union[bpd.DataFrame, bpd.Series],
294+
X: utils.ArrayType,
295+
y: utils.ArrayType,
296296
transforms: Optional[List[str]] = None,
297297
) -> XGBClassifier:
298298
X, y = utils.convert_to_dataframe(X, y)
@@ -305,22 +305,22 @@ def _fit(
305305
)
306306
return self
307307

308-
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
308+
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
309309
if not self._bqml_model:
310310
raise RuntimeError("A model must be fitted before predict")
311-
(X,) = utils.convert_to_dataframe(X)
311+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
312312

313313
return self._bqml_model.predict(X)
314314

315315
def score(
316316
self,
317-
X: Union[bpd.DataFrame, bpd.Series],
318-
y: Union[bpd.DataFrame, bpd.Series],
317+
X: utils.ArrayType,
318+
y: utils.ArrayType,
319319
):
320320
if not self._bqml_model:
321321
raise RuntimeError("A model must be fitted before score")
322322

323-
X, y = utils.convert_to_dataframe(X, y)
323+
X, y = utils.convert_to_dataframe(X, y, session=self._bqml_model.session)
324324

325325
input_data = (
326326
X.join(y, how="outer") if (X is not None) and (y is not None) else None
@@ -427,8 +427,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
427427

428428
def _fit(
429429
self,
430-
X: Union[bpd.DataFrame, bpd.Series],
431-
y: Union[bpd.DataFrame, bpd.Series],
430+
X: utils.ArrayType,
431+
y: utils.ArrayType,
432432
transforms: Optional[List[str]] = None,
433433
) -> RandomForestRegressor:
434434
X, y = utils.convert_to_dataframe(X, y)
@@ -443,18 +443,18 @@ def _fit(
443443

444444
def predict(
445445
self,
446-
X: Union[bpd.DataFrame, bpd.Series],
446+
X: utils.ArrayType,
447447
) -> bpd.DataFrame:
448448
if not self._bqml_model:
449449
raise RuntimeError("A model must be fitted before predict")
450-
(X,) = utils.convert_to_dataframe(X)
450+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
451451

452452
return self._bqml_model.predict(X)
453453

454454
def score(
455455
self,
456-
X: Union[bpd.DataFrame, bpd.Series],
457-
y: Union[bpd.DataFrame, bpd.Series],
456+
X: utils.ArrayType,
457+
y: utils.ArrayType,
458458
):
459459
"""Calculate evaluation metrics of the model.
460460
@@ -476,7 +476,7 @@ def score(
476476
if not self._bqml_model:
477477
raise RuntimeError("A model must be fitted before score")
478478

479-
X, y = utils.convert_to_dataframe(X, y)
479+
X, y = utils.convert_to_dataframe(X, y, session=self._bqml_model.session)
480480

481481
input_data = (
482482
X.join(y, how="outer") if (X is not None) and (y is not None) else None
@@ -583,8 +583,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
583583

584584
def _fit(
585585
self,
586-
X: Union[bpd.DataFrame, bpd.Series],
587-
y: Union[bpd.DataFrame, bpd.Series],
586+
X: utils.ArrayType,
587+
y: utils.ArrayType,
588588
transforms: Optional[List[str]] = None,
589589
) -> RandomForestClassifier:
590590
X, y = utils.convert_to_dataframe(X, y)
@@ -599,18 +599,18 @@ def _fit(
599599

600600
def predict(
601601
self,
602-
X: Union[bpd.DataFrame, bpd.Series],
602+
X: utils.ArrayType,
603603
) -> bpd.DataFrame:
604604
if not self._bqml_model:
605605
raise RuntimeError("A model must be fitted before predict")
606-
(X,) = utils.convert_to_dataframe(X)
606+
(X,) = utils.convert_to_dataframe(X, session=self._bqml_model.session)
607607

608608
return self._bqml_model.predict(X)
609609

610610
def score(
611611
self,
612-
X: Union[bpd.DataFrame, bpd.Series],
613-
y: Union[bpd.DataFrame, bpd.Series],
612+
X: utils.ArrayType,
613+
y: utils.ArrayType,
614614
):
615615
"""Calculate evaluation metrics of the model.
616616
@@ -632,7 +632,7 @@ def score(
632632
if not self._bqml_model:
633633
raise RuntimeError("A model must be fitted before score")
634634

635-
X, y = utils.convert_to_dataframe(X, y)
635+
X, y = utils.convert_to_dataframe(X, y, session=self._bqml_model.session)
636636

637637
input_data = (
638638
X.join(y, how="outer") if (X is not None) and (y is not None) else None
There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.