Skip to content

Commit

Permalink
Resolve google vertex ai deprecations in tests (#40506)
Browse files Browse the repository at this point in the history
* Resolve google vertex ai deprecations in tests

* Resolve google vertex ai deprecations in tests

* Update deprecations_ignore.yml
  • Loading branch information
dirrao committed Jul 4, 2024
1 parent 3583329 commit 3c1120a
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 31 deletions.
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,9 @@ def execute(self, context: Context):
impersonation_chain=self.impersonation_chain,
)
try:
self.log.info("Deleting Auto ML training pipeline: %s", self.training_pipeline)
self.log.info("Deleting Auto ML training pipeline: %s", self.training_pipeline_id)
training_pipeline_operation = hook.delete_training_pipeline(
training_pipeline=self.training_pipeline,
training_pipeline=self.training_pipeline_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
Expand All @@ -668,7 +668,7 @@ def execute(self, context: Context):
hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
self.log.info("Training pipeline was deleted.")
except NotFound:
self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline)
self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline_id)


class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
Expand Down
12 changes: 6 additions & 6 deletions airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,9 +1695,9 @@ def execute(self, context: Context):
impersonation_chain=self.impersonation_chain,
)
try:
self.log.info("Deleting custom training pipeline: %s", self.training_pipeline)
self.log.info("Deleting custom training pipeline: %s", self.training_pipeline_id)
training_pipeline_operation = hook.delete_training_pipeline(
training_pipeline=self.training_pipeline,
training_pipeline=self.training_pipeline_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
Expand All @@ -1707,11 +1707,11 @@ def execute(self, context: Context):
hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
self.log.info("Training pipeline was deleted.")
except NotFound:
self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline)
self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline_id)
try:
self.log.info("Deleting custom job: %s", self.custom_job)
self.log.info("Deleting custom job: %s", self.custom_job_id)
custom_job_operation = hook.delete_custom_job(
custom_job=self.custom_job,
custom_job=self.custom_job_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
Expand All @@ -1721,7 +1721,7 @@ def execute(self, context: Context):
hook.wait_for_operation(timeout=self.timeout, operation=custom_job_operation)
self.log.info("Custom job was deleted.")
except NotFound:
self.log.info("The Custom Job ID %s does not exist.", self.custom_job)
self.log.info("The Custom Job ID %s does not exist.", self.custom_job_id)


class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
Expand Down
7 changes: 0 additions & 7 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,6 @@
- tests/providers/google/cloud/operators/test_kubernetes_engine.py::TestGoogleCloudPlatformContainerOperator::test_create_execute_error_body
- tests/providers/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes
- tests/providers/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes_without_project_id
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateBatchPredictionJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateBatchPredictionJobOperator::test_execute_deferrable
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_deferrable
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_deferrable_sync_error
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAIDeleteAutoMLTrainingJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAIDeleteCustomTrainingJobOperator::test_execute
- tests/providers/google/cloud/sensors/test_gcs.py::TestTsFunction::test_should_support_cron
- tests/providers/google/cloud/sensors/test_gcs.py::TestTsFunction::test_should_support_datetime
- tests/providers/google/cloud/transfers/test_bigquery_to_postgres.py::TestBigQueryToPostgresOperator::test_execute_good_request_to_bq
Expand Down
66 changes: 51 additions & 15 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@

TEST_TEMPLATE_PATH = "test_template_path"

SYNC_DEPRECATION_WARNING = "The 'sync' parameter is deprecated and will be removed after 01.10.2024."
SYNC_DEPRECATION_WARNING = "The 'sync' parameter is deprecated and will be removed after {}."


class TestVertexAICreateCustomContainerTrainingJobOperator:
Expand Down Expand Up @@ -235,7 +235,9 @@ def test_execute(self, mock_hook, mock_dataset):
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand Down Expand Up @@ -324,7 +326,9 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
dataset_id=TEST_DATASET_ID,
parent_model=VERSIONED_TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
Expand Down Expand Up @@ -406,7 +410,9 @@ def test_execute_enters_deferred_state(self, mock_hook):
)
mock_hook.return_value.exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
task.execute(context={"ti": mock.MagicMock()})
assert isinstance(
exc.value.trigger, CustomContainerTrainingJobTrigger
Expand Down Expand Up @@ -505,7 +511,9 @@ def test_execute(self, mock_hook, mock_dataset):
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand Down Expand Up @@ -596,7 +604,9 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
dataset_id=TEST_DATASET_ID,
parent_model=VERSIONED_TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
Expand Down Expand Up @@ -680,7 +690,9 @@ def test_execute_enters_deferred_state(self, mock_hook):
)
mock_hook.return_value.exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
task.execute(context={"ti": mock.MagicMock()})
assert isinstance(
exc.value.trigger, CustomPythonPackageTrainingJobTrigger
Expand Down Expand Up @@ -774,7 +786,9 @@ def test_execute(self, mock_hook, mock_dataset):
dataset_id=TEST_DATASET_ID,
parent_model=TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
Expand Down Expand Up @@ -858,7 +872,9 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
dataset_id=TEST_DATASET_ID,
parent_model=VERSIONED_TEST_PARENT_MODEL,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
op.execute(context={"ti": mock.MagicMock()})
mock_hook.return_value.create_custom_training_job.assert_called_once_with(
staging_bucket=STAGING_BUCKET,
Expand Down Expand Up @@ -935,7 +951,9 @@ def test_execute_enters_deferred_state(self, mock_hook):
)
mock_hook.return_value.exists.return_value = False
with pytest.raises(TaskDeferred) as exc:
with pytest.warns(AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING):
with pytest.warns(
AirflowProviderDeprecationWarning, match=SYNC_DEPRECATION_WARNING.format("01.10.2024")
):
task.execute(context={"ti": mock.MagicMock()})
assert isinstance(
exc.value.trigger, CustomTrainingJobTrigger
Expand Down Expand Up @@ -1922,7 +1940,11 @@ def test_execute(self, mock_hook, mock_link_persist):
batch_size=TEST_BATCH_SIZE,
)
context = {"ti": mock.MagicMock()}
op.execute(context=context)
with pytest.warns(
AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING.format("28.08.2024"),
):
op.execute(context=context)

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_batch_prediction_job.assert_called_once_with(
Expand Down Expand Up @@ -1976,7 +1998,10 @@ def test_execute_deferrable(self, mock_hook, mock_link_persist):
deferrable=True,
)
context = {"ti": mock.MagicMock()}
with pytest.raises(TaskDeferred) as exception_info:
with pytest.raises(TaskDeferred) as exception_info, pytest.warns(
AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING.format("28.08.2024"),
):
op.execute(context=context)

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand Down Expand Up @@ -2302,7 +2327,11 @@ def test_execute(self, mock_hook, to_dict_mock):
max_trial_count=15,
parallel_trial_count=3,
)
op.execute(context={"ti": mock.MagicMock()})
with pytest.warns(
AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING.format("01.09.2024"),
):
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_hyperparameter_tuning_job.assert_called_once_with(
project_id=GCP_PROJECT,
Expand Down Expand Up @@ -2353,7 +2382,11 @@ def test_deferrable(self, mock_hook, mock_defer):
parallel_trial_count=3,
deferrable=True,
)
op.execute(context={"ti": mock.MagicMock()})
with pytest.warns(
AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING.format("01.09.2024"),
):
op.execute(context={"ti": mock.MagicMock()})
mock_defer.assert_called_once()

@pytest.mark.db_test
Expand All @@ -2374,7 +2407,10 @@ def test_deferrable_sync_error(self):
parallel_trial_count=3,
deferrable=True,
)
with pytest.raises(AirflowException):
with pytest.raises(AirflowException), pytest.warns(
AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING.format("01.09.2024"),
):
op.execute(context={"ti": mock.MagicMock()})

@mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook"))
Expand Down

0 comments on commit 3c1120a

Please sign in to comment.