Change dataprep system tests assets (#26488)
MrGeorgeOwl authored Nov 10, 2022
1 parent 34e21ea commit 59e3198
Showing 13 changed files with 1,166 additions and 110 deletions.
79 changes: 0 additions & 79 deletions airflow/providers/google/cloud/example_dags/example_dataprep.py

This file was deleted.

82 changes: 78 additions & 4 deletions airflow/providers/google/cloud/hooks/dataprep.py
@@ -19,8 +19,9 @@
from __future__ import annotations

import json
import os
from enum import Enum
from typing import Any
from urllib.parse import urljoin

import requests
from requests import HTTPError
@@ -43,6 +44,17 @@ def _get_field(extras: dict, field_name: str):
return extras.get(prefixed_name) or None


class JobGroupStatuses(str, Enum):
"""Types of job group run statuses."""

CREATED = "Created"
UNDEFINED = "undefined"
IN_PROGRESS = "InProgress"
COMPLETE = "Complete"
FAILED = "Failed"
CANCELED = "Canceled"
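Reviewer note (illustrative, not part of this commit): because JobGroupStatuses mixes in str, its members compare equal to the raw status strings returned by the Dataprep API, so callers can match either form. A minimal sketch under that assumption:

from airflow.providers.google.cloud.hooks.dataprep import JobGroupStatuses

# str-based Enum members compare equal to their plain string values,
# so a status string taken straight from the API response matches directly.
assert JobGroupStatuses.COMPLETE == "Complete"
assert JobGroupStatuses("Failed") is JobGroupStatuses.FAILED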


class GoogleDataprepHook(BaseHook):
"""
Hook for connection with Dataprep API.
@@ -82,7 +94,7 @@ def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]:
:param job_id: The ID of the job that will be fetched
"""
endpoint_path = f"v4/jobGroups/{job_id}/jobs"
url: str = os.path.join(self._base_url, endpoint_path)
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
return response.json()
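Reviewer note (illustrative, not part of this commit): replacing os.path.join with urljoin matters because os.path.join uses the operating system's path separator and has no notion of URL syntax, while urljoin resolves the endpoint against the base URL. A minimal sketch, assuming the base URL is https://api.clouddataprep.com:

from urllib.parse import urljoin

base_url = "https://api.clouddataprep.com"  # assumed Dataprep API base URL
endpoint = "v4/jobGroups/42/jobs"

# Resolves to https://api.clouddataprep.com/v4/jobGroups/42/jobs on any OS;
# note that urljoin drops a trailing path segment of the base unless it ends with "/".
print(urljoin(base_url, endpoint))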
@@ -99,7 +111,7 @@ def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) ->
"""
params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted}
endpoint_path = f"v4/jobGroups/{job_group_id}"
url: str = os.path.join(self._base_url, endpoint_path)
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers, params=params)
self._raise_for_status(response)
return response.json()
@@ -115,11 +127,73 @@ def run_job_group(self, body_request: dict) -> dict[str, Any]:
:param body_request: The identifier for the recipe you would like to run.
"""
endpoint_path = "v4/jobGroups"
url: str = os.path.join(self._base_url, endpoint_path)
url: str = urljoin(self._base_url, endpoint_path)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
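Usage sketch (illustrative, not part of this commit): run_job_group posts the given body to v4/jobGroups. The connection id, the wrangledDataset payload shape, and the response field read below are assumptions, not confirmed by this diff:

from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")  # assumed connection id

# Hypothetical payload: the id of the wrangled dataset (recipe) to run.
body = {"wrangledDataset": {"id": 12345}}
response = hook.run_job_group(body_request=body)
job_group_id = response["id"]  # assumed response field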

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def copy_flow(
self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False
) -> dict:
"""
Create a copy of the flow with the provided ID, including all contained recipes.

:param flow_id: ID of the flow to be copied
:param name: Name for the copy of the flow
:param description: Description of the copy of the flow
:param copy_datasources: Whether copies of the flow's data inputs should be made as well.
"""
endpoint_path = f"v4/flows/{flow_id}/copy"
url: str = urljoin(self._base_url, endpoint_path)
body_request = {
"name": name,
"description": description,
"copyDatasources": copy_datasources,
}
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
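Usage sketch (illustrative, not part of this commit), continuing the hook from the previous sketch: a system test could copy an existing flow, run the copy, and delete it afterwards. The flow id and the "id" response field are placeholders/assumptions:

copy = hook.copy_flow(flow_id=1234, name="system-test-copy", copy_datasources=True)
copied_flow_id = copy["id"]  # assumed response field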

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def delete_flow(self, *, flow_id: int) -> None:
"""
Delete the flow with the provided ID.

:param flow_id: ID of the flow to be deleted
"""
endpoint_path = f"v4/flows/{flow_id}"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.delete(url, headers=self._headers)
self._raise_for_status(response)

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def run_flow(self, *, flow_id: int, body_request: dict) -> dict:
"""
Run the flow with the provided ID.

:param flow_id: ID of the flow to be run
:param body_request: Body of the POST request to be sent.
"""
endpoint = f"v4/flows/{flow_id}/run"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
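Usage sketch (illustrative, not part of this commit), continuing the hook and copied_flow_id from the sketches above; whether an empty body is accepted by the run endpoint is an assumption:

# Trigger the copied flow; the Dataprep API may accept run parameters here,
# an empty body is assumed to be sufficient for a plain run.
run_response = hook.run_flow(flow_id=copied_flow_id, body_request={})
# The response is assumed to describe the job group(s) created by the run.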

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses:
"""
Check the status of a Dataprep job group.

:param job_group_id: ID of the job group to check
"""
endpoint = f"/v4/jobGroups/{job_group_id}/status"
url: str = urljoin(self._base_url, endpoint)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
return response.json()
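Usage sketch (illustrative, not part of this commit): a simple polling loop over get_job_group_status; the terminal states come from the JobGroupStatuses enum above, and the poll interval is arbitrary:

import time

from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook, JobGroupStatuses

FINISHED_STATES = (JobGroupStatuses.COMPLETE, JobGroupStatuses.FAILED, JobGroupStatuses.CANCELED)

def wait_for_job_group(hook: GoogleDataprepHook, job_group_id: int, poll_seconds: int = 30):
    """Poll until the job group leaves its running state (illustrative helper)."""
    while True:
        status = hook.get_job_group_status(job_group_id=job_group_id)
        # Tuple membership uses equality, so both raw strings and enum members match.
        if status in FINISHED_STATES:
            return status
        time.sleep(poll_seconds)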

def _raise_for_status(self, response: requests.models.Response) -> None:
try:
response.raise_for_status()
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/base.py
@@ -45,6 +45,6 @@ def get_link(
conf = XCom.get_value(key=self.key, ti_key=ti_key)
if not conf:
return ""
if self.format_str.startswith(BASE_LINK):
if self.format_str.startswith("http"):
return self.format_str.format(**conf)
return BASE_LINK + self.format_str.format(**conf)
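Reviewer note (illustrative, not part of this commit): switching the check from startswith(BASE_LINK) to startswith("http") lets a provider link carry a complete URL on another host, such as the clouddataprep.com links added below, instead of having the console base prepended to it. A minimal sketch of the resulting behaviour, assuming BASE_LINK is the Cloud Console base URL:

BASE_LINK = "https://console.cloud.google.com"  # assumed value from links/base.py

def build_link(format_str: str, conf: dict) -> str:
    # Absolute URLs are used verbatim; console-relative paths get the base prepended.
    if format_str.startswith("http"):
        return format_str.format(**conf)
    return BASE_LINK + format_str.format(**conf)

# https://clouddataprep.com/flows/7?projectId=my-project
print(build_link("https://clouddataprep.com/flows/{flow_id}?projectId={project_id}",
                 {"flow_id": 7, "project_id": "my-project"}))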
63 changes: 63 additions & 0 deletions airflow/providers/google/cloud/links/dataprep.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.google.cloud.links.base import BaseGoogleLink

if TYPE_CHECKING:
from airflow.utils.context import Context

BASE_LINK = "https://clouddataprep.com"
DATAPREP_FLOW_LINK = BASE_LINK + "/flows/{flow_id}?projectId={project_id}"
DATAPREP_JOB_GROUP_LINK = BASE_LINK + "/jobs/{job_group_id}?projectId={project_id}"


class DataprepFlowLink(BaseGoogleLink):
"""Helper class for constructing Dataprep flow link."""

name = "Flow details page"
key = "dataprep_flow_page"
format_str = DATAPREP_FLOW_LINK

@staticmethod
def persist(context: Context, task_instance, project_id: str, flow_id: int):
task_instance.xcom_push(
context=context,
key=DataprepFlowLink.key,
value={"project_id": project_id, "flow_id": flow_id},
)


class DataprepJobGroupLink(BaseGoogleLink):
"""Helper class for constructing Dataprep job group link."""

name = "Job group details page"
key = "dataprep_job_group_page"
format_str = DATAPREP_JOB_GROUP_LINK

@staticmethod
def persist(context: Context, task_instance, project_id: str, job_group_id: int):
task_instance.xcom_push(
context=context,
key=DataprepJobGroupLink.key,
value={
"project_id": project_id,
"job_group_id": job_group_id,
},
)
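Usage sketch (illustrative, not part of this commit): an operator would register the link class in operator_extra_links and call persist() from execute() so the button appears in the Airflow UI; the operator name and ids are placeholders:

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.dataprep import DataprepJobGroupLink


class ExampleDataprepOperator(BaseOperator):
    operator_extra_links = (DataprepJobGroupLink(),)

    def execute(self, context):
        # ... run the job group, then publish the link payload via XCom.
        DataprepJobGroupLink.persist(
            context=context,
            task_instance=self,
            project_id="my-project",  # placeholder
            job_group_id=1234,        # placeholder
        )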