
Replace urlparse with urlsplit (#27389)
* Replace urlparse with urlsplit in s3 files

Co-authored-by: eladkal <[email protected]>
westonkl and eladkal authored Nov 14, 2022
1 parent 93699a3 commit 00af5c0
Showing 27 changed files with 64 additions and 69 deletions.
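
Before the per-file diffs, a note on why the swap is safe and what it fixes: urlparse() returns a 6-field ParseResult whose extra params field is split off the last path segment at a ";", so a path component that legitimately contains a semicolon is silently truncated; urlsplit() returns a 5-field SplitResult and leaves the path intact. Both result types expose the same .scheme, .netloc, .path, .hostname, and .port accessors, which is what makes the replacement drop-in. A minimal sketch (bucket and key names are made up):

    from urllib.parse import urlparse, urlsplit

    url = "s3://my-bucket/logs/run;id=42.csv"

    print(urlparse(url))
    # ParseResult(scheme='s3', netloc='my-bucket', path='/logs/run',
    #             params='id=42.csv', query='', fragment='')

    print(urlsplit(url))
    # SplitResult(scheme='s3', netloc='my-bucket', path='/logs/run;id=42.csv',
    #             query='', fragment='')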
13 changes: 4 additions & 9 deletions airflow/cli/commands/connection_command.py
@@ -24,7 +24,7 @@
 import warnings
 from pathlib import Path
 from typing import Any
-from urllib.parse import urlparse, urlunparse
+from urllib.parse import urlsplit, urlunsplit

 from sqlalchemy.orm import exc

@@ -133,12 +133,12 @@ def _is_stdout(fileio: io.TextIOWrapper) -> bool:

 def _valid_uri(uri: str) -> bool:
     """Check if a URI is valid, by checking if both scheme and netloc are available"""
-    uri_parts = urlparse(uri)
+    uri_parts = urlsplit(uri)
     return uri_parts.scheme != "" and uri_parts.netloc != ""


 @cache
-def _get_connection_types():
+def _get_connection_types() -> list[str]:
     """Returns connection types available."""
     _connection_types = ["fs", "mesos_framework-id", "email", "generic"]
     providers_manager = ProvidersManager()
@@ -148,10 +148,6 @@ def _get_connection_types():
     return _connection_types


-def _valid_conn_type(conn_type: str) -> bool:
-    return conn_type in _get_connection_types()
-
-
 def connections_export(args):
     """Exports all connections to a file"""
     file_formats = [".yaml", ".json", ".env"]
@@ -269,15 +265,14 @@ def connections_add(args):
         msg = msg.format(
             conn_id=new_conn.conn_id,
             uri=args.conn_uri
-            or urlunparse(
+            or urlunsplit(
                 (
                     new_conn.conn_type,
                     f"{new_conn.login or ''}:{'******' if new_conn.password else ''}"
                     f"@{new_conn.host or ''}:{new_conn.port or ''}",
                     new_conn.schema or "",
                     "",
-                    "",
                     "",
                 )
             ),
         )
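
Note on the last hunk above: urlunsplit() takes a 5-tuple (scheme, netloc, path, query, fragment), one placeholder fewer than urlunparse()'s 6-tuple, hence the dropped empty string. A quick sketch with hypothetical connection values standing in for the new_conn fields:

    from urllib.parse import urlunsplit

    masked_uri = urlunsplit(("postgresql", "user:******@localhost:5432", "mydb", "", ""))
    print(masked_uri)  # postgresql://user:******@localhost:5432/mydb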
6 changes: 3 additions & 3 deletions airflow/config_templates/airflow_local_settings.py
@@ -21,7 +21,7 @@
 import os
 from pathlib import Path
 from typing import Any
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -221,7 +221,7 @@

     DEFAULT_LOGGING_CONFIG["handlers"].update(S3_REMOTE_HANDLERS)
 elif REMOTE_BASE_LOG_FOLDER.startswith("cloudwatch://"):
-    url_parts = urlparse(REMOTE_BASE_LOG_FOLDER)
+    url_parts = urlsplit(REMOTE_BASE_LOG_FOLDER)
     CLOUDWATCH_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = {
         "task": {
             "class": "airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler",
@@ -264,7 +264,7 @@
 elif REMOTE_BASE_LOG_FOLDER.startswith("stackdriver://"):
     key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None)
     # stackdriver:///airflow-tasks => airflow-tasks
-    log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
+    log_name = urlsplit(REMOTE_BASE_LOG_FOLDER).path[1:]
     STACKDRIVER_REMOTE_HANDLERS = {
         "task": {
             "class": "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler",
4 changes: 2 additions & 2 deletions airflow/configuration.py
@@ -37,7 +37,7 @@
 from json.decoder import JSONDecodeError
 from re import Pattern
 from typing import IO, Any, Dict, Iterable, Tuple, Union
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from typing_extensions import overload

@@ -403,7 +403,7 @@ def _upgrade_postgres_metastore_conn(self):
         old_value = self.get(section, key)
         bad_schemes = ["postgres+psycopg2", "postgres"]
         good_scheme = "postgresql"
-        parsed = urlparse(old_value)
+        parsed = urlsplit(old_value)
         if parsed.scheme in bad_schemes:
             warnings.warn(
                 f"Bad scheme in Airflow configuration core > sql_alchemy_conn: `{parsed.scheme}`. "
4 changes: 2 additions & 2 deletions airflow/datasets/__init__.py
@@ -17,7 +17,7 @@
 from __future__ import annotations

 from typing import Any
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 import attr

@@ -37,6 +37,6 @@ def _check_uri(self, attr, uri: str):
             uri.encode("ascii")
         except UnicodeEncodeError:
             raise ValueError(f"{attr.name!r} must be ascii")
-        parsed = urlparse(uri)
+        parsed = urlsplit(uri)
         if parsed.scheme and parsed.scheme.lower() == "airflow":
             raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")
4 changes: 2 additions & 2 deletions airflow/models/connection.py
@@ -21,7 +21,7 @@
 import logging
 import warnings
 from json import JSONDecodeError
-from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse
+from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit

 from sqlalchemy import Boolean, Column, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
@@ -188,7 +188,7 @@ def _normalize_conn_type(conn_type):
         return conn_type

     def _parse_from_uri(self, uri: str):
-        uri_parts = urlparse(uri)
+        uri_parts = urlsplit(uri)
         conn_type = uri_parts.scheme
         self.conn_type = self._normalize_conn_type(conn_type)
         self.host = _parse_netloc_to_hostname(uri_parts)
4 changes: 2 additions & 2 deletions airflow/models/dataset.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations

-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 import sqlalchemy_jsonfield
 from sqlalchemy import (
@@ -83,7 +83,7 @@ def __init__(self, uri: str, **kwargs):
             uri.encode("ascii")
         except UnicodeEncodeError:
             raise ValueError("URI must be ascii")
-        parsed = urlparse(uri)
+        parsed = urlsplit(uri)
         if parsed.scheme and parsed.scheme.lower() == "airflow":
             raise ValueError("Scheme `airflow` is reserved.")
         super().__init__(uri=uri, **kwargs)
4 changes: 2 additions & 2 deletions airflow/providers/alibaba/cloud/hooks/oss.py
@@ -20,7 +20,7 @@
 from functools import wraps
 from inspect import signature
 from typing import TYPE_CHECKING, Callable, TypeVar, cast
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 import oss2
 from oss2.exceptions import ClientError
@@ -108,7 +108,7 @@ def parse_oss_url(ossurl: str) -> tuple:
         :param ossurl: The OSS Url to parse.
         :return: the parsed bucket name and key
         """
-        parsed_url = urlparse(ossurl)
+        parsed_url = urlsplit(ossurl)

         if not parsed_url.netloc:
             raise AirflowException(f'Please provide a bucket_name instead of "{ossurl}"')
6 changes: 3 additions & 3 deletions airflow/providers/alibaba/cloud/sensors/oss_key.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
@@ -69,13 +69,13 @@ def poke(self, context: Context):
         @returns True if the object exists, False otherwise
         """
         if self.bucket_name is None:
-            parsed_url = urlparse(self.bucket_key)
+            parsed_url = urlsplit(self.bucket_key)
             if parsed_url.netloc == "":
                 raise AirflowException("If key is a relative path from root, please provide a bucket_name")
             self.bucket_name = parsed_url.netloc
             self.bucket_key = parsed_url.path.lstrip("/")
         else:
-            parsed_url = urlparse(self.bucket_key)
+            parsed_url = urlsplit(self.bucket_key)
             if parsed_url.scheme != "" or parsed_url.netloc != "":
                 raise AirflowException(
                     "If bucket_name is provided, bucket_key"
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -31,7 +31,7 @@
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
 from typing import Any, Callable, TypeVar, cast
-from urllib.parse import urlparse
+from urllib.parse import urlsplit
 from uuid import uuid4

 from boto3.s3.transfer import S3Transfer, TransferConfig
@@ -153,7 +153,7 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
         """
         format = s3url.split("//")
         if format[0].lower() == "s3:":
-            parsed_url = urlparse(s3url)
+            parsed_url = urlsplit(s3url)
             if not parsed_url.netloc:
                 raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')

@@ -190,7 +190,7 @@ def get_s3_bucket_key(
         if bucket is None:
             return S3Hook.parse_s3_url(key)

-        parsed_url = urlparse(key)
+        parsed_url = urlsplit(key)
         if parsed_url.scheme != "" or parsed_url.netloc != "":
             raise TypeError(
                 f"If `{bucket_param_name}` is provided, {key_param_name} should be a relative path "
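
This hook is presumably the "s3 files" case named in the commit message: with urlsplit(), an S3 key whose last segment contains a ";" round-trips intact, where urlparse() would have moved the suffix into .params. A simplified sketch of the bucket/key split (not the hook's full logic):

    from urllib.parse import urlsplit

    def split_s3_url(s3url: str) -> tuple[str, str]:
        parsed = urlsplit(s3url)
        return parsed.netloc, parsed.path.lstrip("/")

    print(split_s3_url("s3://my-bucket/logs/run;id=7.log"))
    # ('my-bucket', 'logs/run;id=7.log') -- urlparse would have yielded
    # ('my-bucket', 'logs/run'), dropping "id=7.log" into .params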
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -20,7 +20,7 @@
 import warnings
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -80,7 +80,7 @@ def __init__(
     @staticmethod
     def get_s3_key(s3_key: str) -> str:
         """This parses the correct format for S3 keys regardless of how the S3 url is passed."""
-        parsed_s3_key = urlparse(s3_key)
+        parsed_s3_key = urlsplit(s3_key)
         return parsed_s3_key.path.lstrip("/")

     def execute(self, context: Context) -> None:
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/transfers/sftp_to_s3.py
@@ -19,7 +19,7 @@

 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -76,7 +76,7 @@ def __init__(
     @staticmethod
     def get_s3_key(s3_key: str) -> str:
         """This parses the correct format for S3 keys regardless of how the S3 url is passed."""
-        parsed_s3_key = urlparse(s3_key)
+        parsed_s3_key = urlsplit(s3_key)
         return parsed_s3_key.path.lstrip("/")

     def execute(self, context: Context) -> None:
4 changes: 2 additions & 2 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -28,7 +28,7 @@
 import platform
 import time
 from typing import Any
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 import aiohttp
 import requests
@@ -186,7 +186,7 @@ def _parse_host(host: str) -> str:
         assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
         """
-        urlparse_host = urlparse(host).hostname
+        urlparse_host = urlsplit(host).hostname
         if urlparse_host:
             # In this case, host = https://xx.cloud.databricks.com
             return urlparse_host
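
The .hostname accessor used in _parse_host() exists on SplitResult as well as ParseResult, so behavior is unchanged: a bare host with no scheme parses entirely into .path, .hostname comes back None, and the function falls through to returning the input as-is.

    from urllib.parse import urlsplit

    print(urlsplit("https://xx.cloud.databricks.com").hostname)
    # xx.cloud.databricks.com
    print(urlsplit("xx.cloud.databricks.com").hostname)
    # None -> _parse_host() returns the string unchanged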
4 changes: 2 additions & 2 deletions airflow/providers/databricks/operators/databricks_repos.py
@@ -20,7 +20,7 @@

 import re
 from typing import TYPE_CHECKING, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
@@ -106,7 +106,7 @@ def __init__(
     def __detect_repo_provider__(url):
         provider = None
         try:
-            netloc = urlparse(url).netloc
+            netloc = urlsplit(url).netloc
             idx = netloc.rfind("@")
             if idx != -1:
                 netloc = netloc[(idx + 1) :]
@@ -23,7 +23,7 @@
 import os
 from datetime import datetime
 from typing import Callable
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow import models
 from airflow.exceptions import AirflowException
@@ -53,7 +53,7 @@
 GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar")
 GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET NAME/wordcount_debugging.py")

-GCS_JAR_PARTS = urlparse(GCS_JAR)
+GCS_JAR_PARTS = urlsplit(GCS_JAR)
 GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
 GCS_JAR_OBJECT_NAME = GCS_JAR_PARTS.path[1:]
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -30,7 +30,7 @@
 from os import path
 from tempfile import NamedTemporaryFile
 from typing import IO, Callable, Generator, Sequence, TypeVar, cast, overload
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from google.api_core.exceptions import NotFound

@@ -1161,7 +1161,7 @@ def _parse_gcs_url(gsurl: str) -> tuple[str, str]:
     Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
     tuple containing the corresponding bucket and blob.
     """
-    parsed_url = urlparse(gsurl)
+    parsed_url = urlsplit(gsurl)
     if not parsed_url.netloc:
         raise AirflowException("Please provide a bucket name")
     if parsed_url.scheme.lower() != "gs":
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/cloud_build.py
@@ -22,7 +22,7 @@
 import re
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Sequence
-from urllib.parse import unquote, urlparse
+from urllib.parse import unquote, urlsplit

 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry
@@ -972,7 +972,7 @@ def _convert_repo_url_to_dict(source: str) -> dict[str, Any]:
         https://source.cloud.google.com/airflow-project/airflow-repo/+/branch-name:
         """
-        url_parts = urlparse(source)
+        url_parts = urlsplit(source)

         match = REGEX_REPO_PATH.search(url_parts.path)

@@ -1006,7 +1006,7 @@ def _convert_storage_url_to_dict(storage_url: str) -> dict[str, Any]:
         gs://bucket-name/object-name.tar.gz
         """
-        url_parts = urlparse(storage_url)
+        url_parts = urlsplit(storage_url)

         if url_parts.scheme != "gs" or not url_parts.hostname or not url_parts.path or url_parts.path == "/":
             raise AirflowException(
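
The gs:// validation above keeps working because .hostname is derived from .netloc on a SplitResult too. A hypothetical condensation of the check (the function name and return shape are illustrative, not the operator's exact API):

    from urllib.parse import urlsplit

    def check_storage_url(storage_url: str) -> dict[str, str]:
        parts = urlsplit(storage_url)
        if parts.scheme != "gs" or not parts.hostname or not parts.path or parts.path == "/":
            raise ValueError(f"Invalid Google Cloud Storage URL: {storage_url}")
        return {"bucket": parts.hostname, "object": parts.path.lstrip("/")}

    print(check_storage_url("gs://bucket-name/object-name.tar.gz"))
    # {'bucket': 'bucket-name', 'object': 'object-name.tar.gz'}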
@@ -24,7 +24,7 @@
 import tempfile
 import urllib.request
 from typing import TYPE_CHECKING, Any, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -282,7 +282,7 @@ def execute(self, context: Context):

         # If no custom report_name provided, use DV360 name
         file_url = resource["metadata"]["googleCloudStoragePathForLatestReport"]
-        report_name = self.report_name or urlparse(file_url).path.split("/")[-1]
+        report_name = self.report_name or urlsplit(file_url).path.split("/")[-1]
         report_name = self._resolve_file_name(report_name)

         # Download the report
4 changes: 2 additions & 2 deletions airflow/providers/slack/hooks/slack_webhook.py
@@ -21,7 +21,7 @@
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable
-from urllib.parse import urlparse
+from urllib.parse import urlsplit

 from slack_sdk import WebhookClient

@@ -285,7 +285,7 @@ def _get_conn_params(self) -> dict[str, Any]:

         base_url = base_url.rstrip("/")
         if not webhook_token:
-            parsed_token = (urlparse(base_url).path or "").strip("/")
+            parsed_token = (urlsplit(base_url).path or "").strip("/")
             if base_url == DEFAULT_SLACK_WEBHOOK_ENDPOINT or not parsed_token:
                 # Raise an error in case of password not specified and
                 # 1. Result of constructing base_url equal https://hooks.slack.com/services
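
Same drop-in property for the webhook-token extraction: .path on a SplitResult is always a string (which makes the `or ""` guard a belt-and-braces check), and stripping the slashes yields the token portion. With a made-up webhook URL:

    from urllib.parse import urlsplit

    base_url = "https://hooks.slack.com/services/T000/B000/XXXXXXXX"
    parsed_token = (urlsplit(base_url).path or "").strip("/")
    print(parsed_token)  # services/T000/B000/XXXXXXXX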