Last active
September 8, 2023 12:40
-
-
Save benoit-cty/01a4dd1e81c6ded86395c760490f3b73 to your computer and use it in GitHub Desktop.
Demonstrates how to cancel a task when the client disconnects in FastAPI.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Demonstrates how to cancel a task when the client disconnects in FastAPI.
This reduces CPU consumption and memory usage for long tasks if clients stop waiting for answers.
Code from Scott Driggers posted at https://gist.github.com/msdrigg/02c7716d6e2a0cb4e5ef08d14f180119
Initial discussion here: https://github.com/tiangolo/fastapi/discussions/8805
Note: using these methods will likely discard the request body (because they call receive and discard the result),
so they are not suitable for websocket or streaming requests.
Usage, tested with Python 3.10 and 3.11:
- python -m venv .venv
- source .venv/bin/activate
- pip install uvicorn fastapi
- uvicorn fastapi_cancel_test:app --reload
Then go to http://127.0.0.1:8000/docs, try the routes and close the browser tab to see in the terminal that the task is terminated.
""" | |
import asyncio | |
import uuid | |
from contextlib import asynccontextmanager | |
import functools | |
from typing import Annotated, Any, AsyncContextManager, Awaitable, Callable | |
from anyio import create_task_group | |
from anyio.abc import TaskGroup | |
from fastapi import Request, FastAPI | |
from fastapi.params import Depends | |
# Application instance on which the demo routes below are registered.
app = FastAPI()

# The work is started as a task inside a task group, so its result cannot be
# handed back through a return value; it is written here instead, keyed by
# task id. A production solution may be to use something like Redis.
temporary_returns_storage: dict[str, str] = {}
async def huge_work(task_id: str, task_param: str) -> None:
    """
    Simulate a slow job: five one-second steps, then publish the outcome.

    A task started in a task group has no return channel, so the outcome is
    written to ``temporary_returns_storage`` under ``task_id`` instead of
    being returned.
    """
    print("Starting work " + task_id)
    step = 0
    while step < 5:
        print("Working...", step)
        # Each sleep is a point where a cancellation can interrupt the job.
        await asyncio.sleep(1)
        step += 1
    temporary_returns_storage[task_id] = "Work completed for " + task_param
async def wait_for_disconnect(request: Request) -> bool:
    """
    Block until the client of ``request`` disconnects.

    Pulls ASGI messages from the request's receive channel and discards them
    until one of type ``"http.disconnect"`` arrives. Because every message is
    discarded, any part of the request body that is still pending is lost.

    Returns:
        Always ``True``, once the disconnect message has been received.
    """
    # Public accessor for the receive callable (instead of the private
    # ``request._receive`` attribute).
    receive = request.receive
    while True:
        message = await receive()
        if message["type"] == "http.disconnect":
            break
    print("wait_for_disconnect - Client disconnected, cancelling task.")
    return True
@asynccontextmanager
async def create_request_task_group(request: Request):
    """
    Yield a task group whose tasks are cancelled if the client disconnects.

    How it works:
    1. An outer task group runs a watcher that waits for the disconnect event.
    2. An inner task group, nested inside the outer one, is yielded to the
       caller, who starts the real work in it.
    3. If the client disconnects first, the watcher cancels the outer cancel
       scope, which cancels everything running inside it, including the
       tasks started in the inner group.
    4. If the inner group finishes first, the outer scope is cancelled on the
       way out so that the watcher stops and the context manager can exit.

    Note: while waiting, the watcher reads and discards the request's incoming
    messages (see ``wait_for_disconnect``).
    """

    async def cancel_on_disconnect():
        await wait_for_disconnect(request)
        print("create_request_task_group - Client disconnected, cancelling task group")
        # Cancel through the task group's own cancel scope rather than raising
        # CancelledError by hand, which relies on backend internals.
        outer_tg.cancel_scope.cancel()

    async with create_task_group() as outer_tg:
        outer_tg.start_soon(cancel_on_disconnect)
        async with create_task_group() as tg:
            yield tg
        # The caller's tasks completed normally: stop the disconnect watcher.
        outer_tg.cancel_scope.cancel()
@app.get("/cancel_tg")
async def cancel_tg(request: Request, task_param: str = "default_for_cancel_tg"):
    """
    Simplest way to define a task that could be canceled.

    Runs ``huge_work`` inside a request-bound task group and returns
    ``{"result": ...}``; the value is ``None`` when no result was stored.
    """
    # Unique key under which huge_work stores its result.
    task_id = str(uuid.uuid4())
    async with create_request_task_group(request) as tg:
        tg.start_soon(huge_work, task_id, task_param)
    # pop() reads and evicts in a single step, so an entry is always removed
    # from the storage once it has been read, whatever its value.
    result = temporary_returns_storage.pop(task_id, None)
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": result}
#########################################################
# The routes below use the RequestTaskGroup dependency
async def request_task_group(
    request: Request,
) -> Callable[[], AsyncContextManager[TaskGroup]]:
    """
    Dependency returning a zero-argument factory bound to ``request``.

    Calling the factory gives the same context manager as
    ``create_request_task_group(request)``.
    """

    def open_task_group():
        return create_request_task_group(request)

    return open_task_group
# Annotated parameter type: declaring a route parameter with it makes FastAPI
# resolve request_task_group and inject the factory it returns.
_TaskGroupFactory = Callable[[], AsyncContextManager[TaskGroup]]
RequestTaskGroup = Annotated[_TaskGroupFactory, Depends(request_task_group)]
@app.get("/cancel_tg_dependency")
async def cancel_tg_dependency(
    get_task_group: RequestTaskGroup,
    task_param: str = "default_for_cancel_tg_dependency",
):
    """
    Same as the simplest cancellable route, with the task group factory
    injected as a dependency.

    Returns ``{"result": ...}``; the value is ``None`` when no result was
    stored.
    """
    # Unique key under which huge_work stores its result.
    task_id = str(uuid.uuid4())
    # This task group will be cancelled if the client disconnects
    # before it exits
    async with get_task_group() as tg:
        tg.start_soon(huge_work, task_id, task_param)
    # pop() reads and evicts in a single step, so an entry is always removed
    # from the storage once it has been read, whatever its value.
    result = temporary_returns_storage.pop(task_id, None)
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": result}
#########################################################
# The routes below use an asyncio.Event to cancel a task.
# Only works with Python >= 3.11 because it relies on the asyncio.TaskGroup class, added in Python 3.11.
async def cancellation(
    request: Request,
):
    """
    Dependency yielding an asyncio.Event that is set once the client disconnects.

    A watcher task waits for the disconnect in the background while the
    dependency is active; it is cancelled when control comes back after the
    yield.
    """
    disconnected = asyncio.Event()

    async def watch_for_disconnect():
        # Suspends here until the disconnect message has been received.
        await wait_for_disconnect(request)
        print("set_event_on_disconnect - Client disconnected, cancelling task.")
        disconnected.set()

    async with asyncio.TaskGroup() as watchers:
        # Start the watcher that flips the event on disconnect.
        watcher = watchers.create_task(watch_for_disconnect())
        yield disconnected
        # The task group waits for its tasks on exit, so the watcher must be
        # stopped explicitly.
        watcher.cancel()
# Annotated parameter type: declaring a route parameter with it makes FastAPI
# resolve the cancellation dependency and inject the event it yields.
CancellationEvent = Annotated[
    asyncio.Event,
    Depends(cancellation),
]
@app.get("/cancel_event")
async def cancel_event(
    event: CancellationEvent, task_param: str = "default_for_cancel_tg_dependency"
):
    """
    This route will be cancelled if the client disconnects before it exits.
    It uses a cancellation event to cancel the task.

    The work and ``event.wait`` are started side by side in one task group;
    whichever finishes first cancels the group, and therefore the other one.
    Returns ``{"result": ...}``; the value is ``None`` when no result was
    stored.
    """
    # Unique key under which huge_work stores its result.
    task_id = str(uuid.uuid4())
    async with create_task_group() as tg:

        async def cancel_after_completion(
            func: Callable[..., Awaitable[Any]], *args: Any
        ):
            # Forward the positional arguments given to start_soon; without
            # *args the helper cannot be started with huge_work's parameters.
            await func(*args)
            tg.cancel_scope.cancel()

        tg.start_soon(cancel_after_completion, huge_work, task_id, task_param)
        tg.start_soon(cancel_after_completion, event.wait)
    # pop() reads and evicts in a single step, so an entry is always removed
    # from the storage once it has been read, whatever its value.
    result = temporary_returns_storage.pop(task_id, None)
    print("temporary_returns_storage must keep small:", temporary_returns_storage)
    return {"result": result}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment