@@ -155,14 +155,15 @@ def _fit(
155
155
def predict (self , X : utils .ArrayType ) -> bpd .DataFrame :
156
156
if not self ._bqml_model :
157
157
raise RuntimeError ("A model must be fitted before predict" )
158
-
159
- (X ,) = utils .batch_convert_to_dataframe (X )
158
+ (X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
160
159
161
160
return self ._bqml_model .predict (X )
162
161
163
162
def predict_explain (
164
163
self ,
165
164
X : utils .ArrayType ,
165
+ * ,
166
+ top_k_features : int = 5 ,
166
167
) -> bpd .DataFrame :
167
168
"""
168
169
Explain predictions for a linear regression model.
@@ -175,18 +176,32 @@ def predict_explain(
175
176
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
176
177
pandas.core.frame.DataFrame or pandas.core.series.Series):
177
178
Series or a DataFrame to explain its predictions.
179
+ top_k_features (int, default 5):
180
+ an INT64 value that specifies how many top feature attribution
181
+ pairs are generated for each row of input data. The features are
182
+ ranked by the absolute values of their attributions.
183
+
184
+ By default, top_k_features is set to 5. If its value is greater
185
+ than the number of features in the training data, the
186
+ attributions of all features are returned.
178
187
179
188
Returns:
180
189
bigframes.pandas.DataFrame:
181
190
The predicted DataFrames with explanation columns.
182
191
"""
183
- # TODO(b/377366612): Add support for `top_k_features` parameter
192
+ if top_k_features < 1 :
193
+ raise ValueError (
194
+ f"top_k_features must be at least 1, but is { top_k_features } ."
195
+ )
196
+
184
197
if not self ._bqml_model :
185
198
raise RuntimeError ("A model must be fitted before predict" )
186
199
187
200
(X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
188
201
189
- return self ._bqml_model .explain_predict (X )
202
+ return self ._bqml_model .explain_predict (
203
+ X , options = {"top_k_features" : top_k_features }
204
+ )
190
205
191
206
def score (
192
207
self ,
@@ -356,6 +371,8 @@ def predict(
356
371
def predict_explain (
357
372
self ,
358
373
X : utils .ArrayType ,
374
+ * ,
375
+ top_k_features : int = 5 ,
359
376
) -> bpd .DataFrame :
360
377
"""
361
378
Explain predictions for a logistic regression model.
@@ -368,18 +385,32 @@ def predict_explain(
368
385
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
369
386
pandas.core.frame.DataFrame or pandas.core.series.Series):
370
387
Series or a DataFrame to explain its predictions.
388
+ top_k_features (int, default 5):
389
+ an INT64 value that specifies how many top feature attribution
390
+ pairs are generated for each row of input data. The features are
391
+ ranked by the absolute values of their attributions.
392
+
393
+ By default, top_k_features is set to 5. If its value is greater
394
+ than the number of features in the training data, the
395
+ attributions of all features are returned.
371
396
372
397
Returns:
373
398
bigframes.pandas.DataFrame:
374
399
The predicted DataFrames with explanation columns.
375
400
"""
376
- # TODO(b/377366612): Add support for `top_k_features` parameter
401
+ if top_k_features < 1 :
402
+ raise ValueError (
403
+ f"top_k_features must be at least 1, but is { top_k_features } ."
404
+ )
405
+
377
406
if not self ._bqml_model :
378
407
raise RuntimeError ("A model must be fitted before predict" )
379
408
380
409
(X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
381
410
382
- return self ._bqml_model .explain_predict (X )
411
+ return self ._bqml_model .explain_predict (
412
+ X , options = {"top_k_features" : top_k_features }
413
+ )
383
414
384
415
def score (
385
416
self ,
0 commit comments