Skip to content

Commit

Permalink
Fix TestTranslationLegacyModelPredictLink dataset_id error (#42463)
Browse files Browse the repository at this point in the history
- Add a dataset_id parameter to TranslationLegacyModelPredictLink.persist so that
TestTranslationLegacyModelPredictLink works with the translation model.

Co-authored-by: Oleg Kachur <[email protected]>
  • Loading branch information
olegkachur-e and Oleg Kachur authored Oct 17, 2024
1 parent e286bd7 commit aeb7e90
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,14 @@ def persist(
task_instance,
model_id: str,
project_id: str,
dataset_id: str,
):
task_instance.xcom_push(
context,
key=TranslationLegacyModelPredictLink.key,
value={
"location": task_instance.location,
"dataset_id": task_instance.model.dataset_id,
"dataset_id": dataset_id,
"model_id": model_id,
"project_id": project_id,
},
Expand Down
49 changes: 31 additions & 18 deletions providers/src/airflow/providers/google/cloud/operators/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import ast
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Sequence, Tuple
from typing import TYPE_CHECKING, Sequence, Tuple, cast

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.automl_v1beta1 import (
Expand Down Expand Up @@ -280,17 +280,22 @@ def hook(self) -> CloudAutoMLHook | PredictionServiceHook:
impersonation_chain=self.impersonation_chain,
)

@cached_property
def model(self) -> Model | None:
if self.model_id:
hook = cast(CloudAutoMLHook, self.hook)
return hook.get_model(
model_id=self.model_id,
location=self.location,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return None

def _check_model_type(self):
hook = self.hook
model = hook.get_model(
model_id=self.model_id,
location=self.location,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
if not hasattr(model, "translation_model_metadata"):
if not hasattr(self.model, "translation_model_metadata"):
raise AirflowException(
"AutoMLPredictOperator for text, image, and video prediction has been deprecated. "
"Please use endpoint_id param instead of model_id param."
Expand Down Expand Up @@ -329,11 +334,13 @@ def execute(self, context: Context):
)

project_id = self.project_id or hook.project_id
if project_id and self.model_id:
dataset_id: str | None = self.model.dataset_id if self.model else None
if project_id and self.model_id and dataset_id:
TranslationLegacyModelPredictLink.persist(
context=context,
task_instance=self,
model_id=self.model_id,
dataset_id=dataset_id,
project_id=project_id,
)
return PredictResponse.to_dict(result)
Expand Down Expand Up @@ -431,12 +438,16 @@ def __init__(
self.input_config = input_config
self.output_config = output_config

def execute(self, context: Context):
hook = CloudAutoMLHook(
@cached_property
def hook(self) -> CloudAutoMLHook:
return CloudAutoMLHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.model: Model = hook.get_model(

@cached_property
def model(self) -> Model:
return self.hook.get_model(
model_id=self.model_id,
location=self.location,
project_id=self.project_id,
Expand All @@ -445,6 +456,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)

def execute(self, context: Context):
if not hasattr(self.model, "translation_model_metadata"):
_raise_exception_for_deprecated_operator(
self.__class__.__name__,
Expand All @@ -456,7 +468,7 @@ def execute(self, context: Context):
],
)
self.log.info("Fetch batch prediction.")
operation = hook.batch_predict(
operation = self.hook.batch_predict(
model_id=self.model_id,
input_config=self.input_config,
output_config=self.output_config,
Expand All @@ -467,16 +479,17 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
operation_result = self.hook.wait_for_operation(timeout=self.timeout, operation=operation)
result = BatchPredictResult.to_dict(operation_result)
self.log.info("Batch prediction is ready.")
project_id = self.project_id or hook.project_id
project_id = self.project_id or self.hook.project_id
if project_id:
TranslationLegacyModelPredictLink.persist(
context=context,
task_instance=self,
model_id=self.model_id,
project_id=project_id,
dataset_id=self.model.dataset_id,
)
return result

Expand Down
8 changes: 7 additions & 1 deletion providers/tests/google/cloud/links/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def test_get_link(self, create_task_instance_of_operator, session):
ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
session.add(ti)
session.commit()
link.persist(context={"ti": ti}, task_instance=ti.task, model_id=MODEL, project_id=GCP_PROJECT_ID)
link.persist(
context={"ti": ti},
task_instance=ti.task,
model_id=MODEL,
project_id=GCP_PROJECT_ID,
dataset_id=DATASET,
)
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
assert actual_url == expected_url
4 changes: 4 additions & 0 deletions providers/tests/google/cloud/operators/test_automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_execute(self, mock_hook, mock_link_persist):
mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
mock_hook.return_value.extract_object_id = extract_object_id
mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
mock_context = {"ti": mock.MagicMock()}
with pytest.warns(AirflowProviderDeprecationWarning):
op = AutoMLBatchPredictOperator(
Expand Down Expand Up @@ -175,6 +176,7 @@ def test_execute(self, mock_hook, mock_link_persist):
task_instance=op,
model_id=MODEL_ID,
project_id=GCP_PROJECT_ID,
dataset_id=DATASET_ID,
)

@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
Expand Down Expand Up @@ -243,6 +245,7 @@ class TestAutoMLPredictOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook, mock_link_persist):
mock_hook.return_value.predict.return_value = PredictResponse()
mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
mock_context = {"ti": mock.MagicMock()}
op = AutoMLPredictOperator(
model_id=MODEL_ID,
Expand All @@ -268,6 +271,7 @@ def test_execute(self, mock_hook, mock_link_persist):
task_instance=op,
model_id=MODEL_ID,
project_id=GCP_PROJECT_ID,
dataset_id=DATASET_ID,
)

@pytest.mark.db_test
Expand Down

0 comments on commit aeb7e90

Please sign in to comment.