Skip to content

Instantly share code, notes, and snippets.

@gustavofuhr
Created August 20, 2024 13:52
Show Gist options
  • Save gustavofuhr/09aa0889b56abc612048daf88435e103 to your computer and use it in GitHub Desktop.
Save gustavofuhr/09aa0889b56abc612048daf88435e103 to your computer and use it in GitHub Desktop.
A wrapper for OpenAI GPT-4o mini, Anthropic Claude Sonnet, Meta Llama, and Google Gemini.
import os
from enum import Enum
import argparse
from openai import OpenAI
import anthropic
import ollama
import google.generativeai as genai
class LLM_Models(Enum):
    """Selects which LLM backend an ``LLM_wrapper`` instance talks to."""

    OPENAI_GPT4_MINI = 0  # OpenAI "gpt-4o-mini"
    ANTHROPIC_SONNET = 1  # Anthropic "claude-3-5-sonnet-20240620"
    LLAMA_LOCAL = 2  # "llama3.1" served locally through ollama
    GOOGLE_GEMINI = 3  # Google "gemini-1.5-flash"
class LLM_wrapper:
    """
    A small wrapper around some popular LLMs. It adapts the input and output
    of four different APIs to a single interface:

    - OpenAI GPT-4o mini (OPENAI_GPT4_MINI): OpenAI's gpt-4o-mini model.
    - Anthropic Claude Sonnet (ANTHROPIC_SONNET): Claude 3.5 Sonnet.
    - Llama (LLAMA_LOCAL): Meta's Llama 3.1 running locally, accessed via ollama.
    - Google Gemini (GOOGLE_GEMINI): Google's Gemini 1.5 Flash model.

    The API keys are read from these environment variables:
    OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_GEMINI_API_KEY.
    Optionally, OPENAI_PROJECT_ID selects the OpenAI project to use.

    1. Create the model like this:
           llm = LLM_wrapper(LLM_Models.OPENAI_GPT4_MINI, max_tokens=1024)
       The ollama backend is called without a token limit, so max_tokens is
       ignored for LLAMA_LOCAL.
    2. Start sending messages:
           assistant_response = llm.send_message("Hello, I'm a user", "user")
       For OpenAI, Anthropic and ollama the whole conversation history is sent
       with every message; for Gemini only the newest message is sent and the
       Gemini chat session supplies the earlier turns.
    3. Optionally, print the trace of the conversation:
           llm.print_trace()

    Have fun and remember that LLMs lie! ;)
    """

    def __init__(self, llm_model: LLM_Models, max_tokens: int = 1024):
        """
        Set up the client for ``llm_model``.

        Args:
            llm_model: the backend to talk to.
            max_tokens: maximum number of tokens per response. Forced to
                None (no limit is passed) for LLAMA_LOCAL.

        Raises:
            ValueError: if ``llm_model`` is not a supported LLM_Models member.
        """
        self.llm_model = llm_model
        self.max_tokens = max_tokens
        if llm_model == LLM_Models.OPENAI_GPT4_MINI:
            self.init_openai_gpt4_mini()
        elif llm_model == LLM_Models.ANTHROPIC_SONNET:
            self.init_anthropic()
        elif llm_model == LLM_Models.LLAMA_LOCAL:
            self.init_ollama()
        elif llm_model == LLM_Models.GOOGLE_GEMINI:
            self.init_google_gemini()
        else:
            # Fail fast: without a backend there is no chat_function/model_name
            # and every later call would die with an AttributeError.
            raise ValueError(f"Unsupported LLM model: {llm_model!r}")
        # system_message and user_messages are not read by any method of this
        # class; they are kept only so external code touching them still works.
        self.system_message = ""
        self.user_messages = []
        # Human-readable log of every message and API call (see print_trace).
        self.trace = ""
        # Messages in the provider's format: {"role": ..., "content"/"parts": ...}.
        self.message_history = []

    def init_openai_gpt4_mini(self):
        """Configure the OpenAI client to use gpt-4o-mini."""
        self.client = OpenAI(
            # The project is taken from the environment instead of being
            # hard-coded to one account; when unset, None is passed.
            project=os.environ.get("OPENAI_PROJECT_ID"),
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        self.model_name = "gpt-4o-mini"
        self.chat_function = self.client.chat.completions.create
        self.response_content_fun = lambda response: response.choices[0].message.content

    def init_anthropic(self):
        """Configure the Anthropic client to use Claude 3.5 Sonnet."""
        self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        self.model_name = "claude-3-5-sonnet-20240620"
        self.chat_function = self.client.messages.create
        self.response_content_fun = lambda response: response.content[0].text

    def init_ollama(self):
        """Configure the local ollama backend to use llama3.1 (no token limit)."""
        self.model_name = "llama3.1"
        self.chat_function = ollama.chat
        # No max_tokens is sent to ollama (see call_api_chat_w_history).
        self.max_tokens = None
        self.response_content_fun = lambda response: response["message"]["content"]

    def init_google_gemini(self):
        """
        Configure Gemini 1.5 Flash and open a chat session.

        Raises:
            KeyError: if GOOGLE_GEMINI_API_KEY is not set.
        """
        self.model_name = "gemini-1.5-flash"
        genai.configure(api_key=os.environ["GOOGLE_GEMINI_API_KEY"])
        # Pass the model name and token limit explicitly; otherwise the
        # library's default model is used and max_tokens is silently ignored.
        generation_config = (
            {"max_output_tokens": self.max_tokens}
            if self.max_tokens is not None
            else None
        )
        self.client = genai.GenerativeModel(
            model_name=self.model_name,
            generation_config=generation_config,
        )
        self._start_gemini_chat()
        self.response_content_fun = lambda response: response.candidates[0].content.parts[0].text

    def _start_gemini_chat(self):
        """Open a fresh Gemini chat session and route chat calls to it."""
        self.chat = self.client.start_chat()
        self.chat_function = self.chat.send_message

    def print_trace(self):
        """Print every message and API call recorded so far."""
        print("\nMESSAGES TRACE")
        print(self.trace)

    def add_to_trace(self, message, role):
        """Append a ``role: message`` entry to the trace."""
        self.trace += f"{role}: {message}\n\n"

    def call_api_chat_w_history(self):
        """
        Send the conversation to the backend and return the raw API response.

        OpenAI, Anthropic and ollama receive the whole ``message_history``;
        Gemini receives only the newest message's parts because its chat
        session already holds the earlier turns.
        """
        print(f"Calling api with the following messages\n{self.message_history}")
        if self.llm_model != LLM_Models.GOOGLE_GEMINI:
            chat_kwargs = {
                "model": self.model_name,
                "messages": self.message_history,
            }
            # max_tokens is None for ollama, so the key is omitted there.
            if self.max_tokens is not None:
                chat_kwargs["max_tokens"] = self.max_tokens
            res_api = self.chat_function(**chat_kwargs)
        else:
            last_user_message = self.message_history[-1]["parts"]
            res_api = self.chat_function(last_user_message)
        self.add_to_trace(">> API called.", str(self.llm_model))
        return res_api

    def add_message_history(self, message: str, role: str):
        """Record a message in the provider's format and in the trace."""
        # Gemini names the text field "parts"; the other APIs use "content".
        content_field = "content" if self.llm_model != LLM_Models.GOOGLE_GEMINI else "parts"
        self.message_history.append({"role": role, content_field: message})
        self.add_to_trace(message, role)

    def clear_history(self):
        """Forget the conversation so far (the trace is left untouched)."""
        self.message_history = []
        if self.llm_model == LLM_Models.GOOGLE_GEMINI:
            # The Gemini chat session keeps its own copy of the history, so it
            # must be restarted or the model would still see earlier turns.
            self._start_gemini_chat()

    def send_message(self, message: str, role: str = "user") -> str:
        """
        Add ``message`` to the conversation, query the model and return its reply.

        Args:
            message: text to send.
            role: sender role. "system" is downgraded to "user" (with a printed
                warning) for Anthropic and Gemini.

        Returns:
            The assistant's reply text; it is also appended to the history
            (with role "model" for Gemini, "assistant" otherwise).
        """
        if role == "system" \
                and (self.llm_model == LLM_Models.ANTHROPIC_SONNET or self.llm_model == LLM_Models.GOOGLE_GEMINI):
            print(f"WARNING: {self.llm_model} does not support system messages, changing to user message.")
            self.add_message_history(message, "user")
        else:
            self.add_message_history(message, role)
        llm_res = self.call_api_chat_w_history()
        assistant_response = self.response_content_fun(llm_res)
        assistant_role = "assistant" if self.llm_model != LLM_Models.GOOGLE_GEMINI else "model"
        self.add_message_history(assistant_response, assistant_role)
        return assistant_response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment