# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import litellm
import asyncio
from typing import List
import os
from dotenv import load_dotenv
import json
from datetime import datetime
load_dotenv() # Load environment variables from .env file
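# Note (assumed, not stated in the code): with the models configured below,
# LiteLLM typically reads provider credentials from OPENAI_API_KEY,
# ANTHROPIC_API_KEY and DEEPSEEK_API_KEY in the environment / .env file.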
app = FastAPI()
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # Allow the frontend origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configure LiteLLM
litellm.set_verbose = True
# LLM configurations
llm_configs = {
"gpt-4": {"model": "gpt-4o"},
"claude": {"model": "anthropic/claude-3-5-sonnet-20240620"},
"deepseek": {"model": "deepseek/deepseek-chat"}
}
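# The keys are local labels that appear in responses and log files; the values
# are LiteLLM model identifiers that route each request to its provider.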
class Question(BaseModel):
    text: str

class LLMResponse(BaseModel):
    model: str
    response: str

class AggregatedResponse(BaseModel):
    individual_responses: List[LLMResponse]
    aggregated_response: str
# Ensure logs directory exists
logs_dir = "logs"
os.makedirs(logs_dir, exist_ok=True)
def get_next_log_number():
    existing_logs = [f for f in os.listdir(logs_dir) if f.endswith('.txt')]
    if not existing_logs:
        return 1
    return max([int(f.split('.')[0]) for f in existing_logs]) + 1
def write_log(question, individual_responses, aggregated_response):
    log_number = get_next_log_number()
    filename = f"{log_number}.txt"
    filepath = os.path.join(logs_dir, filename)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write("-" * 100 + "\n")
        f.write(f"Question:\n{question}\n\n")
        f.write("-" * 100 + "\n")
        f.write("Individual Responses:\n")
        f.write("-" * 100 + "\n")
        for resp in individual_responses:
            f.write(f"[{resp.model}]\n{resp.response}\n\n")
        f.write("-" * 100 + "\n")
        f.write("Aggregated Response:\n")
        f.write("-" * 100 + "\n")
        f.write(f"{aggregated_response}\n")
async def query_llm(question: str, llm_name: str):
    try:
        response = await litellm.acompletion(
            model=llm_configs[llm_name]["model"],
            messages=[{"role": "user", "content": question}]
        )
        return LLMResponse(model=llm_name, response=response.choices[0].message.content)
    except Exception as e:
        print(f"Error querying {llm_name}: {str(e)}")
        return LLMResponse(model=llm_name, response=f"Error: {str(e)}")
async def aggregate_responses(question: str, responses: List[LLMResponse]):
    aggregator_prompt = """
You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
User question: {question}
Responses from models:
{responses}
Please provide an aggregated response:
"""
    responses_text = "\n\n".join([f"Model {i+1}:\n{resp.response}" for i, resp in enumerate(responses)])
    full_prompt = aggregator_prompt.format(question=question, responses=responses_text)
    try:
        aggregated = await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet-20240620",
            messages=[{"role": "user", "content": full_prompt}]
        )
        return aggregated.choices[0].message.content
    except Exception as e:
        print(f"Error in aggregation: {str(e)}")
        return "Error occurred during aggregation."
@app.post("/ask", response_model=AggregatedResponse)
async def ask_question(question: Question):
try:
llm_responses = await asyncio.gather(*[query_llm(question.text, llm) for llm in llm_configs.keys()])
aggregated_response = await aggregate_responses(question.text, llm_responses)
# Write log
write_log(question.text, llm_responses, aggregated_response)
return AggregatedResponse(individual_responses=llm_responses, aggregated_response=aggregated_response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8010)
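# Example request against the running service (a sketch; the sample question is
# illustrative, and the field names follow the Question and AggregatedResponse
# models above):
#
#   curl -X POST http://localhost:8010/ask \
#        -H "Content-Type: application/json" \
#        -d '{"text": "What are the main trade-offs of microservices?"}'
#
# The JSON response contains "individual_responses" (one entry per configured
# model) and a single "aggregated_response" string.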