Fix DataprocSubmitJobOperator in deferrable mode=True when task is marked as failed. (#39230)

When a user cancels a task for `DataprocSubmitJobOperator` in deferrable mode, the job should be cancelled just as it is in non-deferrable mode. This PR fixes that behaviour.
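As a sketch of where this lands for users: a DAG task along the following lines, run with `deferrable=True`, now cancels the running Dataproc job when the task instance is killed, matching non-deferrable behaviour. The project, region, cluster, and Pig job body below are placeholder assumptions, not taken from this commit.

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

# Placeholder values; substitute your own project, region, and cluster.
PROJECT_ID = "my-project"
REGION = "us-central1"
PIG_JOB = {
    "reference": {"project_id": PROJECT_ID},
    "placement": {"cluster_name": "my-cluster"},
    "pig_job": {"query_list": {"queries": ["sh echo hello"]}},
}

with DAG(dag_id="dataproc_submit_deferrable", schedule=None) as dag:
    submit_job = DataprocSubmitJobOperator(
        task_id="submit_pig_job",
        project_id=PROJECT_ID,
        region=REGION,
        job=PIG_JOB,
        deferrable=True,      # hand the poll loop off to the triggerer
        cancel_on_kill=True,  # with this fix, also honoured while deferred
    )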
sunank200 authored Apr 26, 2024
1 parent 7683344 commit 2a913b6
Showing 3 changed files with 139 additions and 13 deletions.
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -2592,6 +2592,7 @@ def execute(self, context: Context):
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
                 polling_interval_seconds=self.polling_interval_seconds,
+                cancel_on_kill=self.cancel_on_kill,
             ),
             method_name="execute_complete",
         )
45 changes: 33 additions & 12 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -44,6 +44,7 @@ def __init__(
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         polling_interval_seconds: int = 30,
+        cancel_on_kill: bool = True,
         delete_on_error: bool = True,
     ):
         super().__init__()
@@ -52,6 +53,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
         self.delete_on_error = delete_on_error

     def get_async_hook(self):
@@ -63,8 +65,8 @@ def get_async_hook(self):
     def get_sync_hook(self):
         # The synchronous hook is utilized to delete the cluster when a task is cancelled.
         # This is because the asynchronous hook deletion is not awaited when the trigger task
-        # is cancelled. The call for deleting the cluster through the sync hook is not a blocking
-        # call, which means it does not wait until the cluster is deleted.
+        # is cancelled. The call for deleting the cluster or job through the sync hook is not a blocking
+        # call, which means it does not wait until the cluster or job is deleted.
         return DataprocHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -104,20 +106,39 @@ def serialize(self):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )

     async def run(self):
-        while True:
-            job = await self.get_async_hook().get_job(
-                project_id=self.project_id, region=self.region, job_id=self.job_id
-            )
-            state = job.status.state
-            self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
-            if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
-                break
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        try:
+            while True:
+                job = await self.get_async_hook().get_job(
+                    project_id=self.project_id, region=self.region, job_id=self.job_id
+                )
+                state = job.status.state
+                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
+                    break
+                await asyncio.sleep(self.polling_interval_seconds)
+            yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        except asyncio.CancelledError:
+            self.log.info("Task got cancelled.")
+            try:
+                if self.job_id and self.cancel_on_kill:
+                    self.log.info("Cancelling the job: %s", self.job_id)
+                    # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
+                    # is because the asynchronous hook deletion is not awaited when the trigger task is
+                    # cancelled. The call for deleting the cluster or job through the sync hook is not a
+                    # blocking call, which means it does not wait until the cluster or job is deleted.
+                    self.get_sync_hook().cancel_job(
+                        job_id=self.job_id, project_id=self.project_id, region=self.region
+                    )
+                    self.log.info("Job: %s is cancelled", self.job_id)
+                    yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
+            except Exception as e:
+                self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
+                raise e


 class DataprocClusterTrigger(DataprocBaseTrigger):
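The sync-hook comment above is the heart of the fix: once the triggerer cancels the asyncio task, an `await` inside the `CancelledError` handler may never resume, so the job must be cancelled with a plain blocking call. Below is a minimal, self-contained sketch of that pattern using only the standard library; `cancel_job_synchronously` is a hypothetical stand-in for `get_sync_hook().cancel_job(...)`, not part of this commit.

import asyncio

def cancel_job_synchronously():
    # Hypothetical stand-in for DataprocHook(...).cancel_job(...): a plain
    # blocking call runs to completion even while the task is being torn down.
    print("job cancelled")

async def poll_job():
    try:
        while True:
            await asyncio.sleep(1)  # stands in for the get_job() polling loop
    except asyncio.CancelledError:
        # An `await` placed here may never resume once the event loop starts
        # tearing this task down, so cleanup is done with a synchronous call.
        cancel_job_synchronously()
        raise

async def main():
    task = asyncio.create_task(poll_job())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled, job cancel already issued")

asyncio.run(main())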
106 changes: 105 additions & 1 deletion tests/providers/google/cloud/triggers/test_dataproc.py
@@ -22,14 +22,15 @@
 from unittest import mock

 import pytest
-from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 from google.protobuf.any_pb2 import Any
 from google.rpc.status_pb2 import Status

 from airflow.providers.google.cloud.triggers.dataproc import (
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocOperationTrigger,
+    DataprocSubmitTrigger,
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import TriggerEvent
@@ -47,6 +48,7 @@
 TEST_POLL_INTERVAL = 5
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
+TEST_JOB_ID = "test-job-id"


 @pytest.fixture
@@ -113,6 +115,17 @@ def func(**kwargs):
     return func


+@pytest.fixture
+def submit_trigger():
+    return DataprocSubmitTrigger(
+        job_id=TEST_JOB_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+
+
 @pytest.fixture
 def async_get_batch():
     def func(**kwargs):
@@ -472,3 +485,94 @@ async def test_async_operation_triggers_on_error(self, mock_hook, operation_trigger):
         )
         actual_event = await operation_trigger.run().asend(None)
         assert expected_event == actual_event


+@pytest.mark.db_test
+class TestDataprocSubmitTrigger:
+    def test_submit_trigger_serialization(self, submit_trigger):
+        """Test that the trigger serializes its configuration correctly."""
+        classpath, kwargs = submit_trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger"
+        assert kwargs == {
+            "job_id": TEST_JOB_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+            "cancel_on_kill": True,
+            "impersonation_chain": None,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    async def test_submit_trigger_run_success(self, mock_get_async_hook, submit_trigger):
+        """Test the trigger correctly handles a job completion."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.get_job = mock.AsyncMock(
+            return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE))
+        )
+
+        async_gen = submit_trigger.run()
+        event = await async_gen.asend(None)
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_hook.get_job.return_value}
+        )
+        assert event.payload == expected_event.payload
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigger):
+        """Test the trigger correctly handles a job error."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.get_job = mock.AsyncMock(
+            return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR))
+        )
+
+        async_gen = submit_trigger.run()
+        event = await async_gen.asend(None)
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_hook.get_job.return_value}
+        )
+        assert event.payload == expected_event.payload
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
+    async def test_submit_trigger_run_cancelled(
+        self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+    ):
+        """Test the trigger correctly handles an asyncio.CancelledError."""
+        mock_async_hook = mock_get_async_hook.return_value
+        mock_async_hook.get_job.side_effect = asyncio.CancelledError
+
+        mock_sync_hook = mock_get_sync_hook.return_value
+        mock_sync_hook.cancel_job = mock.MagicMock()
+
+        async_gen = submit_trigger.run()
+
+        try:
+            await async_gen.asend(None)
+            # Should raise StopAsyncIteration if no more items to yield
+            await async_gen.asend(None)
+        except asyncio.CancelledError:
+            # Handle the cancellation as expected
+            pass
+        except StopAsyncIteration:
+            # The generator should be properly closed after handling the cancellation
+            pass
+        except Exception as e:
+            # Catch any other exceptions that should not occur
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+        # Check if cancel_job was correctly called
+        if submit_trigger.cancel_on_kill:
+            mock_sync_hook.cancel_job.assert_called_once_with(
+                job_id=submit_trigger.job_id,
+                project_id=submit_trigger.project_id,
+                region=submit_trigger.region,
+            )
+        else:
+            mock_sync_hook.cancel_job.assert_not_called()
+
+        # Clean up the generator
+        await async_gen.aclose()
