Created
February 9, 2024 00:10
-
-
Save iloveitaly/6f32d926ae2534dd4291dd2247e0214f to your computer and use it in GitHub Desktop.
Example of how to categorize emails using OpenAI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
import sqlite_utils | |
import json | |
import logging | |
# LLM backend selected by `categorize`; "openai" is the only value it handles.
SERVICE = "openai"
# Optional table-name prefix for the mbox tables; the CLI overwrites it from --prefix.
PREFIX = ""
logger = logging.getLogger(__name__)
def call_litellm(prompt):
    """Send `prompt` to a local Ollama mistral model through litellm.

    The prompt goes out as a single user message with JSON output requested.
    Returns the value of the "contact_type" key in the JSON the model replies
    with.
    """
    import litellm

    litellm.set_verbose = True

    # Author's notes on litellm behaviour:
    # - wrapping the prompt in a chat message adds extra markdown to it
    # - format="json" is not simply forwarded to ollama; it adds markdown too
    # - a system message must NOT be included, it makes the response hang
    # ("ollama/mistral:text" was an alternative model that was tried.)
    chat_messages = [{"role": "user", "content": prompt}]
    reply = litellm.completion(
        model="ollama/mistral:instruct",
        api_base="http://localhost:11434",
        messages=chat_messages,
        timeout=60,
        format="json",
    )

    # litellm only exposes the answer through the OpenAI-style `choices`
    # structure; the raw ollama response is not available.
    raw_answer = reply.choices[0].message.content
    parsed = json.loads(raw_answer)
    return parsed["contact_type"]
def call_ollama(prompt):
    """Send `prompt` to the mistral:instruct model via the ollama client.

    Requests a non-streamed JSON response and returns the value of its
    "contact_type" key.
    """
    from ollama import Client

    # A nice property of this client: extra httpx parameters (such as the
    # timeout) can be passed straight to the constructor.
    ollama_client = Client(timeout=60)
    result = ollama_client.generate(
        model="mistral:instruct",
        prompt=prompt,
        format="json",
        stream=False,
    )
    parsed = json.loads(result["response"])
    return parsed["contact_type"]
def call_openai(prompt):
    """Send `prompt` to OpenAI's chat completions API in JSON mode.

    A system message asks for JSON output and `prompt` is sent as the user
    message. Returns the value of the "contact_type" key in the JSON document
    the model replies with.
    """
    from openai import OpenAI
    import os

    # Pass the key to the client itself. Assigning `OpenAI.api_key` on the
    # class after the client has been constructed does not configure it.
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    completion = client.chat.completions.create(
        # only this specific version supports json_object
        model="gpt-3.5-turbo-1106",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant. Your response should be in JSON format.",
            },
            {"role": "user", "content": prompt},
        ],
        response_format={"type": "json_object"},
    )

    json_response = json.loads(completion.choices[0].message.content)
    return json_response["contact_type"]
def categorize(email, conversations: str):
    """Ask the configured LLM service to classify `email`.

    Builds a prompt from the pre-formatted `conversations` text and sends it
    to the backend named by the module-level `SERVICE`.

    Returns:
        One of "work", "vendor" or "friend".

    Raises:
        ValueError: if `SERVICE` is not a supported backend, or the model
            answers with anything other than the three allowed categories.
    """
    prompt = f"""
Past conversations:
{conversations}
IMPORTANT INSTRUCTIONS:
- Categorize email address {email} by analyzing the past included conversations.
- Only use 'work' if the contact is not personal and is related to my professional work (software engineering, startups, venture capital)
- If unsure, use 'vendor'
- Respond ONLY in JSON with either 'work', 'vendor', or 'friend' in a single 'contact_type' field. Example: `{{"contact_type": "vendor"}}`
"""

    # Alternative local backends: call_ollama(prompt), call_litellm(prompt).
    if SERVICE == "openai":
        contact_type = call_openai(prompt)
    else:
        # Explicit error instead of a bare `assert`, which is stripped under
        # `python -O` and gives no hint about what went wrong.
        raise ValueError(f"Unsupported service: {SERVICE}")

    if contact_type not in ("work", "vendor", "friend"):
        raise ValueError(f"Invalid contact type: {contact_type}")

    return contact_type
def get_conversation(db, email, count=10):
    """Fetch the emails of the most recent threads involving `email`.

    Selects the `count` threads (by latest message date) in which `email`
    appears in the `all_contacts` JSON column, then returns every message of
    those threads, newest first.

    Args:
        db: object providing `execute_returning_dicts(sql, params)`;
            the callers in this file pass a `sqlite_utils.Database`.
        email: address to look up.
        count: maximum number of threads to include.

    Returns:
        A list of dicts with "subject" and "body" keys.
    """
    # A table name cannot be a bound parameter, so the prefix is still
    # interpolated; it comes from the CLI option, not from email data.
    formatted_prefix = f"{PREFIX}_" if PREFIX else ""
    table = f"{formatted_prefix}mbox_emails"

    # `email` and `count` are bound as parameters rather than formatted into
    # the SQL text, so quotes in an address cannot break or alter the query.
    query = f"""
        SELECT subject, body
        FROM {table}
        WHERE gmail_thread_id IN (
            SELECT gmail_thread_id
            FROM (
                SELECT gmail_thread_id, MAX(date) AS max_date
                FROM {table}
                WHERE EXISTS (
                    SELECT 1
                    FROM json_each({table}.all_contacts) AS contact
                    WHERE json_extract(contact.value, '$.email') = :email
                )
                GROUP BY gmail_thread_id
                ORDER BY MAX(date) DESC
                LIMIT :count
            )
        )
        ORDER BY date DESC;
    """
    return db.execute_returning_dicts(query, {"email": email, "count": count})
from litellm import token_counter | |
def format_conversations(
    conversations: list[dict],
    *,
    max_tokens: int = 1_750,
    max_conversations: int = 2,
) -> str:
    """Render a few sufficiently short emails as prompt-ready text.

    Emails are considered in the given order. An email whose body has
    `max_tokens` tokens or more is skipped (and logged). The first
    `max_conversations` remaining emails are each rendered as a fenced
    "Subject: ...", blank line, body block, with CRLF line endings normalised
    to LF, and the blocks are joined by blank lines.

    Args:
        conversations: dicts with "subject" and "body" keys.
        max_tokens: exclusive upper bound on an email body's token count.
        max_conversations: how many emails to include. The default of two
            exists because mistral struggles with more, and it also keeps the
            prompt within gpt-3's token limit.

    Returns:
        The formatted conversations, or "" if no email qualified.

    Raises:
        ValueError: if `conversations` is empty.
    """
    if not conversations:
        raise ValueError("No conversations found.")

    selected = []
    for email in conversations:
        # Stop as soon as enough emails are picked; counting tokens for the
        # rest would be wasted work since they are never used.
        if len(selected) >= max_conversations:
            break

        # Counted with the gpt-3.5-turbo tokenizer. The original code carried
        # a "mistral" note here, so this is presumably also used as an
        # approximation for mistral -- TODO confirm.
        token_count = token_counter(
            model="openai/gpt-3.5-turbo",
            text=email["body"],
        )
        if token_count >= max_tokens:
            logger.info("Skipping email due to token count: %s", token_count)
            continue
        selected.append(email)

    return "\n\n".join(
        f"```Subject: {email['subject']}\n\n{email['body']}```".replace("\r\n", "\n")
        for email in selected
    )
def determine_contact_type(db, email):
    """Classify `email` end to end.

    Pulls the contact's recent emails from `db`, formats them into prompt
    text, and returns the category chosen by `categorize`.
    """
    recent_emails = get_conversation(db, email)
    prompt_text = format_conversations(recent_emails)
    return categorize(email, prompt_text)
def run_evals(db_path):
    """Run the hard-coded evaluation set against the database at `db_path`.

    Each (email, expected category) pair is classified with
    `determine_contact_type` and the result is printed. The run stops at the
    first mismatch.

    Raises:
        AssertionError: if a predicted category differs from the expected one.
    """
    db = sqlite_utils.Database(db_path)

    eval_list = [
        ("email1@example.com", "vendor"),
        ("email2@example.com", "friend"),
        ("email3@example.com", "friend"),
        ("email4@example.com", "vendor"),
        ("email5@example.com", "friend"),
        ("email6@example.com", "vendor"),
        ("email7@example.com", "vendor"),
        ("email8@example.com", "work"),
        ("email9@example.com", "work"),
    ]

    for email, expected_contact_type in eval_list:
        contact_type = determine_contact_type(db, email)
        print(f"email: {email}, contact_type: {contact_type}")

        # Raised explicitly rather than via `assert`: under `python -O` an
        # assert is stripped and every eval would appear to pass.
        if contact_type != expected_contact_type:
            raise AssertionError(
                f"{email}: expected {expected_contact_type!r}, got {contact_type!r}"
            )

    print("All tests passed!")
# TODO allow service to be passed in
@click.command()
@click.argument("db_path", type=click.Path(exists=True), required=True)
@click.option("--evals", is_flag=True, help="Run evals")
@click.option("--prefix", help="Table prefix, if there is one")
@click.option("--email", help="Email to categorize")
def cli(db_path, evals, prefix, email):
    """Categorize a contact's email address, or run the built-in evals.

    DB_PATH is the SQLite database holding the mbox_emails table.
    Exactly one mode is used: --evals takes precedence over --email.
    """
    global PREFIX
    # --prefix is None when omitted; keep PREFIX a string as at module level.
    PREFIX = prefix or ""

    if evals:
        # run_evals opens its own database handle from the path.
        run_evals(db_path)
    elif email:
        db = sqlite_utils.Database(db_path)
        click.echo(determine_contact_type(db, email))
    else:
        # UsageError makes click print a usage hint and exit non-zero instead
        # of dumping a traceback.
        raise click.UsageError("Must specify --evals or --email")
# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment