ztrainer.py
import contextlib
import datasets
from datasets.combine import concatenate_datasets
import json
import os
import pandas as pd
from peft import LoftQConfig, PeftModel, PeftConfig
import random
import torch
from transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import PreTrainedTokenizerBase, DataCollatorForLanguageModeling
from transformers.utils import logging
from trl import SFTTrainer
from trl.trainer.utils import ConstantLengthDataset
from typing import Literal, List, Union, Any, Dict, Optional
from unsloth import FastLanguageModel
import wandb
import yaml


@contextlib.contextmanager
def cuda_memory_profiler(display: bool = True):
    """
    A context manager for profiling CUDA memory usage in PyTorch.
    """
    if display is False:
        yield
        return
    if not torch.cuda.is_available():
        print("CUDA is not available, skipping memory profiling")
        yield
        return
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_memory = torch.cuda.memory_allocated()
    try:
        yield
    finally:
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at start: {start_memory / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at end: {end_memory / (1024 ** 2):.2f} MB")
        print(f"Net memory change: {(end_memory - start_memory) / (1024 ** 2):.2f} MB")


REMOVE_LABEL = 0


class DataCollatorForPromptCompletion(DataCollatorForLanguageModeling):
    """
    Data collator used for training models on prompt-completion tasks. It processes a batch of examples,
    separating the prompt and completion parts based on the specified number of prompt turns. The loss
    calculation is focused on the completion part, while the prompt part is ignored during training.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used for encoding the input examples.
        prompt_turns (List[int]): A list specifying the number of prompt turns for each example in the batch.
        max_sequence_len (int, optional): The maximum sequence length; longer inputs are truncated. Defaults to 8192.
        turn_sep (str, optional): The separator used to split the input into turns. Defaults to "\n\n".
        mlm (bool, optional): Whether to use masked language modeling. Defaults to False.
        ignore_index (int, optional): The index to use for ignoring loss calculation. Defaults to -100.
        **kwargs: Additional keyword arguments passed to the superclass constructor.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        prompt_turns: List[int],
        max_sequence_len: int = 8192,
        turn_sep: str = "\n\n",
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(tokenizer=tokenizer, mlm=mlm, **kwargs)
        self.prompt_turns = prompt_turns
        self.max_sequence_len = max_sequence_len
        self.turn_sep_token_ids = [13, 13]  # Placeholder for the turn separator token IDs
        self.ignore_index = ignore_index
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        with cuda_memory_profiler():
            modified_examples = []
            for i, example in enumerate(examples):
                input_ids = example['input_ids']
                attention_mask = example['attention_mask']
                # Truncate the input if it exceeds the maximum sequence length
                if len(input_ids) > self.max_sequence_len:
                    input_ids = input_ids[:self.max_sequence_len]
                    attention_mask = attention_mask[:self.max_sequence_len]
                nth_index = self.find_nth_occurrence(input_ids, self.turn_sep_token_ids, self.prompt_turns[i])
                # This is really dumb, but I am passing my information down through the attention mask.
                # There 1000% has to be a better way to do this. I mean, at least don't use zero.
                # If I try to do this another way, it gets mangled when it passes through accelerate...
                if nth_index > -1:
                    attention_mask[:nth_index] = [REMOVE_LABEL] * nth_index
                modified_example = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                }
                modified_examples.append(modified_example)
            return super().__call__(modified_examples)

    def find_nth_occurrence(self, input_ids: List[int], pattern: List[int], n: int) -> int:
        """
        Find the start index of the nth occurrence of a pattern in a sequence of input_ids.

        Args:
            input_ids (List[int]): The list of input IDs.
            pattern (List[int]): The sequence of token IDs for the pattern to find.
            n (int): The 1-based index of the occurrence to find.

        Returns:
            int: The start index of the nth occurrence, or -1 if not found.
        """
        matches = 0
        pattern_len = len(pattern)
        for i in range(len(input_ids) - pattern_len + 1):
            if input_ids[i:i+pattern_len] == pattern:
                matches += 1
                if matches == n:
                    return i
        return -1
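
    # Worked example (illustrative token IDs, not tied to any particular tokenizer):
    #
    #     collator.find_nth_occurrence([5, 13, 13, 7, 13, 13, 9], pattern=[13, 13], n=2)  # -> 4
    #
    # With the default turn_sep_token_ids of [13, 13], this returns the index where the
    # nth turn separator begins, i.e. where the completion portion of the sequence starts.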

    def torch_call(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Processes a batch of input examples to prepare them for training, focusing loss calculation on completion parts.

        Args:
            examples (List[Dict[str, Any]]): A batch of examples. Each example is expected to be a dictionary
                with keys including "input_ids" and "attention_mask".

        Returns:
            Dict[str, Any]: A dictionary containing processed "input_ids", "attention_mask", and "labels" for the batch.
                "labels" are set to match "input_ids" for the completion part and `ignore_index` for the
                prompt part or padding.
        """
        # We have to invoke our superclass to generate our labels
        batch = super().torch_call(examples)
        for i in range(len(batch['labels'])):
            labels = batch['labels'][i].detach()
            attention = batch['attention_mask'][i].detach()
            # Now we wipe our labels where we had passed our marker through the attention mask earlier
            labels[attention == REMOVE_LABEL] = self.ignore_index
            # And return attention to normal
            attention[attention == REMOVE_LABEL] = 1
            # Just to be thorough, also ignore padding
            labels[batch['input_ids'][i] == self.pad_token_id] = self.ignore_index
            attention[batch['input_ids'][i] == self.pad_token_id] = 0
            batch['attention_mask'][i] = attention
            batch['labels'][i] = labels
        # Ugh, memory leak?
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {
            'input_ids': batch['input_ids'],
            'attention_mask': batch['attention_mask'],
            'labels': batch['labels'],
        }
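
# Sketch of how this collator gets wired up (illustrative; it mirrors the `run()` function below,
# and assumes `tokenizer` and `dataset` already exist and the dataset carries a `prompt_turns` column):
#
#     collator = DataCollatorForPromptCompletion(
#         tokenizer=tokenizer,
#         mlm=False,
#         prompt_turns=dataset['prompt_turns'],
#     )
#     # The collator is then handed to the trainer via the `collator` argument of `get_trainer()`.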


def calculate_perplexity(model, tokenizer, text, output):
    """Compute the perplexity of the `output` completion given the full `text` string."""
    inputs = tokenizer(text, return_tensors="pt")
    target = inputs.input_ids.clone()
    # Only score the tokens belonging to the completion; everything before it is ignored (-100)
    target[:, :-len(tokenizer.encode(output))] = -100
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    target = target.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, labels=target)
    # `outputs.loss` is already the mean NLL over the non-ignored tokens, so the multiplication
    # and division by the sequence length cancel out here.
    log_likelihood = outputs.loss * inputs["input_ids"].shape[1]
    perplexity = torch.exp(log_likelihood / inputs["input_ids"].shape[1])
    return perplexity.item()
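
# Example call (illustrative only; `model`, `tokenizer`, and the strings are placeholders,
# and this helper is not invoked anywhere else in this script):
#
#     ppl = calculate_perplexity(
#         model, tokenizer,
#         text="Question: What is 2+2?\n\nAnswer: 4",
#         output="Answer: 4",
#     )
#     print(f"completion perplexity: {ppl:.2f}")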


def get_model(model_path: str, context_length: int, dtype: str, load_in_4bit: bool, device_map: str = "cuda", **kwargs) -> tuple[FastLanguageModel, PreTrainedTokenizerBase]:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path,
        max_seq_length = context_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
        device_map = device_map,
        # token = "hf_...",
    )
    return model, tokenizer


def patch_model(model: FastLanguageModel, lora_rank: int, load_in_4bit: bool, target_layers: list[str], **kwargs) -> FastLanguageModel:
    if load_in_4bit:
        config = {
            "loftq_config": None
        }
    else:
        #loftq_config = LoftQConfig(loftq_bits=4)
        config = {
            #"init_lora_weights": "loftq",
            #"loftq_config": loftq_config
            "loftq_config": None
        }
    return FastLanguageModel.get_peft_model(
        model,
        r = lora_rank,
        target_modules = target_layers,
        lora_alpha = lora_rank,
        lora_dropout = 0,  # Unsloth only supports dropout = 0
        bias = "none",  # Unsloth only supports bias = "none"
        #use_gradient_checkpointing = True,
        use_gradient_checkpointing = "unsloth",
        random_state = 3407,
        use_rslora = True,
        **config,
    )
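
# Note on patch_model: `target_layers` comes from the YAML config. For Llama/Mistral-family models
# the usual LoRA target modules look like the following (an assumption about the config, not
# something this file defines):
#
#     target_layers = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
#
# With lora_alpha set equal to lora_rank and use_rslora=True, the effective LoRA scaling is
# lora_alpha / sqrt(r) rather than lora_alpha / r.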


def get_trainer(model: FastLanguageModel, tokenizer: PreTrainedTokenizerBase, context_length: int,
                dataset_train: datasets.Dataset,
                epochs: int, batch_size: int,
                learning_rate: float, gradient_accumulation_steps: int,
                dataset_eval: Optional[datasets.Dataset] = None,
                optim: str = "adamw_8bit",
                scheduler: str = "constant_with_warmup",
                neftune_noise_alpha: int = 5,
                save_steps: int = 30,
                collator: Optional[DataCollatorForLanguageModeling] = None,
                **kwargs) -> SFTTrainer:
    if collator is None:
        options = {
            'dataset_text_field': 'text'
        }
    else:
        options = {
            'dataset_text_field': 'text',
            'data_collator': collator
        }
    if dataset_eval is not None:
        print('Using eval dataset')
    return SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset_train,
        eval_dataset = dataset_eval,
        max_seq_length = context_length,
        dataset_num_proc = 2,
        packing = False,  # Packing would bin short sequences together to save time, but is disabled here
        args = TrainingArguments(
            per_device_train_batch_size = batch_size,
            warmup_ratio = 0.05,
            num_train_epochs = epochs,
            learning_rate = learning_rate,
            fp16 = not torch.cuda.is_bf16_supported(),
            bf16 = torch.cuda.is_bf16_supported(),
            logging_steps = 1,
            do_eval = dataset_eval is not None,
            eval_steps = 10,
            optim = optim,
            weight_decay = 0.1,
            lr_scheduler_type = scheduler,
            seed = 3407,
            gradient_accumulation_steps = gradient_accumulation_steps,
            gradient_checkpointing = True,
            output_dir = "outputs",
            save_strategy = "steps",
            save_steps = save_steps,
            neftune_noise_alpha = neftune_noise_alpha,
            report_to = "wandb",
        ),
        **options
    )


def upload_model(model: FastLanguageModel, model_name: str, token: str, **kwargs):
    model.push_to_hub(model_name, token=token)


def run(config_file: str, resume_from_checkpoint: bool = False):
    #os.environ["WANDB_DISABLED"] = "true"
    os.environ["LD_LIBRARY_PATH"] = "/usr/local/cuda-12.1/lib64"
    os.environ["PATH"] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/root/.local/bin:/usr/local/cuda-12.1/bin"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    #tf.enable_eager_execution()
    # Compute capability >= 8 would allow bf16; nothing is branched on it here yet.
    major_version, minor_version = torch.cuda.get_device_capability()

    with open(config_file) as f:
        gen_config = yaml.safe_load(f)
    stage_config = gen_config['stages']['a']
    model_path = gen_config['base_model']

    dsc = gen_config['dataset']
    dataset = datasets.load_from_disk(dsc['filepath'])
    if 'testpath' in dsc:
        testset = datasets.load_from_disk(dsc['testpath'])
    else:
        testset = None

    model, tokenizer = get_model(model_path=model_path, **gen_config)
    if 'type' in dsc and dsc['type'] == 'completion':
        collator = DataCollatorForPromptCompletion(tokenizer=tokenizer, mlm=False, prompt_turns=dataset['prompt_turns'])
    else:
        collator = None
    if "tokenizer_model" in gen_config:
        print(f"Loading {gen_config['tokenizer_model']}...")
        tokenizer = AutoTokenizer.from_pretrained(gen_config["tokenizer_model"])
        #model.resize_token_embeddings(len(tokenizer))

    with cuda_memory_profiler():
        model = patch_model(model, **stage_config, **gen_config)

    wandb.login(key=gen_config['wandb_key'])
    wandb.init(
        project=gen_config["wandb_project"],
        config={
            "learning_rate": stage_config["learning_rate"],
            "architecture": gen_config["architecture"],
            "dataset": dsc["name"],
            "epochs": stage_config["epochs"],
        }
    )

    with cuda_memory_profiler():
        sft_trainer = get_trainer(model, tokenizer, dataset_train=dataset, dataset_eval=testset, low_memory=True, collator=collator, **gen_config, **stage_config)
    with cuda_memory_profiler():
        trainer_stats = sft_trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    wandb.finish()

    model.save_pretrained(f"{gen_config['model_name']}-lora")
    model.save_pretrained_merged(f"{gen_config['model_name']}", tokenizer=tokenizer, save_method="merged_16bit")


if __name__ == "__main__":
    logging.set_verbosity_info()
    options = {
        'config_file': "./configs/mist2-sword.yaml",
        'resume_from_checkpoint': False
    }
    run(**options)
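
# For reference, a sketch of the YAML config that `run()` expects. The keys are inferred from how
# `gen_config` and `stage_config` are read above; every value below is a placeholder, not a recommendation.
#
#     base_model: ./models/my-base-model
#     model_name: my-finetuned-model
#     architecture: mistral
#     context_length: 8192
#     dtype: null                  # let unsloth pick
#     load_in_4bit: true
#     tokenizer_model: ./models/my-tokenizer   # optional
#     wandb_key: <your wandb api key>
#     wandb_project: my-project
#     dataset:
#       name: my-dataset
#       filepath: ./data/train
#       testpath: ./data/test      # optional
#       type: completion           # optional; enables DataCollatorForPromptCompletion
#     stages:
#       a:
#         lora_rank: 64
#         target_layers: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
#         learning_rate: 2.0e-5
#         epochs: 1
#         batch_size: 2
#         gradient_accumulation_steps: 8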