Asynchronous execution of Dataproc jobs with a Sensor (#10673)
varundhussa authored Sep 5, 2020
1 parent 527ea81 commit ece685b
Showing 6 changed files with 289 additions and 6 deletions.
16 changes: 16 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -29,6 +29,7 @@
DataprocSubmitJobOperator,
DataprocUpdateClusterOperator,
)
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
from airflow.utils.dates import days_ago

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
@@ -170,6 +171,20 @@
task_id="spark_task", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID
)

# [START cloud_dataproc_async_submit_sensor]
spark_task_async = DataprocSubmitJobOperator(
task_id="spark_task_async", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID, asynchronous=True
)

spark_task_async_sensor = DataprocJobSensor(
task_id='spark_task_async_sensor_task',
location=REGION,
project_id=PROJECT_ID,
dataproc_job_id="{{task_instance.xcom_pull(task_ids='spark_task_async')}}",
poke_interval=10,
)
# [END cloud_dataproc_async_submit_sensor]

# [START how_to_cloud_dataproc_submit_job_to_cluster_operator]
pyspark_task = DataprocSubmitJobOperator(
task_id="pyspark_task", job=PYSPARK_JOB, location=REGION, project_id=PROJECT_ID
@@ -199,6 +214,7 @@
scale_cluster >> pig_task >> delete_cluster
scale_cluster >> spark_sql_task >> delete_cluster
scale_cluster >> spark_task >> delete_cluster
scale_cluster >> spark_task_async >> spark_task_async_sensor >> delete_cluster
scale_cluster >> pyspark_task >> delete_cluster
scale_cluster >> sparkr_task >> delete_cluster
scale_cluster >> hadoop_task >> delete_cluster
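How the sensor finds the job: with `asynchronous=True` the submit operator returns the job_id, which Airflow pushes to XCom as the task's `return_value`; the sensor's templated `dataproc_job_id` pulls it back, as in the snippet above. A minimal sketch (task IDs are illustrative; older Airflow versions may need `provide_context=True`) of reading the same value in any downstream task:

```python
from airflow.operators.python import PythonOperator


def report_job_id(**context):
    # xcom_pull with no explicit key returns the upstream task's
    # return_value, i.e. the Dataproc job_id from the async submit.
    job_id = context["ti"].xcom_pull(task_ids="spark_task_async")
    print(f"Dataproc job being polled: {job_id}")


report = PythonOperator(task_id="report_job_id", python_callable=report_job_id)
```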
32 changes: 27 additions & 5 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -906,6 +906,10 @@ class DataprocJobBaseOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long-running jobs and
waiting on them asynchronously using the DataprocJobSensor.
:type asynchronous: bool
:var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
This is useful for identifying or linking to the job in the Google Cloud Console
@@ -930,6 +934,7 @@ def __init__(
region: str = 'global',
job_error_states: Optional[Set[str]] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -949,6 +954,7 @@ def __init__(
self.job_template = None
self.job = None
self.dataproc_job_id = None
self.asynchronous = asynchronous

def create_job_template(self):
"""
@@ -980,8 +986,13 @@ def execute(self, context):
project_id=self.project_id, job=self.job["job"], location=self.region,
)
job_id = job_object.reference.job_id
self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id)
self.log.info('Job executed correctly.')
self.log.info('Job %s submitted successfully.', job_id)

if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id)
self.log.info('Job %s completed successfully.', job_id)
return job_id
else:
raise AirflowException("Create a job template before")

@@ -1785,6 +1796,10 @@ class DataprocSubmitJobOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long-running jobs and
waiting on them asynchronously using the DataprocJobSensor.
:type asynchronous: bool
"""

template_fields = (
@@ -1807,6 +1822,7 @@ def __init__(
metadata: Optional[Sequence[Tuple[str, str]]] = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1819,6 +1835,7 @@ def __init__(
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.asynchronous = asynchronous

def execute(self, context: Dict):
self.log.info("Submitting job")
@@ -1833,9 +1850,14 @@ def execute(self, context: Dict):
metadata=self.metadata,
)
job_id = job_object.reference.job_id
self.log.info("Waiting for job %s to complete", job_id)
hook.wait_for_job(job_id=job_id, project_id=self.project_id, location=self.location)
self.log.info("Job completed successfully.")
self.log.info('Job %s submitted successfully.', job_id)

if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
hook.wait_for_job(job_id=job_id, location=self.location, project_id=self.project_id)
self.log.info('Job %s completed successfully.', job_id)

return job_id


class DataprocUpdateClusterOperator(BaseOperator):
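Putting the two pieces together, the intended pattern is submit-then-sense. A minimal sketch (reusing the SPARK_JOB/REGION/PROJECT_ID constants from the example DAG; `mode="reschedule"` is standard BaseSensorOperator behavior that frees the worker slot between pokes):

```python
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor

submit = DataprocSubmitJobOperator(
    task_id="spark_task_async",
    job=SPARK_JOB,
    location=REGION,
    project_id=PROJECT_ID,
    asynchronous=True,  # return right after the job is accepted
)
wait = DataprocJobSensor(
    task_id="spark_task_async_sensor",
    project_id=PROJECT_ID,
    location=REGION,
    dataproc_job_id="{{ task_instance.xcom_pull(task_ids='spark_task_async') }}",
    mode="reschedule",  # release the worker slot between pokes
    poke_interval=30,
)
submit >> wait
```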
81 changes: 81 additions & 0 deletions airflow/providers/google/cloud/sensors/dataproc.py
@@ -0,0 +1,81 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains a Dataproc Job sensor.
"""
# pylint: disable=C0302

from google.cloud.dataproc_v1beta2.types import JobStatus

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults


class DataprocJobSensor(BaseSensorOperator):
"""
Check for the state of a previously submitted Dataproc job.

:param project_id: The ID of the Google Cloud project in which
the Dataproc job is running. (templated)
:type project_id: str
:param dataproc_job_id: The Dataproc job ID to poll. (templated)
:type dataproc_job_id: str
:param location: Required. The Cloud Dataproc region in which to handle the request. (templated)
:type location: str
:param gcp_conn_id: The connection ID to use when connecting to Google Cloud Platform.
:type gcp_conn_id: str
"""

template_fields = ('project_id', 'location', 'dataproc_job_id')
ui_color = '#f0eee4'

@apply_defaults
def __init__(
self,
*,
project_id: str,
dataproc_job_id: str,
location: str,
gcp_conn_id: str = 'google_cloud_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.dataproc_job_id = dataproc_job_id
self.location = location

def poke(self, context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
state = job.status.state

if state == JobStatus.ERROR:
raise AirflowException('Job failed:\n{}'.format(job))
elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
raise AirflowException('Job was cancelled:\n{}'.format(job))
elif state == JobStatus.DONE:
self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
return True
elif state == JobStatus.ATTEMPT_FAILURE:
self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
return False
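For debugging outside a DAG, the same status check the sensor performs in `poke` can be done ad hoc with the hook. A hedged sketch; the connection, project, region, and job IDs are illustrative:

```python
from google.cloud.dataproc_v1beta2.types import JobStatus

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook(gcp_conn_id="google_cloud_default")
job = hook.get_job(job_id="job-1234", location="europe-west1", project_id="my-project")
print(job.status.state == JobStatus.DONE)  # True once the job has finished
```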
2 changes: 1 addition & 1 deletion docs/operators-and-hooks-ref.rst
@@ -764,7 +764,7 @@ These integrations allow you to perform various operations within the Google Cloud
- :doc:`How to use <howto/operator/google/cloud/dataproc>`
- :mod:`airflow.providers.google.cloud.hooks.dataproc`
- :mod:`airflow.providers.google.cloud.operators.dataproc`
-
- :mod:`airflow.providers.google.cloud.sensors.dataproc`

* - `Datastore <https://cloud.google.com/datastore/>`__
- :doc:`How to use <howto/operator/google/cloud/datastore>`
36 changes: 36 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -452,6 +452,42 @@ def test_execute(self, mock_hook):
job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_async(self, mock_hook):
job = {}
job_id = "job_id"
mock_hook.return_value.wait_for_job.return_value = None
mock_hook.return_value.submit_job.return_value.reference.job_id = job_id

op = DataprocSubmitJobOperator(
task_id=TASK_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT,
job=job,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
asynchronous=True,
timeout=TIMEOUT,
metadata=METADATA,
request_id=REQUEST_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
job=job,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_job.assert_not_called()


class TestDataprocUpdateClusterOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
128 changes: 128 additions & 0 deletions tests/providers/google/cloud/sensors/test_dataproc.py
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
from unittest import mock

from google.cloud.dataproc_v1beta2.types import JobStatus

from airflow import AirflowException
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
from airflow.version import version as airflow_version

AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")

DATAPROC_PATH = "airflow.providers.google.cloud.sensors.dataproc.{}"

TASK_ID = "task-id"
GCP_PROJECT = "test-project"
GCP_LOCATION = "test-location"
GCP_CONN_ID = "test-conn"
TIMEOUT = 120


class TestDataprocJobSensor(unittest.TestCase):
def create_job(self, state: int):
job = mock.Mock()
job.status = mock.Mock()
job.status.state = state
return job

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_done(self, mock_hook):
job = self.create_job(JobStatus.DONE)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job

sensor = DataprocJobSensor(
task_id=TASK_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT,
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})

mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)
self.assertTrue(ret)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_error(self, mock_hook):
job = self.create_job(JobStatus.ERROR)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job

sensor = DataprocJobSensor(
task_id=TASK_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT,
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)

with self.assertRaisesRegex(AirflowException, "Job failed"):
sensor.poke(context={})

mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_wait(self, mock_hook):
job = self.create_job(JobStatus.RUNNING)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job

sensor = DataprocJobSensor(
task_id=TASK_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT,
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})

mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)
self.assertFalse(ret)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_cancelled(self, mock_hook):
job = self.create_job(JobStatus.CANCELLED)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job

sensor = DataprocJobSensor(
task_id=TASK_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT,
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
with self.assertRaisesRegex(AirflowException, "Job was cancelled"):
sensor.poke(context={})

mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)
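The suite covers DONE, ERROR, RUNNING, and CANCELLED; a hedged sketch of an analogous case for the ATTEMPT_FAILURE branch (not part of this commit), where `poke` should log and keep waiting:

```python
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_attempt_failure(self, mock_hook):
    job = self.create_job(JobStatus.ATTEMPT_FAILURE)
    job_id = "job_id"
    mock_hook.return_value.get_job.return_value = job

    sensor = DataprocJobSensor(
        task_id=TASK_ID,
        location=GCP_LOCATION,
        project_id=GCP_PROJECT,
        dataproc_job_id=job_id,
        gcp_conn_id=GCP_CONN_ID,
        timeout=TIMEOUT,
    )
    ret = sensor.poke(context={})

    mock_hook.return_value.get_job.assert_called_once_with(
        job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
    )
    self.assertFalse(ret)  # sensor keeps waiting after a failed attempt
```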
