@JoshuaPurtell
Created May 13, 2024 19:42
How does GPT-4o's internal state tracking stack up? A needle-in-a-haystack probe: generate k offsetting pairs of accounting entries plus one unmatched entry, and measure the largest k at which each model can still identify the entry with no offset.
import asyncio
import os
import random
import hashlib
from datetime import datetime
from typing import Dict, List, Type
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from diskcache import Cache
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from together import AsyncTogether
from groq import AsyncGroq
import instructor
# Load environment variables (logger is already imported from loguru above)
load_dotenv()
# Create a cache object
cache = Cache(directory=".cache")
def generate_cache_key(messages: List[Dict], model: str) -> str:
    key = "".join(msg["content"] for msg in messages) + model
    return hashlib.sha256(key.encode()).hexdigest()

def generate_cache_key_with_response_model(messages: List[Dict], model: str, response_model: Type[BaseModel]) -> str:
    key = "".join(msg["content"] for msg in messages) + model + str(response_model.schema())
    return hashlib.sha256(key.encode()).hexdigest()
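
# Usage sketch (hypothetical values): keys depend only on the message contents,
# the model name, and (for structured output) the response-model schema, so
# repeated identical calls are served from the disk cache instead of the API:
#   msgs = [{"role": "user", "content": "hello"}]
#   generate_cache_key(msgs, "gpt-4o")  # same inputs -> same sha256 digest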
# Client initialization; instructor.patch adds response_model support to the
# OpenAI and Groq clients (Groq via markdown-JSON mode)
openai_client = instructor.patch(AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")))
anthropic_client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
tgi_client = AsyncTogether(api_key=os.getenv("TOGETHER_AI_API_KEY"))
groq_client = instructor.patch(AsyncGroq(api_key=os.getenv("GROQ_API_KEY")), mode=instructor.Mode.MD_JSON)
async def chat_completion(client, messages: List[Dict], model: str, temperature: float, max_tokens: int, response_model: Type[BaseModel] = None):
    key = generate_cache_key_with_response_model(messages, model, response_model) if response_model else generate_cache_key(messages, model)
    if key in cache:
        return response_model.parse_raw(cache[key]) if response_model else cache[key]
    if response_model and 'claude' not in model:
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            response_model=response_model
        )
        result = response.json()
        output = response_model.parse_raw(result)
    elif 'claude' in model:
        response = await client.messages.create(
            model=model,
            system=messages[0]["content"],
            messages=messages[1:],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result = output = response.content[0].text
    else:
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result = output = response.choices[0].message.content
    cache[key] = result
    return output
def sync_chat_completion(client, messages: List[Dict], model: str, temperature: float = 0.0, max_tokens: int = 150, response_model: Type[BaseModel] = None):
    return asyncio.run(chat_completion(client, messages, model, temperature, max_tokens, response_model))

def build_messages(sys_msg: str, user_msg: str) -> List[Dict]:
    return [{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}]
class LLM:
    def __init__(self, model_name: str, temperature: float = 0.0, max_tokens: int = 150, response_model: Type[BaseModel] = None):
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.response_model = response_model
        self.client = self.determine_client()

    def determine_client(self):
        # Route to a provider client by substring of the model name;
        # anything unrecognized falls through to Together.
        if "gpt" in self.model_name:
            return openai_client
        elif "claude" in self.model_name:
            return anthropic_client
        elif "llama" in self.model_name:
            return groq_client
        else:
            return tgi_client

    async def respond(self, system_prompt: str, user_prompt: str):
        messages = build_messages(system_prompt, user_prompt)
        return await chat_completion(self.client, messages, self.model_name, self.temperature, self.max_tokens, self.response_model)
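
# Usage sketch (model name illustrative): routing is by substring, so any name
# containing "gpt", "claude", or "llama" selects that provider's client:
#   llm = LLM(model_name="claude-3-opus-20240229", max_tokens=300)
#   answer = asyncio.run(llm.respond("You are terse.", "Say hi."))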
def create_synthetic_data(k=100, trial=0):
    random.seed(420 + trial)
    counterparty_names = ["Google", "Apple", "Microsoft", ...]  # truncated for brevity
    synthetic_data = []
    for _ in range(k):
        cnp = random.choice(counterparty_names)
        date = datetime(2023, random.randint(1, 12), random.randint(1, 28))
        amount = random.randint(1000, 10000)
        # Every transaction gets an offsetting entry of opposite sign.
        synthetic_data.append({"counterparty_name": cnp, "amount": amount, "date": date})
        synthetic_data.append({"counterparty_name": cnp, "amount": -amount, "date": date})
    # The needle: one extra entry with no offsetting counterpart.
    needle = {"counterparty_name": random.choice(counterparty_names), "amount": random.randint(1000, 10000), "date": datetime(2023, random.randint(1, 12), random.randint(1, 28))}
    synthetic_data.append(needle)
    random.shuffle(synthetic_data)
    return "\n".join(f"{data['counterparty_name']} {data['amount']} {data['date']}" for data in synthetic_data), needle
async def check_correctness(stringified_haystack, needle, llm: LLM):
    completion = await llm.respond(
        """
# Premise
You will be provided with records of accounting entries. Some represent real-world transactions, and others represent offsetting entries.
Each real-world transaction ought to have an offsetting entry to balance the books.
Matching entries share the following characteristics:
- Same counterparty name
- Same date
- Amount with the same absolute value but opposite sign
## Examples of Matching Entries
### Matching Pair 1
Google 1000 2023-01-01
Google -1000 2023-01-01
### Matching Pair 2
Apple 2000 2023-02-01
Apple -2000 2023-02-01
### Matching Pair 3
Microsoft 3000 2023-03-01
Microsoft -3000 2023-03-01
# Objective
Identify the entry that does not have an offsetting entry. Respond only with its information, in the same format as it is presented.
""",
        "The entries you have to pick from:\n" + stringified_haystack,
    )
    correctness = (
        (str(needle["counterparty_name"]) in completion)
        and (str(needle["amount"]) in completion)
        and (str(needle["date"]) in completion)
    )
    return correctness, completion
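
# Note: grading is substring-based, so a completion counts as correct as long as
# it contains the needle's counterparty name, amount, and stringified date
# anywhere in the response, regardless of formatting.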
async def full_eval_for_model(model="gpt-4o", dataset_sizes=[10, 25, 50, 100]):
    llm = LLM(model_name=model)
    n_trials = 3
    results = {}
    last_viable_k = None
    for k in dataset_sizes:
        results[k] = {}
        results[k]["prcntg_trials_passed"] = 0
        for trial in range(n_trials):
            stringified, needle = create_synthetic_data(k=k, trial=trial)
            correctness_for_k, full_completion = await check_correctness(stringified, needle, llm)
            results[k]["prcntg_trials_passed"] += correctness_for_k
        results[k]["prcntg_trials_passed"] /= n_trials
        # Stop at the first k where the model fails all trials.
        if results[k]["prcntg_trials_passed"] == 0:
            break
        else:
            last_viable_k = k
            print("Passed for k = ", k)
    return results, last_viable_k
if __name__ == "__main__":
    dataset_sizes = [10, 15, 20, 25, 30, 50, 75, 85, 95, 100, 125, 150, 200, 500, 1000, 2000]
    model = "gpt-4o"
    results, last_viable_k = asyncio.run(full_eval_for_model(model=model, dataset_sizes=dataset_sizes))
    print(last_viable_k)
    print(results)
# k pairs at which each model passes / fails (pass = at least 1 of 3 trials correct)
# OpenAI models
# gpt-4-32k: 95/100 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 1.0}, 30: {'prcntg_trials_passed': 1.0}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.0}}
# gpt-4-turbo: 75/85 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 0.6666666666666666}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.0}}
# gpt-4o: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.0}}
# gpt-3.5-turbo: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.3333333333333333}, 25: {'prcntg_trials_passed': 0.0}}
# Meta models
# llama-3-70b: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.0}}
# llama-3-8b: 10/15
# Anthropic models
# claude-3-opus: 95/100 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 1.0}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 95: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.0}}
# claude-3-sonnet: 100/125 {10: {'prcntg_trials_passed': 0.3333333333333333}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 0.3333333333333333}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.6666666666666666}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 95: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.3333333333333333}, 125: {'prcntg_trials_passed': 0.0}}
# claude-3-haiku: 20/25 {10: {'prcntg_trials_passed': 0.3333333333333333}, 15: {'prcntg_trials_passed': 0.3333333333333333}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.0}}
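# Takeaway: gpt-4o first fails all trials at k=25 pairs, the same breaking point
# as gpt-3.5-turbo, llama-3-70b, and claude-3-haiku, and well below gpt-4-turbo
# (85), gpt-4-32k and claude-3-opus (100), and claude-3-sonnet (125).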