Create dataproc serverless spark batches operator (#19248)
MaksYermak authored Nov 26, 2021
1 parent c97a2e8 commit bf68b9a
Showing 9 changed files with 898 additions and 3 deletions.
49 changes: 49 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -26,10 +26,14 @@
from airflow import models
from airflow.providers.google.cloud.operators.dataproc import (
ClusterGenerator,
DataprocCreateBatchOperator,
DataprocCreateClusterOperator,
DataprocCreateWorkflowTemplateOperator,
DataprocDeleteBatchOperator,
DataprocDeleteClusterOperator,
DataprocGetBatchOperator,
DataprocInstantiateWorkflowTemplateOperator,
DataprocListBatchesOperator,
DataprocSubmitJobOperator,
DataprocUpdateClusterOperator,
)
@@ -174,6 +178,13 @@
},
"jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}],
}
BATCH_ID = "test-batch-id"
BATCH_CONFIG = {
"spark_batch": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"main_class": "org.apache.spark.examples.SparkPi",
},
}
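# Other Dataproc Serverless workload types follow the same shape as the Spark
# config above. A PySpark variant might look like the sketch below; the
# main_python_file_uri is a hypothetical GCS path, not part of this commit.
#
# PYSPARK_BATCH_CONFIG = {
#     "pyspark_batch": {
#         "main_python_file_uri": "gs://example-bucket/scripts/spark_pi.py",
#     },
# }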


with models.DAG(
@@ -282,3 +293,41 @@

# Task dependency created via `XComArgs`:
# spark_task_async >> spark_task_async_sensor

with models.DAG(
"example_gcp_batch_dataproc",
schedule_interval='@once',
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag_batch:
# [START how_to_cloud_dataproc_create_batch_operator]
create_batch = DataprocCreateBatchOperator(
task_id="create_batch",
project_id=PROJECT_ID,
region=REGION,
batch=BATCH_CONFIG,
batch_id=BATCH_ID,
)
# [END how_to_cloud_dataproc_create_batch_operator]

# [START how_to_cloud_dataproc_get_batch_operator]
get_batch = DataprocGetBatchOperator(
task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)
# [END how_to_cloud_dataproc_get_batch_operator]

# [START how_to_cloud_dataproc_list_batches_operator]
list_batches = DataprocListBatchesOperator(
task_id="list_batches",
project_id=PROJECT_ID,
region=REGION,
)
# [END how_to_cloud_dataproc_list_batches_operator]

# [START how_to_cloud_dataproc_delete_batch_operator]
delete_batch = DataprocDeleteBatchOperator(
task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)
# [END how_to_cloud_dataproc_delete_batch_operator]

create_batch >> get_batch >> list_batches >> delete_batch
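    # A quick local smoke test for this DAG (assuming Airflow 2.x and a
    # configured `google_cloud_default` connection) could use the CLI:
    #   airflow dags test example_gcp_batch_dataproc 2021-01-01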
219 changes: 219 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -24,8 +24,11 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from google.api_core.exceptions import ServerError
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
Batch,
BatchControllerClient,
Cluster,
ClusterControllerClient,
Job,
@@ -267,6 +270,34 @@ def get_job_client(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)

def get_batch_client(
self, region: Optional[str] = None, location: Optional[str] = None
) -> BatchControllerClient:
"""Returns BatchControllerClient"""
if location is not None:
warnings.warn(
"Parameter `location` will be deprecated. "
"Please provide value through `region` parameter instead.",
DeprecationWarning,
stacklevel=2,
)
region = location
client_options = None
if region and region != 'global':
client_options = {'api_endpoint': f'{region}-dataproc.googleapis.com:443'}
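        # For any non-global region, pin the client to that region's endpoint,
        # e.g. region "us-central1" -> "us-central1-dataproc.googleapis.com:443".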

return BatchControllerClient(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)

def wait_for_operation(self, timeout: float, operation: Operation):
"""Waits for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@GoogleBaseHook.fallback_to_default_project_id
def create_cluster(
self,
@@ -1030,3 +1061,191 @@ def cancel_job(
metadata=metadata,
)
return job

@GoogleBaseHook.fallback_to_default_project_id
def create_batch(
self,
region: str,
project_id: str,
batch: Union[Dict, Batch],
batch_id: Optional[str] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = "",
):
"""
        Creates a batch workload.

        :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param batch: Required. The batch to create.
:type batch: google.cloud.dataproc_v1.types.Batch
:param batch_id: Optional. The ID to use for the batch, which will become the final component
of the batch's resource name.
This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
:type batch_id: str
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``CreateBatchRequest`` requests with the same id, then the second request will be ignored and
the first ``google.longrunning.Operation`` created and stored in the backend is returned.
:type request_id: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
parent = f'projects/{project_id}/regions/{region}'

result = client.create_batch(
request={
'parent': parent,
'batch': batch,
'batch_id': batch_id,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
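    # Illustrative usage sketch, not part of this commit: `create_batch` returns
    # a google.api_core.operation.Operation, which `wait_for_operation` resolves
    # into the final Batch. Identifiers below are hypothetical.
    #
    #   hook = DataprocHook()
    #   operation = hook.create_batch(
    #       project_id="example-project",
    #       region="us-central1",
    #       batch=BATCH_CONFIG,  # e.g. the spark_batch dict from the example DAG above
    #       batch_id="example-batch",
    #   )
    #   batch = hook.wait_for_operation(timeout=600, operation=operation)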

@GoogleBaseHook.fallback_to_default_project_id
def delete_batch(
self,
batch_id: str,
region: str,
project_id: str,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
        Deletes the batch workload resource.

        :param batch_id: Required. The ID of the batch to delete.
        :type batch_id: str
        :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

result = client.delete_batch(
request={
'name': name,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def get_batch(
self,
batch_id: str,
region: str,
project_id: str,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
        Gets the batch workload resource representation.

        :param batch_id: Required. The ID of the batch to retrieve.
        :type batch_id: str
        :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

result = client.get_batch(
request={
'name': name,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
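    # Illustrative sketch, not part of this commit: polling a batch's terminal
    # state via `get_batch`; `Batch.State` is the enum from
    # google.cloud.dataproc_v1, and the identifiers are hypothetical.
    #
    #   batch = hook.get_batch(
    #       batch_id="example-batch", region="us-central1", project_id="example-project"
    #   )
    #   if batch.state == Batch.State.FAILED:
    #       raise AirflowException(f"Batch failed:\n{batch}")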

@GoogleBaseHook.fallback_to_default_project_id
def list_batches(
self,
region: str,
project_id: str,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
        Lists batch workloads.

        :param project_id: Required. The ID of the Google Cloud project that the batches belong to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param page_size: Optional. The maximum number of batches to return in each response. The service may
return fewer than this value. The default page size is 20; the maximum page size is 1000.
:type page_size: int
:param page_token: Optional. A page token received from a previous ``ListBatches`` call.
Provide this token to retrieve the subsequent page.
:type page_token: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
parent = f'projects/{project_id}/regions/{region}'

result = client.list_batches(
request={
'parent': parent,
'page_size': page_size,
'page_token': page_token,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
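    # Illustrative sketch, not part of this commit: the returned object is a
    # pager from the google-cloud-dataproc client, so callers can iterate it
    # directly and let it handle page tokens. Identifiers are hypothetical.
    #
    #   for batch in hook.list_batches(region="us-central1", project_id="example-project"):
    #       print(batch.name, batch.state)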
(Diff for the remaining 7 changed files not shown.)
