Created
February 28, 2024 08:58
-
-
Save gustavz/cc6fa1b4ff8fba19296fe9c3638d2f85 to your computer and use it in GitHub Desktop.
Async Compare GPT Responses to Dataframe
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import os

import numpy as np
import openai
import pandas as pd
from sentence_transformers import SentenceTransformer
# Prefer the OPENAI_API_KEY environment variable so the secret does not have
# to be written into the source; the placeholder stays as a fallback so the
# previous "edit this line" workflow keeps working.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR-API-KEY")

# Sentence-embedding model, loaded once at import time and used by
# generate_embeddings().
model = SentenceTransformer("all-MiniLM-L6-v2")
async def fetch_gpt_responses(prompts: list[str]) -> list[str]:
    """Request one chat completion per prompt and return the reply texts.

    All requests are awaited together with ``asyncio.gather``, so the
    returned list is in the same order as *prompts*. Each entry is the
    whitespace-stripped content of the first choice of the completion.
    """
    # Sampling settings shared by every request.
    sampling = {
        "temperature": 0.7,
        "max_tokens": 800,
        "top_p": 0.95,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    }

    def _request(prompt):
        # Build (but do not await) the coroutine for a single user prompt.
        return openai.ChatCompletion.acreate(
            engine="gpt-35-turbo",
            messages=[{"role": "user", "content": f"{prompt}"}],
            **sampling,
        )

    pending = [_request(prompt) for prompt in prompts]
    completions = await asyncio.gather(*pending)

    answers = []
    for completion in completions:
        first_choice = completion.choices[0]
        answers.append(first_choice.message['content'].strip())
    return answers
def generate_embeddings(texts: list[str]) -> np.ndarray:
    """Encode *texts* with the module-level ``model`` and return the result.

    The progress bar of ``model.encode`` is disabled.
    """
    vectors = model.encode(texts, show_progress_bar=False)
    return vectors
def cosine_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
    """Return the cosine similarity of two 1-D vectors.

    Args:
        vec_a: First vector.
        vec_b: Second vector, same length as *vec_a*.

    Returns:
        ``dot(vec_a, vec_b) / (|vec_a| * |vec_b|)`` as a plain Python float,
        or ``0.0`` when either vector has zero norm (the similarity is
        undefined there; returning 0.0 avoids a division by zero and a
        ``nan`` score).
    """
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denominator == 0:
        return 0.0
    # Cast so the result matches the annotated return type rather than
    # leaking a numpy scalar.
    return float(np.dot(vec_a, vec_b) / denominator)
async def compare_responses_async(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Add freshly fetched GPT answers and their similarity to the stored ones.

    Reads the ``'prompts'`` and ``'responses'`` columns, fetches a new answer
    for every prompt, embeds the stored and new answers in one batch, and
    scores each stored/new pair with ``cosine_similarity``.

    The input *dataframe* is modified in place (columns ``'gpt_responses'``
    and ``'similarity_scores'`` are assigned) and the same object is returned.
    """
    questions = dataframe['prompts'].tolist()
    reference_answers = dataframe['responses'].tolist()
    fresh_answers = await fetch_gpt_responses(questions)

    # Embed both lists in a single call; the first half of the result belongs
    # to the stored answers and the second half to the new ones.
    all_vectors = generate_embeddings(reference_answers + fresh_answers)
    reference_vectors, fresh_vectors = np.split(all_vectors, 2)

    scores = []
    for reference_vec, fresh_vec in zip(reference_vectors, fresh_vectors):
        scores.append(cosine_similarity(reference_vec, fresh_vec))

    dataframe['gpt_responses'] = fresh_answers
    dataframe['similarity_scores'] = scores
    return dataframe
async def main():
    """Run the comparison on a two-row demonstration DataFrame and print it."""
    # Sample data for demonstration
    sample_rows = {
        'prompts': ['What is the capital of France?', 'What is the leading AutoML platform?'],
        'responses': ['Paris', 'DataRobot'],
    }
    df = pd.DataFrame(sample_rows)
    print(await compare_responses_async(df))
# A bare top-level ``await`` is only valid in an interactive notebook/REPL;
# as a script it is a SyntaxError. Drive the coroutine with asyncio.run and
# guard it so importing this module does not trigger network calls.
# NOTE: inside a notebook with a running event loop, call ``await main()``
# directly instead.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment