-
-
Save zilto/319772fa37f63ee38b893a3e39ed6592 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass

import fastapi
import pydantic
from fastapi.responses import JSONResponse
from hamilton import driver
# Container for the objects the endpoints share; a single instance of it is
# stored in the module-level `global_context` name at application startup.
@dataclass(init=True, repr=True, eq=True)
class GlobalContext:
    """Objects created once at startup and reused by every request handler."""

    # Base URL of the vector database service the dataflows connect to.
    vector_db_url: str
    # Hamilton driver used to execute the dataflow nodes requested by endpoints.
    hamilton_driver: driver.Driver
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[None]:
    """Startup and shutdown logic of the FastAPI app.

    Code before the ``yield`` runs once at startup; code after it runs at
    shutdown. At startup this imports the Hamilton dataflow modules, builds
    the Hamilton driver, publishes it through the module-level
    ``global_context`` so endpoints can reuse it, and executes the
    ``initialize_weaviate_instance`` node.

    Args:
        app: The FastAPI application. Required by the lifespan protocol;
            not used in this function.

    Yields:
        None. Control goes back to FastAPI while the app serves requests.
    """
    # import the Python modules containing your dataflows
    import ingestion
    import retrieval
    import vector_db

    driver_config = {}
    dr = (
        driver.Builder()
        .enable_dynamic_execution(allow_experimental_mode=True)  # to allow Parallelizable/Collect
        .with_config(driver_config)
        .with_modules(ingestion, retrieval, vector_db)  # pass our dataflows
        .build()
    )

    # make the variable global to reuse it within endpoints
    global global_context
    global_context = GlobalContext(vector_db_url="http://weaviate_storage:8083", hamilton_driver=dr)

    # execute Hamilton code to make sure the Weaviate class schemas are instantiated
    global_context.hamilton_driver.execute(
        ["initialize_weaviate_instance"],
        inputs=dict(vector_db_url=global_context.vector_db_url),
    )

    # anything above yield is executed at startup
    yield
    # anything below yield is executed at teardown
# Create the FastAPI application object; the `lifespan` context manager
# supplies its startup/shutdown hooks.
app = fastapi.FastAPI(
    lifespan=lifespan,
    title="Retrieval Augmented Generation with Hamilton",
)
# define a POST endpoint
@app.post("/store_arxiv", tags=["Ingestion"])
async def store_arxiv(arxiv_ids: list[str] = fastapi.Form(...)) -> JSONResponse:
    """Retrieve PDF files of arxiv articles for arxiv_ids\n
    Read the PDF as text, create chunks, and embed them using OpenAI API\n
    Store chunks with embeddings in Weaviate.
    """
    # NOTE: the docstring above is kept verbatim because FastAPI publishes it
    # as this endpoint's OpenAPI description.
    #
    # `hamilton_driver.execute` is a synchronous call that (per the docstring)
    # downloads PDFs, calls the OpenAI API and writes to Weaviate. Calling it
    # directly inside an `async def` would block the event loop for the whole
    # run and stall every other request, so it is offloaded to a worker thread.
    await asyncio.to_thread(
        global_context.hamilton_driver.execute,
        ["store_documents"],
        inputs=dict(
            arxiv_ids=arxiv_ids,
            embedding_model_name="text-embedding-ada-002",
            data_dir="./data",
            vector_db_url=global_context.vector_db_url,
        ),
    )
    return JSONResponse(content=dict(stored_arxiv_ids=arxiv_ids))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment