Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1c934c2

Browse files
authoredMar 13, 2025
perf: eliminate count queries in llm retry (#1489)
* performance: eliminate count queries in llm retry * fix tests
1 parent 2029d08 commit 1c934c2

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed
 

‎bigframes/ml/base.py

+13-22
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
import abc
25-
from typing import Callable, cast, Mapping, Optional, TypeVar
25+
from typing import Callable, cast, Mapping, Optional, TypeVar, Union
2626
import warnings
2727

2828
import bigframes_vendored.sklearn.base
@@ -259,38 +259,29 @@ def _predict_and_retry(
259259
) -> bpd.DataFrame:
260260
assert self._bqml_model is not None
261261

262-
df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder
263-
df_fail = X
264-
for _ in range(max_retries + 1):
262+
df_result: Union[bpd.DataFrame, None] = None # placeholder
263+
df_succ = df_fail = X
264+
for i in range(max_retries + 1):
265+
if i > 0 and df_fail.empty:
266+
break
267+
if i > 0 and df_succ.empty:
268+
msg = bfe.format_message("Can't make any progress, stop retrying.")
269+
warnings.warn(msg, category=RuntimeWarning)
270+
break
271+
265272
df = self._predict_func(df_fail, options)
266273

267274
success = df[self._status_col].str.len() == 0
268275
df_succ = df[success]
269276
df_fail = df[~success]
270277

271-
if df_succ.empty:
272-
if max_retries > 0:
273-
msg = bfe.format_message("Can't make any progress, stop retrying.")
274-
warnings.warn(msg, category=RuntimeWarning)
275-
break
276-
277278
df_result = (
278-
bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
279-
)
280-
281-
if df_fail.empty:
282-
break
283-
284-
if not df_fail.empty:
285-
msg = bfe.format_message(
286-
f"Some predictions failed. Check column {self._status_col} for detailed "
287-
"status. You may want to filter the failed rows and retry."
279+
bpd.concat([df_result, df_succ]) if df_result is not None else df_succ
288280
)
289-
warnings.warn(msg, category=RuntimeWarning)
290281

291282
df_result = cast(
292283
bpd.DataFrame,
293-
bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
284+
bpd.concat([df_result, df_fail]) if df_result is not None else df_fail,
294285
)
295286
return df_result
296287

0 commit comments

Comments
 (0)
Failed to load comments.