Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3068e19

Browse files
arwas11tswast
andauthoredDec 30, 2024
feat: add support for LinearRegression.predict_explain and LogisticRegression.predict_explain parameter, top_k_features (#1228)
* feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns * update tests * chore: add support for predict_explain paramater, top_k_features * update test * update logistic reg method with the new param * add and test new param's validation * Update bigframes/ml/linear_model.py --------- Co-authored-by: Tim Sweña (Swast) <swast@google.com>
1 parent cafd5e5 commit 3068e19

File tree

6 files changed

+100
-24
lines changed

6 files changed

+100
-24
lines changed
 

‎bigframes/ml/core.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,15 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
123123
self._model_manipulation_sql_generator.ml_predict,
124124
)
125125

126-
def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
126+
def explain_predict(
127+
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
128+
) -> bpd.DataFrame:
127129
return self._apply_ml_tvf(
128130
input_data,
129-
self._model_manipulation_sql_generator.ml_explain_predict,
131+
lambda source_sql: self._model_manipulation_sql_generator.ml_explain_predict(
132+
source_sql=source_sql,
133+
struct_options=options,
134+
),
130135
)
131136

132137
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:

‎bigframes/ml/linear_model.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,15 @@ def _fit(
155155
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
156156
if not self._bqml_model:
157157
raise RuntimeError("A model must be fitted before predict")
158-
159-
(X,) = utils.batch_convert_to_dataframe(X)
158+
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
160159

161160
return self._bqml_model.predict(X)
162161

163162
def predict_explain(
164163
self,
165164
X: utils.ArrayType,
165+
*,
166+
top_k_features: int = 5,
166167
) -> bpd.DataFrame:
167168
"""
168169
Explain predictions for a linear regression model.
@@ -175,18 +176,32 @@ def predict_explain(
175176
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
176177
pandas.core.frame.DataFrame or pandas.core.series.Series):
177178
Series or a DataFrame to explain its predictions.
179+
top_k_features (int, default 5):
180+
an INT64 value that specifies how many top feature attribution
181+
pairs are generated for each row of input data. The features are
182+
ranked by the absolute values of their attributions.
183+
184+
By default, top_k_features is set to 5. If its value is greater
185+
than the number of features in the training data, the
186+
attributions of all features are returned.
178187
179188
Returns:
180189
bigframes.pandas.DataFrame:
181190
The predicted DataFrames with explanation columns.
182191
"""
183-
# TODO(b/377366612): Add support for `top_k_features` parameter
192+
if top_k_features < 1:
193+
raise ValueError(
194+
f"top_k_features must be at least 1, but is {top_k_features}."
195+
)
196+
184197
if not self._bqml_model:
185198
raise RuntimeError("A model must be fitted before predict")
186199

187200
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
188201

189-
return self._bqml_model.explain_predict(X)
202+
return self._bqml_model.explain_predict(
203+
X, options={"top_k_features": top_k_features}
204+
)
190205

191206
def score(
192207
self,
@@ -356,6 +371,8 @@ def predict(
356371
def predict_explain(
357372
self,
358373
X: utils.ArrayType,
374+
*,
375+
top_k_features: int = 5,
359376
) -> bpd.DataFrame:
360377
"""
361378
Explain predictions for a logistic regression model.
@@ -368,18 +385,32 @@ def predict_explain(
368385
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
369386
pandas.core.frame.DataFrame or pandas.core.series.Series):
370387
Series or a DataFrame to explain its predictions.
388+
top_k_features (int, default 5):
389+
an INT64 value that specifies how many top feature attribution
390+
pairs are generated for each row of input data. The features are
391+
ranked by the absolute values of their attributions.
392+
393+
By default, top_k_features is set to 5. If its value is greater
394+
than the number of features in the training data, the
395+
attributions of all features are returned.
371396
372397
Returns:
373398
bigframes.pandas.DataFrame:
374399
The predicted DataFrames with explanation columns.
375400
"""
376-
# TODO(b/377366612): Add support for `top_k_features` parameter
401+
if top_k_features < 1:
402+
raise ValueError(
403+
f"top_k_features must be at least 1, but is {top_k_features}."
404+
)
405+
377406
if not self._bqml_model:
378407
raise RuntimeError("A model must be fitted before predict")
379408

380409
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
381410

382-
return self._bqml_model.explain_predict(X)
411+
return self._bqml_model.explain_predict(
412+
X, options={"top_k_features": top_k_features}
413+
)
383414

384415
def score(
385416
self,

‎bigframes/ml/sql.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,13 @@ def ml_predict(self, source_sql: str) -> str:
304304
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
305305
({source_sql}))"""
306306

307-
def ml_explain_predict(self, source_sql: str) -> str:
307+
def ml_explain_predict(
308+
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
309+
) -> str:
308310
"""Encode ML.EXPLAIN_PREDICT for BQML"""
311+
struct_options_sql = self.struct_options(**struct_options)
309312
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
310-
({source_sql}))"""
313+
({source_sql}), {struct_options_sql})"""
311314

312315
def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
313316
"""Encode ML.FORECAST for BQML"""

‎tests/system/small/ml/test_core.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
263263
def test_model_predict_explain(
264264
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
265265
):
266+
options = {"top_k_features": 3}
266267
predictions = penguins_bqml_linear_model.explain_predict(
267-
new_penguins_df
268+
new_penguins_df, options
268269
).to_pandas()
269270
expected = pd.DataFrame(
270271
{
@@ -317,14 +318,15 @@ def test_model_predict_explain_with_unnamed_index(
317318
# need to persist through the call to ML.PREDICT
318319
new_penguins_df = new_penguins_df.reset_index()
319320

321+
options = {"top_k_features": 3}
320322
# remove the middle tag number to ensure we're really keeping the unnamed index
321323
new_penguins_df = typing.cast(
322324
bigframes.dataframe.DataFrame,
323325
new_penguins_df[new_penguins_df.tag_number != 1672],
324326
)
325327

326328
predictions = penguins_bqml_linear_model.explain_predict(
327-
new_penguins_df
329+
new_penguins_df, options
328330
).to_pandas()
329331

330332
expected = pd.DataFrame(

‎tests/system/small/ml/test_linear_model.py

+30
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import re
16+
1517
import google.api_core.exceptions
1618
import pandas
1719
import pytest
@@ -132,6 +134,20 @@ def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df
132134
)
133135

134136

137+
def test_linear_model_predict_explain_top_k_features(
138+
penguins_logistic_model: linear_model.LinearRegression, new_penguins_df
139+
):
140+
top_k_features = 0
141+
142+
with pytest.raises(
143+
ValueError,
144+
match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
145+
):
146+
penguins_logistic_model.predict_explain(
147+
new_penguins_df, top_k_features=top_k_features
148+
).to_pandas()
149+
150+
135151
def test_linear_reg_model_predict_params(
136152
penguins_linear_model: linear_model.LinearRegression, new_penguins_df
137153
):
@@ -307,6 +323,20 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df):
307323
)
308324

309325

326+
def test_logistic_model_predict_explain_top_k_features(
327+
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
328+
):
329+
top_k_features = 0
330+
331+
with pytest.raises(
332+
ValueError,
333+
match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
334+
):
335+
penguins_logistic_model.predict_explain(
336+
new_penguins_df, top_k_features=top_k_features
337+
).to_pandas()
338+
339+
310340
def test_logistic_model_predict_params(
311341
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
312342
):

‎tests/unit/ml/test_sql.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -342,18 +342,6 @@ def test_ml_predict_correct(
342342
)
343343

344344

345-
def test_ml_explain_predict_correct(
346-
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
347-
mock_df: bpd.DataFrame,
348-
):
349-
sql = model_manipulation_sql_generator.ml_explain_predict(source_sql=mock_df.sql)
350-
assert (
351-
sql
352-
== """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
353-
(input_X_y_sql))"""
354-
)
355-
356-
357345
def test_ml_llm_evaluate_correct(
358346
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
359347
mock_df: bpd.DataFrame,
@@ -462,6 +450,23 @@ def test_ml_generate_embedding_correct(
462450
)
463451

464452

453+
def test_ml_explain_predict_correct(
454+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
455+
mock_df: bpd.DataFrame,
456+
):
457+
sql = model_manipulation_sql_generator.ml_explain_predict(
458+
source_sql=mock_df.sql,
459+
struct_options={"option_key1": 1, "option_key2": 2.25},
460+
)
461+
assert (
462+
sql
463+
== """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
464+
(input_X_y_sql), STRUCT(
465+
1 AS `option_key1`,
466+
2.25 AS `option_key2`))"""
467+
)
468+
469+
465470
def test_ml_detect_anomalies_correct_sql(
466471
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
467472
mock_df: bpd.DataFrame,

0 commit comments

Comments
 (0)
Failed to load comments.