From 92335417d881c01b0d2ef77ad254f3f3b491df4c Mon Sep 17 00:00:00 2001
From: Alberto Costa
Date: Sun, 17 Dec 2023 23:32:27 +0100
Subject: [PATCH] Add use_glob to GCSObjectExistenceSensor (#34137)

---------

Co-authored-by: Alberto Costa
---
 airflow/providers/google/cloud/sensors/gcs.py |  11 +++++-
 .../providers/google/cloud/triggers/gcs.py    |  16 ++++++--
 airflow/providers/google/provider.yaml        |   2 +-
 docs/apache-airflow/img/airflow_erd.sha256    |   2 +-
 docs/apache-airflow/img/airflow_erd.svg       |   4 +-
 .../google/cloud/sensors/test_gcs.py          |  22 +++++++++++
 .../google/cloud/triggers/test_gcs.py         |  38 +++++++++++++++++++
 7 files changed, 87 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py
index 453bb3bf44000..c5a80e2d55f68 100644
--- a/airflow/providers/google/cloud/sensors/gcs.py
+++ b/airflow/providers/google/cloud/sensors/gcs.py
@@ -50,6 +50,7 @@ class GCSObjectExistenceSensor(BaseSensorOperator):

     :param bucket: The Google Cloud Storage bucket where the object is.
     :param object: The name of the object to check in the Google cloud storage bucket.
+    :param use_glob: When set to True the object parameter is interpreted as glob
     :param google_cloud_conn_id: The connection ID to use when connecting to Google Cloud Storage.
     :param impersonation_chain: Optional service account to impersonate using short-term
@@ -75,6 +76,7 @@ def __init__(
         self,
         *,
         bucket: str,
         object: str,
+        use_glob: bool = False,
         google_cloud_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         retry: Retry = DEFAULT_RETRY,
@@ -84,7 +86,9 @@ def __init__(
         super().__init__(**kwargs)
         self.bucket = bucket
         self.object = object
+        self.use_glob = use_glob
         self.google_cloud_conn_id = google_cloud_conn_id
+        self._matches: list[str] = []
         self.impersonation_chain = impersonation_chain
         self.retry = retry
@@ -96,7 +100,11 @@ def poke(self, context: Context) -> bool:
             gcp_conn_id=self.google_cloud_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        return hook.exists(self.bucket, self.object, self.retry)
+        if self.use_glob:
+            self._matches = hook.list(self.bucket, match_glob=self.object)
+            return bool(self._matches)
+        else:
+            return hook.exists(self.bucket, self.object, self.retry)

     def execute(self, context: Context) -> None:
         """Airflow runs this method on the worker and defers using the trigger."""
@@ -109,6 +117,7 @@ def execute(self, context: Context) -> None:
                 trigger=GCSBlobTrigger(
                     bucket=self.bucket,
                     object_name=self.object,
+                    use_glob=self.use_glob,
                     poke_interval=self.poke_interval,
                     google_cloud_conn_id=self.google_cloud_conn_id,
                     hook_params={
diff --git a/airflow/providers/google/cloud/triggers/gcs.py b/airflow/providers/google/cloud/triggers/gcs.py
index a5d181cd0566e..f801e5ae9a78d 100644
--- a/airflow/providers/google/cloud/triggers/gcs.py
+++ b/airflow/providers/google/cloud/triggers/gcs.py
@@ -35,6 +35,7 @@ class GCSBlobTrigger(BaseTrigger):

     :param bucket: the bucket in the google cloud storage where the objects are residing.
     :param object_name: the file or folder present in the bucket
+    :param use_glob: if true object_name is interpreted as glob
     :param google_cloud_conn_id: reference to the Google Connection
     :param poke_interval: polling period in seconds to check for file/folder
     :param hook_params: Extra config params to be passed to the underlying hook.
@@ -45,6 +46,7 @@ def __init__(
         self,
         bucket: str,
         object_name: str,
+        use_glob: bool,
         poke_interval: float,
         google_cloud_conn_id: str,
         hook_params: dict[str, Any],
@@ -52,6 +54,7 @@ def __init__(
         super().__init__()
         self.bucket = bucket
         self.object_name = object_name
+        self.use_glob = use_glob
         self.poke_interval = poke_interval
         self.google_cloud_conn_id: str = google_cloud_conn_id
         self.hook_params = hook_params
@@ -63,6 +66,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             {
                 "bucket": self.bucket,
                 "object_name": self.object_name,
+                "use_glob": self.use_glob,
                 "poke_interval": self.poke_interval,
                 "google_cloud_conn_id": self.google_cloud_conn_id,
                 "hook_params": self.hook_params,
@@ -98,9 +102,14 @@ async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name
         async with ClientSession() as s:
             client = await hook.get_storage_client(s)
             bucket = client.get_bucket(bucket_name)
-            object_response = await bucket.blob_exists(blob_name=object_name)
-            if object_response:
-                return "success"
+            if self.use_glob:
+                list_blobs_response = await bucket.list_blobs(match_glob=object_name)
+                if len(list_blobs_response) > 0:
+                    return "success"
+            else:
+                blob_exists_response = await bucket.blob_exists(blob_name=object_name)
+                if blob_exists_response:
+                    return "success"
             return "pending"

@@ -234,6 +243,7 @@ def __init__(
             poke_interval=poke_interval,
             google_cloud_conn_id=google_cloud_conn_id,
             hook_params=hook_params,
+            use_glob=False,
         )
         self.prefix = prefix
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 1c6859696ed4c..1d7bfd317e829 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -87,7 +87,7 @@ dependencies:
   - asgiref>=3.5.2
   - gcloud-aio-auth>=4.0.0,<5.0.0
   - gcloud-aio-bigquery>=6.1.2
-  - gcloud-aio-storage
+  - gcloud-aio-storage>=9.0.0
   - gcsfs>=2023.10.0
   - google-ads>=22.1.0
   - google-api-core>=2.11.0
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index ded2722d373cc..eb2a21ae340d0 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-a5677b0b603e8835f92da4b8b061ec268ce7257ef6b446f12593743ecf90710a
\ No newline at end of file
+194706fc390025f473f73ce934bfe4b394b50ce76748e5df33ae643e38538357
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index 497ef76975f7d..8e85b5fa0cab8 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1342,14 +1342,14 @@
 task_instance--xcom
-0..N
+1
 1
 task_instance--xcom
-1
+0..N
 1
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py
index 1d4bbcec876e0..37697ff58d735 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -94,6 +94,7 @@ def test_should_pass_argument_to_hook(self, mock_hook):
             task_id="task-id",
             bucket=TEST_BUCKET,
             object=TEST_OBJECT,
+            use_glob=False,
             google_cloud_conn_id=TEST_GCP_CONN_ID,
             impersonation_chain=TEST_IMPERSONATION_CHAIN,
         )
@@ -108,6 +109,27 @@ def test_should_pass_argument_to_hook(self, mock_hook):
         )
         mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, DEFAULT_RETRY)

+    @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
+    def test_should_pass_argument_to_hook_using_glob(self, mock_hook):
+        task = GCSObjectExistenceSensor(
task_id="task-id", + bucket=TEST_BUCKET, + object=TEST_OBJECT, + use_glob=True, + google_cloud_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + mock_hook.return_value.list.return_value = [mock.MagicMock()] + + result = task.poke(mock.MagicMock()) + + assert result is True + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, match_glob=TEST_OBJECT) + @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook") @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor.defer") def test_gcs_object_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook): diff --git a/tests/providers/google/cloud/triggers/test_gcs.py b/tests/providers/google/cloud/triggers/test_gcs.py index 4fcde67711cb6..3c4bc9031a8de 100644 --- a/tests/providers/google/cloud/triggers/test_gcs.py +++ b/tests/providers/google/cloud/triggers/test_gcs.py @@ -55,6 +55,19 @@ def trigger(): return GCSBlobTrigger( bucket=TEST_BUCKET, object_name=TEST_OBJECT, + use_glob=False, + poke_interval=TEST_POLLING_INTERVAL, + google_cloud_conn_id=TEST_GCP_CONN_ID, + hook_params=TEST_HOOK_PARAMS, + ) + + +@pytest.fixture +def trigger_using_glob(): + return GCSBlobTrigger( + bucket=TEST_BUCKET, + object_name=TEST_OBJECT, + use_glob=True, poke_interval=TEST_POLLING_INTERVAL, google_cloud_conn_id=TEST_GCP_CONN_ID, hook_params=TEST_HOOK_PARAMS, @@ -73,6 +86,7 @@ def test_gcs_blob_trigger_serialization(self, trigger): assert kwargs == { "bucket": TEST_BUCKET, "object_name": TEST_OBJECT, + "use_glob": False, "poke_interval": TEST_POLLING_INTERVAL, "google_cloud_conn_id": TEST_GCP_CONN_ID, "hook_params": TEST_HOOK_PARAMS, @@ -141,6 +155,30 @@ async def test_object_exists(self, exists, response, trigger): assert res == response bucket.blob_exists.assert_called_once_with(blob_name=TEST_OBJECT) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "blob_list,response", + [ + ([TEST_OBJECT], "success"), + ([], "pending"), + ], + ) + async def test_object_exists_using_glob(self, blob_list, response, trigger_using_glob): + """ + Tests to check if a particular object in Google Cloud Storage + is found or not + """ + hook = AsyncMock(GCSAsyncHook) + storage = AsyncMock(Storage) + hook.get_storage_client.return_value = storage + bucket = AsyncMock(Bucket) + storage.get_bucket.return_value = bucket + bucket.list_blobs.return_value = blob_list + + res = await trigger_using_glob._object_exists(hook, TEST_BUCKET, TEST_OBJECT) + assert res == response + bucket.list_blobs.assert_called_once_with(match_glob=TEST_OBJECT) + class TestGCSPrefixBlobTrigger: TRIGGER = GCSPrefixBlobTrigger(