Skip to content

Instantly share code, notes, and snippets.

@bob-ross27
Last active February 10, 2023 08:39
Show Gist options
  • Save bob-ross27/9089b23fe4e1d5bdede2a3c5386d6bea to your computer and use it in GitHub Desktop.
Save bob-ross27/9089b23fe4e1d5bdede2a3c5386d6bea to your computer and use it in GitHub Desktop.
FastAPI Background downloading with SSE progress streaming
import asyncio
import json
import uuid
from dataclasses import dataclass
from dataclasses import field
from time import sleep

from fastapi import Request
@dataclass
class Download:
    """A (simulated) multi-file download for a single episode.

    Attributes:
        episode: Episode identifier supplied by the caller.
        id: Random UUID4 string generated on construction (not an init arg).
        files_to_download: File names to fetch; filled in by ``__post_init__``
            from ``get_files_For_episode``.
        completed_files: File names finished so far, appended in order by
            ``download_files_for_episode``.
    """

    episode: str
    id: str = field(init=False, default_factory=lambda: str(uuid.uuid4()))
    files_to_download: list[str] = field(init=False, default_factory=list)
    completed_files: list[str] = field(init=False, default_factory=list)

    def __post_init__(self) -> None:
        self.files_to_download = self.get_files_For_episode()

    def get_files_For_episode(self) -> list[str]:
        """Get and return a list of files to be downloaded.

        Placeholder implementation: always returns ``scene-1`` .. ``scene-29``
        regardless of ``self.episode``.
        """
        return [f"scene-{x}" for x in range(1, 30)]

    def download_files_for_episode(self) -> None:
        """Download every file in ``files_to_download``, one at a time.

        Each finished file is appended to ``completed_files``. This uses the
        blocking ``time.sleep`` (1 s per file) to simulate transfer time, so
        it should be run from a worker thread, not awaited on an event loop.
        """
        print(self.files_to_download)
        for file in self.files_to_download:
            self.completed_files.append(file)
            print(f"Downloaded file {file}")
            sleep(1)
        print("All downloads complete.")

    @property
    def progress(self) -> float:
        """Return download progress as a percentage rounded to 1 decimal place.

        An episode with no files is reported as complete (100.0) instead of
        raising ``ZeroDivisionError``; there is nothing left to download.
        """
        total = len(self.files_to_download)
        if not total:
            return 100.0
        return round((len(self.completed_files) / total) * 100, 1)
async def download_progress_gen(request: Request, download: Download):
    """Yield Server-Sent Event dicts describing a download's progress.

    Polls ``download`` every ``STREAM_DELAY`` seconds and yields:

    * a ``progress_update`` event whenever the progress value changes; its
      ``data`` is a JSON string with ``progress`` (percentage) and
      ``last_file`` (most recently completed file, or ``null``);
    * one ``end`` event with empty data once progress reaches 100, after
      which the generator returns.

    The generator also returns, without an ``end`` event, when the client
    disconnects.

    Args:
        request: Incoming request; used only to detect client disconnects.
        download: The download whose progress is reported.
    """
    STREAM_DELAY = 1.0  # seconds between polls
    RETRY_TIMEOUT = 5000  # milliseconds; sent as the SSE ``retry`` field
    last_progress = 0.0
    while True:
        # Handle client disconnects.
        if await request.is_disconnected():
            break
        # Read progress once per iteration: the download advances in another
        # thread, so repeated reads could disagree and let ``last_progress``
        # record a value that was never sent to the client.
        progress = download.progress
        if progress != last_progress:
            completed = download.completed_files
            yield {
                "event": "progress_update",
                "id": str(uuid.uuid4())[:8],
                "retry": RETRY_TIMEOUT,
                # sse_starlette converts ``data`` with ``str()``; serialize
                # explicitly so clients receive JSON rather than a dict repr.
                "data": json.dumps(
                    {
                        "progress": progress,
                        "last_file": completed[-1] if completed else None,
                    }
                ),
            }
            last_progress = progress
        # Checked after the update so the client also receives the final
        # 100% progress event (and last file) before ``end``.
        if progress == 100.0:
            yield {"event": "end", "data": ""}
            break
        await asyncio.sleep(STREAM_DELAY)
from fastapi import BackgroundTasks
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import RedirectResponse
from sse_starlette.sse import EventSourceResponse
from .download import Download
from .download import download_progress_gen
app = FastAPI()
# In-memory registry of Download objects keyed by their ``id``; written by
# the /download route and read by the status-stream route. Nothing in this
# module removes entries.
download_queue: dict[str, Download] = dict()
@app.get("/download", status_code=303)
async def route_download(episode: str, background_tasks: BackgroundTasks):
    """Accept a download request as a background task.

    Creates a ``Download`` for ``episode``, registers it in ``download_queue``
    under its id, schedules ``download_files_for_episode`` as a background
    task, and redirects the client (303) to ``/download/{id}/status``.

    NOTE(review): this GET has side effects (it starts a download) and queue
    entries are never removed; consider POST plus cleanup of finished
    downloads.
    """
    downloader = Download(episode)
    download_queue[downloader.id] = downloader
    background_tasks.add_task(downloader.download_files_for_episode)
    # A Response returned directly is sent as-is, so the decorator's
    # status_code only documents the route; keep it equal to the real 303.
    return RedirectResponse(url=f"/download/{downloader.id}/status", status_code=303)
@app.get("/download/{id}/status")
async def route_status_stream(id: str, request: Request):
    """SSE stream of download progress.

    Args:
        id: Download id issued by ``/download`` (key into ``download_queue``).
        request: Incoming request, passed through so the stream can detect
            client disconnects.

    Returns:
        An ``EventSourceResponse`` streaming ``download_progress_gen`` for a
        known id; otherwise a 404 JSON response with body
        ``{"message": "Invalid id provided."}``.
    """
    from fastapi.responses import JSONResponse

    download = download_queue.get(id)
    if download is None:
        # An unknown id is a client error: report it as 404 with the same body.
        return JSONResponse(
            status_code=404, content={"message": "Invalid id provided."}
        )
    return EventSourceResponse(download_progress_gen(request, download))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment