Skip to content

Commit

Permalink
Add param to CloudDataTransferServiceOperator (#14118)
Browse files Browse the repository at this point in the history
When a one-time job is created with `CloudDataTransferServiceS3(GCS)ToGCSOperator`, the job remains on the GCP console even after the job is completed.

This is a specification of the data transfer service, but I would like to add this parameter because there are normally cases where don't want to leave a one-time job.
  • Loading branch information
ysktir authored Feb 10, 2021
1 parent ebd39ca commit 02288cf
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
HTTP_DATA_SOURCE,
MINUTES,
MONTH,
NAME,
OBJECT_CONDITIONS,
PROJECT_ID,
SCHEDULE,
Expand Down Expand Up @@ -798,7 +799,8 @@ class CloudDataTransferServiceS3ToGCSOperator(BaseOperator):
:param transfer_options: Optional transfer service transfer options; see
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec
:type transfer_options: dict
:param wait: Wait for transfer to finish
:param wait: Wait for transfer to finish. It must be set to True, if
'delete_job_after_completion' is set to True.
:type wait: bool
:param timeout: Time to wait for the operation to end in seconds. Defaults to 60 seconds if not specified.
:type timeout: Optional[Union[float, timedelta]]
Expand All @@ -811,6 +813,9 @@ class CloudDataTransferServiceS3ToGCSOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
:param delete_job_after_completion: If True, delete the job after complete.
If set to True, 'wait' must be set to True.
:type delete_job_after_completion: bool
"""

template_fields = (
Expand Down Expand Up @@ -840,6 +845,7 @@ def __init__( # pylint: disable=too-many-arguments
wait: bool = True,
timeout: Optional[float] = None,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
delete_job_after_completion: bool = False,
**kwargs,
) -> None:

Expand All @@ -857,6 +863,12 @@ def __init__( # pylint: disable=too-many-arguments
self.wait = wait
self.timeout = timeout
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self._validate_inputs()

def _validate_inputs(self) -> None:
if self.delete_job_after_completion and not self.wait:
raise AirflowException("If 'delete_job_after_completion' is True, then 'wait' must also be True.")

def execute(self, context) -> None:
hook = CloudDataTransferServiceHook(
Expand All @@ -872,6 +884,8 @@ def execute(self, context) -> None:

if self.wait:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)

def _create_body(self) -> dict:
body = {
Expand Down Expand Up @@ -955,7 +969,8 @@ class CloudDataTransferServiceGCSToGCSOperator(BaseOperator):
:param transfer_options: Optional transfer service transfer options; see
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#TransferOptions
:type transfer_options: dict
:param wait: Wait for transfer to finish; defaults to `True`
:param wait: Wait for transfer to finish. It must be set to True, if
'delete_job_after_completion' is set to True.
:type wait: bool
:param timeout: Time to wait for the operation to end in seconds. Defaults to 60 seconds if not specified.
:type timeout: Optional[Union[float, timedelta]]
Expand All @@ -968,6 +983,9 @@ class CloudDataTransferServiceGCSToGCSOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
:param delete_job_after_completion: If True, delete the job after complete.
If set to True, 'wait' must be set to True.
:type delete_job_after_completion: bool
"""

template_fields = (
Expand Down Expand Up @@ -996,6 +1014,7 @@ def __init__( # pylint: disable=too-many-arguments
wait: bool = True,
timeout: Optional[float] = None,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
delete_job_after_completion: bool = False,
**kwargs,
) -> None:

Expand All @@ -1012,6 +1031,12 @@ def __init__( # pylint: disable=too-many-arguments
self.wait = wait
self.timeout = timeout
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self._validate_inputs()

def _validate_inputs(self) -> None:
if self.delete_job_after_completion and not self.wait:
raise AirflowException("If 'delete_job_after_completion' is True, then 'wait' must also be True.")

def execute(self, context) -> None:
hook = CloudDataTransferServiceHook(
Expand All @@ -1028,6 +1053,8 @@ def execute(self, context) -> None:

if self.wait:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)

def _create_body(self) -> dict:
body = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ def test_execute(self, mock_aws_hook, mock_transfer_hook):
)

assert mock_transfer_hook.return_value.wait_for_transfer_job.called
assert not mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
Expand All @@ -804,6 +805,63 @@ def test_execute_skip_wait(self, mock_aws_hook, mock_transfer_hook):
)

assert not mock_transfer_hook.return_value.wait_for_transfer_job.called
assert not mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
@mock.patch('airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook')
def test_execute_delete_job_after_completion(self, mock_aws_hook, mock_transfer_hook):
mock_aws_hook.return_value.get_credentials.return_value = Credentials(
TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None
)

operator = CloudDataTransferServiceS3ToGCSOperator(
task_id=TASK_ID,
s3_bucket=AWS_BUCKET_NAME,
gcs_bucket=GCS_BUCKET_NAME,
description=DESCRIPTION,
schedule=SCHEDULE_DICT,
wait=True,
delete_job_after_completion=True,
)

operator.execute(None)

mock_transfer_hook.return_value.create_transfer_job.assert_called_once_with(
body=VALID_TRANSFER_JOB_AWS_RAW
)

assert mock_transfer_hook.return_value.wait_for_transfer_job.called
assert mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
@mock.patch('airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook')
def test_execute_should_throw_ex_when_delete_job_without_wait(self, mock_aws_hook, mock_transfer_hook):
mock_aws_hook.return_value.get_credentials.return_value = Credentials(
TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None
)

with pytest.raises(AirflowException) as ctx:

operator = CloudDataTransferServiceS3ToGCSOperator(
task_id=TASK_ID,
s3_bucket=AWS_BUCKET_NAME,
gcs_bucket=GCS_BUCKET_NAME,
description=DESCRIPTION,
schedule=SCHEDULE_DICT,
wait=False,
delete_job_after_completion=True,
)

operator.execute(None)

err = ctx.value
assert "If 'delete_job_after_completion' is True, then 'wait' must also be True." in str(err)
mock_aws_hook.assert_not_called()
mock_transfer_hook.assert_not_called()


class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator(unittest.TestCase):
Expand Down Expand Up @@ -873,6 +931,7 @@ def test_execute(self, mock_transfer_hook):
body=VALID_TRANSFER_JOB_GCS_RAW
)
assert mock_transfer_hook.return_value.wait_for_transfer_job.called
assert not mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
Expand All @@ -893,3 +952,50 @@ def test_execute_skip_wait(self, mock_transfer_hook):
body=VALID_TRANSFER_JOB_GCS_RAW
)
assert not mock_transfer_hook.return_value.wait_for_transfer_job.called
assert not mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_execute_delete_job_after_completion(self, mock_transfer_hook):

operator = CloudDataTransferServiceGCSToGCSOperator(
task_id=TASK_ID,
source_bucket=GCS_BUCKET_NAME,
destination_bucket=GCS_BUCKET_NAME,
description=DESCRIPTION,
schedule=SCHEDULE_DICT,
wait=True,
delete_job_after_completion=True,
)

operator.execute(None)

mock_transfer_hook.return_value.create_transfer_job.assert_called_once_with(
body=VALID_TRANSFER_JOB_GCS_RAW
)
assert mock_transfer_hook.return_value.wait_for_transfer_job.called
assert mock_transfer_hook.return_value.delete_transfer_job.called

@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_execute_should_throw_ex_when_delete_job_without_wait(self, mock_transfer_hook):

with pytest.raises(AirflowException) as ctx:

operator = CloudDataTransferServiceS3ToGCSOperator(
task_id=TASK_ID,
s3_bucket=AWS_BUCKET_NAME,
gcs_bucket=GCS_BUCKET_NAME,
description=DESCRIPTION,
schedule=SCHEDULE_DICT,
wait=False,
delete_job_after_completion=True,
)

operator.execute(None)

err = ctx.value
assert "If 'delete_job_after_completion' is True, then 'wait' must also be True." in str(err)
mock_transfer_hook.assert_not_called()

0 comments on commit 02288cf

Please sign in to comment.