Skip to content

Instantly share code, notes, and snippets.

@VictorSanh
Last active October 2, 2023 22:07
Show Gist options
  • Save VictorSanh/710930406c8403a3bd3831e7ffba9afe to your computer and use it in GitHub Desktop.
Save VictorSanh/710930406c8403a3bd3831e7ffba9afe to your computer and use it in GitHub Desktop.
IDEFICS fine-tuning with DeepSpeed ZeRO stage 3
#!/bin/bash
#SBATCH --job-name=idefics_zero3_finetuning_multinode # name
#SBATCH --nodes=3 # nodes
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=96 # number of cores per tasks
#SBATCH --gres=gpu:8 # number of gpus
#SBATCH --output=%x-%j.out # output file name

# Multi-node launcher for idefics_zero3_finetuning.py: one torch.distributed
# launcher per node (--ntasks-per-node=1), each spawning $GPUS_PER_NODE workers.
export GPUS_PER_NODE=8
# The first hostname of the allocation serves as the rendezvous master.
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=9901

# `conda activate` is a shell function installed by conda's hook; a
# non-interactive batch shell typically does not have it (no ~/.bashrc), so
# load the hook explicitly before activating the environment.
eval "$(conda shell.bash hook)"
conda activate victor-fresh

# Single quotes are deliberate: $SLURM_PROCID (and the other variables) must be
# expanded inside each srun task, where SLURM_PROCID is that task's rank, i.e.
# the node rank given one task per node.
srun --jobid $SLURM_JOBID bash -c 'python -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \
--master_addr $MASTER_ADDR --master_port $MASTER_PORT \
idefics_zero3_finetuning.py'
"""
Start an interactive session with `srun --pty --cpus-per-task=96 --mem-per-cpu=11G --gpus=8 --mpi=pmix bash -i`
Launch with `deepspeed --num_gpus 8 idefics_zero3_finetuning.py`
Also see the slurm script for multi node launching (necessary for the 80b model)
"""
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, IdeficsForVisionText2Text, Trainer, TrainingArguments
# Hub checkpoint to fine-tune (the 80b variant needs the multi-node slurm launch).
checkpoint = "HuggingFaceM4/idefics-80b"
# Target device for the tensors built in the dataset transform; CPU only when no GPU is visible.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# NOTE(review): `use_auth_token` is deprecated in newer transformers releases in
# favour of `token`; keep as-is until the pinned transformers version is confirmed.
processor = AutoProcessor.from_pretrained(checkpoint, use_auth_token=True)
def convert_to_rgb(image):
    """Return `image` in RGB mode, flattening any transparency onto white.

    A plain `image.convert("RGB")` is only reliable for images without an
    alpha channel: for transparent images it produces a wrong background.
    Compositing the RGBA version over a white canvas first gives transparent
    regions a white background instead.

    Images already in "RGB" mode are returned unchanged (the same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        image = Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
def ds_transforms(example_batch):
    """Turn a batch of dataset rows into IDEFICS model inputs.

    Installed with `set_transform` on the train and eval splits, so it runs
    lazily whenever rows are read from the dataset.

    Args:
        example_batch: mapping of column name -> list of values; the
            "image_url", "name" and "caption" columns are read.

    Returns:
        The processor output moved to `device`, with an added "labels" entry
        that is the same tensor as "input_ids".

    Rows whose image can no longer be fetched are skipped.
    """
    image_processor = processor.image_processor
    image_size = image_processor.image_size
    image_transform = transforms.Compose(
        [
            convert_to_rgb,
            # Light augmentation: random crop covering 90-100% of the image area,
            # resized to the square resolution the image processor expects.
            transforms.RandomResizedCrop(
                (image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        ]
    )

    prompts = []
    for image_url, name, full_caption in zip(
        example_batch["image_url"], example_batch["name"], example_batch["caption"]
    ):
        # Keep only the first sentence of the caption: very long examples would
        # require more GPU memory during training.
        caption = full_caption.split(".")[0]
        try:
            # A handful of images are not hosted anymore; probe the URL and skip
            # the row if it fails. The downloaded image is kept and handed to the
            # processor directly, so the same URL is not fetched a second time.
            image = image_processor.fetch_images(image_url)
        except Exception:
            # Narrower than a bare `except:` so KeyboardInterrupt / SystemExit
            # still stop the run instead of silently skipping rows.
            continue
        prompts.append(
            [
                image,
                f"Question: What's on the picture? Answer: This is {name}. {caption}</s>",
            ],
        )

    # NOTE(review): if every row of the batch is skipped (e.g. batch size 1 with a
    # dead URL), `prompts` is empty and the processor call below will most likely
    # fail — confirm, and consider filtering dead URLs before training.
    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)
    # Train on the full prompt: labels are the input ids themselves (any
    # shifting/masking is presumably done inside the model's loss).
    inputs["labels"] = inputs["input_ids"]
    return inputs
# Load the dataset and hold out a small eval split.
# Every distributed rank executes this script, so the split must be seeded:
# with seed=None, `train_test_split` pulls fresh entropy in each process and the
# ranks would disagree on which rows are train and which are eval.
ds = load_dataset("TheFusion21/PokemonCards")
ds = ds["train"].train_test_split(test_size=0.002, seed=42)
train_ds = ds["train"]
eval_ds = ds["test"]
# Preprocess lazily on access (images are fetched/transformed per batch).
train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)
# Important: create `training_args` (below) before loading the model —
# presumably so ZeRO-3 can partition the weights while `from_pretrained` runs;
# see the transformers DeepSpeed integration docs.
#
# DeepSpeed config: ZeRO stage 3 in bf16 with no optimizer/parameter offload.
# The "auto" entries are presumably resolved from `training_args` by the
# Trainer's DeepSpeed integration.
ds_config = dict(
    communication_data_type="fp32",
    bf16=dict(enabled=True),
    zero_optimization=dict(
        stage=3,
        overlap_comm=False,
        reduce_bucket_size="auto",
        contiguous_gradients=True,
        # NOTE(review): with False, weights are not consolidated into a single
        # 16-bit state dict at save time — confirm how checkpoints are recovered.
        stage3_gather_16bit_weights_on_model_save=False,
        stage3_prefetch_bucket_size="auto",
        stage3_param_persistence_threshold="auto",
        stage3_max_live_parameters=2e9,
        stage3_max_reuse_distance=2e9,
        # Optimizer state and parameters both stay on the GPUs.
        offload_optimizer=dict(device="none"),
        offload_param=dict(device="none"),
    ),
    gradient_clipping="auto",
    train_batch_size="auto",
    train_micro_batch_size_per_gpu="auto",
    steps_per_print=2000000,
)
# Trainer settings; `deepspeed=ds_config` enables the ZeRO-3 config defined above.
# This object must exist before the model is loaded.
training_args = TrainingArguments(
    output_dir="idefics-pokemon",  # plain string: the f-prefix had no placeholders
    learning_rate=2e-4,
    bf16=True,  # consistent with the bf16 DeepSpeed section and the bfloat16 model weights
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,  # Gradient checkpointing helps reducing the memory requirements at a small speed cost
    # Presumably off because ds_transforms already returns tensors on `device`
    # (pinning only applies to CPU tensors).
    dataloader_pin_memory=False,
    save_total_limit=3,
    # NOTE(review): renamed `eval_strategy` in newer transformers releases; keep in
    # sync with the pinned transformers version.
    evaluation_strategy="steps",
    save_strategy="steps",
    # load_best_model_at_end needs saves aligned with evaluations.
    save_steps=10,
    eval_steps=10,
    logging_steps=1,
    max_steps=20,
    # Keep the raw columns: ds_transforms reads "image_url", "name" and "caption".
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to="none",
    optim="adamw_torch",
    deepspeed=ds_config,
)
# Load the model only after `training_args` (and its DeepSpeed config) exists.
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
result = trainer.train()
# Every rank runs this script; report the training summary once, from the main
# process, instead of once per GPU.
if trainer.is_world_process_zero():
    print(result)
# pip requirements for idefics_zero3_finetuning.py
torch
# torchvision: imported directly by the script (`torchvision.transforms`)
torchvision
deepspeed
transformers
pillow
sentencepiece
protobuf
datasets
accelerate
mpi4py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment