Move the try outside the loop when this is possible in Google provider (#33976)

* Move the try outside the loop when this is possible in Google provider


---------

Co-authored-by: Tzu-ping Chung <[email protected]>
hussein-awala and uranusjr authored Sep 3, 2023
1 parent 8918b43 commit 4f20b0f
Showing 10 changed files with 103 additions and 118 deletions.
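The same refactor is applied in every trigger below: the `run()` coroutine used to re-enter a `try` block on each iteration of its polling loop, even though the `except` clauses yield a terminal error event and the trigger never resumes. Hoisting the `try` above the `while` keeps the error handling identical while removing a nesting level, and it also makes the explicit `return` after the error yield unnecessary, because an async generator simply ends when the `except` body finishes. Below is a minimal, standalone sketch of the before/after shape; `poll_job` and the plain-dict events are illustrative stand-ins, not the provider's hook API or `TriggerEvent` class.

```python
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator


async def poll_job() -> str:
    """Placeholder for a hook call such as get_job_status(); always succeeds here."""
    await asyncio.sleep(0)
    return "success"


async def run_try_inside_loop() -> AsyncIterator[dict[str, Any]]:
    # Old shape: the try/except is re-entered on every poll iteration.
    while True:
        try:
            status = await poll_job()
            if status == "success":
                yield {"status": "success"}
                return
            await asyncio.sleep(1)
        except Exception as e:
            yield {"status": "error", "message": str(e)}
            return  # required, otherwise the loop would keep polling after an error


async def run_try_outside_loop() -> AsyncIterator[dict[str, Any]]:
    # New shape: one try wraps the whole loop; an exception leaves the loop,
    # the error event is yielded, and the generator ends without an explicit return.
    try:
        while True:
            status = await poll_job()
            if status == "success":
                yield {"status": "success"}
                return
            await asyncio.sleep(1)
    except Exception as e:
        yield {"status": "error", "message": str(e)}


async def main() -> None:
    async for event in run_try_outside_loop():
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```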
66 changes: 30 additions & 36 deletions airflow/providers/google/cloud/triggers/bigquery.py
@@ -75,8 +75,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent."""
"""Gets current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
yield TriggerEvent(
@@ -95,10 +95,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> BigQueryAsyncHook:
return BigQueryAsyncHook(gcp_conn_id=self.conn_id)
@@ -124,8 +123,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
# Poll for job execution status
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
@@ -160,10 +159,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})


class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
@@ -196,8 +194,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent with response data."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
# Poll for job execution status
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
@@ -220,10 +218,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})


class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
@@ -302,8 +299,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
first_job_response_from_hook = await hook.get_job_status(
job_id=self.first_job_id, project_id=self.project_id
)
@@ -365,10 +362,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
)
return

except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})


class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
@@ -430,8 +426,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
# Poll for job execution status
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
@@ -448,10 +444,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
else:
yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})


class BigQueryTableExistenceTrigger(BaseTrigger):
@@ -501,8 +496,8 @@ def _get_async_hook(self) -> BigQueryTableAsyncHook:

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Will run until the table exists in the Google Big Query."""
while True:
try:
try:
while True:
hook = self._get_async_hook()
response = await self._table_exists(
hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id
@@ -511,10 +506,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
yield TriggerEvent({"status": "success", "message": "success"})
return
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for Table existence")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for Table existence")
yield TriggerEvent({"status": "error", "message": str(e)})

async def _table_exists(
self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str
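For the BigQuery triggers above, the removed `return` after the error yield is redundant because reaching the end of the `except` body already terminates the async generator, so the deferred operator still receives exactly one terminal event. A quick standalone check of that property (not part of the provider's test suite; `failing_poll` is a made-up stand-in for `get_job_status`):

```python
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator


async def failing_poll() -> str:
    """Placeholder hook call that always fails, standing in for get_job_status()."""
    raise RuntimeError("boom")


async def run() -> AsyncIterator[dict[str, Any]]:
    try:
        while True:
            await failing_poll()
            await asyncio.sleep(1)
    except Exception as e:
        yield {"status": "error", "message": str(e)}
        # No return needed: falling off the end of the except body ends the generator.


async def main() -> None:
    events = [event async for event in run()]
    assert events == [{"status": "error", "message": "boom"}]
    print("exactly one terminal event:", events)


if __name__ == "__main__":
    asyncio.run(main())
```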
19 changes: 9 additions & 10 deletions airflow/providers/google/cloud/triggers/bigquery_dts.py
@@ -83,8 +83,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""If the Transfer Run is in a terminal state, then yield TriggerEvent object."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
transfer_run: TransferRun = await hook.get_transfer_run(
project_id=self.project_id,
config_id=self.config_id,
@@ -129,14 +129,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.log.info("Job is still working...")
self.log.info("Waiting for %s seconds", self.poll_interval)
await asyncio.sleep(self.poll_interval)
except Exception as e:
yield TriggerEvent(
{
"status": "failed",
"message": f"Trigger failed with exception: {e}",
}
)
return
except Exception as e:
yield TriggerEvent(
{
"status": "failed",
"message": f"Trigger failed with exception: {e}",
}
)

def _get_async_hook(self) -> AsyncBiqQueryDataTransferServiceHook:
return AsyncBiqQueryDataTransferServiceHook(
14 changes: 6 additions & 8 deletions airflow/providers/google/cloud/triggers/cloud_batch.py
@@ -92,9 +92,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
timeout = self.timeout
hook = self._get_async_hook()
while timeout is None or timeout > 0:

try:
try:
while timeout is None or timeout > 0:
job: Job = await hook.get_batch_job(job_name=self.job_name)

status: JobStatus.State = job.status.state
@@ -134,10 +133,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
if timeout is None or timeout > 0:
await asyncio.sleep(self.polling_period_seconds)

except Exception as e:
self.log.exception("Exception occurred while checking for job completion.")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for job completion.")
yield TriggerEvent({"status": "error", "message": str(e)})
return

self.log.exception(f"Job with name [{self.job_name}] timed out")
yield TriggerEvent(
@@ -147,7 +146,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"message": f"Batch job with name {self.job_name} timed out",
}
)
return

def _get_async_hook(self) -> CloudBatchAsyncHook:
return CloudBatchAsyncHook(
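`CloudBatchJobFinishedTrigger` above differs slightly from the other triggers: its loop is bounded by a timeout and a timeout event is yielded after the loop, so the `except` branch keeps an explicit `return` to avoid falling through to that event. A rough standalone sketch of this shape, using a fake `get_job_state` and a deadline instead of the hook's real API and timeout bookkeeping:

```python
from __future__ import annotations

import asyncio
import time
from typing import Any, AsyncIterator


async def get_job_state() -> str:
    """Placeholder for a hook call such as CloudBatchAsyncHook.get_batch_job()."""
    await asyncio.sleep(0)
    return "RUNNING"


async def run_until_deadline(timeout: float | None, poll_seconds: float = 1.0) -> AsyncIterator[dict[str, Any]]:
    deadline = None if timeout is None else time.monotonic() + timeout
    try:
        while deadline is None or time.monotonic() < deadline:
            state = await get_job_state()
            if state == "SUCCEEDED":
                yield {"status": "success"}
                return
            if state == "FAILED":
                yield {"status": "error", "message": "job failed"}
                return
            await asyncio.sleep(poll_seconds)
    except Exception as e:
        yield {"status": "error", "message": str(e)}
        # Unlike the other triggers, an explicit return is still needed here:
        # without it, control would continue to the timeout event below.
        return

    # Reached only when the while condition went false before the job finished.
    yield {"status": "error", "message": "job timed out"}


async def main() -> None:
    async for event in run_until_deadline(timeout=2.0, poll_seconds=0.5):
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```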
11 changes: 5 additions & 6 deletions airflow/providers/google/cloud/triggers/cloud_build.py
@@ -78,8 +78,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current build execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
# Poll for job execution status
cloud_build_instance = await hook.get_cloud_build(
id_=self.id_,
@@ -119,10 +119,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
)
return

except Exception as e:
self.log.exception("Exception occurred while checking for Cloud Build completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for Cloud Build completion")
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> CloudBuildAsyncHook:
return CloudBuildAsyncHook(gcp_conn_id=self.gcp_conn_id)
20 changes: 10 additions & 10 deletions airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -64,8 +64,8 @@ def serialize(self):
)

async def run(self):
while True:
try:
try:
while True:
operation = await self.hook.get_operation(
project_id=self.project_id, operation_name=self.operation_name
)
@@ -93,11 +93,11 @@ async def run(self):
self.poke_interval,
)
await asyncio.sleep(self.poke_interval)
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)
11 changes: 5 additions & 6 deletions airflow/providers/google/cloud/triggers/dataflow.py
@@ -92,8 +92,8 @@ async def run(self):
amount of time stored in self.poll_sleep variable.
"""
hook = self._get_async_hook()
while True:
try:
try:
while True:
status = await hook.get_job_status(
project_id=self.project_id,
job_id=self.job_id,
@@ -129,10 +129,9 @@ async def run(self):
self.log.info("Current job status is: %s", status)
self.log.info("Sleeping for %s seconds.", self.poll_sleep)
await asyncio.sleep(self.poll_sleep)
except Exception as e:
self.log.exception("Exception occurred while checking for job completion.")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for job completion.")
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> AsyncDataflowHook:
return AsyncDataflowHook(
11 changes: 5 additions & 6 deletions airflow/providers/google/cloud/triggers/datafusion.py
@@ -83,8 +83,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
try:
while True:
# Poll for job execution status
response_from_hook = await hook.get_pipeline_status(
success_states=self.success_states,
@@ -109,10 +109,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
else:
yield TriggerEvent({"status": "error", "message": response_from_hook})
return
except Exception as e:
self.log.exception("Exception occurred while checking for pipeline state")
yield TriggerEvent({"status": "error", "message": str(e)})
return
except Exception as e:
self.log.exception("Exception occurred while checking for pipeline state")
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> DataFusionAsyncHook:
return DataFusionAsyncHook(
38 changes: 18 additions & 20 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -263,8 +263,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Wait until cluster is deleted completely."""
while self.end_time > time.time():
try:
try:
while self.end_time > time.time():
cluster = await self.get_async_hook().get_cluster(
region=self.region, # type: ignore[arg-type]
cluster_name=self.cluster_name,
@@ -277,13 +277,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.polling_interval_seconds,
)
await asyncio.sleep(self.polling_interval_seconds)
except NotFound:
yield TriggerEvent({"status": "success", "message": ""})
return
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
return
yield TriggerEvent({"status": "error", "message": "Timeout"})
except NotFound:
yield TriggerEvent({"status": "success", "message": ""})
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
else:
yield TriggerEvent({"status": "error", "message": "Timeout"})


class DataprocWorkflowTrigger(DataprocBaseTrigger):
@@ -312,8 +311,8 @@ def serialize(self):

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.get_async_hook()
while True:
try:
try:
while True:
operation = await hook.get_operation(region=self.region, operation_name=self.name)
if operation.done:
if operation.error.message:
@@ -338,12 +337,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
else:
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)
return
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)
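In `dataproc.py`, the `DataprocDeleteClusterTrigger` hunk is the one place where the refactor changes more than nesting: the timeout event that used to follow the loop now sits in an `else:` clause of the `try` statement, which runs only when the `while` condition goes false without any exception, while successful deletion is signalled by `get_cluster` raising `NotFound`. A minimal standalone sketch of that control flow, with a fake `get_cluster` standing in for the async hook call:

```python
from __future__ import annotations

import asyncio
import time
from typing import Any, AsyncIterator


class NotFound(Exception):
    """Stand-in for google.api_core.exceptions.NotFound."""


async def get_cluster() -> str:
    """Placeholder for the hook's get_cluster(); raises NotFound immediately here
    to simulate a cluster that has already been deleted."""
    await asyncio.sleep(0)
    raise NotFound("cluster deleted")


async def wait_for_deletion(end_time: float, poll_seconds: float = 1.0) -> AsyncIterator[dict[str, Any]]:
    try:
        while end_time > time.time():
            status = await get_cluster()  # raises NotFound once deletion finishes
            print(f"Cluster still exists in state {status}; sleeping {poll_seconds}s")
            await asyncio.sleep(poll_seconds)
    except NotFound:
        yield {"status": "success", "message": ""}
    except Exception as e:
        yield {"status": "error", "message": str(e)}
    else:
        # Runs only if the while loop exited normally (no exception, no return),
        # i.e. the deadline passed while the cluster still existed.
        yield {"status": "error", "message": "Timeout"}


async def main() -> None:
    async for event in wait_for_deletion(end_time=time.time() + 60):
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```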