@sumandas0
Created April 24, 2024 20:37
from enum import Enum
import gc
import os

import torch
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from peft import LoraConfig, replace_lora_weights_loftq
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

current_mse = float("inf")

DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
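# With the ChatML template above, apply_chat_template renders each turn roughly as
# "<|im_start|>{role}\n{content}<|im_end|>", and with add_generation_prompt=True the
# prompt ends with "<|im_start|>assistant\n" so generation continues as the assistant.
# The Zephyr template does the same with "<|user|>"/"<|system|>"/"<|assistant|>" headers
# terminated by the tokenizer's eos_token.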


class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    def preprocess(samples):
        batch = []
        for conversation in samples["messages"]:
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        return {"content": batch}

    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    print(f"A sample of train dataset: {train_data[0]}")
    return train_data, valid_data
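

# NOTE: with apply_chat_template=True, each raw sample is expected to provide a
# "messages" list of {"role": ..., "content": ...} dicts; preprocess() flattens every
# conversation into a single templated string stored under the "content" column, and
# the original columns are dropped by the map() call above.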


def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel

    bnb_config = None
    quant_storage_dtype = None

    load_in_8bit = args.use_8bit_qunatization
    load_in_4bit = args.use_4bit_quantization

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=load_in_4bit,
        )
    else:
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            load_in_8bit=load_in_8bit,
            quantization_config=bnb_config,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            torch_dtype=torch_dtype,
        )
    peft_config = None
    chat_template = None
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None
    chat_template = None
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE

    if special_tokens is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,
            bos_token=special_tokens.bos_token.value,
            eos_token=special_tokens.eos_token.value,
            additional_special_tokens=special_tokens.list(),
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template
        # make embedding resizing configurable?
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=data_args.max_seq_length,
        )

    return model, peft_config, tokenizer


def get_mae(x, y):
    return (x - y).abs().mean()


def get_mse(x, y):
    return torch.pow(x - y, 2).mean()


def error_report(x, y):
    mae = get_mae(x, y)
    mse = get_mse(x, y)
    print(
        f"Mean absolute error: {mae:>8.5f}\n"
        f"Mean squared error: {mse:>8.5f}"
    )


def loftq_init(model, tokenizer, train_dataset, max_seq_length, args):
    """Replace the LoRA weights of `model` with LoFTQ-initialized weights.

    If `args.use_loftq_callback` is set, a full-precision copy of the base model is used to
    compute reference logits on a random training sample, and each candidate replacement is
    kept only when it lowers the MSE against those reference logits.
    """
    if args.use_loftq_callback:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        base_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=compute_dtype)
        base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
        random_input_ids = torch.randint(0, len(train_dataset), size=(1,)).numpy().tolist()
        random_inputs = [train_dataset[i]["content"] for i in random_input_ids]
        random_inputs = tokenizer(
            random_inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length
        )
        logits_base = base_model(**random_inputs).logits
        del base_model
        gc.collect()

        def loftq_callback(model, module_name):
            """Callable to replace weights with LoFTQ if the mse is lower than the current best one."""
            global current_mse
            logits = model(**random_inputs).logits
            mse = get_mse(logits_base, logits)
            if mse < current_mse:
                current_mse = mse
                print(f"MSE improved for module {module_name}")
                return True
            print(f"MSE did not improve for module {module_name}")
            return False

        replace_lora_weights_loftq(model, callback=loftq_callback)
        logits_loftq_callback = model(**random_inputs).logits
        error_report(logits_base, logits_loftq_callback)
    else:
        replace_lora_weights_loftq(model)


def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class
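

# --- Usage sketch (hypothetical wiring, not defined by this module) ---
# A minimal sketch of how the helpers above fit together. The SimpleNamespace objects
# below stand in for the parsed argument dataclasses this module expects (normally built
# with HfArgumentParser in the accompanying training script); the field names mirror the
# attributes referenced above, and the model/dataset identifiers are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    model_args = SimpleNamespace(
        model_name_or_path="mistralai/Mistral-7B-v0.1",  # placeholder checkpoint
        use_unsloth=False,
        use_8bit_qunatization=False,  # spelling matches the attribute read above
        use_4bit_quantization=True,
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_quant_storage_dtype="bfloat16",
        bnb_4bit_quant_type="nf4",
        use_nested_quant=False,
        use_flash_attn=False,
        use_peft_lora=True,
        lora_alpha=16,
        lora_dropout=0.05,
        lora_r=8,
        lora_target_modules="all-linear",
        chat_template_format="chatml",
    )
    data_args = SimpleNamespace(
        dataset_name="smangrul/ultrachat-10k-chatml",  # placeholder dataset with a "messages" column
        splits="train,test",
        max_seq_length=512,
    )
    training_args = SimpleNamespace(gradient_checkpointing=True, seed=42)

    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)
    train_ds, eval_ds = create_datasets(tokenizer, data_args, training_args, apply_chat_template=True)
    # model, peft_config and the two datasets would then be handed to an SFT trainer.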