Improve taskflow type hints with ParamSpec (#25173)
uranusjr committed Jul 26, 2022
1 parent 5758454 commit c8af059
Showing 8 changed files with 96 additions and 63 deletions.
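For orientation, the sketch below (illustrative only, not part of the commit) shows what the ParamSpec-based hints are meant to buy `@task` users: the decorated function keeps its parameter signature for the type checker, while calling it yields an XComArg. The `task`, `Task`, and `XComArg` names here are simplified stand-ins for the real Airflow classes touched in this diff.

```python
# Illustrative sketch only -- simplified stand-ins, not Airflow's real classes.
from typing import Callable, Generic, TypeVar

from typing_extensions import ParamSpec  # typing.ParamSpec on Python 3.10+

FParams = ParamSpec("FParams")
FReturn = TypeVar("FReturn")


class XComArg:
    """Placeholder for airflow.models.xcom_arg.XComArg."""


class Task(Generic[FParams, FReturn]):
    """Calling a task keeps the original parameters but returns an XComArg."""

    # Declaration only (as in the real Task stub); not callable at runtime here.
    __call__: Callable[FParams, XComArg]
    function: Callable[FParams, FReturn]


def task(function: Callable[FParams, FReturn]) -> Task[FParams, FReturn]:
    decorated: Task[FParams, FReturn] = Task()
    decorated.function = function
    return decorated


@task
def add(x: int, y: int) -> int:
    return x + y


# A type checker now sees that add expects (x: int, y: int) and that add(1, 2)
# evaluates to an XComArg -- something the previous
# `Function = TypeVar("Function", bound=Callable)` approach could not express.
```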
12 changes: 6 additions & 6 deletions airflow/decorators/__init__.pyi
@@ -20,9 +20,9 @@
# necessarily exist at run time. See "Creating Custom @task Decorators"
# documentation for more details.

- from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload

- from airflow.decorators.base import Function, Task, TaskDecorator
+ from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator
from airflow.decorators.branch_python import branch_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
@@ -68,7 +68,7 @@ class TaskDecoratorCollection:
"""
# [START mixin_for_typing]
@overload
- def python(self, python_callable: Function) -> Task[Function]: ...
+ def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [END mixin_for_typing]
@overload
def __call__(
@@ -81,7 +81,7 @@
) -> TaskDecorator:
"""Aliasing ``python``; signature should match exactly."""
@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]:
"""Aliasing ``python``; signature should match exactly."""
@overload
def virtualenv(
@@ -122,7 +122,7 @@ class TaskDecoratorCollection:
such as transmission a large amount of XCom to TaskAPI.
"""
@overload
- def virtualenv(self, python_callable: Function) -> Task[Function]: ...
+ def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
@overload
def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> TaskDecorator:
"""Create a decorator to wrap the decorated callable into a BranchPythonOperator.
@@ -134,7 +134,7 @@
Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
"""
@overload
- def branch(self, python_callable: Function) -> Task[Function]: ...
+ def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [START decorator_signature]
def docker(
self,
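The paired `@overload` declarations above encode the two ways the decorators are used: the bare form (`@task.python`) receives the callable directly, while the factory form (`@task.python(...)`) returns a decorator. A hedged sketch of the same dispatch pattern, with illustrative names rather than the Airflow implementation:

```python
# Illustrative overload pattern; `python` here is not the Airflow implementation.
from typing import Any, Callable, Optional, TypeVar, overload

from typing_extensions import ParamSpec

FParams = ParamSpec("FParams")
FReturn = TypeVar("FReturn")


@overload
def python(python_callable: Callable[FParams, FReturn]) -> Callable[FParams, FReturn]: ...
@overload
def python(
    *, multiple_outputs: Optional[bool] = None, **kwargs: Any
) -> Callable[[Callable[FParams, FReturn]], Callable[FParams, FReturn]]: ...
def python(python_callable=None, *, multiple_outputs=None, **kwargs):
    if python_callable is not None:
        # Bare form: @python placed directly on a function.
        return python_callable
    # Factory form: @python(...) returns a decorator to apply later.
    return lambda func: func
```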
78 changes: 45 additions & 33 deletions airflow/decorators/base.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

- import functools
import inspect
import re
from typing import (
@@ -68,7 +67,7 @@
)
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
- from airflow.typing_compat import Protocol
+ from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
from airflow.utils.task_group import TaskGroup, TaskGroupContext
@@ -236,13 +235,15 @@ def _hook_apply_defaults(self, *args, **kwargs):
return args, kwargs


- Function = TypeVar("Function", bound=Callable)
+ FParams = ParamSpec("FParams")
+
+ FReturn = TypeVar("FReturn")

OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")


@attr.define(slots=False)
- class _TaskDecorator(Generic[Function, OperatorSubclass]):
+ class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]):
"""
Helper class for providing dynamic task mapping to decorated functions.
@@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
:meta private:
"""

- function: Function = attr.ib()
+ function: Callable[FParams, FReturn] = attr.ib()
operator_class: Type[OperatorSubclass]
multiple_outputs: bool = attr.ib()
kwargs: Dict[str, Any] = attr.ib(factory=dict)
@@ -272,7 +273,7 @@ def __attrs_post_init__(self):
raise TypeError(f"@{self.decorator_name} does not support methods")
self.kwargs.setdefault('task_id', self.function.__name__)

- def __call__(self, *args, **kwargs) -> XComArg:
+ def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") -> XComArg:
op = self.operator_class(
python_callable=self.function,
op_args=args,
@@ -285,7 +286,7 @@ def __call__(self, *args, **kwargs) -> XComArg:
return XComArg(op)

@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
return self.function

@cached_property
@@ -337,9 +338,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

- def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
- from airflow.models.xcom_arg import XComArg
-
+ def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
@@ -420,14 +419,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
)
return XComArg(operator=operator)

- def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
self._validate_arg_names("partial", kwargs)
old_kwargs = self.kwargs.get("op_kwargs", {})
prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
kwargs.update(old_kwargs)
return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

- def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})


@@ -506,7 +505,7 @@ def _render_if_not_already_resolved(key: str, value: Any):
return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()}


- class Task(Generic[Function]):
+ class Task(Generic[FParams, FReturn]):
"""Declaration of a @task-decorated callable for type-checking.
An instance of this type inherits the call signature of the decorated
@@ -517,26 +516,32 @@ class Task(Generic[Function]):
This type is implemented by ``_TaskDecorator`` at runtime.
"""

- __call__: Function
+ __call__: Callable[FParams, XComArg]

- function: Function
+ function: Callable[FParams, FReturn]

@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
...

+ def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
+ ...
+
def expand(self, **kwargs: "Mappable") -> XComArg:
...

- def partial(self, **kwargs: Any) -> "Task[Function]":
+ def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
...


class TaskDecorator(Protocol):
"""Type declaration for ``task_decorator_factory`` return type."""

@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__( # type: ignore[misc]
+ self,
+ python_callable: Callable[FParams, FReturn],
+ ) -> Task[FParams, FReturn]:
"""For the "bare decorator" ``@task`` case."""

@overload
@@ -545,7 +550,7 @@ def __call__(
*,
multiple_outputs: Optional[bool] = None,
**kwargs: Any,
- ) -> Callable[[Function], Task[Function]]:
+ ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
"""For the decorator factory ``@task()`` case."""


@@ -556,16 +561,20 @@ def task_decorator_factory(
decorated_operator_class: Type[BaseOperator],
**kwargs,
) -> TaskDecorator:
- """
- A factory that generates a wrapper that wraps a function into an Airflow operator.
- Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+ """Generate a wrapper that wraps a function into an Airflow operator.
- :param python_callable: Function to decorate
- :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
- multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
- :param decorated_operator_class: The operator that executes the logic needed to run the python function in
- the correct environment
+ Can be reused in a single DAG.
+ :param python_callable: Function to decorate.
+ :param multiple_outputs: If set to True, the decorated function's return
+ value will be unrolled to multiple XCom values. Dict will unroll to XCom
+ values with its keys as XCom keys. If set to False (default), only at
+ most one XCom value is pushed.
+ :param decorated_operator_class: The operator that executes the logic needed
+ to run the python function in the correct environment.
+ Other kwargs are directly forwarded to the underlying operator class when
+ it's instantiated.
"""
if multiple_outputs is None:
multiple_outputs = cast(bool, attr.NOTHING)
@@ -579,10 +588,13 @@ def task_decorator_factory(
return cast(TaskDecorator, decorator)
elif python_callable is not None:
raise TypeError('No args allowed while using @task, use kwargs instead')
- decorator_factory = functools.partial(
- _TaskDecorator,
- multiple_outputs=multiple_outputs,
- operator_class=decorated_operator_class,
- kwargs=kwargs,
- )
+
+ def decorator_factory(python_callable):
+ return _TaskDecorator(
+ function=python_callable,
+ multiple_outputs=multiple_outputs,
+ operator_class=decorated_operator_class,
+ kwargs=kwargs,
+ )
+
return cast(TaskDecorator, decorator_factory)
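The tail of `task_decorator_factory` swaps a `functools.partial` object for an explicit inner function. Presumably this is because a `partial` is hard for type checkers to reconcile with the overloaded `TaskDecorator` protocol, while a plain closure is an ordinary callable that `cast` can label cleanly; the runtime behaviour stays the same. A hedged sketch of the two shapes, using simplified stand-ins rather than Airflow code:

```python
# Simplified stand-ins, not Airflow code: both factories build the same object.
import functools


class _TaskDecorator:
    def __init__(self, function, multiple_outputs, operator_class, kwargs):
        self.function = function
        self.multiple_outputs = multiple_outputs
        self.operator_class = operator_class
        self.kwargs = kwargs


def factory_via_partial(multiple_outputs, operator_class, kwargs):
    # Old shape: a functools.partial object.  Works at runtime, but its type is
    # functools.partial[...], not a plain "callable taking python_callable".
    return functools.partial(
        _TaskDecorator,
        multiple_outputs=multiple_outputs,
        operator_class=operator_class,
        kwargs=kwargs,
    )


def factory_via_closure(multiple_outputs, operator_class, kwargs):
    # New shape: an ordinary function with an explicit python_callable parameter.
    def decorator_factory(python_callable):
        return _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=operator_class,
            kwargs=kwargs,
        )

    return decorator_factory
```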
@@ -328,7 +328,7 @@ def _get_ti_pod_labels(
if include_try_number:
labels.update(try_number=ti.try_number)
# In the case of sub dags this is just useful
- if context['dag'].is_subdag:
+ if context['dag'].parent_dag:
labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
# Ensure that label is valid for Kube,
# and if not truncate/remove invalid chars and replace with short hash.
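Switching the check from the `is_subdag` flag to `parent_dag` itself keeps the runtime behaviour while letting a type checker narrow the optional attribute before `.dag_id` is read; a boolean flag gives the checker no such guarantee. A minimal hedged sketch (the `Dag` class below is illustrative, not Airflow's DAG model):

```python
# Illustrative classes only -- not Airflow's DAG model.
from typing import Optional


class Dag:
    def __init__(self, dag_id: str, parent_dag: "Optional[Dag]" = None) -> None:
        self.dag_id = dag_id
        self.parent_dag = parent_dag
        self.is_subdag = parent_dag is not None


def parent_id(dag: Dag) -> Optional[str]:
    if dag.parent_dag:
        # The truthiness check narrows Optional[Dag] to Dag for the type checker.
        return dag.parent_dag.dag_id
    # `if dag.is_subdag:` behaves the same at runtime, but inside that branch a
    # checker still treats dag.parent_dag as possibly None.
    return None
```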
2 changes: 1 addition & 1 deletion airflow/providers/dbt/cloud/hooks/dbt.py
@@ -87,7 +87,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
class JobRunInfo(TypedDict):
"""Type class for the ``job_run_info`` dictionary."""

- account_id: int
+ account_id: Optional[int]
run_id: int


28 changes: 15 additions & 13 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -30,7 +30,6 @@

from google.api_core.exceptions import Conflict
from google.cloud.exceptions import GoogleCloudError
- from pendulum.datetime import DateTime

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -723,22 +722,25 @@ def __init__(
def execute(self, context: "Context") -> List[str]:
# Define intervals and prefixes.
try:
- timespan_start = context["data_interval_start"]
- timespan_end = context["data_interval_end"]
+ orig_start = context["data_interval_start"]
+ orig_end = context["data_interval_end"]
except KeyError:
- timespan_start = pendulum.instance(context["execution_date"])
+ orig_start = pendulum.instance(context["execution_date"])
following_execution_date = context["dag"].following_schedule(context["execution_date"])
if following_execution_date is None:
- timespan_end = None
+ orig_end = None
else:
- timespan_end = pendulum.instance(following_execution_date)
-
- if timespan_end is None: # Only possible in Airflow before 2.2.
- self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end)
- timespan_end = DateTime.max
- elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules.
- self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end)
- timespan_end = DateTime.max
+ orig_end = pendulum.instance(following_execution_date)
+
+ timespan_start = orig_start
+ if orig_end is None: # Only possible in Airflow before 2.2.
+ self.log.warning("No following schedule found, setting timespan end to max %s", orig_end)
+ timespan_end = pendulum.instance(datetime.datetime.max)
+ elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-perodic schedules.
+ self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end)
+ timespan_end = pendulum.instance(datetime.datetime.max)
+ else:
+ timespan_end = orig_end

timespan_start = timespan_start.in_timezone(timezone.utc)
timespan_end = timespan_end.in_timezone(timezone.utc)
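The only reason `pendulum.datetime.DateTime` was imported in this module was the `DateTime.max` sentinel; the new code builds an equivalent "no upper bound" value from the standard library via `pendulum.instance`, which attaches a timezone (UTC by default) to a naive datetime, so the later `.in_timezone(...)` call keeps working. A small hedged illustration of that sentinel (printed value is approximate):

```python
# Hedged illustration of the "no upper bound" sentinel used above.
import datetime

import pendulum

# pendulum.instance() wraps a stdlib datetime; for a naive one it assumes UTC,
# so the sentinel is timezone-aware rather than naive.
timespan_end = pendulum.instance(datetime.datetime.max)
print(timespan_end)  # e.g. 9999-12-31T23:59:59.999999+00:00
```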
3 changes: 2 additions & 1 deletion airflow/providers/qubole/hooks/qubole.py
@@ -46,6 +46,7 @@
from airflow.utils.state import State

if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context


@@ -139,7 +140,7 @@ def __init__(self, *args, **kwargs) -> None:
self.kwargs = kwargs
self.cls = COMMAND_CLASSES[self.kwargs['command_type']]
self.cmd: Optional[Command] = None
- self.task_instance = None
+ self.task_instance: Optional["TaskInstance"] = None

@staticmethod
def handle_failure_retry(context) -> None:
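The `TaskInstance` import is added under `TYPE_CHECKING` so the new annotation is visible to static analysis without paying for (or risking a cycle from) the import at runtime; the quoted name in `Optional["TaskInstance"]` is resolved lazily. A generic hedged sketch of that pattern, not Airflow code:

```python
# Generic sketch of the TYPE_CHECKING / forward-reference pattern (not Airflow code).
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers, so a heavy or circular import costs nothing at runtime.
    from decimal import Decimal  # stand-in for airflow.models.taskinstance.TaskInstance


class Holder:
    def __init__(self) -> None:
        # The quoted annotation is never evaluated at runtime, so this works
        # even though Decimal was only imported for type checking.
        self.value: Optional["Decimal"] = None


holder = Holder()
print(holder.value)  # None
```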
2 changes: 1 addition & 1 deletion airflow/providers/salesforce/operators/bulk.py
@@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator):
def __init__(
self,
*,
- operation: Literal[available_operations],
+ operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete'],
object_name: str,
payload: list,
external_id_field: str = 'Id',
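`typing.Literal` only accepts literal values as parameters, so type checkers reject `Literal[available_operations]` even though it happens to work at runtime; that is presumably why the allowed strings are now spelled out. A hedged illustration, not the Airflow operator:

```python
# Illustration only -- not the Airflow operator.
from typing_extensions import Literal  # typing.Literal on newer Pythons

available_operations = ('insert', 'update', 'upsert', 'delete', 'hard_delete')

# Rejected by type checkers: Literal parameters must be literal values,
# not a variable that holds them.
# operation: Literal[available_operations]


def run_bulk(operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete']) -> None:
    print(f"running {operation}")


run_bulk('upsert')       # OK
# run_bulk('truncate')   # flagged by a type checker
```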
32 changes: 25 additions & 7 deletions airflow/typing_compat.py
@@ -21,10 +21,28 @@
codebase easier.
"""

- try:
- # Literal, Protocol and TypedDict are only added to typing module starting from
- # python 3.8 we can safely remove this shim import after Airflow drops
- # support for <3.8
- from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore
- except ImportError:
- from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa
+ __all__ = [
+ "Literal",
+ "ParamSpec",
+ "Protocol",
+ "TypedDict",
+ "runtime_checkable",
+ ]
+
+ import sys
+
+ if sys.version_info >= (3, 8):
+ from typing import Protocol, TypedDict, runtime_checkable
+ else:
+ from typing_extensions import Protocol, TypedDict, runtime_checkable
+
+ # Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]".
+ if sys.version_info >= (3, 9):
+ from typing import Literal
+ else:
+ from typing_extensions import Literal
+
+ if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+ else:
+ from typing_extensions import ParamSpec
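With the shim extended, the rest of the codebase can import `ParamSpec` from `airflow.typing_compat` and get `typing.ParamSpec` on Python 3.10+ or the `typing_extensions` backport otherwise. A hedged usage sketch (the `log_calls` decorator is illustrative, not from the codebase):

```python
# Illustrative consumer of the shim; log_calls is not an Airflow function.
from typing import Callable, TypeVar

from airflow.typing_compat import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def log_calls(func: Callable[P, R]) -> Callable[P, R]:
    """Pass-through decorator that preserves the wrapped signature for type checkers."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def greet(name: str, excited: bool = False) -> str:
    return f"Hello {name}{'!' if excited else '.'}"


print(greet("Airflow", excited=True))  # checked as (name: str, excited: bool = False) -> str
```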
