Created
May 15, 2023 15:38
-
-
Save morganmcg1/99f6ce7ab9b03c86c185561229875190 to your computer and use it in GitHub Desktop.
Using prompts and FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, Dict, Any, List | |
from fastapi import Depends, FastAPI | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.responses import JSONResponse | |
from lanarky.responses import StreamingResponse | |
from langchain.callbacks.manager import AsyncCallbackManager | |
from langchain.callbacks.base import AsyncCallbackHandler | |
from langchain.chat_models import ChatOpenAI | |
from langchain.chains import LLMChain, SequentialChain, ConversationChain | |
from pydantic import BaseModel | |
import logging, toml, json | |
from datetime import datetime | |
import wandb | |
# Set up wandb | |
from wandb.integration.langchain import WandbTracer | |
# os.environ["WANDB_API_KEY"] = toml.load("api_keys.toml")["wandb_api_key"] | |
# os.environ["WANDB_HOST"] = "https://wandb.myhost.com" | |
# W&B run configuration consumed by WandbTracer callbacks in the /chat handler.
wandb_config = {
    "project": "my-project",
    # "entity" : "my-entity",
    "tags" : ["hi wandb team"],
    # Timestamped run name so repeated server starts get distinct W&B runs.
    "name" : "my_log_" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
}

app = FastAPI()
@app.on_event("shutdown")
async def shutdown_event():
    """Close the W&B run on server shutdown so buffered traces are flushed."""
    # NOTE(review): `on_event` is the older FastAPI hook style — presumably fine
    # for this FastAPI version; confirm before upgrading.
    WandbTracer.finish()
# Set up server endpoint
class Request(BaseModel):
    """Request body for the /chat endpoint.

    NOTE(review): this name shadows `fastapi.Request` — consider renaming
    (e.g. ChatRequest) to avoid confusion.
    """
    # The user's message to send through the chain.
    query : str
import os

# NOTE(review): placeholder key hard-coded in source — load from a secrets
# store or the environment instead of committing it.
os.environ["OPENAI_API_KEY"] = "xxx"

import langchain

# Debug output: confirm which library versions are loaded at startup.
print(wandb.__version__)
print(langchain.__version__)
@app.post("/chat")
async def chat(
    request: Request
):  # -> StreamingResponse:
    """Run the user's query through an LLMChain, traced to Weights & Biases.

    Args:
        request: Request body; ``request.query`` is the user's message.

    Returns:
        The raw chain output dict (LangChain's default, including the
        generated ``"text"`` key).
    """
    # Local import kept from the original; PromptTemplate is only needed here.
    from langchain import PromptTemplate

    # Must stream for lanarky's StreamingResponse. Bug fix: the previous
    # hard-coded placeholder key ("XXX") overrode the OPENAI_API_KEY
    # environment variable set at module import; ChatOpenAI now reads the
    # key from the environment.
    llm = ChatOpenAI(
        streaming=True,
    )

    # Set up chain; the template is a fixed system-style prompt.
    llm_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate.from_template("You are a super happy robot"),
    )

    # Bug fix: the endpoint previously ignored the incoming request entirely
    # and always ran a hard-coded input string. Pass the user's query instead.
    response = llm_chain(
        {"input": request.query},
        callbacks=[WandbTracer(wandb_config)],
        # callbacks = AsyncCallbackManager([WandbTracer(wandb_config)])
    )
    return response
import uvicorn

# Local dev entry point; serves the app on localhost:8010.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8010)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment