A wrapper for GPT-4, Claude Sonnet, Llama, and Google Gemini.
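Typical usage, matching the docstring below — a minimal sketch, assuming the file is saved as llm_wrapper.py (the module name is hypothetical) and that OPENAI_API_KEY is set in the environment:

from llm_wrapper import LLM_wrapper, LLM_Models

llm = LLM_wrapper(LLM_Models.OPENAI_GPT4_MINI, max_tokens=1024)
reply = llm.send_message("Give me one sentence about the moon.", "user")
print(reply)
llm.print_trace()  # optional: dump the whole conversation so far
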
import os
from enum import Enum
import argparse

from openai import OpenAI
import anthropic
import ollama
import google.generativeai as genai

class LLM_Models(Enum):
    OPENAI_GPT4_MINI = 0
    ANTHROPIC_SONNET = 1
    LLAMA_LOCAL = 2
    GOOGLE_GEMINI = 3

class LLM_wrapper:
    """
    This is a wrapper for some popular LLMs currently available. It is a simple
    piece of code that adapts the input and output of four different APIs:

    - OpenAI GPT-4 Mini (OPENAI_GPT4_MINI): the famous OpenAI gpt-4o-mini model.
    - Anthropic Claude Sonnet (ANTHROPIC_SONNET): the Claude 3.5 Sonnet model.
    - Llama (LLAMA_LOCAL): Meta's model running locally, interfaced through ollama.
    - Google Gemini (GOOGLE_GEMINI): Google's Gemini model.

    You're expected to have set environment variables for the API keys:
    OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_GEMINI_API_KEY

    1. Create the model like this:

        llm = LLM_wrapper(LLM_Models.OPENAI_GPT4_MINI, max_tokens=1024)

       Notice that the llama implementation does not provide a max_tokens
       parameter, so it will be ignored.

    2. Start sending messages:

        assistant_response = llm.send_message("Hello, I'm a user", "user")

       For each message, the whole conversation history is sent to the API.

    3. Optionally, you can print the trace of the conversation:

        llm.print_trace()

    Have fun and remember that LLMs lie! ;)
    """

    def __init__(self, llm_model: LLM_Models, max_tokens: int = 1024):
        self.llm_model = llm_model
        self.max_tokens = max_tokens
        if llm_model == LLM_Models.OPENAI_GPT4_MINI:
            self.init_openai_gpt4_mini()
        elif llm_model == LLM_Models.ANTHROPIC_SONNET:
            self.init_anthropic()
        elif llm_model == LLM_Models.LLAMA_LOCAL:
            self.init_ollama()
        elif llm_model == LLM_Models.GOOGLE_GEMINI:
            self.init_google_gemini()

        self.system_message = ""
        self.user_messages = []
        self.trace = ""
        self.message_history = []

    def init_openai_gpt4_mini(self):
        self.client = OpenAI(
            project='proj_XCapd6U3grhxrToccuW6UdNO',  # note: project id hardcoded in the original gist
            api_key=os.environ.get('OPENAI_API_KEY')
        )
        self.model_name = "gpt-4o-mini"
        self.chat_function = self.client.chat.completions.create
        self.response_content_fun = lambda response: response.choices[0].message.content

    def init_anthropic(self):
        self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        self.model_name = "claude-3-5-sonnet-20240620"
        self.chat_function = self.client.messages.create
        self.response_content_fun = lambda response: response.content[0].text

    def init_ollama(self):
        self.model_name = "llama3.1"
        self.chat_function = ollama.chat
        # ollama.chat does not accept a max_tokens argument, so drop it.
        self.max_tokens = None
        self.response_content_fun = lambda response: response["message"]["content"]

    def init_google_gemini(self):
        self.model_name = "gemini-1.5-flash"
        genai.configure(api_key=os.environ["GOOGLE_GEMINI_API_KEY"])
        # Pass the model name explicitly; otherwise GenerativeModel falls back
        # to its default model and self.model_name is never used.
        self.client = genai.GenerativeModel(self.model_name)
        self.chat = self.client.start_chat()
        self.chat_function = self.chat.send_message
        self.response_content_fun = lambda response: response.candidates[0].content.parts[0].text

    def print_trace(self):
        print("\nMESSAGES TRACE")
        print(self.trace)

    def add_to_trace(self, message, role):
        self.trace += f"{role}: {message}\n\n"

    def call_api_chat_w_history(self):
        print(f"Calling API with the following messages:\n{self.message_history}")
        if self.llm_model != LLM_Models.GOOGLE_GEMINI:
            chat_kwargs = {
                "model": self.model_name,
                "messages": self.message_history,
            }
            if self.max_tokens is not None:
                chat_kwargs["max_tokens"] = self.max_tokens
            res_api = self.chat_function(**chat_kwargs)
        else:
            # Gemini's ChatSession keeps its own history, so only the latest
            # user message is sent.
            last_user_message = self.message_history[-1]["parts"]
            res_api = self.chat_function(last_user_message)
        self.add_to_trace(">> API called.", str(self.llm_model))
        return res_api

    def add_message_history(self, message: str, role: str):
        # Gemini message dicts use "parts" where the other APIs use "content".
        content_field = "content" if self.llm_model != LLM_Models.GOOGLE_GEMINI else "parts"
        self.message_history.append({"role": role, content_field: message})
        self.add_to_trace(message, role)

    def clear_history(self):
        self.message_history = []

    def send_message(self, message: str, role: str = "user"):
        # Anthropic and Gemini do not accept a "system" role inside the message
        # list (Anthropic takes system prompts via a separate top-level
        # parameter); this simple wrapper just demotes them to user messages.
        if role == "system" \
           and (self.llm_model == LLM_Models.ANTHROPIC_SONNET or self.llm_model == LLM_Models.GOOGLE_GEMINI):
            print(f"WARNING: {self.llm_model} does not support system messages, changing to user message.")
            self.add_message_history(message, "user")
        else:
            self.add_message_history(message, role)

        llm_res = self.call_api_chat_w_history()
        assistant_response = self.response_content_fun(llm_res)
        assistant_role = "assistant" if self.llm_model != LLM_Models.GOOGLE_GEMINI else "model"
        self.add_message_history(assistant_response, assistant_role)
        return assistant_response
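
The file imports argparse but never uses it. A command-line entry point along these lines would exercise the wrapper end to end; this is a hedged sketch, not part of the original gist (the flag names and the REPL loop are assumptions):

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chat with one of the wrapped LLMs.")
    parser.add_argument("--model", choices=[m.name for m in LLM_Models],
                        default=LLM_Models.OPENAI_GPT4_MINI.name)
    parser.add_argument("--max-tokens", type=int, default=1024)
    args = parser.parse_args()

    llm = LLM_wrapper(LLM_Models[args.model], max_tokens=args.max_tokens)
    # Simple REPL: each turn resends the accumulated history (or, for Gemini,
    # defers to the ChatSession's own history). Exit with Ctrl-D.
    while True:
        try:
            user_input = input("you> ")
        except EOFError:
            break
        print("llm>", llm.send_message(user_input, "user"))
    llm.print_trace()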