Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9d8970a

Browse files
authoredNov 20, 2024
docs: add snippet for evaluating a boosted tree model (#1154)
* must verify evaluation_data and edit Output * line edit * edit desc * correct model info
1 parent a972668 commit 9d8970a

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed
 

‎samples/snippets/classification_boosted_tree_model_test.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,42 @@ def test_boosted_tree_model(random_model_id: str) -> None:
4848
y = training_data["income_bracket"]
4949

5050
# create and train the model
51-
census_model = ensemble.XGBClassifier(
51+
tree_model = ensemble.XGBClassifier(
5252
n_estimators=1,
5353
booster="gbtree",
5454
tree_method="hist",
5555
max_iterations=1, # For a more accurate model, try 50 iterations.
5656
subsample=0.85,
5757
)
58-
census_model.fit(X, y)
58+
tree_model.fit(X, y)
5959

60-
census_model.to_gbq(
61-
your_model_id, # For example: "your-project.census.census_model"
60+
tree_model.to_gbq(
61+
your_model_id, # For example: "your-project.bqml_tutorial.tree_model"
6262
replace=True,
6363
)
6464
# [END bigquery_dataframes_bqml_boosted_tree_create]
65+
# [START bigquery_dataframes_bqml_boosted_tree_explain]
66+
# Select model you'll use for predictions. `read_gbq_model` loads model
67+
# data from BigQuery, but you could also use the `tree_model` object
68+
# from the previous step.
69+
tree_model = bpd.read_gbq_model(
70+
your_model_id, # For example: "your-project.bqml_tutorial.tree_model"
71+
)
72+
73+
# input_data is defined in an earlier step.
74+
evaluation_data = input_data[input_data["dataframe"] == "evaluation"]
75+
X = evaluation_data.drop(columns=["income_bracket", "dataframe"])
76+
y = evaluation_data["income_bracket"]
77+
78+
# The score() method evaluates how the model performs compared to the
79+
# actual data. Output DataFrame matches that of ML.EVALUATE().
80+
score = tree_model.score(X, y)
81+
score.peek()
82+
# Output:
83+
# precision recall accuracy f1_score log_loss roc_auc
84+
# 0 0.671924 0.578804 0.839429 0.621897 0.344054 0.887335
85+
# [END bigquery_dataframes_bqml_boosted_tree_explain]
86+
assert tree_model is not None
87+
assert evaluation_data is not None
88+
assert score is not None
6589
assert input_data is not None
66-
assert census_model is not None

0 commit comments

Comments
 (0)
Failed to load comments.