Skip to content

Instantly share code, notes, and snippets.

@juliensimon
Last active April 16, 2024 21:27
Show Gist options
  • Save juliensimon/4c77ec7ed44587aaf5666fe95e0dbec2 to your computer and use it in GitHub Desktop.
Save juliensimon/4c77ec7ed44587aaf5666fe95e0dbec2 to your computer and use it in GitHub Desktop.
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
# Few-shot sentiment classification on Yelp Polarity with SetFit:
# sample a handful of labelled reviews per class, fine-tune a sentence
# transformer contrastively, then report accuracy on the test split.

NUM_SAMPLES_PER_CLASS = 8  # few-shot budget: labelled examples drawn per class
SEED = 42  # fixed so the sampled training set is reproducible

dataset = load_dataset("yelp_polarity")
print(dataset)

# Select N examples per class (8 in this case).
# A plain shuffle().select(range(8 * 2)) only guarantees 16 examples overall,
# not 8 of each label, so the few-shot set could end up class-imbalanced.
# sample_dataset draws up to `num_samples` examples for every label value.
train_ds = sample_dataset(
    dataset["train"],
    label_column="label",
    num_samples=NUM_SAMPLES_PER_CLASS,
    seed=SEED,
)

# Evaluate on the full test split. For a quicker run, subsample it, e.g.
# dataset["test"].shuffle(seed=42).select(range(10000)).
test_ds = dataset["test"]
print(train_ds)
print(test_ds)

# Load SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    metric="accuracy",
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=16,  # Number of text pairs to generate for contrastive learning
    num_epochs=1,  # Number of epochs to use for contrastive learning
)

# Train and evaluate!
trainer.train()
metrics = trainer.evaluate()
print(metrics)

# To persist and reload the fine-tuned model, prefer the public API over the
# private `_save_pretrained` / `_from_pretrained` hooks:
#   model.save_pretrained(save_directory)
#   saved_model = SetFitModel.from_pretrained(save_directory)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment