Created October 3, 2024 23:20
-
-
Save pgarbacki/870d5fa69e13b0f45266157a0be11038 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import necessary libraries | |
import json | |
import random | |
from sklearn.model_selection import train_test_split | |
import dspy | |
from dspy.evaluate import Evaluate | |
from dspy.teleprompt import BootstrapFewShotWithRandomSearch | |
import os | |
import dspy | |
# Decoding grammar handed to the client through `response_format`: the only
# completions it admits are the literal labels "easy", "medium" and "hard".
diagnosis_grammar = """
root ::= diagnosis
diagnosis ::= "easy" | "medium" | "hard"
"""

# Example configuration: DSPy's OpenAI client pointed at the Fireworks
# endpoint via base_url. The API key is read from the environment, so this
# raises KeyError when FIREWORKS_API_KEY is unset.
_fireworks_lm_kwargs = dict(
    model="fireworks/qwen2p5-72b-instruct",
    model_type="chat",
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.environ["FIREWORKS_API_KEY"],
    max_tokens=500,
    stop="---",
    response_format={"type": "grammar", "grammar": diagnosis_grammar},
)
turbo = dspy.OpenAI(**_fireworks_lm_kwargs)

# Register this LM as the default for every DSPy module in the script.
dspy.settings.configure(lm=turbo)
# 1. Loading and Preparing the Data
class ClassificationDataset:
    """Load labelled problems from a JSON file and split them into train/dev/test.

    The JSON file must contain a list of objects. Each object becomes a
    ``dspy.Example`` with ``question`` marked as its input field. Objects that
    store the problem text under ``question_content`` have that key renamed to
    ``question``. All other keys (e.g. ``difficulty``) are carried over
    unchanged.

    Args:
        json_path: Path to the JSON file.
        train_size: Fraction of records assigned to ``self.train``.
        dev_size: Fraction of records assigned to ``self.dev``.
        test_size: Accepted for API symmetry but NOT used to size the split.
            ``self.test`` receives every record left after train and dev, so
            (because of ``int`` truncation) it holds at least
            ``1 - train_size - dev_size`` of the data.
        random_seed: Seed for the shuffle that precedes splitting.

    Raises:
        ValueError: If ``train_size`` or ``dev_size`` is negative, or if they
            sum to more than 1.
    """

    def __init__(
        self, json_path, train_size=0.7, dev_size=0.15, test_size=0.15, random_seed=42
    ):
        # Reject fractions that would silently yield a truncated dev split and
        # an empty test split. The small tolerance absorbs float rounding.
        if train_size < 0 or dev_size < 0 or train_size + dev_size > 1 + 1e-9:
            raise ValueError(
                "train_size and dev_size must be non-negative and sum to at most 1; "
                f"got train_size={train_size}, dev_size={dev_size}"
            )

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Shuffle with a private generator. This yields the same permutation as
        # seeding the global `random` module, without resetting global random
        # state as a side effect of constructing a dataset.
        random.Random(random_seed).shuffle(data)

        # Split boundaries; the test split takes whatever remains.
        total = len(data)
        train_end = int(total * train_size)
        dev_end = train_end + int(total * dev_size)

        # Normalize the input key so every example exposes `question`.
        for item in data:
            item["question"] = item.pop("question_content", item.get("question"))

        def to_examples(records):
            """Wrap raw records as DSPy Examples whose input field is `question`."""
            return [dspy.Example(**item).with_inputs("question") for item in records]

        self.train = to_examples(data[:train_end])
        self.dev = to_examples(data[train_end:dev_end])
        self.test = to_examples(data[dev_end:])

        print(f"Train set size: {len(self.train)}")
        print(f"Dev set size: {len(self.dev)}")
        print(f"Test set size: {len(self.test)}")
