Skip to content

Instantly share code, notes, and snippets.

@davidmezzetti
Created July 11, 2024 10:56
Show Gist options
  • Save davidmezzetti/a3c423b827db93103aeab5045b3d1711 to your computer and use it in GitHub Desktop.
import time
from datetime import timedelta
from datasets import load_dataset
from txtai import LLM
from txtai.pipeline import Labels, HFTrainer
def prompt(text):
    """Wrap *text* in a single-turn chat message asking the LLM to classify sentiment.

    Returns a one-element list of chat messages suitable for a txtai LLM call,
    where the model is instructed to answer with only a numeric label (0-5).
    """
    instruction = f"""
Given the following text:
{text}
Analyze the sentiment and apply one of the following labels:
0-sadness
1-joy
2-love
3-anger
4-fear
5-surprise
Only return the label.
""".strip()

    message = {"role": "user", "content": instruction}
    return [message]
def train(path, epochs, lora):
    """Fine-tune the model at *path* on the emotion training split.

    Args:
        path: Hugging Face model path/name to fine-tune.
        epochs: number of training epochs.
        lora: when True, train a LoRA adapter instead of full weights.

    Returns:
        A txtai Labels pipeline wrapping the trained (model, tokenizer) pair.
    """
    # LoRA training keeps more state per step, so use a smaller batch size
    batch = 4 if lora else 8

    trainer = HFTrainer()
    model, tokenizer = trainer(
        path,
        ds["train"],
        lora=lora,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch,
    )

    print(f"PARAMETERS: {model.num_parameters():,}")

    # Fold LoRA adapter weights back into the base model for inference
    if lora:
        model = model.merge_and_unload()

    return Labels((model, tokenizer), dynamic=False)
def test(model):
    """Score *model* against the emotion test split and print its accuracy.

    Accepts either a txtai LLM (queried via a zero-shot prompt) or a Labels
    classification pipeline (queried directly); both are expected to yield a
    numeric label string comparable to the dataset's integer labels.
    """
    total = 0
    correct = 0
    for row in ds["test"]:
        if isinstance(model, LLM):
            # Zero-shot path: ask the LLM via the chat prompt
            result = model(prompt(row["text"]))[0]
        else:
            # Classifier path: take the top-ranked label
            result = model(row["text"])[0][0]

        total += 1
        if int(result) == row["label"]:
            correct += 1

    print(f"ACCURACY: {correct} / {total} ({correct / total:.2f})")
def evaluate(path, epochs=3, lora=False):
    """Train (when needed) and benchmark a model, reporting wall-clock time.

    Args:
        path: a model path string to fine-tune first, or an already-built
              model/pipeline to evaluate as-is.
        epochs: training epochs, forwarded to train() for string paths.
        lora: forwarded to train(); enables LoRA fine-tuning.
    """
    start = time.time()

    # Strings name models that still need training; anything else is ready to run
    model = train(path, epochs, lora) if isinstance(path, str) else path

    test(model)

    elapsed = timedelta(seconds=time.time() - start)
    print(f"ELAPSED: {elapsed}")
# Load dataset
# dair-ai/emotion: text samples labeled 0-5 (sadness/joy/love/anger/fear/surprise);
# the "train" and "test" splits are read as a module-level global by train()/test()
ds = load_dataset("dair-ai/emotion")
# Zero-shot prompts with LLMs (no training — baseline accuracy via prompting)
evaluate(LLM("microsoft/Phi-3-mini-4k-instruct"))
# Evaluate encoder-only models (full fine-tune: tiny BERT, then BERT-base)
evaluate("google/bert_uncased_L-2_H-128_A-2", epochs=10)
evaluate("google-bert/bert-base-uncased", epochs=5)
# Fine-tune LLM for classification (LoRA adapter, merged before inference)
evaluate("microsoft/Phi-3-mini-4k-instruct", epochs=3, lora=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment