Skip to content

Commit

Permalink
apply_default keeps the function signature for mypy (#9784)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Jul 22, 2020
1 parent 39a0288 commit 33f0cd2
Show file tree
Hide file tree
Showing 65 changed files with 216 additions and 241 deletions.
2 changes: 1 addition & 1 deletion airflow/contrib/operators/qubole_check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# pylint: disable=unused-import
from airflow.providers.qubole.operators.qubole_check import ( # noqa
QuboleCheckOperator, QuboleValueCheckOperator, ValueCheckOperator,
QuboleCheckOperator, QuboleValueCheckOperator,
)

warnings.warn(
Expand Down
33 changes: 17 additions & 16 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import warnings
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import (
Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union,
)

import attr
import jinja2
Expand Down Expand Up @@ -264,9 +266,9 @@ class derived from this one results in the creation of a task object,
:type do_xcom_push: bool
"""
# For derived classes to define which fields will get jinjaified
template_fields: Iterable[str] = []
template_fields: Iterable[str] = ()
# Defines which files extensions to look for in the templated fields
template_ext: Iterable[str] = []
template_ext: Iterable[str] = ()
# Defines the color in the UI
ui_color = '#fff' # type: str
ui_fgcolor = '#000' # type: str
Expand Down Expand Up @@ -357,24 +359,23 @@ def __init__(
do_xcom_push: bool = True,
inlets: Optional[Any] = None,
outlets: Optional[Any] = None,
*args,
**kwargs
):
from airflow.models.dag import DagContext
super().__init__()
if args or kwargs:
if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
raise AirflowException(
"Invalid arguments were passed to {c} (task_id: {t}). Invalid "
"arguments were:\n*args: {a}\n**kwargs: {k}".format(
c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
"arguments were:\n**kwargs: {k}".format(
c=self.__class__.__name__, k=kwargs, t=task_id),
)
warnings.warn(
'Invalid arguments were passed to {c} (task_id: {t}). '
'Support for passing such arguments will be dropped in '
'future. Invalid arguments were:'
'\n*args: {a}\n**kwargs: {k}'.format(
c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
'\n**kwargs: {k}'.format(
c=self.__class__.__name__, k=kwargs, t=task_id),
category=PendingDeprecationWarning,
stacklevel=3
)
Expand Down Expand Up @@ -1149,7 +1150,7 @@ def add_only_new(self, item_set: Set[str], item: str) -> None:
item_set.add(item)

def _set_relatives(self,
task_or_task_list: Union['BaseOperator', List['BaseOperator']],
task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']],
upstream: bool = False) -> None:
"""Sets relatives for the task or task list."""
from airflow.models.xcom_arg import XComArg
Expand Down Expand Up @@ -1208,14 +1209,14 @@ def _set_relatives(self,
self.add_only_new(self._downstream_task_ids, task.task_id)
task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id)

def set_downstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperator']]) -> None:
def set_downstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None:
"""
Set a task or a task list to be directly downstream from the current
task.
"""
self._set_relatives(task_or_task_list, upstream=False)

def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperator']]) -> None:
def set_upstream(self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator']]) -> None:
"""
Set a task or a task list to be directly upstream from the current
task.
Expand Down Expand Up @@ -1338,7 +1339,7 @@ def get_serialized_fields(cls):
return cls.__serialized_fields


def chain(*tasks: Union[BaseOperator, List[BaseOperator]]):
def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
r"""
Given a number of tasks, builds a dependency chain.
Support mix airflow.models.BaseOperator and List[airflow.models.BaseOperator].
Expand Down Expand Up @@ -1375,7 +1376,7 @@ def chain(*tasks: Union[BaseOperator, List[BaseOperator]]):
if isinstance(down_task, BaseOperator):
down_task.set_upstream(up_task)
continue
if not isinstance(up_task, List) or not isinstance(down_task, List):
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
raise TypeError(
'Chain not supported between instances of {up_type} and {down_type}'.format(
up_type=type(up_task), down_type=type(down_task)))
Expand All @@ -1389,8 +1390,8 @@ def chain(*tasks: Union[BaseOperator, List[BaseOperator]]):
up_t.set_downstream(down_t)


def cross_downstream(from_tasks: List[BaseOperator],
to_tasks: Union[BaseOperator, List[BaseOperator]]):
def cross_downstream(from_tasks: Sequence[BaseOperator],
to_tasks: Union[BaseOperator, Sequence[BaseOperator]]):
r"""
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
Expand Down
9 changes: 3 additions & 6 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ def __init__(
op_kwargs: Optional[Dict] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
*args,
**kwargs
) -> None:
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
if not callable(python_callable):
raise AirflowException('`python_callable` param must be callable')
self.python_callable = python_callable
Expand Down Expand Up @@ -404,12 +403,11 @@ def __init__( # pylint: disable=too-many-arguments
python_version: Optional[str] = None,
use_dill: bool = False,
system_site_packages: bool = True,
op_args: Optional[Iterable] = None,
op_args: Optional[List] = None,
op_kwargs: Optional[Dict] = None,
string_args: Optional[Iterable[str]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[Iterable[str]] = None,
*args,
templates_exts: Optional[List[str]] = None,
**kwargs
):
super().__init__(
Expand All @@ -418,7 +416,6 @@ def __init__( # pylint: disable=too-many-arguments
op_kwargs=op_kwargs,
templates_dict=templates_dict,
templates_exts=templates_exts,
*args,
**kwargs)
self.requirements = requirements or []
self.string_args = string_args or []
Expand Down
16 changes: 5 additions & 11 deletions airflow/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,8 @@ class SQLCheckOperator(BaseOperator):
:type sql: str
"""

template_fields = ("sql",) # type: Iterable[str]
template_ext = (
".hql",
".sql",
) # type: Iterable[str]
template_fields: Iterable[str] = ("sql",)
template_ext: Iterable[str] = (".hql", ".sql",)
ui_color = "#fff7e6"

@apply_defaults
Expand Down Expand Up @@ -264,11 +261,8 @@ class SQLIntervalCheckOperator(BaseOperator):
"""

__mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
template_fields = ("sql1", "sql2") # type: Iterable[str]
template_ext = (
".hql",
".sql",
) # type: Iterable[str]
template_fields: Iterable[str] = ("sql1", "sql2")
template_ext: Iterable[str] = (".hql", ".sql",)
ui_color = "#fff7e6"

ratio_formulas = {
Expand Down Expand Up @@ -415,7 +409,7 @@ class SQLThresholdCheckOperator(BaseOperator):
:type max_threshold: numeric or str
"""

template_fields = ("sql", "min_threshold", "max_threshold") # type: Iterable[str]
template_fields = ("sql", "min_threshold", "max_threshold")
template_ext = (
".hql",
".sql",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.utils.dates import days_ago

# [START howto_operator_s3_to_redshift_env_variables]
S3_BUCKET = getenv("S3_BUCKET")
S3_BUCKET = getenv("S3_BUCKET", "test-bucket")
S3_KEY = getenv("S3_KEY", "key")
REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "test_table")
# [END howto_operator_s3_to_redshift_env_variables]
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/s3_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class S3ListOperator(BaseOperator):
aws_conn_id='aws_customers_conn'
)
"""
template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str]
template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter')
ui_color = '#ffd700'

@apply_defaults
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/sagemaker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(self,
check_interval=30,
max_ingestion_time=None,
action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
*args, **kwargs):
super().__init__(config=config, *args, **kwargs)
**kwargs):
super().__init__(config=config, **kwargs)

self.wait_for_completion = wait_for_completion
self.print_log = print_log
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
This module contains Google Cloud Storage to S3 operator.
"""
import warnings
from typing import Iterable

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
Expand Down Expand Up @@ -73,7 +74,7 @@ class GCSToS3Operator(GCSListObjectsOperator):
in the destination bucket.
:type replace: bool
"""
template_fields = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
ui_color = '#f0eee4'

@apply_defaults
Expand Down
7 changes: 3 additions & 4 deletions airflow/providers/apache/cassandra/sensors/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
of a record in a Cassandra cluster.
"""

from typing import Any, Dict, Tuple
from typing import Any, Dict

from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand Down Expand Up @@ -56,9 +56,8 @@ class CassandraRecordSensor(BaseSensorOperator):
template_fields = ('table', 'keys')

@apply_defaults
def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str,
*args: Tuple[Any, ...], **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.cassandra_conn_id = cassandra_conn_id
self.table = table
self.keys = keys
Expand Down
7 changes: 3 additions & 4 deletions airflow/providers/apache/cassandra/sensors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
of a table in a Cassandra cluster.
"""

from typing import Any, Dict, Tuple
from typing import Any, Dict

from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand Down Expand Up @@ -54,9 +54,8 @@ class CassandraTableSensor(BaseSensorOperator):
template_fields = ('table',)

@apply_defaults
def __init__(self, table: str, cassandra_conn_id: str, *args: Tuple[Any, ...],
**kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(self, table: str, cassandra_conn_id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.cassandra_conn_id = cassandra_conn_id
self.table = table

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/apache/druid/operators/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class DruidOperator(BaseOperator):
def __init__(self, json_index_file: str,
druid_ingest_conn_id: str = 'druid_ingest_default',
max_ingestion_time: Optional[int] = None,
*args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
**kwargs: Any) -> None:
super().__init__(**kwargs)
self.json_index_file = json_index_file
self.conn_id = druid_ingest_conn_id
self.max_ingestion_time = max_ingestion_time
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/apache/druid/operators/druid_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(
self,
sql: str,
druid_broker_conn_id: str = 'druid_broker_default',
*args: Any, **kwargs: Any
**kwargs: Any
) -> None:
super().__init__(sql=sql, *args, **kwargs)
super().__init__(sql=sql, **kwargs)
self.druid_broker_conn_id = druid_broker_conn_id
self.sql = sql

Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/apache/druid/transfers/hive_to_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,9 @@ def __init__( # pylint: disable=too-many-arguments
segment_granularity: str = "DAY",
hive_tblproperties: Optional[Dict[Any, Any]] = None,
job_properties: Optional[Dict[Any, Any]] = None,
*args: Any,
**kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.sql = sql
self.druid_datasource = druid_datasource
self.ts_dim = ts_dim
Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/apache/hdfs/sensors/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def __init__(self,
ignore_copying: bool = True,
file_size: Optional[int] = None,
hook: Type[HDFSHook] = HDFSHook,
*args: Any,
**kwargs: Any) -> None:
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
if ignored_ext is None:
ignored_ext = ['_COPYING_']
self.filepath = filepath
Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/apache/hdfs/sensors/web_hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ class WebHdfsSensor(BaseSensorOperator):
def __init__(self,
filepath: str,
webhdfs_conn_id: str = 'webhdfs_default',
*args: Any,
**kwargs: Any) -> None:
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.filepath = filepath
self.webhdfs_conn_id = webhdfs_conn_id

Expand Down
6 changes: 2 additions & 4 deletions airflow/providers/apache/hive/operators/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
import os
import re
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional

from airflow.configuration import conf
from airflow.models import BaseOperator
Expand Down Expand Up @@ -81,11 +81,9 @@ def __init__(
mapred_queue: Optional[str] = None,
mapred_queue_priority: Optional[str] = None,
mapred_job_name: Optional[str] = None,
*args: Tuple[Any, ...],
**kwargs: Any
) -> None:

super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.hql = hql
self.hive_cli_conn_id = hive_cli_conn_id
self.schema = schema
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/apache/hive/operators/hive_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -72,7 +72,6 @@ def __init__(self,
metastore_conn_id: str = 'metastore_default',
presto_conn_id: str = 'presto_default',
mysql_conn_id: str = 'airflow_db',
*args: Tuple[Any, ...],
**kwargs: Any
) -> None:
if 'col_blacklist' in kwargs:
Expand All @@ -84,7 +83,7 @@ def __init__(self,
stacklevel=2
)
excluded_columns = kwargs.pop('col_blacklist')
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.table = table
self.partition = partition
self.extra_exprs = extra_exprs or {}
Expand Down
Loading

0 comments on commit 33f0cd2

Please sign in to comment.