@pszemraj
Last active May 19, 2024 18:40
how to use unsloth grad checkpointing
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import transformers
import inspect


class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non blocking calls.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad = True
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)


def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    assert gradient_checkpointing_kwargs is None
    if not self.supports_gradient_checkpointing:
        raise ValueError(
            f"{self.__class__.__name__} does not support gradient checkpointing."
        )

    gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply

    # For old GC format (transformers < 4.35.0) for models that live on the Hub
    # we will fall back to the overwritten `_set_gradient_checkpointing` method
    _is_using_old_format = (
        "value" in inspect.signature(self._set_gradient_checkpointing).parameters
    )
    if not _is_using_old_format:
        self._set_gradient_checkpointing(
            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
        )
    else:
        raise NotImplementedError()

    if getattr(self, "_hf_peft_config_loaded", False):
        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
        # the gradients to make sure the gradient flows.
        self.enable_input_require_grads()


def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
    # Replace the stock gradient_checkpointing_enable with the offloaded version.
    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
        new_gradient_checkpointing_enable
    )

usage

Credit/source: here

how to use unsloth grad checkpointing

steps

To integrate the provided monkey patch for offloaded gradient checkpointing into the Hugging Face transformers library, follow these steps:

  1. Understand the provided code: The code defines a custom torch.autograd.Function, Unsloth_Offloaded_Gradient_Checkpointer, which offloads the checkpointed hidden states to CPU RAM during the forward pass and reloads them for recomputation during the backward pass, saving VRAM. A new method, new_gradient_checkpointing_enable, wires this function in as the model's gradient checkpointing implementation (see the toy sketch after this list).

  2. Apply the monkey patch: The provided function apply_unsloth_offloaded_gradient_checkpoint_monkey_patch replaces the gradient_checkpointing_enable method on transformers.modeling_utils.PreTrainedModel, so models use the custom offloaded gradient checkpointing once you enable it.

  3. Use the patched method in your model: After applying the monkey patch, you can enable gradient checkpointing in your model, and it will use the custom offloading method.
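
As a mental model for item 1, here is a tiny standalone sketch (not part of the gist) that calls the checkpointer directly on a single torch.nn.Linear layer. It assumes a CUDA GPU and that the class is importable from the script you will save in Step 1 (unsloth_offload_gc.py):

# Toy sketch (not from the gist): exercising the checkpointer on one Linear layer.
# The wrapped forward must return a 1-tuple, because backward() unpacks it as
# `(output,) = ctx.forward_function(...)`. A CUDA device is assumed.
import torch

from unsloth_offload_gc import Unsloth_Offloaded_Gradient_Checkpointer

layer = torch.nn.Linear(16, 16).cuda()

def layer_forward(hidden_states):
    return (layer(hidden_states),)  # 1-tuple, to match the unpacking in backward()

x = torch.randn(4, 16, device="cuda", requires_grad=True)

# Forward: the input is copied to CPU (non-blocking) and the layer runs under no_grad.
(out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(layer_forward, x)

# Backward: the saved input is moved back to the GPU and the layer is re-run with grad enabled.
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])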

Here's a step-by-step guide:

Step 1: Save the provided code into a Python script

Save the provided code in a Python script, for example, unsloth_offload_gc.py.

Step 2: Import and apply the monkey patch

In your training script, import the necessary components and apply the monkey patch:

import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Import the custom gradient checkpointing patch
from unsloth_offload_gc import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch

# Apply the monkey patch
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

# Load your model
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Rest of your training code
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # assumed to be prepared elsewhere in your script
    eval_dataset=eval_dataset,    # assumed to be prepared elsewhere in your script
)

trainer.train()

Detailed Explanation:

  1. Import the Patch: Ensure that you import the apply_unsloth_offloaded_gradient_checkpoint_monkey_patch function from the script where you saved the provided code.

  2. Apply the Patch: Call apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(). This replaces the gradient_checkpointing_enable method of transformers models with the custom gradient checkpointing function (a quick sanity check is sketched after this list).

  3. Load and Configure the Model: Load your desired model using AutoModelForSequenceClassification or any other relevant class. Then, enable gradient checkpointing by calling model.gradient_checkpointing_enable().

  4. Training Script: Continue with your usual training script, setting up TrainingArguments and Trainer as needed.
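
As referenced in item 2, here is an optional sanity check (a sketch, not part of the gist) that the patch took effect. It continues from the training script above and assumes the module name unsloth_offload_gc:

# Sanity check (assumption, not part of the original gist): verify the monkey patch
# took effect and that gradient checkpointing is enabled on the model.
import transformers
from unsloth_offload_gc import new_gradient_checkpointing_enable

assert (
    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable
    is new_gradient_checkpointing_enable
), "monkey patch was not applied"

print(model.is_gradient_checkpointing)  # expected: True after gradient_checkpointing_enable()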

By following these steps, you integrate the custom gradient checkpointing into your training process, potentially saving VRAM by offloading the checkpointed activations to CPU RAM during the forward pass and bringing them back during the backward pass. A rough way to check the effect is sketched below.
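
One rough way to gauge the savings (a sketch, not something the gist prescribes) is to compare peak GPU memory for a short run with and without the patch applied:

# Rough VRAM comparison (sketch): run once with the patch applied and once without,
# and compare the reported peaks. Assumes a CUDA device.
import torch

torch.cuda.reset_peak_memory_stats()
trainer.train()
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")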
