Enable string normalization in python formatting - providers (#27205)
dstandish committed Oct 23, 2022
1 parent e789a2b commit 2a34dc9
Showing 488 changed files with 7,830 additions and 7,823 deletions.
11 changes: 9 additions & 2 deletions .pre-commit-config.yaml
@@ -152,9 +152,16 @@ repos:
     rev: 22.3.0
     hooks:
       - id: black
-        name: Run Black (the uncompromising Python code formatter)
+        name: Run black (python formatter) on core
         args: [--config=./pyproject.toml, --skip-string-normalization]
+        exclude: ^airflow/_vendor/|^airflow/contrib/|^airflow/providers/
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        name: Run black (python formatter) on providers
+        args: [--config=./pyproject.toml]
-        exclude: ^airflow/_vendor/|^airflow/contrib
+        files: ^airflow/providers/
   - repo: https://github.com/asottile/blacken-docs
     rev: v1.12.1
     hooks:
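
For reference, the behavioral difference between the two hooks above is black's string normalization: with normalization on (the providers hook), single-quoted string literals are rewritten to double quotes, which is the mechanical change applied across the provider files below. A minimal sketch of the two modes via black's Python API (assumes black 22.x is installed):

    # Sketch: black's string normalization, off vs. on (assumes black is installed).
    import black

    src = "conn_type = 'airbyte'\n"

    # Core hook: --skip-string-normalization leaves quotes alone.
    print(black.format_str(src, mode=black.Mode(string_normalization=False)))
    # -> conn_type = 'airbyte'

    # Providers hook: normalization is black's default and rewrites to double quotes.
    print(black.format_str(src, mode=black.Mode()))
    # -> conn_type = "airbyte"
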
3 changes: 2 additions & 1 deletion STATIC_CODE_CHECKS.rst
@@ -132,7 +132,8 @@ require Breeze Docker image to be build locally.
 +--------------------------------------------------------+------------------------------------------------------------------+---------+
 | ID | Description | Image |
 +========================================================+==================================================================+=========+
-| black | Run Black (the uncompromising Python code formatter) | |
+| black | * Run black (python formatter) on core | |
+| | * Run black (python formatter) on providers | |
 +--------------------------------------------------------+------------------------------------------------------------------+---------+
 | blacken-docs | Run black on python code blocks in documentation files | |
 +--------------------------------------------------------+------------------------------------------------------------------+---------+
16 changes: 8 additions & 8 deletions airflow/providers/airbyte/hooks/airbyte.py
@@ -33,10 +33,10 @@ class AirbyteHook(HttpHook):
     :param api_version: Optional. Airbyte API version.
     """
 
-    conn_name_attr = 'airbyte_conn_id'
-    default_conn_name = 'airbyte_default'
-    conn_type = 'airbyte'
-    hook_name = 'Airbyte'
+    conn_name_attr = "airbyte_conn_id"
+    default_conn_name = "airbyte_default"
+    conn_type = "airbyte"
+    hook_name = "Airbyte"
 
     RUNNING = "running"
     SUCCEEDED = "succeeded"
@@ -121,19 +121,19 @@ def cancel_job(self, job_id: int) -> Any:
 
     def test_connection(self):
         """Tests the Airbyte connection by hitting the health API"""
-        self.method = 'GET'
+        self.method = "GET"
         try:
             res = self.run(
                 endpoint=f"api/{self.api_version}/health",
                 headers={"accept": "application/json"},
-                extra_options={'check_response': False},
+                extra_options={"check_response": False},
             )
 
             if res.status_code == 200:
-                return True, 'Connection successfully tested'
+                return True, "Connection successfully tested"
             else:
                 return False, res.text
         except Exception as e:
             return False, str(e)
         finally:
-            self.method = 'POST'
+            self.method = "POST"
10 changes: 5 additions & 5 deletions airflow/providers/airbyte/operators/airbyte.py
@@ -47,7 +47,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
         Only used when ``asynchronous`` is False.
     """
 
-    template_fields: Sequence[str] = ('connection_id',)
+    template_fields: Sequence[str] = ("connection_id",)
 
     def __init__(
         self,
@@ -71,18 +71,18 @@ def execute(self, context: Context) -> None:
         """Create Airbyte Job and wait to finish"""
         self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
         job_object = self.hook.submit_sync_connection(connection_id=self.connection_id)
-        self.job_id = job_object.json()['job']['id']
+        self.job_id = job_object.json()["job"]["id"]
 
         self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
         if not self.asynchronous:
-            self.log.info('Waiting for job %s to complete', self.job_id)
+            self.log.info("Waiting for job %s to complete", self.job_id)
             self.hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
-            self.log.info('Job %s completed successfully', self.job_id)
+            self.log.info("Job %s completed successfully", self.job_id)
 
         return self.job_id
 
     def on_kill(self):
         """Cancel the job if task is cancelled"""
         if self.job_id:
-            self.log.info('on_kill: cancel the airbyte Job %s', self.job_id)
+            self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
             self.hook.cancel_job(self.job_id)
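
For context, a hedged DAG sketch that exercises the operator's templated `connection_id` and the synchronous wait path shown above (the connection UUID and DAG settings are placeholders, not from this commit):

    # Sketch: triggering an Airbyte sync and blocking until it finishes.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator

    with DAG(dag_id="example_airbyte_sync", start_date=datetime(2022, 10, 1), schedule_interval=None) as dag:
        trigger_sync = AirbyteTriggerSyncOperator(
            task_id="trigger_sync",
            airbyte_conn_id="airbyte_default",
            connection_id="00000000-0000-0000-0000-000000000000",  # placeholder Airbyte connection UUID
            asynchronous=False,  # execute() then waits via hook.wait_for_job()
            timeout=3600,
        )
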
8 changes: 4 additions & 4 deletions airflow/providers/airbyte/sensors/airbyte.py
@@ -38,14 +38,14 @@ class AirbyteJobSensor(BaseSensorOperator):
     :param api_version: Optional. Airbyte API version.
     """
 
-    template_fields: Sequence[str] = ('airbyte_job_id',)
-    ui_color = '#6C51FD'
+    template_fields: Sequence[str] = ("airbyte_job_id",)
+    ui_color = "#6C51FD"
 
     def __init__(
         self,
         *,
         airbyte_job_id: int,
-        airbyte_conn_id: str = 'airbyte_default',
+        airbyte_conn_id: str = "airbyte_default",
         api_version: str = "v1",
         **kwargs,
     ) -> None:
@@ -57,7 +57,7 @@ def __init__(
     def poke(self, context: Context) -> bool:
         hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
         job = hook.get_job(job_id=self.airbyte_job_id)
-        status = job.json()['job']['status']
+        status = job.json()["job"]["status"]
 
         if status == hook.FAILED:
             raise AirflowException(f"Job failed: \n{job}")
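
When the operator runs with `asynchronous=True`, its `execute()` returns the job id (pushed to XCom), and this sensor can poll it to completion. A hedged sketch continuing the DAG example above:

    # Sketch: poll an Airbyte job id produced by an asynchronous trigger task.
    from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor

    wait_for_sync = AirbyteJobSensor(
        task_id="wait_for_sync",
        airbyte_conn_id="airbyte_default",
        airbyte_job_id=trigger_sync.output,  # templated field; XCom from the trigger task
        api_version="v1",
    )
    trigger_sync >> wait_for_sync
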
44 changes: 22 additions & 22 deletions airflow/providers/alibaba/cloud/hooks/oss.py
@@ -45,10 +45,10 @@ def provide_bucket_name(func: T) -> T:
     def wrapper(*args, **kwargs) -> T:
         bound_args = function_signature.bind(*args, **kwargs)
         self = args[0]
-        if bound_args.arguments.get('bucket_name') is None and self.oss_conn_id:
+        if bound_args.arguments.get("bucket_name") is None and self.oss_conn_id:
             connection = self.get_connection(self.oss_conn_id)
             if connection.schema:
-                bound_args.arguments['bucket_name'] = connection.schema
+                bound_args.arguments["bucket_name"] = connection.schema
 
         return func(*bound_args.args, **bound_args.kwargs)
 
@@ -67,13 +67,13 @@ def wrapper(*args, **kwargs) -> T:
         bound_args = function_signature.bind(*args, **kwargs)
 
         def get_key() -> str:
-            if 'key' in bound_args.arguments:
-                return 'key'
-            raise ValueError('Missing key parameter!')
+            if "key" in bound_args.arguments:
+                return "key"
+            raise ValueError("Missing key parameter!")
 
         key_name = get_key()
-        if 'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None:
-            bound_args.arguments['bucket_name'], bound_args.arguments['key'] = OSSHook.parse_oss_url(
+        if "bucket_name" not in bound_args.arguments or bound_args.arguments["bucket_name"] is None:
+            bound_args.arguments["bucket_name"], bound_args.arguments["key"] = OSSHook.parse_oss_url(
                 bound_args.arguments[key_name]
             )
 
@@ -85,12 +85,12 @@ def get_key() -> str:
 class OSSHook(BaseHook):
     """Interact with Alibaba Cloud OSS, using the oss2 library."""
 
-    conn_name_attr = 'alibabacloud_conn_id'
-    default_conn_name = 'oss_default'
-    conn_type = 'oss'
-    hook_name = 'OSS'
+    conn_name_attr = "alibabacloud_conn_id"
+    default_conn_name = "oss_default"
+    conn_type = "oss"
+    hook_name = "OSS"
 
-    def __init__(self, region: str | None = None, oss_conn_id='oss_default', *args, **kwargs) -> None:
+    def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None:
         self.oss_conn_id = oss_conn_id
         self.oss_conn = self.get_connection(oss_conn_id)
         self.region = self.get_default_region() if region is None else region
@@ -114,7 +114,7 @@ def parse_oss_url(ossurl: str) -> tuple:
             raise AirflowException(f'Please provide a bucket_name instead of "{ossurl}"')
 
         bucket_name = parsed_url.netloc
-        key = parsed_url.path.lstrip('/')
+        key = parsed_url.path.lstrip("/")
 
         return bucket_name, key
 
@@ -146,7 +146,7 @@ def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket:
         """
         auth = self.get_credential()
         assert self.region is not None
-        return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name)
+        return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name)
 
     @provide_bucket_name
     @unify_bucket_name_and_key
@@ -337,7 +337,7 @@ def key_exist(self, bucket_name: str | None, key: str) -> bool:
         :param key: oss bucket key
         """
         # full_path = None
-        self.log.info('Looking up oss bucket %s for bucket key %s ...', bucket_name, key)
+        self.log.info("Looking up oss bucket %s for bucket key %s ...", bucket_name, key)
         try:
             return self.get_bucket(bucket_name).object_exists(key)
         except Exception as e:
@@ -346,14 +346,14 @@ def key_exist(self, bucket_name: str | None, key: str) -> bool:
 
     def get_credential(self) -> oss2.auth.Auth:
         extra_config = self.oss_conn.extra_dejson
-        auth_type = extra_config.get('auth_type', None)
+        auth_type = extra_config.get("auth_type", None)
         if not auth_type:
             raise Exception("No auth_type specified in extra_config. ")
 
-        if auth_type != 'AK':
+        if auth_type != "AK":
             raise Exception(f"Unsupported auth_type: {auth_type}")
-        oss_access_key_id = extra_config.get('access_key_id', None)
-        oss_access_key_secret = extra_config.get('access_key_secret', None)
+        oss_access_key_id = extra_config.get("access_key_id", None)
+        oss_access_key_secret = extra_config.get("access_key_secret", None)
         if not oss_access_key_id:
             raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}")
 
@@ -364,14 +364,14 @@ def get_credential(self) -> oss2.auth.Auth:
 
     def get_default_region(self) -> str | None:
         extra_config = self.oss_conn.extra_dejson
-        auth_type = extra_config.get('auth_type', None)
+        auth_type = extra_config.get("auth_type", None)
         if not auth_type:
             raise Exception("No auth_type specified in extra_config. ")
 
-        if auth_type != 'AK':
+        if auth_type != "AK":
             raise Exception(f"Unsupported auth_type: {auth_type}")
 
-        default_region = extra_config.get('region', None)
+        default_region = extra_config.get("region", None)
         if not default_region:
             raise Exception(f"No region is specified for connection: {self.oss_conn_id}")
         return default_region
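
Two practical details fall out of the hook code above: `parse_oss_url` splits an `oss://` URL into bucket (the netloc) and key (the path without its leading slash), and the connection's extras must carry AK-style credentials plus a region. A hedged sketch (assumes the connection exists in the Airflow metadata database; all values are placeholders):

    # Sketch: URL parsing and the connection extras the hook expects.
    from airflow.providers.alibaba.cloud.hooks.oss import OSSHook

    bucket, key = OSSHook.parse_oss_url("oss://my-bucket/airflow/logs/attempt=1.log")
    print(bucket, key)  # my-bucket airflow/logs/attempt=1.log

    # The "oss_default" connection would carry extras like (placeholders):
    # {"auth_type": "AK", "access_key_id": "...", "access_key_secret": "...", "region": "cn-hangzhou"}
    hook = OSSHook(oss_conn_id="oss_default")  # region falls back to get_default_region()
    print(hook.key_exist(bucket, key))
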
24 changes: 12 additions & 12 deletions airflow/providers/alibaba/cloud/log/oss_task_handler.py
@@ -39,23 +39,23 @@ def __init__(self, base_log_folder, oss_log_folder, filename_template=None):
         self.log.info("Using oss_task_handler for remote logging...")
         super().__init__(base_log_folder, filename_template)
         (self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder)
-        self.log_relative_path = ''
+        self.log_relative_path = ""
         self._hook = None
         self.closed = False
         self.upload_on_close = True
 
     @cached_property
     def hook(self):
-        remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
+        remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID")
         self.log.info("remote_conn_id: %s", remote_conn_id)
         try:
             return OSSHook(oss_conn_id=remote_conn_id)
         except Exception as e:
             self.log.error(e, exc_info=True)
             self.log.error(
                 'Could not create an OSSHook with connection id "%s". '
-                'Please make sure that airflow[oss] is installed and '
-                'the OSS connection exists.',
+                "Please make sure that airflow[oss] is installed and "
+                "the OSS connection exists.",
                 remote_conn_id,
             )
 
@@ -70,7 +70,7 @@ def set_context(self, ti):
         # Clear the file first so that duplicate data is not uploaded
         # when re-using the same path (e.g. with rescheduled sensors)
         if self.upload_on_close:
-            with open(self.handler.baseFilename, 'w'):
+            with open(self.handler.baseFilename, "w"):
                 pass
 
     def close(self):
@@ -119,8 +119,8 @@ def _read(self, ti, try_number, metadata=None):
         # local machine even if there are errors reading remote logs, as
         # returned remote_log will contain error messages.
         remote_log = self.oss_read(remote_loc, return_error=True)
-        log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
-        return log, {'end_of_log': True}
+        log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n"
+        return log, {"end_of_log": True}
 
     def oss_log_exists(self, remote_log_location):
         """
@@ -129,7 +129,7 @@ def oss_log_exists(self, remote_log_location):
         :param remote_log_location: log's location in remote storage
         :return: True if location exists else False
         """
-        oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+        oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
         with contextlib.suppress(Exception):
             return self.hook.key_exist(self.bucket_name, oss_remote_log_location)
         return False
@@ -144,11 +144,11 @@ def oss_read(self, remote_log_location, return_error=False):
         error occurs. Otherwise returns '' when an error occurs.
         """
         try:
-            oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+            oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
             self.log.info("read remote log: %s", oss_remote_log_location)
             return self.hook.read_key(self.bucket_name, oss_remote_log_location)
         except Exception:
-            msg = f'Could not read logs from {oss_remote_log_location}'
+            msg = f"Could not read logs from {oss_remote_log_location}"
             self.log.exception(msg)
             # return error if needed
             if return_error:
@@ -164,7 +164,7 @@ def oss_write(self, log, remote_log_location, append=True):
         :param append: if False, any existing log file is overwritten. If True,
             the new log is appended to any existing logs.
         """
-        oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+        oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
         pos = 0
         if append and self.oss_log_exists(oss_remote_log_location):
             head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
@@ -175,7 +175,7 @@ def oss_write(self, log, remote_log_location, append=True):
             self.hook.append_string(self.bucket_name, log, oss_remote_log_location, pos)
         except Exception:
             self.log.exception(
-                'Could not write logs to %s, log write pos is: %s, Append is %s',
+                "Could not write logs to %s, log write pos is: %s, Append is %s",
                 oss_remote_log_location,
                 str(pos),
                 str(append),
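
Operationally, this handler resolves its OSS connection from `[logging] remote_log_conn_id` and splits the remote base folder with `OSSHook.parse_oss_url`, then reads and writes task logs under `{base_folder}/{remote_log_location}`. A hedged construction sketch (assumes the class is OSSTaskHandler, per this module's name; paths and bucket are placeholders):

    # Sketch: how the handler derives its bucket and base folder at init time.
    from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSTaskHandler

    handler = OSSTaskHandler(
        base_log_folder="/opt/airflow/logs",            # local log directory
        oss_log_folder="oss://my-bucket/airflow/logs",  # remote base, parsed at init
    )
    print(handler.bucket_name)  # my-bucket
    print(handler.base_folder)  # airflow/logs
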
12 changes: 6 additions & 6 deletions airflow/providers/alibaba/cloud/operators/oss.py
@@ -40,7 +40,7 @@ def __init__(
         self,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -66,7 +66,7 @@ def __init__(
         self,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -96,7 +96,7 @@ def __init__(
         file: str,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -128,7 +128,7 @@ def __init__(
         file: str,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -158,7 +158,7 @@ def __init__(
         keys: list,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -187,7 +187,7 @@ def __init__(
         key: str,
         region: str,
         bucket_name: str | None = None,
-        oss_conn_id: str = 'oss_default',
+        oss_conn_id: str = "oss_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
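
The operators in this module share the same connection plumbing: a `region`, an optional `bucket_name` (falling back to the connection schema via `provide_bucket_name`), and an `oss_conn_id` that now defaults to the double-quoted "oss_default". A hedged usage sketch for the upload variant (the class name follows the provider package and is an assumption; values are placeholders):

    # Sketch: uploading a local file to OSS with the provider operator.
    from airflow.providers.alibaba.cloud.operators.oss import OSSUploadObjectOperator

    upload_report = OSSUploadObjectOperator(
        task_id="upload_report",
        region="cn-hangzhou",
        bucket_name="my-bucket",
        key="reports/2022-10-23.csv",  # object key in the bucket
        file="/tmp/report.csv",        # local path to upload
        oss_conn_id="oss_default",
    )
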