# 2. Defining the Classification Dataset
# The default points at a machine-specific LiveCodeBench output file. Set the
# LCB_EVAL_JSON environment variable to load a different copy without editing
# the script.
DATASET_JSON_PATH = os.environ.get(
    "LCB_EVAL_JSON",
    "/home/bchen/LiveCodeBench/output/Qwen2p5_72bInstruct/Scenario.codegeneration_1_0.2_eval_all.json",
)
dataset = ClassificationDataset(json_path=DATASET_JSON_PATH)
# 3. Defining the Classification Signature and Predictor
# In DSPy the class docstring and the field `desc` strings are sent to the model
# as task instructions, so documentation lives in `#` comments instead.
# The label set named in those instructions must match the decoding grammar
# (which only admits "easy" / "medium" / "hard"). Asking for "difficult" would
# contradict the constraint the model is decoded under.
# NOTE(review): the gold labels are presumably the same three strings --
# confirm against the dataset.
class DifficultyClassificationSignature(dspy.Signature):
    """Classify the difficulty of a problem as easy, medium, or hard."""

    # Input: the problem statement text.
    question = dspy.InputField(desc="The content of the problem to classify.")
    # Output: a single difficulty label.
    difficulty = dspy.OutputField(
        desc="The difficulty level: easy, medium, or hard."
    )
# Zero-shot predictor built directly from the signature (no demonstrations yet).
difficulty_classifier = dspy.Predict(DifficultyClassificationSignature)

# Smoke test: classify the first dev example once before the full evaluation,
# so configuration problems surface early.
example_question = dataset.dev[0].question
prediction = difficulty_classifier(question=example_question)
print("Predicted Difficulty: {}".format(prediction.difficulty))
# 4. Creating the Evaluation Metric
def _normalize_label(value):
    """Return *value* as a comparable difficulty label.

    ``None`` becomes ``""``. Otherwise the value is stringified, surrounding
    whitespace is removed, only the first line is kept (anything after a
    newline is ignored), and the result is lower-cased.
    """
    if value is None:
        return ""
    first_line = str(value).strip().split("\n", 1)[0]
    return first_line.strip().lower()


def classification_metric(gold, pred, trace=None):
    """Score one prediction: 1 if its difficulty matches the gold label, else 0.

    Both labels are normalised the same way, so case, surrounding whitespace
    and any text after the first line do not affect the match. A missing or
    ``None`` prediction scores 0 instead of raising. An example without a gold
    label also scores 0, because an empty gold label never counts as a match.

    Args:
        gold: Example carrying the reference ``difficulty`` attribute.
        pred: Prediction carrying the model's ``difficulty`` attribute.
        trace: Unused; accepted because DSPy optimizers pass it to metrics.

    Returns:
        ``1`` for a match, ``0`` otherwise.
    """
    expected = _normalize_label(getattr(gold, "difficulty", None))
    predicted = _normalize_label(getattr(pred, "difficulty", None))
    return int(bool(expected) and expected == predicted)
# 5. Evaluating the Classifier
# Dev-set accuracy harness. num_threads sets how many examples are scored in
# parallel (adjust based on your machine / rate limits). display_table caps the
# printed per-example table at 10 rows.
evaluator = Evaluate(
    display_table=10,
    display_progress=True,
    num_threads=8,
    metric=classification_metric,
    devset=dataset.dev,
)

# Run evaluation: score the un-optimized predictor on the dev split and keep
# the returned score for interactive inspection.
baseline_dev_score = evaluator(difficulty_classifier)
# 6. Optimizing the DSPy Program (Optional)
# Random search over bootstrapped few-shot demo sets, scored with the same
# metric used for evaluation.
optimizer = BootstrapFewShotWithRandomSearch(
    metric=classification_metric,
    max_bootstrapped_demos=3,
    max_labeled_demos=6,
    num_candidate_programs=6,
)
compiled_classifier = optimizer.compile(
    student=difficulty_classifier,
    trainset=dataset.train,
    valset=dataset.dev[:100],  # Use a subset for faster optimization
)

# Evaluate the optimized classifier.
# The optimizer picks its best candidate by score on dev data, so the compiled
# program's dev accuracy is optimistically biased. Both programs are therefore
# scored on the held-out test split, which has not been used so far, for an
# unbiased before/after comparison. The compiled program runs last so that
# inspect_history below shows one of its calls.
if dataset.test:
    baseline_test_score = evaluator(difficulty_classifier, devset=dataset.test)
    compiled_test_score = evaluator(compiled_classifier, devset=dataset.test)
    print(
        f"Test accuracy: baseline={baseline_test_score}, "
        f"optimized={compiled_test_score}"
    )
else:
    # No held-out data available: fall back to the (biased) dev-set score.
    evaluator(compiled_classifier)

# Inspect the optimized prompts
print(compiled_classifier)
_ = turbo.inspect_history(n=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.