from enum import Enum

import gc
import os

import torch
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from peft import LoraConfig, replace_lora_weights_loftq
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Best logit MSE seen so far by the LoFTQ callback; updated globally in loftq_init().
current_mse = float("inf")

DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]
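
# Illustration (not in the original file): `.list()` simply enumerates the member
# values in definition order, e.g.
#   ChatmlSpecialTokens.list()
#   -> ["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>system",
#       "<|im_end|>", "<s>", "<pad>"]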
def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    def preprocess(samples):
        batch = []
        for conversation in samples["messages"]:
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        return {"content": batch}

    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # First, try to load the split from a Hub repo.
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # Otherwise, fall back to a local dataset saved with `save_to_disk`.
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(
                f"Split type {split} not recognized as one of test or train."
            )

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
    print(
        f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
    )
    print(f"A sample of train dataset: {train_data[0]}")

    return train_data, valid_data
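
# Example call site (hypothetical; the dataclasses behind `data_args` and
# `training_args` live in the accompanying training script, not in this file):
#   train_ds, eval_ds = create_datasets(
#       tokenizer, data_args, training_args, apply_chat_template=True
#   )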
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel

    bnb_config = None
    quant_storage_dtype = None
    load_in_8bit = args.use_8bit_qunatization
    load_in_4bit = args.use_4bit_quantization

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print(
                    "Your GPU supports bfloat16, you can accelerate training with the argument --bf16"
                )
                print("=" * 80)

    if args.use_unsloth:
        # Load the model through Unsloth's fast path.
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=load_in_4bit,
        )
    else:
        # Keep non-quantized weights in the quant-storage dtype so that distributed
        # wrappers (e.g. FSDP) see a uniform floating-point dtype across parameters.
        torch_dtype = (
            quant_storage_dtype
            if quant_storage_dtype and quant_storage_dtype.is_floating_point
            else torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            load_in_8bit=load_in_8bit,
            quantization_config=bnb_config,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            torch_dtype=torch_dtype,
        )

    peft_config = None
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None
    chat_template = None
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE

    if special_tokens is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,
            bos_token=special_tokens.bos_token.value,
            eos_token=special_tokens.eos_token.value,
            additional_special_tokens=special_tokens.list(),
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template
        # TODO: make embedding resizing configurable?
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
        # Patch the model and add fast LoRA weights.
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=data_args.max_seq_length,
        )

    return model, peft_config, tokenizer
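
# Example call site (hypothetical; `model_args`/`data_args`/`training_args` are
# defined by the accompanying training script's argument parser):
#   model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)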
def get_mae(x, y):
    return (x - y).abs().mean()


def get_mse(x, y):
    return torch.pow(x - y, 2).mean()


def error_report(x, y):
    mae = get_mae(x, y)
    mse = get_mse(x, y)
    print(
        f"Mean absolute error: {mae:>8.5f}\n"
        f"Mean squared error: {mse:>8.5f}"
    )
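
# Quick illustration (not part of the original file): for
#   x = torch.tensor([1.0, 2.0]) and y = torch.tensor([1.5, 1.0]),
# get_mae(x, y) = (0.5 + 1.0) / 2 = 0.75 and get_mse(x, y) = (0.25 + 1.0) / 2 = 0.625.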
def loftq_init(model, tokenizer, train_dataset, max_seq_length, args):
    if args.use_loftq_callback:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        base_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=compute_dtype)
        base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

        # Build a small random probe batch from the training set and record the
        # full-precision logits as the reference for the LoFTQ callback.
        random_input_ids = torch.randint(0, len(train_dataset), size=(1,)).numpy().tolist()
        random_inputs = [train_dataset[i]["content"] for i in random_input_ids]
        random_inputs = tokenizer(
            random_inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_seq_length,
        )
        logits_base = base_model(**random_inputs).logits
        del base_model
        gc.collect()

        def loftq_callback(model, module_name):
            """Callable to replace weights with LoFTQ if the MSE is lower than the current best one."""
            global current_mse
            logits = model(**random_inputs).logits
            mse = get_mse(logits_base, logits)
            if mse < current_mse:
                current_mse = mse
                print(f"MSE improved for module {module_name}")
                return True
            print(f"MSE did not improve for module {module_name}")
            return False

        replace_lora_weights_loftq(model, callback=loftq_callback)
        logits_loftq_callback = model(**random_inputs).logits
        error_report(logits_base, logits_loftq_callback)
    else:
        replace_lora_weights_loftq(model)
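
# Typical call site (hypothetical, not shown in this file): after the 4-bit base
# model has been wrapped with LoRA adapters, e.g.
#   peft_model = get_peft_model(model, peft_config)
#   loftq_init(peft_model, tokenizer, train_ds, data_args.max_seq_length, args)
# `args.use_loftq_callback` toggles the greedy, MSE-guided variant implemented above.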
def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class
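
# Example (illustrative): this helper is useful for FSDP auto-wrap policies that
# need the transformer block class by name, e.g.
#   get_module_class_from_name(model, "LlamaDecoderLayer")
# returns `transformers.models.llama.modeling_llama.LlamaDecoderLayer` when `model`
# is a Llama checkpoint, and `None` if no matching child class is found.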