@rohan-paul
Created August 31, 2023 11:28
import argparse
import os
import random
from functools import partial

import bitsandbytes as bnb
import pandas as pd
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, AutoPeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)

seed = 42
set_seed(seed)
def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = f'{40960}MB'

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",  # shard the model across all visible GPUs
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)

    # Llama 2 ships without a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f'Number of prompts: {len(dataset)}')
print(f'Column names are: {dataset.column_names}')

nb_samples = 3
random_indices = random.sample(range(len(dataset)), nb_samples)
samples = []

for idx in random_indices:
    sample = dataset[idx]
    sample_data = {
        'instruction': sample['instruction'],
        'context': sample['context'],
        'response': sample['response'],
        'category': sample['category']
    }
    samples.append(sample_data)

df = pd.DataFrame(samples)
def create_prompt_formats(sample):
    """
    Format various fields of the sample ('instruction', 'context', 'response'),
    then concatenate them using two newline characters.
    :param sample: Sample dictionary
    """
    INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    INSTRUCTION_KEY = "### Instruction:"
    INPUT_KEY = "Input:"
    RESPONSE_KEY = "### Response:"
    END_KEY = "### End"

    blurb = f"{INTRO_BLURB}"
    instruction = f"{INSTRUCTION_KEY}\n{sample['instruction']}"
    input_context = f"{INPUT_KEY}\n{sample['context']}" if sample["context"] else None
    response = f"{RESPONSE_KEY}\n{sample['response']}"
    end = f"{END_KEY}"

    parts = [part for part in [blurb, instruction, input_context, response, end] if part]
    formatted_prompt = "\n\n".join(parts)
    sample["text"] = formatted_prompt

    return sample
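# For illustration: a sample that has a non-empty 'context' field comes out of
# create_prompt_formats looking roughly like this (placeholders in <...>):
#
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   ### Instruction:
#   <instruction text>
#
#   Input:
#   <context text>
#
#   ### Response:
#   <response text>
#
#   ### End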
# Preview the prompt format on the last randomly drawn sample from the loop above
print(create_prompt_formats(sample)["text"])
def get_max_length(model):
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(conf, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length
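# For meta-llama/Llama-2-7b-hf this loop resolves to max_position_embeddings,
# i.e. a 4096-token context window; the 1024 default is only a fallback for
# configs that expose none of the three attribute names above.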
def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenize a batch of formatted prompts.
    """
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed, dataset):
    """Format and tokenize the dataset so it is ready for training.
    :param tokenizer (AutoTokenizer): Model tokenizer
    :param max_length (int): Maximum number of tokens to emit from the tokenizer
    """
    print("Preprocessing dataset...")
    dataset = dataset.map(create_prompt_formats)

    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "context", "response", "text", "category"],
    )

    # Drop samples that exceed the model's context window, then shuffle
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)
    dataset = dataset.shuffle(seed=seed)

    return dataset
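# Optional sanity check (not part of the original flow): decode one processed
# example back to text to confirm the prompt template survived tokenization.
# print(tokenizer.decode(dataset[0]["input_ids"]))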
def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return bnb_config
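# Rough memory math for this config: NF4 stores each of Llama-2-7B's ~7e9
# weights in half a byte, so the quantized base model occupies roughly
# 7e9 * 0.5 bytes ~ 3.5 GB, versus ~14 GB in fp16. Double quantization also
# quantizes the quantization constants, saving about 0.4 bits per parameter
# on top of that. Activations and optimizer state come extra, which is why
# gradient checkpointing and a paged optimizer are used below.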
def create_peft_config(modules):
    """
    Create a Parameter-Efficient Fine-Tuning config for the model.
    :param modules: Names of the modules to apply LoRA to
    """
    config = LoraConfig(
        r=16,            # rank of the low-rank update matrices
        lora_alpha=64,   # scaling factor for the updates
        target_modules=modules,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return config
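# Note on the hyperparameters: PEFT scales each LoRA update by lora_alpha / r,
# so this config multiplies the low-rank updates by 64 / 16 = 4 before they
# are added to the frozen base weights.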
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
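# For Llama-2 this typically resolves to the seven linear projections in each
# block (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj).
# lm_head is excluded because it is kept in higher precision rather than
# quantized to 4 bit, so it is not a bnb.nn.Linear4bit target.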
def print_trainable_parameters(model, use_4bit=False):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # With DeepSpeed ZeRO-3, sharded params report numel() == 0 and expose
        # the true count via ds_numel
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    if use_4bit:
        trainable_params //= 2  # integer division keeps the :,d format below valid
    print(
        f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {100 * trainable_params / all_param}"
    )
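# With r=16 adapters on all seven linear projections of Llama-2-7B, expect this
# to report on the order of 40 million trainable parameters, i.e. well under 1%
# of the ~6.7B total (the exact figure depends on which modules get adapters).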
model_name = "meta-llama/Llama-2-7b-hf"
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)
max_length = get_max_length(model)
dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)
def train(model, tokenizer, dataset, output_dir):
    # Enable gradient checkpointing to reduce activation memory during training
    model.gradient_checkpointing_enable()

    # Prepare the quantized model for k-bit training (casts layer norms, enables input grads)
    model = prepare_model_for_kbit_training(model)

    # Attach LoRA adapters to every 4-bit linear layer
    modules = find_all_linear_names(model)
    peft_config = create_peft_config(modules)
    model = get_peft_model(model, peft_config)

    print_trainable_parameters(model)

    trainer = Trainer(
        model=model,
        train_dataset=dataset,
        args=TrainingArguments(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            max_steps=15,
            learning_rate=2e-4,
            fp16=True,
            logging_steps=1,
            output_dir="outputs",
            optim="paged_adamw_8bit",
        ),
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
    )

    model.config.use_cache = False  # silence cache warnings; re-enable for inference

    # Verify the dtypes of the model parameters
    dtypes = {}
    for _, p in model.named_parameters():
        dtype = p.dtype
        if dtype not in dtypes:
            dtypes[dtype] = 0
        dtypes[dtype] += p.numel()
    total = 0
    for k, v in dtypes.items():
        total += v
    for k, v in dtypes.items():
        print(k, v, v / total)

    do_train = True

    # Launch training
    print("Training...")
    if do_train:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        print(metrics)

    # Save the adapter weights of the last checkpoint
    print("Saving last checkpoint of the model...")
    os.makedirs(output_dir, exist_ok=True)
    trainer.model.save_pretrained(output_dir)

    # Free GPU memory before the merging step
    del model
    del trainer
    torch.cuda.empty_cache()
output_dir = "results/llama2/final_checkpoint"
train(model, tokenizer, dataset, output_dir)

# Reload the trained adapter and fold the LoRA weights into the base model
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = "results/llama2/final_merged_checkpoint"
os.makedirs(output_merged_dir, exist_ok=True)
model.save_pretrained(output_merged_dir, safe_serialization=True)
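# merge_and_unload() materializes W + (lora_alpha / r) * B @ A for every
# adapted layer and strips the PEFT wrappers, so the merged checkpoint loads
# as a plain AutoModelForCausalLM with no peft dependency at inference time.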
##########################################################
# NOW INFERENCING
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(output_merged_dir)

text = "Llama-2 is Great?"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Push the merged model and tokenizer to the Hugging Face Hub
model.push_to_hub("llama2-fine-tuned-dolly-15k")
tokenizer.push_to_hub("llama2-fine-tuned-dolly-15k")
# Later, reload the fine-tuned model from the Hub (replace <username> with your Hub handle)
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("<username>/llama2-fine-tuned-dolly-15k")
model = AutoModelForCausalLM.from_pretrained("<username>/llama2-fine-tuned-dolly-15k")
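# A minimal end-to-end check (sketch): prompt the reloaded model with the same
# template used during fine-tuning. The instruction text here is illustrative.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nWhat is the Dolly-15k dataset?\n\n"
    "### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))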