Skip to content

Commit

Permalink
Create DataprocStartClusterOperator and DataprocStopClusterOperator (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
molcay authored Jan 26, 2024
1 parent bd06434 commit 0f2670e
Show file tree
Hide file tree
Showing 8 changed files with 672 additions and 0 deletions.
88 changes: 88 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,94 @@ def update_cluster(
)
return operation

@GoogleBaseHook.fallback_to_default_project_id
def start_cluster(
self,
region: str,
project_id: str,
cluster_name: str,
cluster_uuid: str | None = None,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Operation:
"""Start a cluster in a project.
:param region: Cloud Dataproc region to handle the request.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param cluster_name: The cluster name.
:param cluster_uuid: The cluster UUID
:param request_id: A unique id used to identify the request. If the
server receives two *UpdateClusterRequest* requests with the same
ID, the second request will be ignored, and an operation created
for the first one and stored in the backend is returned.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
to complete. If *retry* is specified, the timeout applies to each
individual attempt.
:param metadata: Additional metadata that is provided to the method.
:return: An instance of ``google.api_core.operation.Operation``
"""
client = self.get_cluster_client(region=region)
return client.start_cluster(
request={
"project_id": project_id,
"region": region,
"cluster_name": cluster_name,
"cluster_uuid": cluster_uuid,
"request_id": request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
def stop_cluster(
self,
region: str,
project_id: str,
cluster_name: str,
cluster_uuid: str | None = None,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Operation:
"""Start a cluster in a project.
:param region: Cloud Dataproc region to handle the request.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param cluster_name: The cluster name.
:param cluster_uuid: The cluster UUID
:param request_id: A unique id used to identify the request. If the
server receives two *UpdateClusterRequest* requests with the same
ID, the second request will be ignored, and an operation created
for the first one and stored in the backend is returned.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
to complete. If *retry* is specified, the timeout applies to each
individual attempt.
:param metadata: Additional metadata that is provided to the method.
:return: An instance of ``google.api_core.operation.Operation``
"""
client = self.get_cluster_client(region=region)
return client.stop_cluster(
request={
"project_id": project_id,
"region": region,
"cluster_name": cluster_name,
"cluster_uuid": cluster_uuid,
"request_id": request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
def create_workflow_template(
self,
Expand Down
197 changes: 197 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,17 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
cluster = self._get_cluster(hook)
return cluster

def _start_cluster(self, hook: DataprocHook):
op: operation.Operation = hook.start_cluster(
region=self.region,
project_id=self.project_id,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)

def execute(self, context: Context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
Expand Down Expand Up @@ -801,6 +812,9 @@ def execute(self, context: Context) -> dict:
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)
elif cluster.status.state == cluster.status.State.STOPPED:
# if the cluster exists and already stopped, then start the cluster
self._start_cluster(hook)

return Cluster.to_dict(cluster)

Expand Down Expand Up @@ -1082,6 +1096,189 @@ def _delete_cluster(self, hook: DataprocHook):
)


class _DataprocStartStopClusterBaseOperator(GoogleCloudBaseOperator):
"""Base class to start or stop a cluster in a project.
:param cluster_name: Required. Name of the cluster to create
:param region: Required. The specified region where the dataproc cluster is created.
:param project_id: Optional. The ID of the Google Cloud project the cluster belongs to.
:param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
if cluster with specified UUID does not exist.
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = (
"cluster_name",
"region",
"project_id",
"request_id",
"impersonation_chain",
)

def __init__(
self,
*,
cluster_name: str,
region: str,
project_id: str | None = None,
cluster_uuid: str | None = None,
request_id: str | None = None,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float = 1 * 60 * 60,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
self.cluster_uuid = cluster_uuid
self.request_id = request_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self._hook: DataprocHook | None = None

@property
def hook(self):
if self._hook is None:
self._hook = DataprocHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
return self._hook

def _get_project_id(self) -> str:
return self.project_id or self.hook.project_id

def _get_cluster(self) -> Cluster:
"""Retrieve the cluster information.
:return: Instance of ``google.cloud.dataproc_v1.Cluster``` class
"""
return self.hook.get_cluster(
project_id=self._get_project_id(),
region=self.region,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
"""Implement this method in child class to return whether the cluster is in desired state or not.
If the cluster is in desired stated you can return a log message content as a second value
for the return tuple.
:param cluster: Required. Instance of ``google.cloud.dataproc_v1.Cluster``
class to interact with Dataproc API
:return: Tuple of (Boolean, Optional[str]) The first value of the tuple is whether the cluster is
in desired state or not. The second value of the tuple will use if you want to log something when
the cluster is in desired state already.
"""
raise NotImplementedError

def _get_operation(self) -> operation.Operation:
"""Implement this method in child class to call the related hook method and return its result.
:return: ``google.api_core.operation.Operation`` value whether the cluster is in desired state or not
"""
raise NotImplementedError

def execute(self, context: Context) -> dict | None:
cluster: Cluster = self._get_cluster()
is_already_desired_state, log_str = self._check_desired_cluster_state(cluster)
if is_already_desired_state:
self.log.info(log_str)
return None

op: operation.Operation = self._get_operation()
result = self.hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)
return Cluster.to_dict(result)


class DataprocStartClusterOperator(_DataprocStartStopClusterBaseOperator):
"""Start a cluster in a project."""

operator_extra_links = (DataprocClusterLink(),)

def execute(self, context: Context) -> dict | None:
self.log.info("Starting the cluster: %s", self.cluster_name)
cluster = super().execute(context)
DataprocClusterLink.persist(
context=context,
operator=self,
cluster_id=self.cluster_name,
project_id=self._get_project_id(),
region=self.region,
)
self.log.info("Cluster started")
return cluster

def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
if cluster.status.state == cluster.status.State.RUNNING:
return True, f'The cluster "{self.cluster_name}" already running!'
return False, None

def _get_operation(self) -> operation.Operation:
return self.hook.start_cluster(
region=self.region,
project_id=self._get_project_id(),
cluster_name=self.cluster_name,
cluster_uuid=self.cluster_uuid,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)


class DataprocStopClusterOperator(_DataprocStartStopClusterBaseOperator):
"""Stop a cluster in a project."""

def execute(self, context: Context) -> dict | None:
self.log.info("Stopping the cluster: %s", self.cluster_name)
cluster = super().execute(context)
self.log.info("Cluster stopped")
return cluster

def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
if cluster.status.state in [cluster.status.State.STOPPED, cluster.status.State.STOPPING]:
return True, f'The cluster "{self.cluster_name}" already stopped!'
return False, None

def _get_operation(self) -> operation.Operation:
return self.hook.stop_cluster(
region=self.region,
project_id=self._get_project_id(),
cluster_name=self.cluster_name,
cluster_uuid=self.cluster_uuid,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)


class DataprocJobBaseOperator(GoogleCloudBaseOperator):
"""Base class for operators that launch job on DataProc.
Expand Down
24 changes: 24 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,30 @@ You can use deferrable mode for this action in order to run the operator asynchr
:start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]

Starting a cluster
---------------------------

To start a cluster you can use the
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_start_cluster_operator]
:end-before: [END how_to_cloud_dataproc_start_cluster_operator]

Stopping a cluster
---------------------------

To stop a cluster you can use the
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_stop_cluster_operator]
:end-before: [END how_to_cloud_dataproc_stop_cluster_operator]

Deleting a cluster
------------------

Expand Down
1 change: 1 addition & 0 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator",
"airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator",
"airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator",
}
Expand Down
Loading

0 comments on commit 0f2670e

Please sign in to comment.