Credit/source: here
How to use Unsloth gradient checkpointing
To integrate the provided monkey patch for offloading gradient checkpointing into the Hugging Face transformers
library, you need to follow these steps:
- Understand the provided code: The code defines a custom gradient checkpointing function, Unsloth_Offloaded_Gradient_Checkpointer, that offloads tensors to CPU to save VRAM. This function is then used in a new method, new_gradient_checkpointing_enable, to enable gradient checkpointing with this custom functionality.
- Apply the monkey patch: The provided function apply_unsloth_offloaded_gradient_checkpoint_monkey_patch modifies the gradient_checkpointing_enable method of transformers models to use the custom offloaded gradient checkpointing.
- Use the patched method in your model: After applying the monkey patch, you can enable gradient checkpointing in your model, and it will use the custom offloading method.
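To make the first point more concrete, here is a minimal sketch of what an offloading checkpoint function could look like. It is not the provided Unsloth code: OffloadedCheckpoint and run_block are hypothetical names, and the sketch assumes the block takes a single tensor that requires gradients. It keeps only a CPU copy of the block input during the forward pass and recomputes the block during the backward pass.

```python
import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Hypothetical sketch: checkpoint one block and park its input on the CPU."""

    @staticmethod
    def forward(ctx, run_block, hidden_states):
        # Keep only a CPU copy of the block input; the block's intermediate
        # activations are not stored because no graph is built here.
        ctx.save_for_backward(hidden_states.detach().to("cpu"))
        ctx.run_block = run_block
        ctx.device = hidden_states.device
        with torch.no_grad():
            return run_block(hidden_states)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        # Move the input back to the compute device and rebuild the graph.
        hidden_states = saved.to(ctx.device).detach().requires_grad_(True)
        with torch.enable_grad():
            output = ctx.run_block(hidden_states)
        # Accumulates parameter gradients and fills hidden_states.grad.
        torch.autograd.backward(output, grad_output)
        # No gradient for the callable argument, one for hidden_states.
        return None, hidden_states.grad


# Usage on a toy block (runs on CPU, so the "offload" is a no-op there).
block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Tanh())
x = torch.randn(4, 8, requires_grad=True)
out = OffloadedCheckpoint.apply(block, x)
out.sum().backward()  # triggers the recomputation in backward()
```

The trade-off in this sketch is extra compute (one additional forward pass per block) and host-device transfers in exchange for lower peak VRAM.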
Here's a step-by-step guide:
Save the provided code in a Python script, for example, unsloth_offload_gc.py.
In your training script, import the necessary components and apply the monkey patch:
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Import the custom gradient checkpointing patch
from unsloth_offload_gc import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch

# Apply the monkey patch
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

# Load your model
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Enable gradient checkpointing (now routed through the patched method)
model.gradient_checkpointing_enable()

# Rest of your training code
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

# train_dataset and eval_dataset are not defined in this snippet; they are
# assumed to be tokenized datasets you have prepared earlier in the script.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
- Import the Patch: Ensure that you import the apply_unsloth_offloaded_gradient_checkpoint_monkey_patch function from the script where you saved the provided code.
- Apply the Patch: Call the apply_unsloth_offloaded_gradient_checkpoint_monkey_patch() function. This modifies the gradient_checkpointing_enable method of the transformers models to use the custom gradient checkpointing function.
- Load and Configure the Model: Load your desired model using AutoModelForSequenceClassification or any other relevant class. Then, enable gradient checkpointing by calling model.gradient_checkpointing_enable().
- Training Script: Continue with your usual training script, setting up TrainingArguments and Trainer as needed.
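For readers unfamiliar with monkey patching, the mechanism behind the "apply the patch" step can be illustrated without any library at all. The sketch below uses a hypothetical stand-in class, FakePreTrainedModel, and a hypothetical replacement method; it only shows how rebinding a method on a class changes behavior for all instances, and says nothing about what the real patch does internally.

```python
class FakePreTrainedModel:
    """Hypothetical stand-in for a transformers model base class."""

    def gradient_checkpointing_enable(self):
        self.checkpointing = "default"


def offloaded_gradient_checkpointing_enable(self):
    # Hypothetical replacement method, standing in for the patched version.
    self.checkpointing = "offloaded"


def apply_patch():
    # Rebinding the attribute on the class changes the method for every
    # instance, including ones created before the patch was applied.
    FakePreTrainedModel.gradient_checkpointing_enable = (
        offloaded_gradient_checkpointing_enable
    )


model = FakePreTrainedModel()
apply_patch()
model.gradient_checkpointing_enable()
print(model.checkpointing)  # offloaded
```

Because the method is replaced on the class, the patch affects the whole Python process, so it is usually applied once, near the top of the training script.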
By following these steps, you integrate the custom gradient checkpointing functionality into your model training process, potentially saving VRAM by offloading tensors to CPU RAM during the forward pass and bringing them back for the backward pass.
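As a rough way to gauge how much memory is involved, the back-of-the-envelope calculation below estimates the size of the tensors that would be moved to CPU RAM. It assumes each checkpointed layer saves exactly one input tensor of shape (batch, sequence length, hidden size); the function name and the model dimensions are hypothetical and chosen only for illustration.

```python
def offloaded_bytes(num_layers, batch_size, seq_len, hidden_size, bytes_per_value):
    # One checkpointed input of shape (batch, seq_len, hidden) per layer.
    return num_layers * batch_size * seq_len * hidden_size * bytes_per_value


# Hypothetical configuration: 32 layers, batch 2, 4096 tokens,
# hidden size 4096, 2 bytes per value (16-bit precision).
gib = offloaded_bytes(32, 2, 4096, 4096, 2) / 2**30
print(f"{gib:.1f} GiB")  # 2.0 GiB
```

The estimate grows linearly with sequence length and batch size, so the potential saving is larger for long-context training than for short inputs.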