@cyberfox
Created June 19, 2023 05:13
Patch to support training LoRA for StarCoder-based (GPTBigCode) models on Oobabooga's text-generation-webui. It maps GPTBigCodeForCausalLM to the "c_attn", "c_proj" and "c_fc" target modules, forces fp16 off during training, and enables Weights & Biases reporting.
diff --git a/modules/training.py b/modules/training.py
index 75ba82c..c90d823 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -30,12 +30,14 @@ try:
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
-    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
+    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"],
+                             "gpt_bigcode": ["c_attn", "c_proj", "c_fc"]}
     MODEL_CLASSES = {
         "LlamaForCausalLM": "llama",
         "OPTForCausalLM": "opt",
         "GPTJForCausalLM": "gptj",
-        "GPTNeoXForCausalLM": "gpt_neox"
+        "GPTNeoXForCausalLM": "gpt_neox",
+        "GPTBigCodeForCausalLM": "gpt_bigcode"
     }
@@ -423,7 +425,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
             num_train_epochs=epochs,
             learning_rate=actual_lr,
-            fp16=False if shared.args.cpu else True,
+            fp16=False,  # if shared.args.cpu else True,
             optim=optimizer,
             logging_steps=5,
             evaluation_strategy="steps" if eval_data is not None else "no",
@@ -434,7 +436,8 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             load_best_model_at_end=eval_data is not None,
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
-            no_cuda=shared.args.cpu
+            no_cuda=shared.args.cpu,
+            report_to="wandb"
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
         callbacks=list([Callbacks()])
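For reference, "c_attn", "c_proj" and "c_fc" are the attention and MLP projection layers inside GPTBigCode blocks, which is why they are the LoRA target modules for StarCoder-style models. The snippet below is a minimal standalone sketch of the same idea using the PEFT library directly; the checkpoint name and the r/alpha/dropout values are illustrative assumptions, not values taken from the webui or from this patch.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Any GPTBigCode-based checkpoint should work; "bigcode/starcoderbase" is just an example.
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase")

lora_config = LoraConfig(
    r=32,               # rank of the LoRA update matrices (illustrative)
    lora_alpha=64,      # scaling factor (illustrative)
    lora_dropout=0.05,  # dropout applied to the LoRA layers (illustrative)
    bias="none",
    task_type="CAUSAL_LM",
    # Same module names the patch maps to "gpt_bigcode":
    target_modules=["c_attn", "c_proj", "c_fc"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # shows how small the trainable fraction is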