Skip to content

Instantly share code, notes, and snippets.

@haoch
Last active August 5, 2020 21:11
Show Gist options
  • Save haoch/a0a2ac5053ed57de366043a77cf67903 to your computer and use it in GitHub Desktop.
--- /usr/local/lib/python3.7/site-packages/airflow/models/dagrun.py 2020-08-05 21:09:34.304289571 +0000
+++ /usr/local/lib/python3.7/site-packages/airflow/models/dagrun.py.backup 2020-08-05 07:16:06.909641726 +0000
@@ -21,7 +21,7 @@
import six
from sqlalchemy import (
Column, Integer, String, Boolean, PickleType, Index, UniqueConstraint, func, DateTime, or_,
- and_, desc
+ and_
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import synonym
@@ -161,33 +161,33 @@
return dr
-# @provide_session
-# def get_task_instances(self, state=None, session=None):
-# """
-# Returns the task instances for this dag run
-# """
-# from airflow.models.taskinstance import TaskInstance # Avoid circular import
-# tis = session.query(TaskInstance).filter(
-# TaskInstance.dag_id == self.dag_id,
-# TaskInstance.execution_date == self.execution_date,
-# )
-# if state:
-# if isinstance(state, six.string_types):
-# tis = tis.filter(TaskInstance.state == state)
-# else:
-# # this is required to deal with NULL values
-# if None in state:
-# tis = tis.filter(
-# or_(TaskInstance.state.in_(state),
-# TaskInstance.state.is_(None))
-# )
-# else:
-# tis = tis.filter(TaskInstance.state.in_(state))
-#
-# if self.dag and self.dag.partial:
-# tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))
-#
-# return tis.all()
+ @provide_session
+ def get_task_instances(self, state=None, session=None):
+ """
+ Returns the task instances for this dag run
+ """
+ from airflow.models.taskinstance import TaskInstance # Avoid circular import
+ tis = session.query(TaskInstance).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.execution_date == self.execution_date,
+ )
+ if state:
+ if isinstance(state, six.string_types):
+ tis = tis.filter(TaskInstance.state == state)
+ else:
+ # this is required to deal with NULL values
+ if None in state:
+ tis = tis.filter(
+ or_(TaskInstance.state.in_(state),
+ TaskInstance.state.is_(None))
+ )
+ else:
+ tis = tis.filter(TaskInstance.state.in_(state))
+
+ if self.dag and self.dag.partial:
+ tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))
+
+ return tis.all()
@provide_session
def get_task_instance(self, task_id, session=None):
@@ -300,10 +300,7 @@
duration = (timezone.utcnow() - start_dttm).total_seconds() * 1000
Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)
-# leaf_tis = [ti for ti in tis if ti.task_id in {t.task_id for t in dag.leaves}]
- leaf_task_ids = {t.task_id for t in dag.leaves}
- leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]
-
+ leaf_tis = [ti for ti in tis if ti.task_id in {t.task_id for t in dag.leaves}]
# if all roots finished and at least one failed, the run failed
if not unfinished_tasks and any(
@@ -449,56 +446,3 @@
.all()
)
return dagruns
-
- @provide_session
- def get_task_instances(self, state=None, session=None, batch=500):
- """
- Returns the task instances for this dag run
- """
- from airflow.models.taskinstance import TaskInstance # Avoid circular import
-
- def query():
- tis = session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.execution_date == self.execution_date,
- )
- if state:
- if isinstance(state, six.string_types):
- tis = tis.filter(TaskInstance.state == state)
- else:
- # this is required to deal with NULL values
- if None in state:
- tis = tis.filter(
- or_(TaskInstance.state.in_(state),
- TaskInstance.state.is_(None))
- )
- else:
- tis = tis.filter(TaskInstance.state.in_(state))
-
- if self.dag and self.dag.partial:
- tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))
-
- return tis
-
- def select(pos=-1, limit=-1):
- tis = query().order_by(desc(TaskInstance.priority_weight))
- if pos >= 0:
- tis = tis.offset(pos)
- if limit > 0:
- tis = tis.limit(limit)
- return tis.all()
-
- if batch <= 0:
- # Return all at once
- return select()
- else:
- result = []
- offset = 0
- while True:
- query_result = select(offset, batch)
- result.extend(query_result)
- if len(query_result) < batch:
- break
- else:
- offset += batch
- return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment