Skip to content

Commit

Permalink
fix: current_state method on TaskInstance doesn't filter by map_index (
Browse files Browse the repository at this point in the history
  • Loading branch information
xlanor authored Dec 3, 2022
1 parent c931d88 commit 51c70a5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
14 changes: 5 additions & 9 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,17 +736,13 @@ def current_state(self, session: Session = NEW_SESSION) -> str:
we use and looking up the state becomes part of the session, otherwise
a new session is used.
sqlalchemy.inspect is used here to get the primary keys ensuring that if they change
it will not regress
:param session: SQLAlchemy ORM Session
"""
return (
session.query(TaskInstance.state)
.filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == self.run_id,
)
.scalar()
)
filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
return session.query(TaskInstance.state).filter(*filters).scalar()

@provide_session
def error(self, session: Session = NEW_SESSION) -> None:
Expand Down
23 changes: 23 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,29 @@ def test_outlet_datasets_failed(self, create_task_instance):
# check that no dataset events were generated
assert session.query(DatasetEvent).count() == 0

def test_mapped_current_state(self, dag_maker):
with dag_maker(dag_id="test_mapped_current_state") as _:
from airflow.decorators import task

@task()
def raise_an_exception(placeholder: int):
if placeholder == 0:
raise AirflowFailException("failing task")
else:
pass

_ = raise_an_exception.expand(placeholder=[0, 1])

tis = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances
for task_instance in tis:
if task_instance.map_index == 0:
with pytest.raises(AirflowFailException):
task_instance.run()
assert task_instance.current_state() == TaskInstanceState.FAILED
else:
task_instance.run()
assert task_instance.current_state() == TaskInstanceState.SUCCESS

def test_outlet_datasets_skipped(self, create_task_instance):
"""
Verify that when we have an outlet dataset on a task, and the task
Expand Down

0 comments on commit 51c70a5

Please sign in to comment.