Skip to content

Instantly share code, notes, and snippets.

@krishnakalyan3
Created December 5, 2023 23:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krishnakalyan3/cb1b72628c4cdb7d3197f6c9a2468e3e to your computer and use it in GitHub Desktop.
Save krishnakalyan3/cb1b72628c4cdb7d3197f6c9a2468e3e to your computer and use it in GitHub Desktop.
Fine-tune an audio model for regression
#!/usr/bin/env python3
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, Audio
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification
from transformers import TrainingArguments
from transformers import Trainer
import torch
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
import numpy as np
import evaluate
from audiomentations import Compose, RoomSimulator, TimeMask
import wandb
from datasets import load_from_disk
def compute_metrics_for_regression(eval_pred):
    """Compute mean-squared error for a HF Trainer regression run.

    Args:
        eval_pred: An ``(predictions, labels)`` pair as passed by
            ``transformers.Trainer`` — predictions typically have shape
            ``(batch, 1)`` for a single-output regression head.

    Returns:
        dict: ``{'mse': float}`` — a flat mapping, as the Trainer expects
        scalar metric values for logging and best-model selection.
    """
    predictions, target = eval_pred
    # Collapse the trailing singleton label dimension, (batch, 1) -> (batch,).
    preds = np.squeeze(predictions)
    # Compute MSE directly with NumPy. The original wrapped the result of
    # evaluate.load("mse").compute(...) (itself a dict {'mse': value}) in
    # another {'mse': ...}, producing a nested dict that breaks metric
    # logging; it also reloaded the metric and printed eval_pred every call.
    mse = float(np.mean((np.asarray(preds) - np.asarray(target)) ** 2))
    return {'mse': mse}
if __name__ == "__main__":
    # W&B configuration must be set via env vars *before* wandb.init():
    # project name for the run, and checkpoint upload as model artifacts.
    os.environ["WANDB_PROJECT"] = "audio-classifier"
    os.environ["WANDB_LOG_MODEL"] = "checkpoint"
    # Initialize a new run
    run = wandb.init()

    # Pretrained audio backbone. The feature extractor normalizes input
    # waveforms and returns attention masks for padded batches.
    model_id = "ntu-spml/distilhubert"
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_id, do_normalize=True, return_attention_mask=True
    )

    # Datasets are assumed to be already feature-encoded and saved with
    # Dataset.save_to_disk — TODO confirm they contain the columns the
    # model forward expects (input_values / attention_mask / labels).
    train_data_encoded = load_from_disk('/mnt/raid/krishna/data/oe/regression/train')
    val_data_encoded = load_from_disk('/mnt/raid/krishna/data/oe/regression/val')

    # num_labels=1 with problem_type="regression" gives the classification
    # head a single scalar output trained with MSE loss.
    model = AutoModelForAudioClassification.from_pretrained(
        model_id,
        num_labels=1,
        problem_type="regression"
    )

    model_name = model_id.split("/")[-1]
    batch_size = 128
    num_train_epochs = 40
    training_args = TrainingArguments(
        # Run name folded into the output dir so checkpoints don't collide.
        f"checkpoints/{model_name}-{run.name}",
        evaluation_strategy="steps",
        save_strategy="steps",
        remove_unused_columns=True,
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        # NOTE(review): saving every 10 steps is very frequent and, combined
        # with WANDB_LOG_MODEL=checkpoint, uploads each one — confirm intended.
        save_steps=10,
        lr_scheduler_type="cosine_with_restarts",
        do_predict=True,
        report_to="wandb",
        bf16=True,
    )
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_data_encoded,
        eval_dataset=val_data_encoded,
        # Passing the feature extractor as `tokenizer` makes Trainer save it
        # alongside checkpoints and use it for batch collation.
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics_for_regression,
    )
    trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment