Skip to content

Instantly share code, notes, and snippets.

@iloveitaly
Created February 9, 2024 00:10
Show Gist options
  • Save iloveitaly/6f32d926ae2534dd4291dd2247e0214f to your computer and use it in GitHub Desktop.
Example of how to categorize email contacts (work, vendor, or friend) using OpenAI or a local LLM
import click
import sqlite_utils
import json
import logging
# LLM backend selector read by categorize().
SERVICE = "openai"
# Optional table-name prefix; overwritten by the --prefix CLI option.
# An empty/None value means tables are used without a prefix.
PREFIX = ""
logger = logging.getLogger(__name__)
def call_litellm(prompt):
    """Classify *prompt* with a local Ollama ``mistral:instruct`` model via litellm.

    Returns the value of the ``contact_type`` key in the model's JSON reply.
    Raises whatever ``json.loads`` / key lookup raise if the reply is malformed.
    """
    import litellm
    from litellm import completion

    litellm.set_verbose = True

    # Only a user message is sent: adding a system prompt caused the request
    # to hang. litellm wraps this message with extra markdown of its own.
    chat = [{"role": "user", "content": prompt}]

    # "ollama/mistral:text" is an alternative model for this call.
    # format="json" is not simply forwarded to Ollama; litellm also wraps the
    # prompt with additional markdown when it is set.
    reply = completion(
        model="ollama/mistral:instruct",
        api_base="http://localhost:11434",
        messages=chat,
        timeout=60,
        format="json",
    )

    # litellm exposes only its normalized `choices` structure, not the raw
    # Ollama response body.
    payload = reply.choices[0].message.content
    return json.loads(payload)["contact_type"]
def call_ollama(prompt):
    """Classify *prompt* with the local ``mistral:instruct`` model through the ollama client.

    Returns the value of the ``contact_type`` key parsed from the model's
    JSON output.
    """
    from ollama import Client

    # Client accepts extra httpx parameters, such as this request timeout.
    ollama_client = Client(timeout=60)
    result = ollama_client.generate(
        model="mistral:instruct",
        prompt=prompt,
        format="json",
        stream=False,
    )
    parsed = json.loads(result["response"])
    return parsed["contact_type"]
def call_openai(prompt):
    """Classify *prompt* with OpenAI chat completions in JSON mode.

    The OpenAI client reads the API key from the ``OPENAI_API_KEY``
    environment variable when it is constructed.

    Returns:
        The value of the ``contact_type`` key in the model's JSON reply.

    Raises:
        KeyError: if the JSON reply has no ``contact_type`` key.
        json.JSONDecodeError: if the reply is not valid JSON.
    """
    from openai import OpenAI

    client = OpenAI()
    completion = client.chat.completions.create(
        # The author notes only this model version supported json_object
        # response_format when this was written.
        model="gpt-3.5-turbo-1106",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant. Your response should be in JSON format.",
            },
            {"role": "user", "content": prompt},
        ],
        response_format={"type": "json_object"},
    )
    json_response = json.loads(completion.choices[0].message.content)
    return json_response["contact_type"]
def categorize(email, conversations: str):
    """Ask the configured LLM backend to classify *email* as work, vendor, or friend.

    The backend is chosen by the module-level ``SERVICE`` setting:
    ``"openai"`` (call_openai), ``"ollama"`` (call_ollama) or
    ``"litellm"`` (call_litellm).

    Args:
        email: The email address being categorized; it is embedded in the prompt.
        conversations: Pre-formatted conversation excerpts to include in the prompt.

    Returns:
        One of ``"work"``, ``"vendor"`` or ``"friend"``.

    Raises:
        ValueError: if ``SERVICE`` names an unknown backend, or the model
            returns anything other than one of the three allowed values
            (including an empty or missing value).
    """
    prompt = f"""
Past conversations:
{conversations}
IMPORTANT INSTRUCTIONS:
- Categorize email address {email} by analyzing the past included conversations.
- Only use 'work' if the contact is not personal and is related to my professional work (software engineering, startups, venture capital)
- If unsure, use 'vendor'
- Respond ONLY in JSON with either 'work', 'vendor', or 'friend' in a single 'contact_type' field. Example: `{{"contact_type": "vendor"}}`
"""
    backends = {
        "openai": call_openai,
        "ollama": call_ollama,
        "litellm": call_litellm,
    }
    try:
        backend = backends[SERVICE]
    except KeyError:
        raise ValueError(f"Unknown SERVICE: {SERVICE!r}") from None

    contact_type = backend(prompt)
    # An explicit check rather than `assert`, so validation survives `python -O`;
    # None or "" also fail this membership test.
    if contact_type not in ("work", "vendor", "friend"):
        raise ValueError(f"Invalid contact type: {contact_type}")
    return contact_type
def get_conversation(db, email, count=10):
    """Return ``subject``/``body`` rows from the *count* most recent threads involving *email*.

    A thread qualifies when any of its messages lists *email* in the
    ``all_contacts`` JSON array (matched on each element's ``email`` field).
    Threads are ranked by their latest message date. All messages of the
    selected threads are returned, newest first.

    Args:
        db: A ``sqlite_utils.Database`` holding the ``mbox_emails`` table
            (optionally prefixed by the module-level ``PREFIX``).
        email: Address to look up. It is bound as a SQL parameter, so quotes
            in the address are safe.
        count: Maximum number of threads to include.

    Returns:
        A list of dicts with ``subject`` and ``body`` keys.
    """
    formatted_prefix = f"{PREFIX}_" if PREFIX else ""
    # Identifiers cannot be bound as parameters, so the table name is quoted
    # (doubling any embedded double quotes) instead.
    table_name = f"{formatted_prefix}mbox_emails"
    table = '"{}"'.format(table_name.replace('"', '""'))
    query = f"""
SELECT subject, body
FROM {table}
WHERE gmail_thread_id IN (
  SELECT gmail_thread_id
  FROM (
    SELECT gmail_thread_id, MAX(date) AS max_date
    FROM {table}
    WHERE EXISTS (
      SELECT 1
      FROM json_each({table}.all_contacts) AS contact
      WHERE json_extract(contact.value, '$.email') = :email
    )
    GROUP BY gmail_thread_id
    ORDER BY MAX(date) DESC
    LIMIT :count
  )
)
ORDER BY date DESC;
"""
    return db.execute_returning_dicts(query, {"email": email, "count": count})
from litellm import token_counter
def format_conversations(
    conversations: list[dict],
    *,
    max_conversations: int = 2,
    max_tokens: int = 1_750,
) -> str:
    """Render the first few emails as fenced text blocks for the categorization prompt.

    Emails are considered in the order given. An email whose body counts as
    *max_tokens* or more tokens is skipped and logged. At most
    *max_conversations* emails are kept. By default that is two: mistral
    struggles with more, and two also fit within gpt-3.5's token limit.

    Args:
        conversations: Rows with ``subject`` and ``body`` keys.
        max_conversations: Maximum number of emails to include.
        max_tokens: Exclusive upper bound on a body's token count.

    Returns:
        The kept emails, each formatted as ```` ```Subject: ...\\n\\n<body>``` ````
        with ``\\r\\n`` normalized to ``\\n``, joined by blank lines.

    Raises:
        ValueError: if *conversations* is empty, or no email fits the token limit.
    """
    if not conversations:
        raise ValueError("No conversations found.")

    kept = []
    for email in conversations:
        # Only the first `max_conversations` usable emails are used, so there is
        # no need to tokenize the rest.
        if len(kept) >= max_conversations:
            break
        # The gpt-3.5-turbo tokenizer is used as an approximation regardless of
        # which backend (OpenAI or local mistral) will receive the prompt.
        token_count = token_counter(
            model="openai/gpt-3.5-turbo",
            text=email["body"],
        )
        if token_count < max_tokens:
            kept.append(email)
        else:
            logger.info("Skipping email due to token count: %s", token_count)

    if not kept:
        # Without any usable email the model would be classifying blind.
        raise ValueError("No conversations fit within the token limit.")

    return "\n\n".join(
        f"```Subject: {email['subject']}\n\n{email['body']}```".replace("\r\n", "\n")
        for email in kept
    )
def determine_contact_type(db, email):
    """Classify *email* from its recent conversations stored in *db*.

    This is the full pipeline: look up the recent threads, format them into
    prompt text, then ask the LLM. Returns whatever ``categorize`` returns.
    """
    return categorize(email, format_conversations(get_conversation(db, email)))
def run_evals(db_path):
    """Categorize a fixed set of labelled addresses and check the results.

    Every address is evaluated and printed, even after a mismatch, so one
    run shows the full score.

    Args:
        db_path: Path to the SQLite database holding the mail archive.

    Raises:
        AssertionError: if any address is categorized differently from its
            label; the message lists every mismatch.
    """
    db = sqlite_utils.Database(db_path)
    eval_list = [
        ("email1@example.com", "vendor"),
        ("email2@example.com", "friend"),
        ("email3@example.com", "friend"),
        ("email4@example.com", "vendor"),
        ("email5@example.com", "friend"),
        ("email6@example.com", "vendor"),
        ("email7@example.com", "vendor"),
        ("email8@example.com", "work"),
        ("email9@example.com", "work"),
    ]

    failures = []
    for email, expected_contact_type in eval_list:
        contact_type = determine_contact_type(db, email)
        print(f"email: {email}, contact_type: {contact_type}")
        if contact_type != expected_contact_type:
            failures.append(
                f"{email}: expected {expected_contact_type}, got {contact_type}"
            )

    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would report success on failing evals.
    if failures:
        raise AssertionError(
            f"{len(failures)}/{len(eval_list)} evals failed:\n" + "\n".join(failures)
        )
    print("All tests passed!")
# TODO allow service to be passed in
@click.command()
@click.argument("db_path", type=click.Path(exists=True), required=True)
@click.option("--evals", is_flag=True, help="Run evals")
@click.option("--prefix", help="Table prefix, if there is one")
@click.option("--email", help="Email to categorize")
def cli(db_path, evals, prefix, email):
    """Categorize an email contact as work, vendor, or friend from a mail database."""
    # get_conversation reads the module-level PREFIX when building table names.
    global PREFIX
    PREFIX = prefix

    if evals:
        run_evals(db_path)
    elif email:
        db = sqlite_utils.Database(db_path)
        click.echo(determine_contact_type(db, email))
    else:
        # UsageError makes click print a usage message and exit with status 2
        # instead of dumping a Python traceback.
        raise click.UsageError("Must specify --evals or --email")
# Run the click command when this file is executed as a script.
if __name__ == "__main__":
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment