Skip to content

Commit

Permalink
Fix parent_model parameter in GCP Vertex AI AutoML and Custom Job operators (#38417)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 authored Mar 23, 2024
1 parent 8520778 commit 3ac0aaf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 24 deletions.
5 changes: 0 additions & 5 deletions airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_forecasting_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -284,7 +283,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_image_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -393,7 +391,6 @@ def execute(self, context: Context):
impersonation_chain=self.impersonation_chain,
)
credentials, _ = self.hook.get_credentials_and_project_id()
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_tabular_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -488,7 +485,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_text_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -565,7 +561,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_video_training_job(
project_id=self.project_id,
region=self.region,
Expand Down
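5 changes: 0 additions & 5 deletions airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
Original file line number Diff line number Diff line change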
Expand Up @@ -468,8 +468,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None

model, training_id, custom_job_id = self.hook.create_custom_container_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -850,7 +848,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job(
project_id=self.project_id,
region=self.region,
Expand Down Expand Up @@ -1234,8 +1231,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None

model, training_id, custom_job_id = self.hook.create_custom_training_job(
project_id=self.project_id,
region=self.region,
Expand Down
46 changes: 32 additions & 14 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"metadata": "test-image-dataset",
}
TEST_DATASET_ID = "test-dataset-id"
TEST_PARENT_MODEL = "test-parent-model"
TEST_EXPORT_CONFIG = {
"annotationsFilter": "test-filter",
"gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"},
Expand Down Expand Up @@ -190,8 +191,9 @@


class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
def test_execute(self, mock_hook, mock_dataset):
mock_hook.return_value.create_custom_container_training_job.return_value = (
None,
"training_id",
Expand All @@ -217,8 +219,11 @@ def test_execute(self, mock_hook):
test_fraction_split=TEST_FRACTION_SPLIT,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
Expand All @@ -227,7 +232,7 @@ def test_execute(self, mock_hook):
container_uri=CONTAINER_URI,
model_serving_container_image_uri=CONTAINER_URI,
command=COMMAND_2,
dataset=None,
dataset=mock_dataset.return_value,
model_display_name=DISPLAY_NAME_2,
replica_count=REPLICA_COUNT,
machine_type=MACHINE_TYPE,
Expand All @@ -238,7 +243,7 @@ def test_execute(self, mock_hook):
test_fraction_split=TEST_FRACTION_SPLIT,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
model_serving_container_predict_route=None,
model_serving_container_health_route=None,
model_serving_container_command=None,
Expand Down Expand Up @@ -276,8 +281,9 @@ def test_execute(self, mock_hook):


class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
def test_execute(self, mock_hook, mock_dataset):
mock_hook.return_value.create_custom_python_package_training_job.return_value = (
None,
"training_id",
Expand All @@ -304,8 +310,11 @@ def test_execute(self, mock_hook):
test_fraction_split=TEST_FRACTION_SPLIT,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
Expand All @@ -315,7 +324,7 @@ def test_execute(self, mock_hook):
model_serving_container_image_uri=CONTAINER_URI,
python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
python_module_name=PYTHON_MODULE_NAME,
dataset=None,
dataset=mock_dataset.return_value,
model_display_name=DISPLAY_NAME_2,
replica_count=REPLICA_COUNT,
machine_type=MACHINE_TYPE,
Expand All @@ -326,7 +335,7 @@ def test_execute(self, mock_hook):
test_fraction_split=TEST_FRACTION_SPLIT,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
is_default_version=None,
model_version_aliases=None,
model_version_description=None,
Expand Down Expand Up @@ -364,8 +373,9 @@ def test_execute(self, mock_hook):


class TestVertexAICreateCustomTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
def test_execute(self, mock_hook, mock_dataset):
mock_hook.return_value.create_custom_training_job.return_value = (
None,
"training_id",
Expand All @@ -385,9 +395,12 @@ def test_execute(self, mock_hook):
replica_count=1,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
mock_hook.return_value.create_custom_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
display_name=DISPLAY_NAME,
Expand All @@ -396,7 +409,7 @@ def test_execute(self, mock_hook):
model_serving_container_image_uri=CONTAINER_URI,
script_path=PYTHON_PACKAGE,
requirements=[],
dataset=None,
dataset=mock_dataset.return_value,
model_display_name=None,
replica_count=REPLICA_COUNT,
machine_type=MACHINE_TYPE,
Expand All @@ -405,7 +418,7 @@ def test_execute(self, mock_hook):
training_fraction_split=None,
validation_fraction_split=None,
test_fraction_split=None,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model_serving_container_predict_route=None,
Expand Down Expand Up @@ -751,6 +764,7 @@ def test_execute(self, mock_hook, mock_dataset):
sync=True,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -768,7 +782,7 @@ def test_execute(self, mock_hook, mock_dataset):
forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
optimization_objective=None,
column_specs=None,
column_transformations=None,
Expand Down Expand Up @@ -814,6 +828,7 @@ def test_execute(self, mock_hook, mock_dataset):
sync=True,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -824,7 +839,7 @@ def test_execute(self, mock_hook, mock_dataset):
display_name=DISPLAY_NAME,
dataset=mock_dataset.return_value,
prediction_type="classification",
parent_model=None,
parent_model=TEST_PARENT_MODEL,
multi_label=False,
model_type="CLOUD",
base_model=None,
Expand Down Expand Up @@ -869,6 +884,7 @@ def test_execute(self, mock_hook, mock_dataset):
sync=True,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -880,7 +896,7 @@ def test_execute(self, mock_hook, mock_dataset):
region=GCP_LOCATION,
display_name=DISPLAY_NAME,
dataset=mock_dataset.return_value,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
target_column=None,
optimization_prediction_type=None,
optimization_objective=None,
Expand Down Expand Up @@ -928,6 +944,7 @@ def test_execute(self, mock_hook, mock_dataset):
sync=True,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -937,7 +954,7 @@ def test_execute(self, mock_hook, mock_dataset):
region=GCP_LOCATION,
display_name=DISPLAY_NAME,
dataset=mock_dataset.return_value,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
prediction_type=None,
multi_label=False,
sentiment_max=10,
Expand Down Expand Up @@ -975,6 +992,7 @@ def test_execute(self, mock_hook, mock_dataset):
sync=True,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
parent_model=TEST_PARENT_MODEL,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand All @@ -984,7 +1002,7 @@ def test_execute(self, mock_hook, mock_dataset):
region=GCP_LOCATION,
display_name=DISPLAY_NAME,
dataset=mock_dataset.return_value,
parent_model=None,
parent_model=TEST_PARENT_MODEL,
prediction_type="classification",
model_type="CLOUD",
labels=None,
Expand Down

0 comments on commit 3ac0aaf

Please sign in to comment.