Skip to content

Commit

Permalink
Use AsyncClient for Composer Operators in deferrable mode (#25951)
Browse files Browse the repository at this point in the history
  • Loading branch information
Łukasz Wyszomirski authored Aug 29, 2022
1 parent 57fc3e9 commit da8f133
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 8 deletions.
127 changes: 126 additions & 1 deletion airflow/providers/google/cloud/hooks/cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.operation import Operation
from google.api_core.operation_async import AsyncOperation
from google.api_core.retry import Retry
from google.cloud.orchestration.airflow.service_v1 import EnvironmentsClient, ImageVersionsClient
from google.cloud.orchestration.airflow.service_v1 import (
EnvironmentsAsyncClient,
EnvironmentsClient,
ImageVersionsClient,
)
from google.cloud.orchestration.airflow.service_v1.services.environments.pagers import ListEnvironmentsPager
from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
ListImageVersionsPager,
Expand Down Expand Up @@ -275,3 +280,123 @@ def list_image_versions(
metadata=metadata,
)
return result


class CloudComposerAsyncHook(GoogleBaseHook):
"""Hook for Google Cloud Composer async APIs."""

client_options = ClientOptions(api_endpoint='composer.googleapis.com:443')

def get_environment_client(self) -> EnvironmentsAsyncClient:
"""Retrieves client library object that allow access Environments service."""
return EnvironmentsAsyncClient(
credentials=self.get_credentials(),
client_info=CLIENT_INFO,
client_options=self.client_options,
)

def get_environment_name(self, project_id, region, environment_id):
return f'projects/{project_id}/locations/{region}/environments/{environment_id}'

def get_parent(self, project_id, region):
return f'projects/{project_id}/locations/{region}'

async def get_operation(self, operation_name):
return await self.get_environment_client().transport.operations_client.get_operation(
name=operation_name
)

@GoogleBaseHook.fallback_to_default_project_id
async def create_environment(
self,
project_id: str,
region: str,
environment: Union[Environment, Dict],
retry: Union[Retry, _MethodDefault] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncOperation:
"""
Create a new environment.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param environment: The environment to create. This corresponds to the ``environment`` field on the
``request`` instance; if ``request`` is provided, this should not be set.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()
return await client.create_environment(
request={'parent': self.get_parent(project_id, region), 'environment': environment},
retry=retry,
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
async def delete_environment(
self,
project_id: str,
region: str,
environment_id: str,
retry: Union[Retry, _MethodDefault] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncOperation:
"""
Delete an environment.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()
name = self.get_environment_name(project_id, region, environment_id)
return await client.delete_environment(
request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
)

@GoogleBaseHook.fallback_to_default_project_id
async def update_environment(
self,
project_id: str,
region: str,
environment_id: str,
environment: Union[Environment, Dict],
update_mask: Union[Dict, FieldMask],
retry: Union[Retry, _MethodDefault] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncOperation:
r"""
Update an environment.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
:param environment: A patch environment. Fields specified by the ``updateMask`` will be copied from
the patch environment into the environment under update.
This corresponds to the ``environment`` field on the ``request`` instance; if ``request`` is
provided, this should not be set.
:param update_mask: Required. A comma-separated list of paths, relative to ``Environment``, of fields
to update. If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.protobuf.field_mask_pb2.FieldMask`
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()
name = self.get_environment_name(project_id, region, environment_id)

return await client.update_environment(
request={"name": name, "environment": environment, "update_mask": update_mask},
retry=retry,
timeout=timeout,
metadata=metadata,
)
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/triggers/cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook

try:
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(

self.pooling_period_seconds = pooling_period_seconds

self.gcp_hook = CloudComposerHook(
self.gcp_hook = CloudComposerAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
delegate_to=self.delegate_to,
Expand All @@ -80,7 +80,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:

async def run(self):
while True:
operation = self.gcp_hook.get_operation(operation_name=self.operation_name)
operation = await self.gcp_hook.get_operation(operation_name=self.operation_name)
if operation.done:
break
elif operation.error.message:
Expand Down
81 changes: 80 additions & 1 deletion tests/providers/google/cloud/hooks/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import unittest
from unittest import mock

import pytest
from google.api_core.gapic_v1.method import DEFAULT

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook, CloudComposerHook

TEST_GCP_REGION = "global"
TEST_GCP_PROJECT = "test-project"
Expand Down Expand Up @@ -193,3 +194,81 @@ def test_list_image_versions(self, mock_client) -> None:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)


class TestCloudComposerAsyncHook(unittest.TestCase):
def setUp(
self,
) -> None:
with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
self.hook = CloudComposerAsyncHook(gcp_conn_id="test")

@pytest.mark.asyncio
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
async def test_create_environment(self, mock_client) -> None:
await self.hook.create_environment(
project_id=TEST_GCP_PROJECT,
region=TEST_GCP_REGION,
environment=TEST_ENVIRONMENT,
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
mock_client.assert_called_once()
mock_client.return_value.create_environment.assert_called_once_with(
request={
'parent': self.hook.get_parent(TEST_GCP_PROJECT, TEST_GCP_REGION),
'environment': TEST_ENVIRONMENT,
},
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)

@pytest.mark.asyncio
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
async def test_delete_environment(self, mock_client) -> None:
await self.hook.delete_environment(
project_id=TEST_GCP_PROJECT,
region=TEST_GCP_REGION,
environment_id=TEST_ENVIRONMENT_ID,
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
mock_client.assert_called_once()
mock_client.return_value.delete_environment.assert_called_once_with(
request={
"name": self.hook.get_environment_name(TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID)
},
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)

@pytest.mark.asyncio
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
async def test_update_environment(self, mock_client) -> None:
await self.hook.update_environment(
project_id=TEST_GCP_PROJECT,
region=TEST_GCP_REGION,
environment_id=TEST_ENVIRONMENT_ID,
environment=TEST_UPDATED_ENVIRONMENT,
update_mask=TEST_UPDATE_MASK,
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
mock_client.assert_called_once()
mock_client.return_value.update_environment.assert_called_once_with(
request={
"name": self.hook.get_environment_name(
TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
),
"environment": TEST_UPDATED_ENVIRONMENT,
"update_mask": TEST_UPDATE_MASK,
},
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/operators/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_execute(self, mock_hook, to_dict_mode) -> None:

@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
op = CloudComposerCreateEnvironmentOperator(
task_id=TASK_ID,
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_execute(self, mock_hook) -> None:
)

@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
op = CloudComposerDeleteEnvironmentOperator(
task_id=TASK_ID,
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_execute(self, mock_hook, to_dict_mode) -> None:

@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
op = CloudComposerUpdateEnvironmentOperator(
task_id=TASK_ID,
Expand Down

0 comments on commit da8f133

Please sign in to comment.