16
16
17
17
from __future__ import annotations
18
18
19
- from typing import cast , Literal , Optional
19
+ from typing import Callable , cast , Literal , Mapping , Optional
20
20
import warnings
21
21
22
22
import bigframes_vendored .constants as constants
@@ -616,7 +616,7 @@ def to_gbq(
616
616
617
617
618
618
@log_adapter .class_logger
619
- class TextEmbeddingGenerator (base .BaseEstimator ):
619
+ class TextEmbeddingGenerator (base .RetriableRemotePredictor ):
620
620
"""Text embedding generator LLM model.
621
621
622
622
Args:
@@ -715,18 +715,33 @@ def _from_bq(
715
715
model ._bqml_model = core .BqmlModel (session , bq_model )
716
716
return model
717
717
718
- def predict (self , X : utils .ArrayType ) -> bpd .DataFrame :
718
+ @property
719
+ def _predict_func (self ) -> Callable [[bpd .DataFrame , Mapping ], bpd .DataFrame ]:
720
+ return self ._bqml_model .generate_embedding
721
+
722
+ @property
723
+ def _status_col (self ) -> str :
724
+ return _ML_GENERATE_EMBEDDING_STATUS
725
+
726
+ def predict (self , X : utils .ArrayType , * , max_retries : int = 0 ) -> bpd .DataFrame :
719
727
"""Predict the result from input DataFrame.
720
728
721
729
Args:
722
730
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
723
731
Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
724
732
733
+ max_retries (int, default 0):
734
+ Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
735
+ Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
736
+
725
737
Returns:
726
738
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
727
739
"""
740
+ if max_retries < 0 :
741
+ raise ValueError (
742
+ f"max_retries must be larger than or equal to 0, but is { max_retries } ."
743
+ )
728
744
729
- # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
730
745
(X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
731
746
732
747
if len (X .columns ) == 1 :
@@ -738,15 +753,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
738
753
"flatten_json_output" : True ,
739
754
}
740
755
741
- df = self ._bqml_model .generate_embedding (X , options )
742
-
743
- if (df [_ML_GENERATE_EMBEDDING_STATUS ] != "" ).any ():
744
- warnings .warn (
745
- f"Some predictions failed. Check column { _ML_GENERATE_EMBEDDING_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
746
- RuntimeWarning ,
747
- )
748
-
749
- return df
756
+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
750
757
751
758
def to_gbq (self , model_name : str , replace : bool = False ) -> TextEmbeddingGenerator :
752
759
"""Save the model to BigQuery.
@@ -765,7 +772,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
765
772
766
773
767
774
@log_adapter .class_logger
768
- class GeminiTextGenerator (base .BaseEstimator ):
775
+ class GeminiTextGenerator (base .RetriableRemotePredictor ):
769
776
"""Gemini text generator LLM model.
770
777
771
778
Args:
@@ -891,6 +898,14 @@ def _bqml_options(self) -> dict:
891
898
}
892
899
return options
893
900
901
+ @property
902
+ def _predict_func (self ) -> Callable [[bpd .DataFrame , Mapping ], bpd .DataFrame ]:
903
+ return self ._bqml_model .generate_text
904
+
905
+ @property
906
+ def _status_col (self ) -> str :
907
+ return _ML_GENERATE_TEXT_STATUS
908
+
894
909
def fit (
895
910
self ,
896
911
X : utils .ArrayType ,
@@ -1028,41 +1043,7 @@ def predict(
1028
1043
"ground_with_google_search" : ground_with_google_search ,
1029
1044
}
1030
1045
1031
- df_result = bpd .DataFrame (session = self ._bqml_model .session ) # placeholder
1032
- df_fail = X
1033
- for _ in range (max_retries + 1 ):
1034
- df = self ._bqml_model .generate_text (df_fail , options )
1035
-
1036
- success = df [_ML_GENERATE_TEXT_STATUS ].str .len () == 0
1037
- df_succ = df [success ]
1038
- df_fail = df [~ success ]
1039
-
1040
- if df_succ .empty :
1041
- if max_retries > 0 :
1042
- warnings .warn (
1043
- "Can't make any progress, stop retrying." , RuntimeWarning
1044
- )
1045
- break
1046
-
1047
- df_result = (
1048
- bpd .concat ([df_result , df_succ ]) if not df_result .empty else df_succ
1049
- )
1050
-
1051
- if df_fail .empty :
1052
- break
1053
-
1054
- if not df_fail .empty :
1055
- warnings .warn (
1056
- f"Some predictions failed. Check column { _ML_GENERATE_TEXT_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
1057
- RuntimeWarning ,
1058
- )
1059
-
1060
- df_result = cast (
1061
- bpd .DataFrame ,
1062
- bpd .concat ([df_result , df_fail ]) if not df_result .empty else df_fail ,
1063
- )
1064
-
1065
- return df_result
1046
+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
1066
1047
1067
1048
def score (
1068
1049
self ,
@@ -1144,7 +1125,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
1144
1125
1145
1126
1146
1127
@log_adapter .class_logger
1147
- class Claude3TextGenerator (base .BaseEstimator ):
1128
+ class Claude3TextGenerator (base .RetriableRemotePredictor ):
1148
1129
"""Claude3 text generator LLM model.
1149
1130
1150
1131
Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
@@ -1273,13 +1254,22 @@ def _bqml_options(self) -> dict:
1273
1254
}
1274
1255
return options
1275
1256
1257
+ @property
1258
+ def _predict_func (self ) -> Callable [[bpd .DataFrame , Mapping ], bpd .DataFrame ]:
1259
+ return self ._bqml_model .generate_text
1260
+
1261
+ @property
1262
+ def _status_col (self ) -> str :
1263
+ return _ML_GENERATE_TEXT_STATUS
1264
+
1276
1265
def predict (
1277
1266
self ,
1278
1267
X : utils .ArrayType ,
1279
1268
* ,
1280
1269
max_output_tokens : int = 128 ,
1281
1270
top_k : int = 40 ,
1282
1271
top_p : float = 0.95 ,
1272
+ max_retries : int = 0 ,
1283
1273
) -> bpd .DataFrame :
1284
1274
"""Predict the result from input DataFrame.
1285
1275
@@ -1307,6 +1297,10 @@ def predict(
1307
1297
Specify a lower value for less random responses and a higher value for more random responses.
1308
1298
Default 0.95. Possible values [0.0, 1.0].
1309
1299
1300
+ max_retries (int, default 0):
1301
+ Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
1302
+ Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
1303
+
1310
1304
1311
1305
Returns:
1312
1306
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -1324,6 +1318,11 @@ def predict(
1324
1318
if top_p < 0.0 or top_p > 1.0 :
1325
1319
raise ValueError (f"top_p must be [0.0, 1.0], but is { top_p } ." )
1326
1320
1321
+ if max_retries < 0 :
1322
+ raise ValueError (
1323
+ f"max_retries must be larger than or equal to 0, but is { max_retries } ."
1324
+ )
1325
+
1327
1326
(X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
1328
1327
1329
1328
if len (X .columns ) == 1 :
@@ -1338,15 +1337,7 @@ def predict(
1338
1337
"flatten_json_output" : True ,
1339
1338
}
1340
1339
1341
- df = self ._bqml_model .generate_text (X , options )
1342
-
1343
- if (df [_ML_GENERATE_TEXT_STATUS ] != "" ).any ():
1344
- warnings .warn (
1345
- f"Some predictions failed. Check column { _ML_GENERATE_TEXT_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
1346
- RuntimeWarning ,
1347
- )
1348
-
1349
- return df
1340
+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
1350
1341
1351
1342
def to_gbq (self , model_name : str , replace : bool = False ) -> Claude3TextGenerator :
1352
1343
"""Save the model to BigQuery.
0 commit comments