
Fix cancel_on_kill after execution timeout for DataprocSubmitJobOperator (#22955)

Synchronous tasks killed by an execution timeout were not canceled,
because the job_id attribute was assigned only after wait_for_job returned.
tauzen authored Apr 14, 2022
1 parent 52b724e commit ea1ae19
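
For context: on_kill cancels the running Dataproc job only when self.job_id is set. Before this fix, execute assigned self.job_id after wait_for_job returned, so an AirflowTaskTimeout raised during the wait left the attribute unset and on_kill had nothing to cancel. A minimal sketch of that guard follows; the on_kill body is not part of this diff, so the exact code here is an assumption based on the operator's documented cancel_on_kill behavior:

    def on_kill(self):
        # Sketch only -- the real body is not shown in this diff.
        # Cancellation is skipped entirely when job_id was never assigned,
        # which is exactly what happened when the timeout fired mid-wait.
        if self.job_id and self.cancel_on_kill:
            self.hook.cancel_job(
                project_id=self.project_id, region=self.region, job_id=self.job_id
            )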
Showing 2 changed files with 38 additions and 7 deletions.
16 changes: 9 additions & 7 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -1858,19 +1858,21 @@ def execute(self, context: 'Context'):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        job_id = job_object.reference.job_id
-        self.log.info('Job %s submitted successfully.', job_id)
+        new_job_id: str = job_object.reference.job_id
+        self.log.info('Job %s submitted successfully.', new_job_id)
         # Save data required by extra links no matter what the job status will be
-        DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id)
+        DataprocLink.persist(
+            context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id
+        )
 
+        self.job_id = new_job_id
         if not self.asynchronous:
-            self.log.info('Waiting for job %s to complete', job_id)
+            self.log.info('Waiting for job %s to complete', new_job_id)
             self.hook.wait_for_job(
-                job_id=job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
+                job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
             )
-            self.log.info('Job %s completed successfully.', job_id)
+            self.log.info('Job %s completed successfully.', new_job_id)
 
-        self.job_id = job_id
         return self.job_id
 
     def on_kill(self):
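
The change renames the local variable to new_job_id and assigns self.job_id before entering wait_for_job, so the attribute is already populated when an execution timeout interrupts the wait. A hypothetical usage that exercises this path (task id, project, cluster, and job config below are made up for illustration):

    from datetime import timedelta

    from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

    # If the job runs past execution_timeout, AirflowTaskTimeout is raised
    # inside wait_for_job; with cancel_on_kill=True, on_kill() can now cancel
    # the remote job because self.job_id was set before the wait started.
    submit_job = DataprocSubmitJobOperator(
        task_id="pyspark_job",        # hypothetical
        region="europe-west1",        # hypothetical
        project_id="my-gcp-project",  # hypothetical
        job={
            "reference": {"project_id": "my-gcp-project"},
            "placement": {"cluster_name": "my-cluster"},  # hypothetical cluster
            "pyspark_job": {"main_python_file_uri": "gs://my-bucket/job.py"},
        },
        cancel_on_kill=True,
        execution_timeout=timedelta(minutes=30),
    )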
29 changes: 29 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -25,6 +25,7 @@
 from google.api_core.retry import Retry
 
 from airflow import AirflowException
+from airflow.exceptions import AirflowTaskTimeout
 from airflow.models import DAG, DagBag
 from airflow.providers.google.cloud.operators.dataproc import (
     DATAPROC_CLUSTER_LINK,
@@ -877,6 +878,34 @@ def test_on_kill(self, mock_hook):
             project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
         )
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_on_kill_after_execution_timeout(self, mock_hook):
+        job = {}
+        job_id = "job_id"
+        mock_hook.return_value.wait_for_job.side_effect = AirflowTaskTimeout()
+        mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            job=job,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            request_id=REQUEST_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            cancel_on_kill=True,
+        )
+        with pytest.raises(AirflowTaskTimeout):
+            op.execute(context=self.mock_context)
+
+        op.on_kill()
+        mock_hook.return_value.cancel_job.assert_called_once_with(
+            project_id=GCP_PROJECT, region=GCP_LOCATION, job_id=job_id
+        )
+
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_location_deprecation_warning(self, mock_hook):
         xcom_push_call = call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED)
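
The regression test mirrors the failure mode directly: the mocked hook's wait_for_job raises AirflowTaskTimeout to simulate the execution timeout firing mid-wait, execute is expected to propagate that exception, and the final assertion verifies that a subsequent on_kill still cancels the job using the id captured at submission time.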
