Skip to content

Instantly share code, notes, and snippets.

@GohioAC
Last active June 28, 2024 09:49
Show Gist options
  • Save GohioAC/44cc9ac2ee54380ff80aa6ab89fea698 to your computer and use it in GitHub Desktop.
Save GohioAC/44cc9ac2ee54380ff80aa6ab89fea698 to your computer and use it in GitHub Desktop.
Issue with LLaVA-Next SFT
"""
python vsft.py \
--dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-v1.6-mistral-7b-hf" \
--report_to="tensorboard" \
--learning_rate=2e-5 \
--lr_scheduler_type="cosine" \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=1 \
--num_train_epochs=1 \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
--max_seq_length=4096 \
--attn_implementation="flash_attention_2"
"""
from contextlib import nullcontext
from trl.commands.cli_utils import SFTScriptArguments, TrlParser
import torch
from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer, AutoProcessor, LlavaNextForConditionalGeneration
from trl import (
ModelConfig,
SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
tqdm.pandas()

if __name__ == "__main__":
    # Parse script-level, training and model options in one pass.
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    sft_script_args, training_args, model_config = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

    # Jinja chat template: a fixed system preamble, then "USER: " / "ASSISTANT: "
    # turns. Image items render as the literal "<image>" placeholder and every
    # assistant turn is terminated with the tokenizer's eos_token.
    LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

    # "auto" and None are forwarded untouched; any other value is looked up as
    # an attribute of the torch module (e.g. "float16" -> torch.float16).
    requested_dtype = model_config.torch_dtype
    if requested_dtype in ("auto", None):
        torch_dtype = requested_dtype
    else:
        torch_dtype = getattr(torch, requested_dtype)

    # A device map is only requested when a quantization config was produced.
    quantization_config = get_quantization_config(model_config)
    device_map = None
    if quantization_config is not None:
        device_map = get_kbit_device_map()

    model_kwargs = {
        "revision": model_config.model_revision,
        "trust_remote_code": model_config.trust_remote_code,
        "attn_implementation": model_config.attn_implementation,
        "torch_dtype": torch_dtype,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }

    # The tokenizer is loaded separately so the custom chat template and
    # right-side padding can be set, then swapped into the processor.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        use_fast=True,
        padding_side="right",
    )
    tokenizer.chat_template = LLAVA_CHAT_TEMPLATE

    processor = AutoProcessor.from_pretrained(model_config.model_name_or_path)
    processor.tokenizer = tokenizer

    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_config.model_name_or_path,
        **model_kwargs,
    )
class LLavaDataCollator:
    """Turn a list of chat examples into a padded model batch with labels.

    Each example must provide ``"messages"`` (rendered to text with the
    processor's tokenizer chat template) and ``"images"`` holding exactly one
    image. The rendered texts and images are passed to the processor, and a
    ``"labels"`` entry is added to the returned batch.
    """

    def __init__(self, processor):
        """Store the processor used for templating, tokenizing and padding."""
        self.processor = processor

    def __call__(self, examples):
        """Collate ``examples`` into a single batch.

        Args:
            examples: Iterable of mappings with ``"messages"`` and ``"images"``.

        Returns:
            The processor output with an extra ``"labels"`` tensor: a copy of
            ``input_ids`` in which padded positions are set to ``-100``.

        Raises:
            ValueError: If an example does not contain exactly one image.
        """
        texts = []
        images = []
        for example in examples:
            # `!= 1` also rejects image-less examples, which previously fell
            # through to a bare IndexError on `example["images"][0]`.
            if len(example["images"]) != 1:
                raise ValueError(
                    "This collator only supports one image per example"
                )
            text = self.processor.tokenizer.apply_chat_template(
                example["messages"], tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        # Padding must not contribute to the loss: -100 is the index ignored
        # by torch.nn.CrossEntropyLoss by default. The attention mask is
        # preferred because it stays correct even when the pad token id
        # coincides with a real token id (e.g. pad == eos).
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100
        else:
            pad_token_id = self.processor.tokenizer.pad_token_id
            if pad_token_id is not None:
                labels[labels == pad_token_id] = -100
        batch["labels"] = labels
        return batch
if __name__ == "__main__":
    # Continuation of the script entry point: build the collator, load the
    # dataset splits, then train and save.
    data_collator = LLavaDataCollator(processor)

    raw_datasets = load_dataset(sft_script_args.dataset_name)
    train_dataset = raw_datasets[sft_script_args.dataset_train_split]
    eval_dataset = raw_datasets[sft_script_args.dataset_test_split]

    # These options are set on the SFTConfig rather than passed to SFTTrainer:
    # passing them as trainer keyword arguments is deprecated and only
    # overrides the config values anyway.
    training_args.dataset_text_field = "text"  # need a dummy field
    # The collator does all templating/tokenization, so dataset preparation
    # inside the trainer is skipped.
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # NOTE(review): when this script is started with plain `python` on a host
    # with several visible GPUs, the model ends up wrapped in
    # torch.nn.DataParallel, whose gather step needs identically shaped outputs
    # from every replica. Presumably the per-replica logits differ in sequence
    # length here, which would explain a "Input tensor at index 1 has invalid
    # shape" RuntimeError. Launching through `accelerate launch`/`torchrun`, or
    # exposing a single GPU via CUDA_VISIBLE_DEVICES, should avoid that path -
    # confirm on the target machine.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_config),
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)
@GohioAC
Copy link
Author

GohioAC commented Jun 28, 2024

Error stack trace

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  3.24it/s]
Resolving data files: 100%|████████████████████████████████████████████| 20/20 [00:00<00:00, 301748.49it/s]
Resolving data files: 100%|████████████████████████████████████████████| 20/20 [00:00<00:00, 348075.02it/s]
Loading dataset shards: 100%|████████████████████████████████████████████| 23/23 [00:00<00:00, 1762.92it/s]
/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, dataset_kwargs. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:307: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:355: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
  0%|                                                                                                                                                                                        | 0/4050 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
Traceback (most recent call last):
  File "/opt/aritra.c/worktree/llava-finetune-v1/LLaVA/scripts/vsft.py", line 139, in <module>
    trainer.train()
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 440, in train
    output = super().train(*args, **kwargs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/transformers/trainer.py", line 3264, in compute_loss
  outputs = model(**inputs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 186, in forward
    return self.gather(outputs, self.output_device)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 203, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 104, in gather
    res = gather_map(outputs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 95, in gather_map
    return type(out)((k, gather_map([d[k] for d in outputs]))
  File "<string>", line 9, in __init__
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/transformers/utils/generic.py", line 389, in __post_init__
    for idx, element in enumerate(iterator):
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 95, in <genexpr>
    return type(out)((k, gather_map([d[k] for d in outputs]))
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 89, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/_functions.py", line 75, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/opt/aritra.c/.venvs/llava-trl/lib/python3.10/site-packages/torch/nn/parallel/comm.py", line 231, in gather
    return torch._C._gather(tensors, dim, destination)
RuntimeError: Input tensor at index 1 has invalid shape [8, 2785, 32064], but expected [8, 2889, 32064]
  0%|          | 0/4050 [01:35<?, ?it/s]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment