Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ARIMA_EVALUATE options in forecasting models #336

Merged
merged 9 commits into from
Jan 24, 2024
7 changes: 7 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
@@ -136,6 +136,13 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None):

return self._session.read_gbq(sql)

def arima_evaluate(self, show_all_candidate_models: bool = False):
    """Run the ARIMA evaluation query for this model and return its result.

    The SQL text comes from the model-manipulation SQL generator's
    ``ml_arima_evaluate`` and is read through the session's ``read_gbq``.

    Args:
        show_all_candidate_models: Forwarded unchanged to the SQL generator.

    Returns:
        Whatever ``self._session.read_gbq`` returns for the generated query.
    """
    sql_generator = self._model_manipulation_sql_generator
    query = sql_generator.ml_arima_evaluate(show_all_candidate_models)
    return self._session.read_gbq(query)

def centroids(self) -> bpd.DataFrame:
assert self._model.model_type == "KMEANS"

25 changes: 25 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
@@ -151,6 +151,31 @@ def score(
input_data = X.join(y, how="outer")
return self._bqml_model.evaluate(input_data)

def summary(
    self,
    show_all_candidate_models: bool = False,
) -> bpd.DataFrame:
    """Summary of the evaluation metrics of the time series model.

    .. note::

        Output matches that of the BigQuery ML.ARIMA_EVALUATE function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-evaluate
        for the outputs relevant to this model type.

    Args:
        show_all_candidate_models (bool, default to False):
            Whether to show evaluation metrics or an error message for either
            all candidate models or for only the best model with the lowest
            AIC. Default to False.

    Returns:
        bigframes.dataframe.DataFrame: A DataFrame as evaluation result.

    Raises:
        RuntimeError: If no underlying BQML model is set, i.e. the model
            has not been fitted yet.
    """
    if not self._bqml_model:
        raise RuntimeError("A model must be fitted before summary")
    return self._bqml_model.arima_evaluate(show_all_candidate_models)

def to_gbq(self, model_name: str, replace: bool = False) -> ARIMAPlus:
"""Save the model to BigQuery.

6 changes: 6 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
@@ -260,6 +260,12 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
({source_sql}))"""

# ML evaluation TVFs
def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
    """Encode ML.ARIMA_EVALUATE for BQML"""
    # The flag is interpolated as-is, so a Python bool renders as True/False.
    options = f"STRUCT({show_all_candidate_models} AS show_all_candidate_models)"
    return f"SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`,\n{options})"

def ml_centroids(self) -> str:
    """Encode ML.CENTROIDS for BQML"""
    model_ref = f"MODEL `{self._model_name}`"
    return f"SELECT * FROM ML.CENTROIDS({model_ref})"
35 changes: 33 additions & 2 deletions tests/system/large/ml/test_forecasting.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,20 @@

from bigframes.ml import forecasting

# Column names expected in the ML.ARIMA_EVALUATE result. The summary test in
# this file asserts a 12-column result, so all 12 names are listed here
# ("has_drift" was previously missing, leaving one column unchecked).
ARIMA_EVALUATE_OUTPUT_COL = [
    "non_seasonal_p",
    "non_seasonal_d",
    "non_seasonal_q",
    "has_drift",
    "log_likelihood",
    "AIC",
    "variance",
    "seasonal_periods",
    "has_holiday_effect",
    "has_spikes_and_dips",
    "has_step_changes",
    "error_message",
]


def test_arima_plus_model_fit_score(
time_series_df_default_index, dataset_id, new_time_series_df
@@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score(
pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1)

# save, load to ensure configuration was kept
reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True)
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
assert (
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)


def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
    """Fit ARIMAPlus, check the summary() output, then save the model to BigQuery."""
    model = forecasting.ARIMAPlus()
    X_train = time_series_df_default_index[["parsed_date"]]
    y_train = time_series_df_default_index[["total_visits"]]
    model.fit(X_train, y_train)

    result = model.summary()
    assert result.shape == (1, 12)
    assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)

    # save, load to ensure configuration was kept
    reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
    # Only the name passed to to_gbq above is checked; the stale
    # "temp_configured_model" expression left in the assert was removed.
    assert (
        f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
    )
40 changes: 40 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,20 @@

from bigframes.ml import forecasting

# Column names expected in the ML.ARIMA_EVALUATE result. The summary tests in
# this file assert a 12-column result, so all 12 names are listed here
# ("has_drift" was previously missing, leaving one column unchecked).
ARIMA_EVALUATE_OUTPUT_COL = [
    "non_seasonal_p",
    "non_seasonal_d",
    "non_seasonal_q",
    "has_drift",
    "log_likelihood",
    "AIC",
    "variance",
    "seasonal_periods",
    "has_holiday_effect",
    "has_spikes_and_dips",
    "has_step_changes",
    "error_message",
]


def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
@@ -104,6 +118,24 @@ def test_model_score(
)


def test_model_summary(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """summary() with default arguments yields one row and 12 columns."""
    summary_df = time_series_arima_plus_model.summary()
    assert summary_df.shape == (1, 12)
    # Every expected ARIMA_EVALUATE column must be present in the result.
    for expected_column in ARIMA_EVALUATE_OUTPUT_COL:
        assert expected_column in summary_df.columns


def test_model_summary_show_all_candidates(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """summary(show_all_candidate_models=True) yields more than one row."""
    summary_df = time_series_arima_plus_model.summary(show_all_candidate_models=True)
    row_count = summary_df.shape[0]
    assert row_count > 1
    # Every expected ARIMA_EVALUATE column must be present in the result.
    for expected_column in ARIMA_EVALUATE_OUTPUT_COL:
        assert expected_column in summary_df.columns


def test_model_score_series(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
@@ -126,3 +158,11 @@ def test_model_score_series(
rtol=0.1,
check_index_type=False,
)


def test_model_summary_series(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """summary() with default arguments yields one row and 12 columns.

    NOTE(review): summary() takes no input data, so this currently performs
    the same checks as test_model_summary in this file.
    """
    summary_df = time_series_arima_plus_model.summary()
    assert summary_df.shape == (1, 12)
    # Every expected ARIMA_EVALUATE column must be present in the result.
    for expected_column in ARIMA_EVALUATE_OUTPUT_COL:
        assert expected_column in summary_df.columns
13 changes: 13 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
@@ -273,6 +273,19 @@ def test_ml_evaluate_produces_correct_sql(
)


def test_ml_arima_evaluate_produces_correct_sql(
    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):
    """ml_arima_evaluate(True) renders the expected ML.ARIMA_EVALUATE query."""
    expected_sql = (
        "SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,\n"
        "STRUCT(True AS show_all_candidate_models))"
    )
    actual_sql = model_manipulation_sql_generator.ml_arima_evaluate(
        show_all_candidate_models=True
    )
    assert actual_sql == expected_sql


def test_ml_evaluate_no_source_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):