Refactor unneeded jumps in providers (#33833)
eumiro authored Sep 1, 2023
1 parent 0e1c106 commit 875387a
Showing 17 changed files with 108 additions and 142 deletions.
25 changes: 9 additions & 16 deletions airflow/providers/amazon/aws/hooks/datasync.py
@@ -301,25 +301,18 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int =
         if not task_execution_arn:
             raise AirflowBadRequest("task_execution_arn not specified")
 
-        status = None
-        iterations = max_iterations
-        while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+        for _ in range(max_iterations):
             task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
             status = task_execution["Status"]
             self.log.info("status=%s", status)
-            iterations -= 1
-            if status in self.TASK_EXECUTION_FAILURE_STATES:
-                break
             if status in self.TASK_EXECUTION_SUCCESS_STATES:
-                break
-            if iterations <= 0:
-                break
-            time.sleep(self.wait_interval_seconds)
-
-        if status in self.TASK_EXECUTION_SUCCESS_STATES:
-            return True
-        if status in self.TASK_EXECUTION_FAILURE_STATES:
-            return False
-        if iterations <= 0:
-            raise AirflowTaskTimeout("Max iterations exceeded!")
-        raise AirflowException(f"Unknown status: {status}")  # Should never happen
+                return True
+            elif status in self.TASK_EXECUTION_FAILURE_STATES:
+                return False
+            elif status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+                time.sleep(self.wait_interval_seconds)
+            else:
+                raise AirflowException(f"Unknown status: {status}")  # Should never happen
+        else:
+            raise AirflowTaskTimeout("Max iterations exceeded!")
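
The new loop leans on Python's for/else: the else suite runs only when the loop exhausts without break, and an early return or raise skips it as well. A minimal sketch of the polling pattern, with hypothetical states and a stubbed describe callable (not the real DataSync API):

```python
import time

INTERMEDIATE, SUCCESS, FAILURE = "RUNNING", "SUCCESS", "ERROR"

def wait_for(describe, max_iterations: int = 60, wait_interval: float = 1.0) -> bool:
    for _ in range(max_iterations):
        status = describe()  # stub: returns the current status string
        if status == SUCCESS:
            return True
        elif status == FAILURE:
            return False
        elif status == INTERMEDIATE:
            time.sleep(wait_interval)
        else:
            raise ValueError(f"Unknown status: {status}")
    else:
        # Reached only when all iterations completed without returning.
        raise TimeoutError("Max iterations exceeded!")
```

The for/else is what replaces the old `iterations` counter: timing out is expressed as "the loop ran out" rather than separate bookkeeping.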
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -252,12 +252,12 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge
         ]
         events: list[Any | None] = []
         for event_stream in event_iters:
-            if not event_stream:
-                events.append(None)
-                continue
-            try:
-                events.append(next(event_stream))
-            except StopIteration:
+            if event_stream:
+                try:
+                    events.append(next(event_stream))
+                except StopIteration:
+                    events.append(None)
+            else:
                 events.append(None)
 
         while any(events):
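
An observation rather than part of the commit: next() accepts a default, which would collapse the whole branch — a hypothetical further simplification:

```python
# next(stream, None) returns None instead of raising StopIteration.
events = [next(stream, None) if stream else None for stream in event_iters]
```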
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/sensors/sqs.py
@@ -204,12 +204,12 @@ def poke(self, context: Context):
 
         if "Successful" not in response:
             raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
-        if not message_batch:
+        if message_batch:
+            context["ti"].xcom_push(key="messages", value=message_batch)
+            return True
+        else:
             return False
-
-        context["ti"].xcom_push(key="messages", value=message_batch)
-        return True
 
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> SqsHook:
         """Create and return an SqsHook."""
12 changes: 4 additions & 8 deletions airflow/providers/amazon/aws/utils/sqs.py
@@ -79,13 +79,9 @@ def filter_messages_jsonpath(messages, message_filtering_match_values, message_f
             # Body is a string, deserialize to an object and then parse
             body = json.loads(body)
             results = jsonpath_expr.find(body)
-            if not results:
-                continue
-            if message_filtering_match_values is None:
+            if results and (
+                message_filtering_match_values is None
+                or any(result.value in message_filtering_match_values for result in results)
+            ):
                 filtered_messages.append(message)
-                continue
-            for result in results:
-                if result.value in message_filtering_match_values:
-                    filtered_messages.append(message)
-                    break
     return filtered_messages
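
The rewrite folds three exits into one condition: keep the message when the JSONPath matched anything and either no match values were configured or some result value is among them. A self-contained sketch with hypothetical sample messages (jsonpath-ng is the library the provider already uses):

```python
import json
from jsonpath_ng import parse

messages = [
    {"Body": '{"status": "ok"}'},
    {"Body": '{"status": "bad"}'},
    {"Body": "{}"},
]
jsonpath_expr = parse("$.status")
match_values = ["ok"]  # None would mean "any match is enough"

filtered = [
    message
    for message in messages
    if (results := jsonpath_expr.find(json.loads(message["Body"])))
    and (match_values is None or any(r.value in match_values for r in results))
]
print(filtered)  # keeps only the message whose status is "ok"
```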
4 changes: 1 addition & 3 deletions airflow/providers/cncf/kubernetes/utils/delete_from.py
@@ -81,9 +81,7 @@ def delete_from_yaml(
     **kwargs,
 ):
     for yml_document in yaml_objects:
-        if yml_document is None:
-            continue
-        else:
+        if yml_document is not None:
             delete_from_dict(
                 k8s_client=k8s_client,
                 data=yml_document,
36 changes: 17 additions & 19 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -2206,27 +2206,25 @@ def run_query(
             if param_name == "schemaUpdateOptions" and param:
                 self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options)
 
-            if param_name != "destinationTable":
-                continue
-
-            for key in ["projectId", "datasetId", "tableId"]:
-                if key not in configuration["query"]["destinationTable"]:
-                    raise ValueError(
-                        "Not correct 'destinationTable' in "
-                        "api_resource_configs. 'destinationTable' "
-                        "must be a dict with {'projectId':'', "
-                        "'datasetId':'', 'tableId':''}"
-                    )
-
-            configuration["query"].update(
-                {
-                    "allowLargeResults": allow_large_results,
-                    "flattenResults": flatten_results,
-                    "writeDisposition": write_disposition,
-                    "createDisposition": create_disposition,
-                }
-            )
+            if param_name == "destinationTable":
+                for key in ["projectId", "datasetId", "tableId"]:
+                    if key not in configuration["query"]["destinationTable"]:
+                        raise ValueError(
+                            "Not correct 'destinationTable' in "
+                            "api_resource_configs. 'destinationTable' "
+                            "must be a dict with {'projectId':'', "
+                            "'datasetId':'', 'tableId':''}"
+                        )
+                else:
+                    configuration["query"].update(
+                        {
+                            "allowLargeResults": allow_large_results,
+                            "flattenResults": flatten_results,
+                            "writeDisposition": write_disposition,
+                            "createDisposition": create_disposition,
+                        }
+                    )
 
         if (
             "useLegacySql" in configuration["query"]
             and configuration["query"]["useLegacySql"]
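
One subtlety: the new else: is attached to the for loop, not the if. Because the validation loop contains no break, its else clause always runs once no exception was raised, so the construct is equivalent to straight-line code — a sketch with stand-in names (keys, table, apply_update are hypothetical):

```python
def validate_then_update(keys, table, apply_update):
    for key in keys:
        if key not in table:
            raise ValueError(key)
    else:
        apply_update()  # form used in the hunk above

def validate_then_update_flat(keys, table, apply_update):
    for key in keys:
        if key not in table:
            raise ValueError(key)
    apply_update()  # straight-line equivalent, since the loop never breaks
```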
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/hooks/datafusion.py
@@ -371,12 +371,12 @@ def delete_pipeline(
                 self._check_response_status_and_data(
                     response, f"Deleting a pipeline failed with code {response.status}: {response.data}"
                 )
-                if response.status == 200:
-                    break
             except ConflictException as exc:
                 self.log.info(exc)
                 sleep(time_to_wait)
-                continue
+            else:
+                if response.status == 200:
+                    break
 
     def list_pipelines(
         self,
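
This hunk is the try/except/else idiom: the else suite runs only when the try block raised nothing, and unlike code at the end of the try block it is not itself guarded by the except clause. A two-branch illustration:

```python
try:
    value = int("42")
except ValueError:
    print("bad input")
else:
    print(value + 1)  # runs only on success; an exception here would propagate
```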
14 changes: 4 additions & 10 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -361,7 +361,6 @@ def download(
                 # Wait with exponential backoff scheme before retrying.
                 timeout_seconds = 2 ** (num_file_attempts - 1)
                 time.sleep(timeout_seconds)
-                continue
 
     def download_as_byte_array(
         self,
@@ -508,28 +507,23 @@ def _call_with_retry(f: Callable[[], None]) -> None:
            :param f: Callable that should be retried.
            """
-            num_file_attempts = 0
-
-            while num_file_attempts < num_max_attempts:
+            for attempt in range(1, 1 + num_max_attempts):
                 try:
-                    num_file_attempts += 1
                     f()
-
                 except GoogleCloudError as e:
-                    if num_file_attempts == num_max_attempts:
+                    if attempt == num_max_attempts:
                         self.log.error(
                             "Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.",
                             object_name,
                             object_name,
-                            num_file_attempts,
+                            attempt,
                             num_max_attempts,
                         )
                         raise e
 
                     # Wait with exponential backoff scheme before retrying.
-                    timeout_seconds = 2 ** (num_file_attempts - 1)
+                    timeout_seconds = 2 ** (attempt - 1)
                     time.sleep(timeout_seconds)
-                    continue
 
         client = self.get_conn()
         bucket = client.bucket(bucket_name, user_project=user_project)
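
The backoff term 2 ** (attempt - 1) sleeps 1, 2, 4, 8, ... seconds after failed attempts 1, 2, 3, 4, ... A generic sketch of the bounded-retry shape (stand-in callable and exception type; unlike the hunk above, the sketch exits explicitly on success):

```python
import time

def call_with_retry(f, num_max_attempts: int = 5) -> None:
    for attempt in range(1, 1 + num_max_attempts):
        try:
            f()
        except OSError:
            if attempt == num_max_attempts:
                raise  # last attempt: give up
            time.sleep(2 ** (attempt - 1))  # exponential backoff: 1, 2, 4, ...
        else:
            return  # success: stop retrying
```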
4 changes: 1 addition & 3 deletions airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -243,9 +243,7 @@ def gcs_write(self, log, remote_log_location) -> bool:
                 old_log = blob.download_as_bytes().decode()
                 log = "\n".join([old_log, log]) if old_log else log
             except Exception as e:
-                if self.no_log_found(e):
-                    pass
-                else:
+                if not self.no_log_found(e):
                     log += self._add_message(
                         f"Error checking for previous log; if exists, may be overwritten: {e}"
                     )
45 changes: 21 additions & 24 deletions airflow/providers/google/cloud/operators/compute.py
@@ -174,14 +174,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_inputs(self) -> None:
         super()._validate_inputs()
@@ -915,14 +914,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body["properties"]:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body["properties"]:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_all_body_fields(self) -> None:
         if self._field_validator:
@@ -1500,14 +1498,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["base_instance_name", "target_size", "instance_template"]
         for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_all_body_fields(self) -> None:
         if self._field_validator:
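
The same check now appears verbatim in three operators; a hypothetical consolidation (not part of this commit) could hoist it into a shared helper:

```python
from airflow.exceptions import AirflowException

def check_required_body_fields(body: dict, required_params: list[str]) -> None:
    """Raise if any required field is missing from an operator body."""
    for param in required_params:
        if param not in body:
            readable_param = param.replace("_", " ")
            raise AirflowException(
                f"The body '{body}' should contain at least {readable_param} "
                f"in the '{param}' field."
            )
```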
10 changes: 4 additions & 6 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -797,9 +797,8 @@ def execute(self, context: Context) -> list[str]:
                     num_max_attempts=self.download_num_attempts,
                 )
             except GoogleCloudError:
-                if self.download_continue_on_fail:
-                    continue
-                raise
+                if not self.download_continue_on_fail:
+                    raise
 
         self.log.info("Starting the transformation")
         cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script
@@ -847,9 +846,8 @@ def execute(self, context: Context) -> list[str]:
                 )
                 files_uploaded.append(str(upload_file_name))
             except GoogleCloudError:
-                if self.upload_continue_on_fail:
-                    continue
-                raise
+                if not self.upload_continue_on_fail:
+                    raise
 
         return files_uploaded
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -855,8 +855,8 @@ def wait_for_pipeline_run_status(
             except ServiceRequestError:
                 if executed_after_token_refresh:
                     self.refresh_conn()
-                    continue
-                raise
+                else:
+                    raise
 
         return pipeline_run_status in expected_statuses
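
The pattern here (and in the triggers below) is refresh-once-then-raise: the first ServiceRequestError after a successful call triggers a token refresh, a second consecutive one propagates. A sketch with stand-in names and ConnectionError in place of the Azure exception:

```python
def poll_until_terminal(fetch_status, refresh_conn, is_terminal):
    executed_after_token_refresh = True
    while True:
        try:
            status = fetch_status()
            executed_after_token_refresh = True  # reset after a successful call
            if is_terminal(status):
                return status
        except ConnectionError:
            if executed_after_token_refresh:
                refresh_conn()
                executed_after_token_refresh = False  # only one free retry
            else:
                raise
```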
8 changes: 4 additions & 4 deletions airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -103,8 +103,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                     if executed_after_token_refresh:
                         await hook.refresh_conn()
                         executed_after_token_refresh = False
-                        continue
-                    raise
+                    else:
+                        raise
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
@@ -207,8 +207,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                     if executed_after_token_refresh:
                         await hook.refresh_conn()
                         executed_after_token_refresh = False
-                        continue
-                    raise
+                    else:
+                        raise
 
             yield TriggerEvent(
                 {
14 changes: 6 additions & 8 deletions airflow/providers/openlineage/utils/utils.py
@@ -208,16 +208,14 @@ def _include_fields(self):
             raise Exception("Don't use both includes and excludes.")
         if self.includes:
             for field in self.includes:
-                if field in self._fields or not hasattr(self.obj, field):
-                    continue
-                setattr(self, field, getattr(self.obj, field))
-                self._fields.append(field)
+                if field not in self._fields and hasattr(self.obj, field):
+                    setattr(self, field, getattr(self.obj, field))
+                    self._fields.append(field)
         else:
             for field, val in self.obj.__dict__.items():
-                if field in self._fields or field in self.excludes or field in self.renames:
-                    continue
-                setattr(self, field, val)
-                self._fields.append(field)
+                if field not in self._fields and field not in self.excludes and field not in self.renames:
+                    setattr(self, field, val)
+                    self._fields.append(field)
 
 
 class DagInfo(InfoJsonEncodable):
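
The inverted guards are De Morgan's laws at work: not (A or B) is (not A) and (not B). A quick exhaustive check:

```python
for a in (False, True):
    for b in (False, True):
        assert (not (a or b)) == ((not a) and (not b))
```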
24 changes: 12 additions & 12 deletions airflow/providers/smtp/hooks/smtp.py
@@ -87,14 +87,14 @@ def get_conn(self) -> SmtpHook:
             try:
                 self.smtp_client = self._build_client()
             except smtplib.SMTPServerDisconnected:
-                if attempt < self.smtp_retry_limit:
-                    continue
-                raise AirflowException("Unable to connect to smtp server")
-            if self.smtp_starttls:
-                self.smtp_client.starttls()
-            if self.smtp_user and self.smtp_password:
-                self.smtp_client.login(self.smtp_user, self.smtp_password)
-            break
+                if attempt == self.smtp_retry_limit:
+                    raise AirflowException("Unable to connect to smtp server")
+            else:
+                if self.smtp_starttls:
+                    self.smtp_client.starttls()
+                if self.smtp_user and self.smtp_password:
+                    self.smtp_client.login(self.smtp_user, self.smtp_password)
+                break
 
         return self
@@ -234,10 +234,10 @@ def send_email_smtp(
                     from_addr=from_email, to_addrs=recipients, msg=mime_msg.as_string()
                 )
             except smtplib.SMTPServerDisconnected as e:
-                if attempt < self.smtp_retry_limit:
-                    continue
-                raise e
-            break
+                if attempt == self.smtp_retry_limit:
+                    raise e
+            else:
+                break
 
     def _build_mime_message(
         self,
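
Both smtp hunks share one skeleton: bounded retries where the last failed attempt raises and success flows into an else that finishes the work and breaks. A condensed sketch (connect_with_retries, build_client, retry_limit are stand-ins, not the hook's API):

```python
import smtplib

def connect_with_retries(build_client, retry_limit: int = 5):
    for attempt in range(1, retry_limit + 1):
        try:
            client = build_client()
        except smtplib.SMTPServerDisconnected:
            if attempt == retry_limit:
                raise  # out of retries
        else:
            return client  # else runs only if try raised nothing
```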