@acookin
Created September 8, 2022 22:58
Agent + flow to cleanup local process flow runs
"""
Runs prefect agent
Keeps track of flow runs the agent kicked off,
and sets a DateTime block storage to NOW() for
every running flow, every couple of seconds.
Also, has a cleanup function that will set the state of
every running flow to AwaitingRetry for each running
flow.
This is meant to be used in a context where the agent
process is PID 1 (e.g. in a container), and local process
infrastructure is being used to run flows. This may be the
only reasonable way to run agents/flows in a restrictive
infrastructure, such as heroku, but I wouldn't recommend
running your agents or flows this way unless you had to.
case in a restricted managed infrastructure, such as heroku.
"""
import asyncio
import logging
from datetime import datetime
from typing import List, Set
from uuid import UUID

import anyio
import pendulum
from prefect.agent import OrionAgent
from prefect.blocks.system import DateTime
from prefect.client import OrionClient, get_client
from prefect.orion.schemas.states import AwaitingRetry, StateType

logger = logging.getLogger(__name__)

RUNNING_STATE_TYPES = [StateType.PENDING, StateType.RUNNING]


async def flow_is_running(client: OrionClient, run_id: UUID) -> bool:
    flow_run = await client.read_flow_run(run_id)
    return flow_run.state.type in RUNNING_STATE_TYPES


async def retry_flow(client: OrionClient, run_id: UUID):
    if await flow_is_running(client, run_id):
        retry_state = AwaitingRetry(scheduled_time=datetime.now())
        await client.set_flow_run_state(run_id, retry_state)


class PrefectAgent:
    def __init__(self, queues: List[str]) -> None:
        self.work_queues = set(queues)
        self.running_flows: Set[UUID] = set()

    async def _ping_for_flow_run(self, flow_run_id: UUID):
        """Ping for a running flow, so external processes can see this flow
        is actually running."""
        try:
            storage_key = f"flowrun-{str(flow_run_id)}"
            block = DateTime(name=storage_key, value=pendulum.now())
            await block.save(storage_key, overwrite=True)
        except Exception as e:
            logger.error("Error pinging for flow run: %s", str(e))

    async def _reset_running_flows(self):
        still_running = set()
        async with get_client() as client:
            for run_id in self.running_flows:
                if await flow_is_running(client, run_id):
                    still_running.add(run_id)
        self.running_flows = still_running

    async def start(self):
        logger.info("Starting prefect agent...")
        async with OrionAgent(work_queues=self.work_queues) as agent:
            while True:
                flow_runs = await agent.get_and_submit_flow_runs()
                for r in flow_runs:
                    self.running_flows.add(r.id)
                await anyio.sleep(2.0)
                ping_tasks = []
                for flow_run_id in self.running_flows:
                    ping_tasks.append(self._ping_for_flow_run(flow_run_id))
                await asyncio.gather(*ping_tasks, self._reset_running_flows())

    async def cleanup_flow_runs(self):
        async with get_client() as client:
            await asyncio.gather(*[retry_flow(client, f) for f in self.running_flows])


def main():
    prefect_agent = PrefectAgent(queues=["default"])
    try:
        asyncio.run(prefect_agent.start())
    except KeyboardInterrupt:
        logger.info("Cleaning up flows")
        asyncio.run(prefect_agent.cleanup_flow_runs())


if __name__ == "__main__":
    main()
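
As an aside, here is a minimal sketch of how another process might read the flowrun-<id> DateTime blocks the agent saves above to judge whether a run is still alive. It assumes Prefect 2.x; the is_stale helper, the five-minute threshold, and the handling of a missing block are illustrative assumptions, not part of the gist.

# Sketch only: read the "ping" block the agent writes for a flow run and
# decide whether the run looks stale. The block name mirrors the agent's
# f"flowrun-{flow_run_id}" storage key; the threshold is an assumption.
from datetime import timedelta
from uuid import UUID

import pendulum
from prefect.blocks.system import DateTime

STALE_AFTER = timedelta(minutes=5)  # illustrative threshold


async def is_stale(flow_run_id: UUID) -> bool:
    try:
        block = await DateTime.load(f"flowrun-{str(flow_run_id)}")
    except ValueError:
        # Assumption: no saved block means no ping has been recorded yet.
        return False
    return pendulum.now() - pendulum.instance(block.value) > STALE_AFTER
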
"""
Flow that can cancel flow runs that have not "pinged" their
DateTime storage block for some amount of time.
Helps cleanup any "zombies" left from the agent above that
may have gotten sigkilled before finishing its cleanup process.
"""
import asyncio
from datetime import timedelta
from typing import List

import pendulum
from prefect import flow, get_run_logger, task
from prefect.client import OrionClient, get_client
from prefect.orion.schemas.core import FlowRun
from prefect.orion.schemas.filters import (
    FlowRunFilter,
    FlowRunFilterState,
    FlowRunFilterStateType,
)
from prefect.orion.schemas.states import Cancelled, StateType

MAX_TIME = timedelta(minutes=5)
AGENT_NAMES = ["default"]


def get_agent_tags():
    tags = []
    for agent in AGENT_NAMES:
        tags.append(f"agent:{agent}")
    return tags


async def get_matching_flow_runs(
    client: OrionClient, state_types: List[StateType], tags: List[str]
) -> List[FlowRun]:
    flow_runs = await client.read_flow_runs(
        flow_run_filter=FlowRunFilter(
            state=FlowRunFilterState(type=FlowRunFilterStateType(any_=state_types)),
        )
    )
    ret = []
    for run in flow_runs:
        matching_run_tags = [t for t in run.tags if t in tags]
        if len(matching_run_tags) > 0:
            ret.append(run)
    return ret


@task
async def get_stale_runs():
    stale_run_ids = []
    async with get_client() as c:
        scheduled_runs = await get_matching_flow_runs(
            c, [StateType.SCHEDULED], get_agent_tags()
        )
        cutoff_time = pendulum.now() - MAX_TIME
        for run in scheduled_runs:
            if (
                "auto-scheduled" not in run.tags
                and run.next_scheduled_start_time < cutoff_time
            ):
                stale_run_ids.append(run.id)
    return stale_run_ids


@task(retries=3)
async def cancel_run(flow_run_id):
    logger = get_run_logger()
    logger.info(f"Cancelling flow run {flow_run_id}")
    async with get_client() as c:
        cancelled_state = Cancelled(
            message="Cancelled because run was in Scheduled state beyond time limit"
        )
        await c.set_flow_run_state(flow_run_id, cancelled_state)


@flow()
async def cancel_stale_runs():
    stale_runs = await get_stale_runs()
    cancel_tasks = [cancel_run(str(run_id)) for run_id in stale_runs]
    await asyncio.gather(*cancel_tasks)
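
To run the cleanup flow periodically, one option is to register it as a scheduled deployment picked up by the agent above. A minimal sketch, assuming Prefect 2.x where Deployment.build_from_flow is available; the deployment name, interval, and work queue are placeholders, not part of the gist.

# Sketch only: register cancel_stale_runs as a scheduled deployment so the
# agent listening on the "default" queue runs it periodically. The name,
# interval, and work queue below are illustrative assumptions.
from datetime import timedelta

from prefect.deployments import Deployment
from prefect.orion.schemas.schedules import IntervalSchedule

deployment = Deployment.build_from_flow(
    flow=cancel_stale_runs,
    name="cancel-stale-runs",
    schedule=IntervalSchedule(interval=timedelta(minutes=5)),
    work_queue_name="default",
)

if __name__ == "__main__":
    deployment.apply()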