
Commit 8077ff4

GarrettWu and Shuowei Li authored Jan 6, 2025

feat: add max_retries to TextEmbeddingGenerator and Claude3TextGenerator (#1259)

* chore: fix wording of Gemini max_retries
* feat: add max_retries to TextEmbeddingGenerator and Claude3TextGenerator

Co-authored-by: Shuowei Li <shuowei.l@outlook.com>
1 parent 796fc3e commit 8077ff4

File tree: 3 files changed (+369, -92 lines)
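In practice the new argument is passed to predict(); a minimal usage sketch of the change (the input data and the default model/connection are illustrative assumptions, not part of this commit):

import bigframes.pandas as bpd
from bigframes.ml.llm import TextEmbeddingGenerator

# Hypothetical input; any DataFrame with a "content" column works.
df = bpd.DataFrame({"content": ["hello world", "bigframes ml"]})
model = TextEmbeddingGenerator()  # default model/connection assumed here

embeddings = model.predict(df)                 # default max_retries=0: single attempt, warn on failures
embeddings = model.predict(df, max_retries=3)  # new: re-send failed rows up to 3 more times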
 

bigframes/ml/base.py (+57, -7)
@@ -22,7 +22,8 @@
 """
 
 import abc
-from typing import cast, Optional, TypeVar
+from typing import Callable, cast, Mapping, Optional, TypeVar
+import warnings
 
 import bigframes_vendored.sklearn.base
 
@@ -77,6 +78,9 @@ def fit_transform(self, x_train: Union[DataFrame, Series], y_train: Union[DataFr
         ...
         """
 
+    def __init__(self):
+        self._bqml_model: Optional[core.BqmlModel] = None
+
     def __repr__(self):
         """Print the estimator's constructor with all non-default parameter values."""
 
@@ -95,9 +99,6 @@ def __repr__(self):
 class Predictor(BaseEstimator):
     """A BigQuery DataFrames ML Model base class that can be used to predict outputs."""
 
-    def __init__(self):
-        self._bqml_model: Optional[core.BqmlModel] = None
-
     @abc.abstractmethod
     def predict(self, X):
         pass
@@ -213,12 +214,61 @@ def fit(
         return self._fit(X, y)
 
 
+class RetriableRemotePredictor(BaseEstimator):
+    @property
+    @abc.abstractmethod
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def _status_col(self) -> str:
+        pass
+
+    def _predict_and_retry(
+        self, X: bpd.DataFrame, options: Mapping, max_retries: int
+    ) -> bpd.DataFrame:
+        assert self._bqml_model is not None
+
+        df_result = bpd.DataFrame(session=self._bqml_model.session)  # placeholder
+        df_fail = X
+        for _ in range(max_retries + 1):
+            df = self._predict_func(df_fail, options)
+
+            success = df[self._status_col].str.len() == 0
+            df_succ = df[success]
+            df_fail = df[~success]
+
+            if df_succ.empty:
+                if max_retries > 0:
+                    warnings.warn(
+                        "Can't make any progress, stop retrying.", RuntimeWarning
+                    )
+                break
+
+            df_result = (
+                bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
+            )
+
+            if df_fail.empty:
+                break
+
+        if not df_fail.empty:
+            warnings.warn(
+                f"Some predictions failed. Check column {self._status_col} for detailed status. You may want to filter the failed rows and retry.",
+                RuntimeWarning,
+            )
+
+        df_result = cast(
+            bpd.DataFrame,
+            bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
+        )
+        return df_result
+
+
 class BaseTransformer(BaseEstimator):
     """Transformer base class."""
 
-    def __init__(self):
-        self._bqml_model: Optional[core.BqmlModel] = None
-
     @abc.abstractmethod
     def _keys(self):
         pass
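To make the semantics of _predict_and_retry concrete, here is a standalone plain-pandas sketch that mirrors the loop added above; the helper name, the flaky stand-in predictor, and the "status" column are illustrative assumptions, not part of this diff:

import warnings

import pandas as pd


def predict_and_retry(predict_func, X, status_col, max_retries):
    """Plain-pandas rendering of the retry loop added above (illustration only)."""
    df_result = pd.DataFrame()  # placeholder, like the bpd.DataFrame in the diff
    df_fail = X
    for _ in range(max_retries + 1):
        df = predict_func(df_fail)

        success = df[status_col].str.len() == 0
        df_succ, df_fail = df[success], df[~success]

        if df_succ.empty:
            # No progress this round: stop instead of looping on the same failures.
            if max_retries > 0:
                warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
            break

        df_result = pd.concat([df_result, df_succ]) if not df_result.empty else df_succ

        if df_fail.empty:
            break  # everything succeeded

    # Rows that still failed are appended at the end with their status intact.
    if not df_fail.empty:
        warnings.warn(f"Some predictions failed; see column {status_col!r}.", RuntimeWarning)
        df_result = pd.concat([df_result, df_fail]) if not df_result.empty else df_fail
    return df_result


# A flaky stand-in "model": one more row succeeds on every call (transient failures).
_call = {"n": 0}


def flaky_predict(batch):
    _call["n"] += 1
    out = batch.copy()
    out["status"] = ["" if i < _call["n"] else "resource exhausted" for i in out.index]
    return out


data = pd.DataFrame({"content": ["a", "b", "c"]})
print(predict_and_retry(flaky_predict, data, status_col="status", max_retries=2))

With max_retries=2 the stand-in succeeds for one additional row per attempt, so all three rows end up in the result with an empty status; with max_retries=0 only the first row would succeed and the rest would be appended with their failure status.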

bigframes/ml/llm.py (+50, -59)
@@ -16,7 +16,7 @@
 
 from __future__ import annotations
 
-from typing import cast, Literal, Optional
+from typing import Callable, cast, Literal, Mapping, Optional
 import warnings
 
 import bigframes_vendored.constants as constants
@@ -616,7 +616,7 @@ def to_gbq(
 
 
 @log_adapter.class_logger
-class TextEmbeddingGenerator(base.BaseEstimator):
+class TextEmbeddingGenerator(base.RetriableRemotePredictor):
     """Text embedding generator LLM model.
 
     Args:
@@ -715,18 +715,33 @@ def _from_bq(
         model._bqml_model = core.BqmlModel(session, bq_model)
         return model
 
-    def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
+    @property
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        return self._bqml_model.generate_embedding
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_EMBEDDING_STATUS
+
+    def predict(self, X: utils.ArrayType, *, max_retries: int = 0) -> bpd.DataFrame:
         """Predict the result from input DataFrame.
 
         Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
                Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
 
+           max_retries (int, default 0):
+               Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
+               Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
+
         Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
         """
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
 
-        # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         if len(X.columns) == 1:
@@ -738,15 +753,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
             "flatten_json_output": True,
         }
 
-        df = self._bqml_model.generate_embedding(X, options)
-
-        if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        return df
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
         """Save the model to BigQuery.
@@ -765,7 +772,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
 
 
 @log_adapter.class_logger
-class GeminiTextGenerator(base.BaseEstimator):
+class GeminiTextGenerator(base.RetriableRemotePredictor):
     """Gemini text generator LLM model.
 
     Args:
@@ -891,6 +898,14 @@ def _bqml_options(self) -> dict:
         }
         return options
 
+    @property
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        return self._bqml_model.generate_text
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_TEXT_STATUS
+
     def fit(
         self,
         X: utils.ArrayType,
@@ -1028,41 +1043,7 @@ def predict(
             "ground_with_google_search": ground_with_google_search,
         }
 
-        df_result = bpd.DataFrame(session=self._bqml_model.session)  # placeholder
-        df_fail = X
-        for _ in range(max_retries + 1):
-            df = self._bqml_model.generate_text(df_fail, options)
-
-            success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
-            df_succ = df[success]
-            df_fail = df[~success]
-
-            if df_succ.empty:
-                if max_retries > 0:
-                    warnings.warn(
-                        "Can't make any progress, stop retrying.", RuntimeWarning
-                    )
-                break
-
-            df_result = (
-                bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
-            )
-
-            if df_fail.empty:
-                break
-
-        if not df_fail.empty:
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        df_result = cast(
-            bpd.DataFrame,
-            bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
-        )
-
-        return df_result
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def score(
         self,
@@ -1144,7 +1125,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
 
 
 @log_adapter.class_logger
-class Claude3TextGenerator(base.BaseEstimator):
+class Claude3TextGenerator(base.RetriableRemotePredictor):
     """Claude3 text generator LLM model.
 
     Go to Google Cloud Console -> Vertex AI -> Model Garden page to enabe the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
@@ -1273,13 +1254,22 @@ def _bqml_options(self) -> dict:
         }
         return options
 
+    @property
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        return self._bqml_model.generate_text
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_TEXT_STATUS
+
     def predict(
         self,
         X: utils.ArrayType,
         *,
         max_output_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
+        max_retries: int = 0,
     ) -> bpd.DataFrame:
         """Predict the result from input DataFrame.
 
@@ -1307,6 +1297,10 @@ def predict(
                Specify a lower value for less random responses and a higher value for more random responses.
                Default 0.95. Possible values [0.0, 1.0].
 
+           max_retries (int, default 0):
+               Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
+               Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
+
 
         Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -1324,6 +1318,11 @@ def predict(
         if top_p < 0.0 or top_p > 1.0:
             raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
 
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
+
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         if len(X.columns) == 1:
@@ -1338,15 +1337,7 @@ def predict(
             "flatten_json_output": True,
         }
 
-        df = self._bqml_model.generate_text(X, options)
-
-        if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        return df
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
         """Save the model to BigQuery.
(The remainder of the diff, covering the third changed file, did not load.)
