Skip to content

Instantly share code, notes, and snippets.

@benoit-cty
Last active September 8, 2023 12:40
Show Gist options
  • Save benoit-cty/01a4dd1e81c6ded86395c760490f3b73 to your computer and use it in GitHub Desktop.
Save benoit-cty/01a4dd1e81c6ded86395c760490f3b73 to your computer and use it in GitHub Desktop.
Demonstrates how to cancel a task when the client disconnects in FastAPI.
"""
Demonstrates how to cancel a task when the client disconnects in FastAPI.
This reduces CPU consumption and memory usage for long-running tasks when clients stop waiting for the answer.
Code from Scott Driggers posted at https://gist.github.com/msdrigg/02c7716d6e2a0cb4e5ef08d14f180119
Initial discussion here: https://github.com/tiangolo/fastapi/discussions/8805
Note: these methods will likely discard the request body (because they call receive() and discard the result),
so they are not suitable for WebSocket or streaming requests.
Usage tested with Python 3.10 and 3.11 :
- python -m venv .venv
- source .venv/bin/activate
- pip install uvicorn fastapi
- uvicorn fastapi_cancel_test:app --reload
Then go to http://127.0.0.1:8000/docs, try the routes, and close the browser tab to see in the terminal that the task is terminated.
"""
import asyncio
import uuid
from contextlib import asynccontextmanager
import functools
from typing import Annotated, Any, AsyncContextManager, Awaitable, Callable
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import Request, FastAPI
from fastapi.params import Depends
# The FastAPI application that every route in this module is registered on.
app = FastAPI()

# Work runs in background tasks that cannot hand a return value back to the
# route, so finished results are parked here, keyed by task id, until the
# route that started the work collects them. A production setup would more
# likely use an external store such as Redis.
temporary_returns_storage: dict[str, str] = {}
async def huge_work(task_id: str, task_param: str):
    """
    Stand-in for an expensive job: five steps of one second each.

    A task started with ``start_soon`` has no way to return a value to the
    route, so the outcome is written into ``temporary_returns_storage`` under
    ``task_id``. The write happens only after the last step, so a run that is
    cancelled part-way leaves no entry behind.
    """
    print("Starting work " + task_id)
    for step in range(5):
        print("Working...", step)
        await asyncio.sleep(1)
    outcome = "Work completed for " + task_param
    temporary_returns_storage[task_id] = outcome
async def wait_for_disconnect(request: Request) -> bool:
    """
    Block until the client behind ``request`` disconnects.

    ASGI messages are pulled from the request's ``receive`` channel and
    discarded until one of type ``"http.disconnect"`` arrives. Any body
    messages received along the way are thrown away too.

    :param request: the incoming request whose connection is watched.
    :returns: always ``True``, once the disconnect has been observed.
    """
    # Use Starlette's public ``receive`` accessor rather than the private
    # ``_receive`` attribute it wraps.
    receive = request.receive
    while True:
        message = await receive()
        if message["type"] == "http.disconnect":
            break
    print("wait_for_disconnect - Client disconnected, cancelling task.")
    return True
@asynccontextmanager
async def create_request_task_group(request: Request):
    """
    Yield an anyio task group whose tasks are cancelled if the client disconnects.

    1. An outer task group runs a watcher that waits for the client to disconnect.
    2. An inner task group, nested inside the outer one, is yielded to the caller,
       who starts its work in it.
    3. If the client disconnects first, the watcher cancels the outer group's
       cancel scope. Because the inner group is nested in that scope, every task
       the caller started is cancelled as well.
    4. If the caller's tasks finish first, the outer scope is cancelled so the
       watcher stops and the context manager can exit.

    :param request: the request whose client connection is watched.
    """

    async def cancel_on_disconnect() -> None:
        await wait_for_disconnect(request)
        print("create_request_task_group - Client disconnected, cancelling task group")
        # Cancel through anyio's cancel scope instead of raising
        # asyncio.CancelledError by hand: this is anyio's supported way to
        # cancel a group and does not depend on how a given backend treats a
        # stray CancelledError escaping from a child task.
        outer_tg.cancel_scope.cancel()

    async with create_task_group() as outer_tg:
        outer_tg.start_soon(cancel_on_disconnect)
        async with create_task_group() as tg:
            yield tg
        # The caller's work is done; stop the disconnect watcher.
        outer_tg.cancel_scope.cancel()
@app.get("/cancel_tg")
async def cancel_tg(request: Request, task_param: str = "default_for_cancel_tg"):
    """
    Simplest way to define a task that could be canceled.
    """
    # Each call gets its own key into the shared result storage.
    job_id = str(uuid.uuid4())
    # Leaving this block waits for huge_work to finish, or cancels it if the
    # client disconnects first.
    async with create_request_task_group(request) as group:
        group.start_soon(huge_work, job_id, task_param)
    # huge_work only stores a result when it runs to completion, so
    # ``outcome`` is None after a cancellation.
    if outcome := temporary_returns_storage.get(job_id):
        del temporary_returns_storage[job_id]
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": outcome}
#########################################################
# The methods below use RequestTaskGroup
async def request_task_group(request: Request):
    """
    FastAPI dependency providing a task-group factory bound to ``request``.

    Calling the returned object with no arguments gives the async context
    manager produced by ``create_request_task_group(request)``.
    """
    bound_factory = functools.partial(create_request_task_group, request)
    return bound_factory
# Parameter annotation that makes FastAPI inject, via request_task_group, a
# zero-argument factory opening a disconnect-aware task group for the request.
RequestTaskGroup = Annotated[
    Callable[[], AsyncContextManager[TaskGroup]],
    Depends(dependency=request_task_group),
]
@app.get("/cancel_tg_dependency")
async def cancel_tg_dependency(
    get_task_group: RequestTaskGroup,
    task_param: str = "default_for_cancel_tg_dependency",
):
    # Same flow as /cancel_tg, except the request-bound task-group factory is
    # injected by FastAPI through the RequestTaskGroup dependency.
    job_id = str(uuid.uuid4())
    # Work started in this group is cancelled if the client disconnects
    # before the block exits.
    async with get_task_group() as group:
        group.start_soon(huge_work, job_id, task_param)
    outcome = temporary_returns_storage.get(job_id)
    if outcome:
        # Drop the entry so the shared storage does not grow without bound.
        temporary_returns_storage.pop(job_id)
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": outcome}
#########################################################
# The methods below use an asyncio.Event to cancel a task
# Only works with Python >= 3.11 because of the asyncio.TaskGroup class, added in Python 3.11
async def cancellation(
    request: Request,
):
    """
    FastAPI dependency that yields an ``asyncio.Event`` set when the client disconnects.

    A background task watches the connection while the endpoint runs and is
    cancelled once the dependency exits, whether the endpoint finished
    normally or raised.

    The watcher is a plain ``asyncio`` task guarded by ``try``/``finally``
    rather than an ``asyncio.TaskGroup``. A TaskGroup re-raises any exception
    thrown into the ``yield`` (for example an ``HTTPException`` raised by the
    endpoint, or ``GeneratorExit`` when the generator is closed) wrapped in a
    ``BaseExceptionGroup``. FastAPI's handlers registered for the original
    exception type would then no longer match it.
    """
    event = asyncio.Event()

    async def set_event_on_disconnect():
        # This will block until the client disconnects
        await wait_for_disconnect(request)
        print("set_event_on_disconnect - Client disconnected, cancelling task.")
        event.set()

    disconnect_task = asyncio.create_task(set_event_on_disconnect())
    try:
        yield event
    finally:
        # Stop watching once the endpoint is done; this is a no-op if the
        # client already disconnected and the task has finished.
        disconnect_task.cancel()
# Parameter annotation that makes FastAPI inject the asyncio.Event yielded by
# the ``cancellation`` dependency; the event is set when the client disconnects.
CancellationEvent = Annotated[asyncio.Event, Depends(dependency=cancellation)]
@app.get("/cancel_event")
async def cancel_event(
    event: CancellationEvent, task_param: str = "default_for_cancel_tg_dependency"
):
    """
    This route is cancelled if the client disconnects before it completes.
    It uses a cancellation event to cancel the task.
    """
    # NOTE(review): the default for task_param looks copy-pasted from
    # /cancel_tg_dependency; kept unchanged because it is part of the API.
    task_id = str(uuid.uuid4())

    async with create_task_group() as tg:

        async def cancel_after_completion(
            func: Callable[..., Awaitable[Any]], *args: Any
        ) -> None:
            # Run ``func(*args)``. Whichever of the two racing tasks finishes
            # first cancels the whole group, which stops the other one.
            await func(*args)
            tg.cancel_scope.cancel()

        # start_soon forwards its extra positional arguments to the callable,
        # so the helper must accept and pass them on to ``func``.
        tg.start_soon(cancel_after_completion, huge_work, task_id, task_param)
        tg.start_soon(cancel_after_completion, event.wait)

    # huge_work only stores a result when it ran to completion.
    result = temporary_returns_storage.get(task_id)
    if result:
        del temporary_returns_storage[task_id]
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": result}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment