Train Mamba
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)
from torch.utils.data import DataLoader
from datasets import load_dataset
import wandb
import torch
import os
import argparse
import numpy as np
import pandas as pd
os.environ["WANDB_MODE"] = "offline"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer, then move the model to the appropriate device
model = AutoModelForCausalLM.from_pretrained('Q-bert/Mamba-130M', trust_remote_code=True)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained('Q-bert/Mamba-130M')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
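# Optional sanity check (an assumption, not in the original gist): make sure the new [PAD]
# id maps to a row inside the model's embedding table. Mamba checkpoints usually pad their
# vocab, so this is normally a no-op; it is skipped silently if the remote-code model does
# not expose the standard embedding API.
try:
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))
except (AttributeError, NotImplementedError):
    pass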
# Load dataset
dataset = load_dataset("Abirate/english_quotes", split='train')
# Token-length window: quotes are padded/truncated to max_size and later filtered to at least min_size
max_size = 25
min_size = 10
# Preprocessing function to tokenize the 'quotes' field
def tokenize_function(examples):
    return tokenizer(examples["quote"], padding="max_length", truncation=True, max_length=max_size)
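# Quick illustration (not part of the original gist): tokenizing a single quote shows the
# fields map() will add — input_ids and attention_mask, each padded/truncated to max_size.
_example = tokenize_function({"quote": ["Simplicity is the ultimate sophistication."]})
print({k: len(v[0]) for k, v in _example.items()})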
# Apply the tokenizer to the dataset
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Filter function to keep quotes whose real (unpadded) length is between min_size and max_size tokens
def filter_quotes(batch):
    # Calculate actual lengths for each example in the batch (padding positions have mask == 0)
    actual_lengths = [sum(mask) for mask in batch["attention_mask"]]
    # Determine which examples to keep based on their actual length
    keep = [min_size <= length <= max_size for length in actual_lengths]
    return keep
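# Toy check (illustrative only): a mask with 12 real tokens passes the window, one with 3 does not.
print(filter_quotes({"attention_mask": [[1] * 12 + [0] * 13, [1] * 3 + [0] * 22]}))  # -> [True, False]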
# Apply the filter to the dataset
filtered_dataset = tokenized_dataset.filter(filter_quotes, batched=True)
# Splitting the dataset into training and evaluation sets
split_dataset = filtered_dataset.train_test_split(test_size=0.1) # 10% for evaluation
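# Report the split sizes; the epoch_iters default below is derived from the training split.
print({name: len(ds) for name, ds in split_dataset.items()})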
# Hyperparameters. batch_size and epoch_iters are derived after parsing so that
# command-line overrides of target_tokens, block_size, or gradient_steps propagate into them.
parser = argparse.ArgumentParser()
parser.add_argument("--block_size", type=int, default=int(np.max([len(t) for t in filtered_dataset['input_ids']])))
# 1200 = 16GB & 130M (target tokens per batch that fit on a 16GB card with the 130M model)
parser.add_argument("--target_tokens", type=int, default=1200)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--gradient_steps", type=int, default=4)
parser.add_argument("--epoch_iters", type=int, default=None)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.1)
args = parser.parse_args()
if args.batch_size is None:
    args.batch_size = int(np.round(args.target_tokens / args.block_size))
if args.epoch_iters is None:
    # optimizer steps per epoch: training examples / (batch size * gradient accumulation)
    args.epoch_iters = int(np.round(len(split_dataset['train']) / (args.batch_size * args.gradient_steps)))
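# Derived budget check: tokens processed per optimizer step = batch_size * block_size * gradient_steps;
# target_tokens above is what the original comment sizes for a 16GB card with the 130M model.
print("tokens per optimizer step:", args.batch_size * args.block_size * args.gradient_steps)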
print("filtered examples:", len(filtered_dataset))
print("block_size:", args.block_size)
print("batch_size:", args.batch_size)
print("steps per epoch (epoch_iters):", args.epoch_iters)
# Define custom trainer: compute the causal-LM loss by hand (shift-by-one next-token cross-entropy)
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids)[0]
        labels = input_ids.to(lm_logits.device)
        # Standard causal shift: logits at positions 0..T-2 predict tokens at positions 1..T-1
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return (lm_loss, lm_logits) if return_outputs else lm_loss
# Training arguments with logging and evaluation strategies
training_args = TrainingArguments(
    output_dir="./mamba_trainer_output",
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    num_train_epochs=args.epochs,
    logging_strategy="steps",      # Log training metrics every logging_steps steps
    logging_steps=1,               # Log every step
    evaluation_strategy="epoch",   # Evaluate at the end of each epoch
    save_steps=args.epoch_iters,
    save_total_limit=2,
    weight_decay=args.weight_decay,
    learning_rate=args.learning_rate,
    gradient_accumulation_steps=args.gradient_steps
)
# Initialize trainer
trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"]
)
# Manually compute an evaluation loss on a few held-out examples before training
input_data = {"input_ids": torch.tensor(split_dataset["test"]["input_ids"][0:4], dtype=torch.long).to(device)}
with torch.no_grad():
    eval_loss = trainer.compute_loss(model=trainer.model, inputs=input_data)
# Print the evaluation loss
print("Evaluation Loss:", eval_loss)
# Start training
trainer.train()
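# Persist the fine-tuned weights and tokenizer once training finishes (a sketch added here;
# the path below is a placeholder, not something the original gist writes to).
trainer.save_model("./mamba_trainer_output/final")
tokenizer.save_pretrained("./mamba_trainer_output/final")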