Skip to content

Instantly share code, notes, and snippets.

@zilto
Created September 6, 2023 21:09
Show Gist options
  • Save zilto/319772fa37f63ee38b893a3e39ed6592 to your computer and use it in GitHub Desktop.
Save zilto/319772fa37f63ee38b893a3e39ed6592 to your computer and use it in GitHub Desktop.
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass

import fastapi
import pydantic
from fastapi.responses import JSONResponse
from hamilton import driver
# define a global dataclass that is shared across endpoints
@dataclass
class GlobalContext:
    """State shared across the FastAPI endpoints.

    Attributes:
        vector_db_url: URL of the vector database; passed as the
            ``vector_db_url`` input when executing Hamilton dataflows.
        hamilton_driver: Hamilton driver used to execute the dataflows.
    """

    vector_db_url: str
    hamilton_driver: driver.Driver
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[None]:
    """Run the startup and shutdown logic of the FastAPI app.

    Everything before the ``yield`` runs at startup: the Hamilton dataflow
    modules are imported, the Hamilton driver is built, the shared
    ``global_context`` is published, and the ``initialize_weaviate_instance``
    node is executed. Everything after the ``yield`` runs at shutdown; there
    is currently no teardown work.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the function.

    Yields:
        None. Control returns to the application while it serves requests.
    """
    # Import the Python modules containing the dataflows.
    import ingestion
    import retrieval
    import vector_db

    driver_config = {}
    dr = (
        driver.Builder()
        .enable_dynamic_execution(allow_experimental_mode=True)  # to allow Parallelizable/Collect
        .with_config(driver_config)
        .with_modules(ingestion, retrieval, vector_db)  # pass our dataflows
        .build()
    )

    # Publish the context as a module-level global so endpoints can reuse it.
    global global_context
    global_context = GlobalContext(
        vector_db_url="http://weaviate_storage:8083",
        hamilton_driver=dr,
    )

    # Execute Hamilton code to make sure the Weaviate class schemas are
    # instantiated before the first request is served.
    global_context.hamilton_driver.execute(
        ["initialize_weaviate_instance"],
        inputs={"vector_db_url": global_context.vector_db_url},
    )

    # Anything above yield is executed at startup.
    yield
    # Anything below yield is executed at teardown.
# Create the FastAPI application; `lifespan` supplies its startup/shutdown hook.
app = fastapi.FastAPI(
    lifespan=lifespan,
    title="Retrieval Augmented Generation with Hamilton",
)
# define a POST endpoint
# NOTE: the docstring below is kept verbatim because FastAPI uses an endpoint's
# docstring as its description in the generated API docs.
@app.post("/store_arxiv", tags=["Ingestion"])
async def store_arxiv(arxiv_ids: list[str] = fastapi.Form(...)) -> JSONResponse:
    """Retrieve PDF files of arxiv articles for arxiv_ids\n
    Read the PDF as text, create chunks, and embed them using OpenAI API\n
    Store chunks with embeddings in Weaviate.
    """
    # Function-scope import so this block does not depend on other edits.
    from fastapi.concurrency import run_in_threadpool

    # Driver.execute() is synchronous. Calling it directly inside this
    # coroutine would block the event loop (and every other request) for the
    # whole ingestion run, so it is offloaded to the worker thread pool.
    # NOTE(review): assumes the shared driver tolerates concurrent execute()
    # calls from worker threads - confirm.
    await run_in_threadpool(
        global_context.hamilton_driver.execute,
        ["store_documents"],
        inputs={
            "arxiv_ids": arxiv_ids,
            "embedding_model_name": "text-embedding-ada-002",
            "data_dir": "./data",
            "vector_db_url": global_context.vector_db_url,
        },
    )
    return JSONResponse(content={"stored_arxiv_ids": arxiv_ids})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment