|
22 | 22 | """
|
23 | 23 |
|
24 | 24 | import abc
|
25 |
| -from typing import Callable, cast, Mapping, Optional, TypeVar |
| 25 | +from typing import Callable, cast, Mapping, Optional, TypeVar, Union |
26 | 26 | import warnings
|
27 | 27 |
|
28 | 28 | import bigframes_vendored.sklearn.base
|
@@ -259,38 +259,29 @@ def _predict_and_retry(
|
259 | 259 | ) -> bpd.DataFrame:
|
260 | 260 | assert self._bqml_model is not None
|
261 | 261 |
|
262 |
| - df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder |
263 |
| - df_fail = X |
264 |
| - for _ in range(max_retries + 1): |
| 262 | + df_result: Union[bpd.DataFrame, None] = None # placeholder |
| 263 | + df_succ = df_fail = X |
| 264 | + for i in range(max_retries + 1): |
| 265 | + if i > 0 and df_fail.empty: |
| 266 | + break |
| 267 | + if i > 0 and df_succ.empty: |
| 268 | + msg = bfe.format_message("Can't make any progress, stop retrying.") |
| 269 | + warnings.warn(msg, category=RuntimeWarning) |
| 270 | + break |
| 271 | + |
265 | 272 | df = self._predict_func(df_fail, options)
|
266 | 273 |
|
267 | 274 | success = df[self._status_col].str.len() == 0
|
268 | 275 | df_succ = df[success]
|
269 | 276 | df_fail = df[~success]
|
270 | 277 |
|
271 |
| - if df_succ.empty: |
272 |
| - if max_retries > 0: |
273 |
| - msg = bfe.format_message("Can't make any progress, stop retrying.") |
274 |
| - warnings.warn(msg, category=RuntimeWarning) |
275 |
| - break |
276 |
| - |
277 | 278 | df_result = (
|
278 |
| - bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ |
279 |
| - ) |
280 |
| - |
281 |
| - if df_fail.empty: |
282 |
| - break |
283 |
| - |
284 |
| - if not df_fail.empty: |
285 |
| - msg = bfe.format_message( |
286 |
| - f"Some predictions failed. Check column {self._status_col} for detailed " |
287 |
| - "status. You may want to filter the failed rows and retry." |
| 279 | + bpd.concat([df_result, df_succ]) if df_result is not None else df_succ |
288 | 280 | )
|
289 |
| - warnings.warn(msg, category=RuntimeWarning) |
290 | 281 |
|
291 | 282 | df_result = cast(
|
292 | 283 | bpd.DataFrame,
|
293 |
| - bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail, |
| 284 | + bpd.concat([df_result, df_fail]) if df_result is not None else df_fail, |
294 | 285 | )
|
295 | 286 | return df_result
|
296 | 287 |
|
|
0 commit comments