Add DataflowStartFlexTemplateOperator (#8550)
mik-laj authored Oct 16, 2020
1 parent 45d6083 commit 3c10ca6
Showing 7 changed files with 566 additions and 67 deletions.
@@ -0,0 +1,61 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Example Airflow DAG for Google Cloud Dataflow service
"""
import os

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator
from airflow.utils.dates import days_ago

GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")

DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get('DATAFLOW_FLEX_TEMPLATE_JOB_NAME', "dataflow-flex-template")

# For simplicity we use the same topic name as the subscription name.
PUBSUB_FLEX_TEMPLATE_TOPIC = os.environ.get('DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC', "dataflow-flex-template")
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = PUBSUB_FLEX_TEMPLATE_TOPIC
GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get(
    'DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH',
    "gs://test-airflow-dataflow-flex-template/samples/dataflow/templates/streaming-beam-sql.json",
)
BQ_FLEX_TEMPLATE_DATASET = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_DATASET', 'airflow_dataflow_samples')
BQ_FLEX_TEMPLATE_LOCATION = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION', 'us-west1')

with models.DAG(
    dag_id="example_gcp_dataflow_flex_template_java",
    start_date=days_ago(1),
    schedule_interval=None,  # Override to match your needs
) as dag_flex_template:
    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_streaming_beam_sql",
        body={
            "launchParameter": {
                "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
                "jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
                "parameters": {
                    "inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
                    "outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql",
                },
            }
        },
        do_xcom_push=True,
        location=BQ_FLEX_TEMPLATE_LOCATION,
    )
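Because the task above sets do_xcom_push=True and the operator's execute() returns the final job resource, downstream tasks can read the job metadata from XCom. A minimal, hedged sketch of such a follow-up task (not part of this commit; the BashOperator import path depends on the Airflow version, and the task ID is a placeholder):

from airflow.operators.bash import BashOperator

with dag_flex_template:
    # Hypothetical downstream task: print the Dataflow job ID that the
    # flex-template task pushed to XCom as its return value.
    print_job_id = BashOperator(
        task_id="print_dataflow_job_id",
        bash_command=(
            "echo Dataflow job id: "
            "{{ ti.xcom_pull(task_ids='start_flex_template_streaming_beam_sql')['id'] }}"
        ),
    )
    start_flex_template >> print_job_id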
115 changes: 90 additions & 25 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -96,16 +96,32 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs):
class DataflowJobStatus:
"""
Helper class with Dataflow job statuses.
Reference: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState
"""

JOB_STATE_DONE = "JOB_STATE_DONE"
JOB_STATE_UNKNOWN = "JOB_STATE_UNKNOWN"
JOB_STATE_STOPPED = "JOB_STATE_STOPPED"
JOB_STATE_RUNNING = "JOB_STATE_RUNNING"
JOB_STATE_FAILED = "JOB_STATE_FAILED"
JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED"
JOB_STATE_UPDATED = "JOB_STATE_UPDATED"
JOB_STATE_DRAINING = "JOB_STATE_DRAINING"
JOB_STATE_DRAINED = "JOB_STATE_DRAINED"
JOB_STATE_PENDING = "JOB_STATE_PENDING"
JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING"
JOB_STATE_QUEUED = "JOB_STATE_QUEUED"
FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED}
SUCCEEDED_END_STATES = {JOB_STATE_DONE}
END_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED}
TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
AWAITING_STATES = {
JOB_STATE_RUNNING,
JOB_STATE_PENDING,
JOB_STATE_QUEUED,
JOB_STATE_CANCELLING,
JOB_STATE_DRAINING,
JOB_STATE_STOPPED,
}
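For orientation, a hedged sketch (not part of the diff) of how these new state sets classify a state polled from the Dataflow API; the example value is a placeholder:

from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus

# Placeholder value, as it would appear in a job resource's "currentState" field.
state = "JOB_STATE_DRAINED"

if state in DataflowJobStatus.FAILED_END_STATES:
    outcome = "job failed or was cancelled"
elif state in DataflowJobStatus.SUCCEEDED_END_STATES:
    outcome = "job finished successfully"  # drained and updated jobs now count as success
elif state in DataflowJobStatus.AWAITING_STATES:
    outcome = "job is still in progress"
else:
    outcome = "unexpected or unknown state"
print(outcome)  # prints "job finished successfully"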


class DataflowJobType:
@@ -170,7 +186,7 @@ def is_job_running(self) -> bool:
            return False

        for job in self._jobs:
            if job['currentState'] not in DataflowJobStatus.END_STATES:
            if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES:
                return True
        return False

@@ -261,10 +277,7 @@ def _check_dataflow_job_state(self, job) -> bool:
            and DataflowJobType.JOB_TYPE_STREAMING == job['type']
        ):
            return True
        elif job['currentState'] in {
            DataflowJobStatus.JOB_STATE_RUNNING,
            DataflowJobStatus.JOB_STATE_PENDING,
        }:
        elif job['currentState'] in DataflowJobStatus.AWAITING_STATES:
            return False
        self.log.debug("Current job: %s", str(job))
        raise Exception(
@@ -282,14 +295,14 @@ def wait_for_done(self) -> None:
            time.sleep(self._poll_sleep)
            self._refresh_jobs()

    def get_jobs(self) -> List[dict]:
    def get_jobs(self, refresh: bool = False) -> List[dict]:
        """
        Returns Dataflow jobs.

        :param refresh: Forces the newest job data to be fetched, even if jobs were fetched before.
        :type refresh: bool
        :return: list of jobs
        :rtype: list
        """
        if not self._jobs:
        if not self._jobs or refresh:
            self._refresh_jobs()
        if not self._jobs:
            raise ValueError("Could not read _jobs")
@@ -300,23 +313,26 @@ def cancel(self) -> None:
"""
Cancels current job
"""
jobs = self._get_current_jobs()
batch = self._dataflow.new_batch_http_request()
job_ids = [job['id'] for job in jobs]
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
jobs = self.get_jobs()
job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES]
if job_ids:
batch = self._dataflow.new_batch_http_request()
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
)
)
)
batch.execute()
batch.execute()
else:
self.log.info("No jobs to cancel")


class _DataflowRunner(LoggingMixin):
@@ -631,6 +647,52 @@ def start_template_dataflow(
jobs_controller.wait_for_done()
return response["job"]

    @GoogleBaseHook.fallback_to_default_project_id
    def start_flex_template(
        self,
        body: dict,
        location: str,
        project_id: str,
        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
    ):
        """
        Starts a Dataflow job from a Flex Template.

        :param body: The request body. See:
            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
        :type body: dict
        :param location: The location of the Dataflow job (for example europe-west1)
        :type location: str
        :param project_id: The ID of the GCP project that owns the job.
            If set to ``None`` or missing, the default project_id from the GCP connection is used.
        :type project_id: Optional[str]
        :param on_new_job_id_callback: A callback that is called when a Job ID is detected.
        :return: the Job
        """
        service = self.get_conn()
        request = (
            service.projects()  # pylint: disable=no-member
            .locations()
            .flexTemplates()
            .launch(projectId=project_id, body=body, location=location)
        )
        response = request.execute(num_retries=self.num_retries)
        job_id = response['job']['id']

        if on_new_job_id_callback:
            on_new_job_id_callback(job_id)

        jobs_controller = _DataflowJobsController(
            dataflow=self.get_conn(),
            project_number=project_id,
            job_id=job_id,
            location=location,
            poll_sleep=self.poll_sleep,
            num_retries=self.num_retries,
        )
        jobs_controller.wait_for_done()

        return jobs_controller.get_jobs(refresh=True)[0]

    @_fallback_to_location_from_variables
    @_fallback_to_project_id_from_variables
    @GoogleBaseHook.fallback_to_default_project_id
@@ -659,6 +721,9 @@ def start_python_dataflow( # pylint: disable=too-many-arguments
        :type dataflow: str
        :param py_options: Additional options.
        :type py_options: List[str]
        :param project_id: The ID of the GCP project that owns the job.
            If set to ``None`` or missing, the default project_id from the GCP connection is used.
        :type project_id: Optional[str]
        :param py_interpreter: Python version of the beam pipeline.
            If None, this defaults to the python3.
            To track python versions supported by beam and related
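To illustrate the new DataflowHook.start_flex_template method added above, a minimal, hedged usage sketch of calling the hook directly (not part of this commit; the connection ID, project, region, bucket path, and parameter values are placeholders, and the call blocks until the job reaches a terminal state):

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
job = hook.start_flex_template(
    body={
        "launchParameter": {
            "containerSpecGcsPath": "gs://my-bucket/templates/streaming-beam-sql.json",
            "jobName": "my-flex-template-job",
            "parameters": {
                "inputSubscription": "my-subscription",
                "outputTable": "my-project:my_dataset.my_table",
            },
        }
    },
    location="us-central1",
    project_id="my-project",
    on_new_job_id_callback=lambda job_id: print(f"Started Dataflow job {job_id}"),
)
print(job["id"], job["currentState"])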
66 changes: 66 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -451,6 +451,72 @@ def on_kill(self) -> None:
            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)


class DataflowStartFlexTemplateOperator(BaseOperator):
    """
    Starts a Dataflow job from a Flex Template.

    :param body: The request body. See:
        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
    :type body: dict
    :param location: The location of the Dataflow job (for example europe-west1)
    :type location: str
    :param project_id: The ID of the GCP project that owns the job.
        If set to ``None`` or missing, the default project_id from the GCP connection is used.
    :type project_id: Optional[str]
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud
        Platform.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: str
    """

    template_fields = ['body', 'location', 'project_id', 'gcp_conn_id']

    @apply_defaults
    def __init__(
        self,
        body: Dict,
        location: str,
        project_id: Optional[str] = None,
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.body = body
        self.location = location
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.job_id = None
        self.hook: Optional[DataflowHook] = None

    def execute(self, context):
        self.hook = DataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
        )

        def set_current_job_id(job_id):
            self.job_id = job_id

        job = self.hook.start_flex_template(
            body=self.body,
            location=self.location,
            project_id=self.project_id,
            on_new_job_id_callback=set_current_job_id,
        )

        return job

    def on_kill(self) -> None:
        self.log.info("On kill.")
        if self.job_id:
            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
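Since "body" is listed in template_fields and Airflow renders template fields recursively through nested dicts and lists, values inside the request body can use Jinja templating. A hedged sketch of such a task (not part of this commit; the bucket path, subscription, table, and region are placeholders), defined inside a DAG context:

templated_flex_template = DataflowStartFlexTemplateOperator(
    task_id="start_flex_template_templated",
    body={
        "launchParameter": {
            "containerSpecGcsPath": "gs://my-bucket/templates/streaming-beam-sql.json",
            "jobName": "flex-template-{{ ds_nodash }}",  # one job name per DAG run
            "parameters": {
                "inputSubscription": "my-subscription",
                "outputTable": "my-project:my_dataset.results_{{ ds_nodash }}",
            },
        }
    },
    location="us-central1",
)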


class DataflowCreatePythonJobOperator(BaseOperator):
"""
Launching Cloud Dataflow jobs written in python. Note that both
