@ebsmothers
Created March 15, 2024 19:15
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
    GenerationConfig,
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import bitsandbytes as bnb

# Collect the names of all 4-bit (bnb.nn.Linear4bit) layers in the model so they
# can be passed to LoraConfig as target_modules.
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

# Allow up to 80,000 MB per visible GPU when device_map="auto" places the model.
max_memory = '80000MB'
max_memory = {i: max_memory for i in range(torch.cuda.device_count())}
compute_dtype = torch.bfloat16

# Load the base model in 4-bit NF4 with double quantization; compute runs in bf16.
model = AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b',
    cache_dir=None,
    load_in_4bit=True,
    load_in_8bit=False,
    device_map="auto",
    max_memory=max_memory,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        load_in_8bit=False,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    ),
    torch_dtype=torch.bfloat16,
    trust_remote_code=False,
    use_auth_token=False,
)

# Attach LoRA adapters to every quantized linear layer found above.
modules = find_all_linear_names(model)
config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=modules,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

# Print the dtype of every tensor in the wrapped model's state dict (quantized
# base weights vs. the floating-point LoRA adapter weights).
for k, v in model.state_dict().items():
    if isinstance(v, torch.Tensor):
        print(k, v.dtype)
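
A possible follow-up, not part of the original gist: the object returned by get_peft_model also exposes print_trainable_parameters(), which summarizes how many parameters the LoRA adapters add on top of the frozen 4-bit base weights. A minimal sanity-check sketch:

# Sketch only, not from the gist: inspect which modules LoRA targets and how
# many parameters are actually trainable after wrapping.
print(sorted(modules))
model.print_trainable_parameters()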