Last active: February 10, 2023 08:39
-
-
Save bob-ross27/9089b23fe4e1d5bdede2a3c5386d6bea to your computer and use it in GitHub Desktop.
FastAPI Background downloading with SSE progress streaming
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import uuid | |
from dataclasses import dataclass | |
from dataclasses import field | |
from time import sleep | |
from fastapi import Request | |
@dataclass
class Download:
    """A single episode download job whose progress can be polled.

    Attributes:
        episode: Identifier of the episode requested by the client.
        id: Random UUID4 string generated per instance; used as the job key.
        files_to_download: Files that make up the episode, populated in
            ``__post_init__`` from ``get_files_For_episode``.
        completed_files: Files recorded as finished so far, in order.
    """

    episode: str
    id: str = field(init=False, default_factory=lambda: str(uuid.uuid4()))
    files_to_download: list[str] = field(init=False, default_factory=list)
    completed_files: list[str] = field(init=False, default_factory=list)

    def __post_init__(self) -> None:
        # Resolve the file list once, at construction time.
        self.files_to_download = self.get_files_For_episode()

    def get_files_For_episode(self) -> list[str]:
        """Get and return a list of files to be downloaded.

        Placeholder implementation: ignores ``self.episode`` and always
        returns ``scene-1`` through ``scene-29``.
        """
        return [f"scene-{x}" for x in range(1, 30)]

    def download_files_for_episode(self) -> None:
        """Download all files for the episode (simulated).

        Blocking: each file is appended to ``completed_files`` and then the
        call sleeps for one second with ``time.sleep``, so this must not be
        run directly on an event loop.
        """
        print(self.files_to_download)
        for file in self.files_to_download:
            self.completed_files.append(file)
            print(f"Downloaded file {file}")
            sleep(1)
        print("All downloads complete.")

    @property
    def progress(self) -> float:
        """Return the download progress as a percentage rounded to 0.1.

        An episode with no files is reported as fully complete (100.0)
        instead of raising ZeroDivisionError.
        """
        total = len(self.files_to_download)
        if not total:
            return 100.0
        return round(len(self.completed_files) / total * 100, 1)
async def download_progress_gen(request: Request, download: Download):
    """Yield SSE event dicts describing a download's progress until it ends.

    Polls ``download`` every ``STREAM_DELAY`` seconds. Whenever the progress
    value changes, a ``progress_update`` event is yielded whose ``data`` is a
    JSON-encoded object with ``progress`` and ``last_file``. When progress
    reports 100.0, that final update is sent first, followed by an ``end``
    event, and the generator returns. It also returns, without an ``end``
    event, as soon as the client disconnects.

    Args:
        request: The incoming request; used only to detect disconnects.
        download: The job whose progress is reported.

    Yields:
        dicts with ``event``/``id``/``retry``/``data`` keys for
        ``EventSourceResponse``.
    """
    import json  # local import keeps this block self-contained

    STREAM_DELAY = 1.0  # seconds between polls
    RETRY_TIMEOUT = 5000  # client reconnection delay, in milliseconds

    last_progress = 0.0
    while True:
        # Stop streaming once the client has gone away.
        if await request.is_disconnected():
            break

        # Snapshot progress once: the download may be advancing concurrently
        # (e.g. in a background task), so repeated reads could disagree and
        # cause an update to be skipped.
        progress = download.progress

        # Send a progress update when the value has changed. This runs before
        # the completion check so clients also receive the 100 % update.
        if progress != last_progress and download.completed_files:
            yield {
                "event": "progress_update",
                "id": str(uuid.uuid4())[:8],
                "retry": RETRY_TIMEOUT,
                # Encode explicitly: a raw dict would be sent as its Python
                # repr (single quotes), which clients cannot JSON.parse.
                "data": json.dumps(
                    {
                        "progress": progress,
                        "last_file": download.completed_files[-1],
                    }
                ),
            }
            last_progress = progress

        # End when download completes.
        if progress == 100.0:
            yield {"event": "end", "data": ""}
            break

        await asyncio.sleep(STREAM_DELAY)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import BackgroundTasks
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from sse_starlette.sse import EventSourceResponse

from .download import Download
from .download import download_progress_gen
# The application object that the route decorators below register against.
app = FastAPI()

# In-memory registry of download jobs, keyed by Download.id.
# NOTE(review): nothing visible here ever removes entries, so the registry
# grows for the lifetime of the process.
download_queue: dict[str, Download] = dict()
@app.get("/download", status_code=201)
async def route_download(episode: str, background_tasks: BackgroundTasks):
    """Register a download for ``episode`` and run it as a background task.

    The job is stored in ``download_queue`` under its generated id, and the
    client is redirected (303 See Other) to that job's SSE status stream.

    NOTE: because a Response object is returned directly, its 303 status is
    what gets sent; the decorator's ``status_code=201`` is not used for it.
    """
    job = Download(episode)
    job_id = job.id
    download_queue[job_id] = job

    # The actual file transfer runs after the response has been sent.
    background_tasks.add_task(job.download_files_for_episode)

    status_url = f"/download/{job_id}/status"
    return RedirectResponse(url=status_url, status_code=303)
@app.get("/download/{id}/status")
async def route_status_stream(id: str, request: Request):
    """Stream a download's progress to the client as Server-Sent Events.

    Args:
        id: Download id issued by ``/download`` (a key of ``download_queue``).
            The name must stay ``id`` because FastAPI binds it to the
            ``{id}`` path segment.
        request: The incoming request, forwarded so the stream can stop when
            the client disconnects.

    Returns:
        An ``EventSourceResponse`` driven by ``download_progress_gen``, or,
        for an unknown id, a 404 JSON response
        ``{"message": "Invalid id provided."}``.
    """
    download = download_queue.get(id)
    if download is None:
        # Report "not found" with a proper status instead of a 200 whose
        # body happens to contain an error message.
        return JSONResponse(
            status_code=404, content={"message": "Invalid id provided."}
        )
    return EventSourceResponse(download_progress_gen(request, download))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.