@ColeMurray · Created July 19, 2024
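"""Interactive LLM evaluation tool.

Asks the user for a task description, uses an OpenAI model to generate a
Pydantic input/output schema and synthetic input-output pairs for that task,
then evaluates a user-chosen model against those pairs (with lenient category
and confidence matching) and saves the pairs and results to CSV files.
"""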
import csv
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple, Type

from openai import OpenAI
from pydantic import BaseModel, Field, create_model, field_validator

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
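# The OpenAI client reads the API key from the OPENAI_API_KEY environment variable by default.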
client = OpenAI()

def get_task_description() -> str:
    logger.info("Requesting task description from user")
    description = input("Please describe the task you'd like to evaluate: ")
    logger.info(f"Received task description: {description}")
    return description

def get_pair_generation_model() -> str:
    logger.info("Requesting model for pair generation")
    model = input("Enter the name of the model to use for pair generation (e.g., gpt-3.5-turbo-0125): ")
    logger.info(f"Selected model for pair generation: {model}")
    return model

def get_data_samples() -> Optional[List[Dict[str, Any]]]:
    logger.info("Requesting data samples from user")
    use_samples = input("Do you want to provide data samples to influence pair generation? (yes/no): ").lower()
    if use_samples == 'yes':
        file_path = input("Enter the path to your JSON file containing data samples: ")
        try:
            with open(file_path, 'r') as file:
                samples = json.load(file)
            if not isinstance(samples, list):
                logger.error("The JSON file should contain a list of samples")
                return None
            logger.info(f"Successfully loaded {len(samples)} data samples from {file_path}")
            return samples
        except FileNotFoundError:
            logger.error(f"File not found: {file_path}")
        except json.JSONDecodeError:
            logger.error(f"Invalid JSON in file: {file_path}")
        except Exception as e:
            logger.error(f"Error reading file: {str(e)}")
    return None

def generate_schema(task_description: str, pair_generation_model: str) -> Dict[str, Any]:
    logger.info("Generating schema based on task description")
    prompt = f"""
Task: Create a Pydantic schema for input and output based on the following task description:
"{task_description}"

Instructions:
1. Analyze the task description carefully.
2. Determine appropriate input and output fields based on the task.
3. Create a JSON object with two keys: "input_schema" and "output_schema".
4. For each schema, specify field names and their corresponding Python type hints.
5. Use only the following Python type hints: str, int, float, bool, List[str], List[int], List[float].
6. Ensure the output_schema always includes a "category" field (str) and a "confidence" field (float).
7. Provide your response as a valid JSON object, nothing else.

Example of a valid response:
{{
    "input_schema": {{
        "text": "str",
        "max_length": "int"
    }},
    "output_schema": {{
        "category": "str",
        "confidence": "float",
        "is_relevant": "bool"
    }}
}}

Now, generate the schema for the given task. Respond with only the JSON:
"""
    start_time = time.time()
    response = client.chat.completions.create(
        model=pair_generation_model,
        messages=[{"role": "user", "content": prompt}]
    )
    end_time = time.time()
    logger.info(f"Schema generation took {end_time - start_time:.2f} seconds")
    try:
        schema = json.loads(response.choices[0].message.content)
        if not isinstance(schema, dict) or "input_schema" not in schema or "output_schema" not in schema:
            raise ValueError("Invalid schema structure")
        if "category" not in schema["output_schema"] or "confidence" not in schema["output_schema"]:
            raise ValueError("Output schema must include 'category' and 'confidence' fields")
        logger.info(f"Generated schema: {json.dumps(schema, indent=2)}")
        return schema
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from model output: {str(e)}")
        logger.error(f"Raw model output: {response.choices[0].message.content}")
        raise
    except ValueError as e:
        logger.error(f"Invalid schema structure: {str(e)}")
        logger.error(f"Generated schema: {response.choices[0].message.content}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in schema generation: {str(e)} {response.choices[0].message.content}")
        raise

def create_pydantic_models(schema: Dict[str, Any]):
    logger.info("Creating Pydantic models from schema")

    def create_field(field_type: str):
        if field_type == "str":
            return (str, ...)
        elif field_type == "int":
            return (int, ...)
        elif field_type == "float":
            # Every float field is constrained to [0, 1]; this matches the
            # required "confidence" field, the only float the generated
            # schemas are guaranteed to contain.
            return (float, Field(..., ge=0, le=1))
        elif field_type == "bool":
            return (bool, ...)
        elif field_type.startswith("List["):
            inner_type = field_type[5:-1]
            return (List[create_field(inner_type)[0]], ...)
        else:
            logger.warning(f"Unknown field type: {field_type}. Defaulting to str.")
            return (str, ...)

    input_fields = {k: create_field(v) for k, v in schema['input_schema'].items()}
    output_fields = {k: create_field(v) for k, v in schema['output_schema'].items()}

    class OutputBase(BaseModel):
        # check_fields=False because 'confidence' is added by create_model below.
        @field_validator('confidence', check_fields=False)
        @classmethod
        def check_confidence(cls, v):
            return round(v, 2)

    # Build both models directly from the generated schema so that any extra
    # output fields (e.g., "is_relevant") are validated rather than dropped.
    InputModel = create_model('InputModel', **input_fields)
    OutputModel = create_model('OutputModel', __base__=OutputBase, **output_fields)
    logger.info(f"Created InputModel: {json.dumps(InputModel.model_json_schema())}")
    logger.info(f"Created OutputModel: {json.dumps(OutputModel.model_json_schema())}")
    return InputModel, OutputModel
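
# Illustration: given the example schema shown in the generate_schema prompt,
# create_pydantic_models produces models equivalent to:
#
#   class InputModel(BaseModel):
#       text: str
#       max_length: int
#
#   class OutputModel(OutputBase):  # inherits the confidence-rounding validator
#       category: str
#       confidence: float = Field(..., ge=0, le=1)
#       is_relevant: bool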

def generate_input_output_pairs(
    task_description: str,
    pair_generation_model: str,
    InputModel: Type[BaseModel],
    OutputModel: Type[BaseModel],
    num_pairs: int = 5,
    data_samples: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    logger.info(f"Generating {num_pairs} input-output pairs")
    sample_prompt = ""
    if data_samples:
        sample_prompt = f"Use these data samples as inspiration: {json.dumps(data_samples)}\n"
    prompt = f"""
Task description: {task_description}
{sample_prompt}
Generate {num_pairs} input-output pairs for the above task.
Input schema: {json.dumps(InputModel.model_json_schema())}
Output schema: {json.dumps(OutputModel.model_json_schema())}

Respond with a JSON array of objects, each containing 'input' and 'output' keys.
Ensure that the types match the schema exactly.

Example response format:
[
    {{
        "input": {{ ... input fields matching InputModel ... }},
        "output": {{ ... output fields matching OutputModel ... }}
    }},
    ...
]

IMPORTANT: Your response must be a single, valid JSON array with no additional text or formatting.
"""
    start_time = time.time()
    response = client.chat.completions.create(
        model=pair_generation_model,
        messages=[{"role": "user", "content": prompt}]
    )
    end_time = time.time()
    logger.info(f"Input-output pair generation took {end_time - start_time:.2f} seconds")
    try:
        pairs = json.loads(response.choices[0].message.content)
        if not isinstance(pairs, list):
            raise ValueError("Generated pairs are not in the expected list format")
        # Validate each pair against the models
        validated_pairs = []
        for pair in pairs:
            try:
                input_data = InputModel(**pair['input'])
                output_data = OutputModel(**pair['output'])
                validated_pairs.append({
                    "input": input_data.model_dump(),
                    "output": output_data.model_dump()
                })
            except Exception as e:
                logger.warning(f"Skipping invalid pair: {pair}. Error: {str(e)}")
        logger.info(f"Generated {len(validated_pairs)} valid input-output pairs")
        return validated_pairs
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from model output: {str(e)}")
        logger.error(f"Raw model output: {response.choices[0].message.content}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in pair generation: {str(e)}")
        raise

def evaluate_model(
    model_name: str,
    task_description: str,
    pairs: List[Dict[str, Any]],
    InputModel: Type[BaseModel],
    OutputModel: Type[BaseModel],
) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
    logger.info(f"Evaluating model: {model_name}")
    correct = 0
    total = len(pairs)
    results = []
    for i, pair in enumerate(pairs):
        logger.info(f"Evaluating pair {i+1}/{total}")
        input_data = InputModel(**pair['input'])
        expected_output = OutputModel(**pair['output'])
        logger.info(f"Input: {input_data.model_dump_json()}")
        logger.info(f"Expected output: {expected_output.model_dump_json()}")
        prompt = f"""Task description: {task_description}

Given the input: {input_data.model_dump_json()}, perform the task described above and provide the output.
The output should be a valid JSON object matching this schema: {json.dumps(OutputModel.model_json_schema())}

IMPORTANT: Your response must be a single, valid JSON object with no additional text, code blocks, or formatting. Example of a valid response:
{{"category": "Electronics", "confidence": 0.95}}"""
        start_time = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}]
        )
        end_time = time.time()
        logger.info(f"Model response time: {end_time - start_time:.2f} seconds")
        logger.info(f"Raw model output: {response.choices[0].message.content}")
        try:
            json_response = json.loads(response.choices[0].message.content)
            actual_output = OutputModel(**json_response)
            logger.info(f"Parsed output: {actual_output.model_dump_json()}")
            # Lenient matching: categories match if either string contains the
            # other, and confidences match within an absolute tolerance of 0.2.
            category_match = expected_output.category.lower() in actual_output.category.lower() or \
                             actual_output.category.lower() in expected_output.category.lower()
            confidence_match = abs(expected_output.confidence - actual_output.confidence) <= 0.2
            if category_match and confidence_match:
                correct += 1
                logger.info("Output matched expected output (with tolerance)")
                result = "Correct"
            else:
                logger.info("Output did not match expected output")
                if not category_match:
                    logger.info(f"Category mismatch: expected '{expected_output.category}', got '{actual_output.category}'")
                if not confidence_match:
                    logger.info(f"Confidence mismatch: expected {expected_output.confidence}, got {actual_output.confidence}")
                result = "Incorrect"
            results.append({
                "input": input_data.model_dump_json(),
                "expected_output": expected_output.model_dump_json(),
                "actual_output": actual_output.model_dump_json(),
                "result": result
            })
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON from model output: {str(e)}")
            results.append({
                "input": input_data.model_dump_json(),
                "expected_output": expected_output.model_dump_json(),
                "actual_output": "Error: Invalid JSON",
                "result": "Error"
            })
        except Exception as e:
            logger.error(f"Error parsing model output: {str(e)}")
            results.append({
                "input": input_data.model_dump_json(),
                "expected_output": expected_output.model_dump_json(),
                "actual_output": f"Error: {str(e)}",
                "result": "Error"
            })
        logger.info("---")
    accuracy = correct / total if total else 0.0
    logger.info(f"Evaluation complete. Accuracy: {accuracy:.2f}")
    return {"accuracy": accuracy}, results

def save_to_csv(data: List[Dict[str, Any]], filename: str):
    if not data:
        logger.warning(f"No data to save to {filename}")
        return
    keys = data[0].keys()
    with open(filename, 'w', newline='') as output_file:
        dict_writer = csv.DictWriter(output_file, keys)
        dict_writer.writeheader()
        dict_writer.writerows(data)
    logger.info(f"Data saved to {filename}")

def main():
    logger.info("Starting LLM evaluation tool")
    task_description = get_task_description()
    pair_generation_model = get_pair_generation_model()
    data_samples = get_data_samples()
    schema = generate_schema(task_description, pair_generation_model)
    print("Generated schema:")
    print(json.dumps(schema, indent=2))
    confirm = input("Do you confirm this schema? (yes/no): ")
    if confirm.lower() != 'yes':
        logger.info("User did not confirm schema. Exiting.")
        print("Please try again with a different task description.")
        return
    InputModel, OutputModel = create_pydantic_models(schema)
    pairs = generate_input_output_pairs(task_description, pair_generation_model, InputModel, OutputModel, data_samples=data_samples, num_pairs=25)
    print("Generated input/output pairs:")
    print(json.dumps(pairs, indent=2))
    confirm = input("Do you confirm these pairs? (yes/no): ")
    if confirm.lower() != 'yes':
        logger.info("User did not confirm input/output pairs. Exiting.")
        print("Please try again or modify the pairs manually.")
        return
    save_to_csv(pairs, "input_output_pairs.csv")
    model_name = input("Enter the name of the model you want to evaluate (e.g., gpt-3.5-turbo-0125): ")
    logger.info(f"User selected model: {model_name}")
    results, detailed_results = evaluate_model(model_name, task_description, pairs, InputModel, OutputModel)
    print("Evaluation results:")
    print(json.dumps(results, indent=2))
    save_to_csv(detailed_results, "evaluation_results.csv")
    logger.info("LLM evaluation tool execution completed")


if __name__ == "__main__":
    main()
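
# Example session (illustrative only; model output will vary, and the script
# filename "llm_eval.py" is assumed):
#
#   $ python llm_eval.py
#   Please describe the task you'd like to evaluate: Classify customer reviews by product category
#   Enter the name of the model to use for pair generation (e.g., gpt-3.5-turbo-0125): gpt-3.5-turbo-0125
#   Do you want to provide data samples to influence pair generation? (yes/no): no
#   Generated schema:
#   { "input_schema": { ... }, "output_schema": { "category": "str", "confidence": "float", ... } }
#   Do you confirm this schema? (yes/no): yes
#   ...
#   Evaluation results:
#   { "accuracy": ... }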