"""Finetuning PaliGemma2 on Bridge robot data for task success check.

Gist by @hnyu, created February 6, 2025.
"""
from typing import Any, Dict, List, Optional, Tuple

import logging
import os
import random

import cv2
import numpy as np
from pynvml import (nvmlInit, nvmlDeviceGetHandleByIndex,
                    nvmlDeviceGetMemoryInfo)

# Set the logging level to INFO
logging.basicConfig(level=logging.INFO)

import tensorflow as tf
import tensorflow_datasets as tfds
import torch
from peft import LoraConfig, get_peft_model
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)

def print_gpu_utilization():
    """Print the memory currently used on GPU 0, queried via NVML."""
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024**2} MB.")

# Filter options: single arm, language
DATASETS = [
    "bridge",
    "taco_play",
    "jaco_play",
    "roboturk",
    "viola",
    "berkeley_autolab_ur5",
    "language_table",
    "nyu_rot_dataset_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "maniskill_dataset_converted_externally_to_rlds",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "tokyo_u_lsmo_converted_externally_to_rlds",
    "dlr_edan_shared_control_converted_externally_to_rlds",
    "asu_table_top_converted_externally_to_rlds",
    "stanford_robocook_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "utaustin_mutex",
    "berkeley_fanuc_manipulation",
    "cmu_playing_with_food",
    "cmu_play_fusion",
    "droid",
    "fmb",
    "robo_set",
    "vima_converted_externally_to_rlds",
    "spoc",
]

def dataset2path(dataset_name):
    """Resolve a dataset name to a local TFDS directory or a GCS path."""
    if dataset_name in ['robo_net', 'bridge']:
        version = '1.0.0'
    elif dataset_name == 'language_table':
        version = '0.0.1'
    else:
        version = '0.1.0'
    local_path = os.path.expanduser(
        f'~/tensorflow_datasets/{dataset_name}/{version}')
    if os.path.isdir(local_path):
        # If the local path exists, use it directly.
        return local_path
    return f'gs://gresearch/robotics/{dataset_name}/{version}'

def get_dataset(dataset_name, split):
    # We will convert TF tensors to numpy arrays anyway, so no need to use GPU.
    with tf.device('/CPU:0'):
        b = tfds.builder_from_directory(builder_dir=dataset2path(dataset_name))
        ds = b.as_dataset(split=split).shuffle(buffer_size=10)
        ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds
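
# A minimal usage sketch (never called in this script, purely illustrative).
# It assumes either ~/tensorflow_datasets/bridge/1.0.0 exists locally or the
# gs://gresearch/robotics bucket is reachable; otherwise dataset2path falls
# back to the GCS path and this will fail without network access.
def _demo_get_dataset():
    ds = get_dataset('bridge', split='train')
    for i, episode in enumerate(ds):
        # Each episode is a dict; 'steps' holds the per-timestep data.
        print('episode keys:', list(episode.keys()))
        if i == 2:
            break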

def get_frames_and_instr(
        episode) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str]:
    """Given an episode, extract its first frame, a random middle frame, its
    last frame, and its instruction.

    Args:
        episode: a dictionary representing an episode.

    Returns:
        tuple:
        - first_frame: the first frame of the episode.
        - middle_frame: a random middle frame, sampled from the first 90% of
          the episode.
        - last_frame: the last frame of the episode.
        - instr: the instruction of the episode.
    """
    rgb_key = 'image_0'
    # episode['steps'] is a _VariantDataset and can only be consumed once,
    # so materialize it as a list first.
    steps = [step for step in episode['steps']]
    frames = [step['observation'] for step in steps]
    length = len(frames)
    lower, upper = 0, int(length * 0.9)
    assert lower <= upper
    idx = random.randint(lower, upper)
    first_frame = frames[0][rgb_key].numpy()
    middle_frame = frames[idx][rgb_key].numpy()
    last_frame = frames[-1][rgb_key].numpy()
    if 'natural_language_instruction' in frames[0]:
        instr = frames[0]['natural_language_instruction'].numpy().decode('utf-8')
    else:
        instr = steps[0]['language_instruction'].numpy().decode('utf-8')
    return first_frame, middle_frame, last_frame, instr

def prepare_paligemma_model_inputs(texts: List[str],
                                   images: List[List[np.ndarray]],
                                   processor,
                                   device,
                                   reformat_text: bool = False):
    """Given instructions and one or more images each, prepare the model
    inputs for PaliGemma2.

    Args:
        texts: a list of sentences.
        images: a list of lists of images. Each inner list represents a set
            of images corresponding to one instruction. The length of the
            outer list should be equal to the length of ``texts``.
        processor: the ``PaliGemmaProcessor`` used for tokenization.
        device: the device to place the model inputs on.
        reformat_text: if True, wrap each sentence into a yes/no task-success
            question; otherwise use the sentence verbatim.

    Returns:
        dict: a dictionary of model inputs, with keys 'input_ids',
        'pixel_values', and 'attention_mask'.
    """
    assert len(texts) == len(images)
    # We need to insert special '<image>' tokens at the beginning of each
    # instruction, one per image.
    image_token = "<image>"
    updated_instructions = []
    for text, imgs in zip(texts, images):
        n_imgs = len(imgs)
        if reformat_text:
            question = (f"Is the task '{text}' finished given the initial "
                        "image [IMAGE1] and final image [IMAGE2]? "
                        "Please only answer 'yes' or 'no'.")
        else:
            question = text
        updated_instr = f"{image_token * n_imgs}{question}"
        updated_instructions.append(updated_instr)
    # Flatten the image list
    images = [i for imgs in images for i in imgs]
    model_inputs = processor(
        text=updated_instructions,
        images=images,
        return_tensors="pt",
        padding=True).to(torch.bfloat16).to(device)
    return model_inputs
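
# A minimal sketch of what the prompt expansion above produces. The image is
# a random array and the instruction is made up; the processor download from
# the hub is an assumption about network access.
def _demo_prepare_inputs():
    processor = PaliGemmaProcessor.from_pretrained(
        "google/paligemma2-3b-pt-224")
    img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    inputs = prepare_paligemma_model_inputs(
        ["pick up the spoon"], [[img, img]], processor, 'cpu',
        reformat_text=True)
    # Two '<image>' tokens were prepended, so 'input_ids' contains
    # 2 * 256 image-token placeholders plus the text tokens.
    print(inputs["input_ids"].shape, inputs["pixel_values"].shape)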

class PaliGemma3b224Generation(torch.nn.Module):

    def __init__(self):
        super().__init__()
        model_id = "google/paligemma2-3b-pt-224"
        self._model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto")
        self._processor = PaliGemmaProcessor.from_pretrained(model_id)
        self._device = self._model.device

    def generate(self,
                 texts: List[str],
                 images: List[List[np.ndarray]],
                 max_response_tokens: int = 100,
                 sample: bool = False):
        """
        Args:
            texts: a list of prompts.
            images: a list of lists of images. Each inner list represents a
                set of images corresponding to one prompt. The length of the
                outer list should be equal to the length of ``texts``.
                Each image can have a shape of either (H, W, 3) or (3, H, W).
            max_response_tokens: maximum number of tokens to generate in the
                response (default 100).
            sample: whether to sample from the model. If False (default),
                argmax will be performed for each output token distribution.

        Returns:
            List[str]: a list of generated responses, one per user prompt.

        .. code-block:: python

            # Example:
            texts = ['Look at the image; what do you see?',
                     'What is the difference between the two images?']
            images = [[image1], [image2, image3]]
            # The processed inputs to the language model will be two sequences.
            # The first sequence will contain N + 256 (image1) tokens, where N
            # is the number of text tokens.
            # The second sequence will contain M + 512 (image2 + image3)
            # tokens, where M is the number of text tokens.
        """
        self._model.eval()
        model_inputs = prepare_paligemma_model_inputs(
            texts, images, self._processor, self._device)
        input_len = model_inputs["input_ids"].shape[-1]
        responses = []
        with torch.inference_mode():
            generation = self._model.generate(
                **model_inputs,
                max_new_tokens=max_response_tokens,
                num_beams=3,
                do_sample=sample)
        # If not truncated here, the generation would also include the prompt.
        generation = generation[:, input_len:]
        for i in range(generation.shape[0]):
            decoded = self._processor.decode(
                generation[i], skip_special_tokens=True)
            responses.append(decoded)
        return responses

def pali_gemma3b224_generation():
    texts = ['captioning in details',
             'What is the difference between the two images?']
    # Note: cv2.imread returns BGR images, while PaliGemma expects RGB.
    img = cv2.imread('/tmp/car.jpg')
    img2 = cv2.imread('/tmp/pexels-pixabay-45201.jpg')
    images = [[img], [img, img2]]
    m = PaliGemma3b224Generation()
    response = m.generate(texts, images)
    return response

class PaliGemma3b224Reward(torch.nn.Module):
    """This class adapts ``PaliGemmaForConditionalGeneration`` so that the
    model outputs a reward in [0, 1], denoting the probability of the robot
    completing the instruction given the first and last frames of an episode.

    It loads a pretrained ``PaliGemmaForConditionalGeneration`` and adds a
    reward head on top of its last hidden state.
    """

    def __init__(self, lora_config=None):
        super().__init__()
        model_id = "google/paligemma2-3b-pt-224"
        self._model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto")
        if lora_config is not None:
            self._model = get_peft_model(self._model, lora_config)
        self._processor = PaliGemmaProcessor.from_pretrained(model_id)
        self._device = self._model.device
        self._bos_token_id = self._model.config.bos_token_id
        # Get the Gemma2 language model hidden size
        hidden_size = self._model.language_model.config.hidden_size
        self._reward_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1))
        self._reward_head = self._reward_head.to(self._device)

    @property
    def device(self):
        """Return the device of the model."""
        return self._device

    def set_device(self, device):
        self._device = device
        self.to(device)
    def forward(self,
                input_ids: torch.Tensor,
                pixel_values: torch.Tensor,
                attention_mask: torch.Tensor):
        """Given a sentence represented as an input sequence of tokens,
        and one or more images, compute a reward denoting how well the
        sentence describes the scene represented by the given images.

        Args:
            input_ids: a sequence of token ids. Each token is either a normal
                text token or a special image token. For each complete 224x224
                image, there are 256 (16x16) special image tokens already
                inserted contiguously in this sequence, and each image token
                will later be replaced by an image patch feature embedding.
                The normal text tokens will be encoded by a text embedding
                table. Shape: ``[B, sequence_length]``. Sequences are padded
                to the same length.
            pixel_values: a batch of images. Shape: ``[N, 3, 224, 224]``.
                ``N`` does not have to be equal to ``B``.
            attention_mask: mask to avoid performing attention on padding
                token indices. Mask values are in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                It is generated by ``PaliGemmaProcessor.__call__``.

        Returns:
            torch.Tensor: unnormalized (logit) reward.
        """
        with torch.set_grad_enabled(True):
            # First append self._bos_token_id to the end of the input
            # sequence; the hidden state at this position feeds the
            # reward head.
            bos_tokens = torch.full(
                (input_ids.shape[0], 1), self._bos_token_id,
                device=input_ids.device)
            # [B,L1]
            input_ids = torch.cat([input_ids, bos_tokens], dim=1)
            # [B,L1,D]
            input_embeddings = self._model.get_input_embeddings()(input_ids)
            device = input_embeddings.device
            dtype = input_embeddings.dtype
            # [L1]
            position_ids = torch.arange(
                0, input_embeddings.shape[1], device=device)
            # PaliGemma positions are 1-indexed
            position_ids = position_ids + 1
            # [B,L1]
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
            image_token_index = self._model.config.image_token_index

            ## Merge image and text embeddings
            # [N,L2,D]
            image_features = self._model.get_image_features(pixel_values)
            # [B,L1,1]
            special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
            # [B,L1,D]
            special_image_mask = special_image_mask.expand_as(
                input_embeddings).to(device)
            if input_embeddings[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = torch.sum(input_ids == image_token_index)
                raise ValueError(
                    f"Number of images does not match number of special image "
                    f"tokens in the input text. Got {image_tokens_in_text} "
                    f"image tokens in the text but "
                    f"{image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings.")
            image_features = image_features.to(device, dtype)
            # Fill image features into the image token placeholders
            input_embeddings = input_embeddings.masked_scatter(
                special_image_mask, image_features)

            # Update ``attention_mask`` to include the appended bos_token
            mask = torch.ones_like(bos_tokens).to(
                attention_mask.device, attention_mask.dtype)
            # [B,L1]
            attention_mask = torch.cat([attention_mask, mask], dim=1)

            outputs = self._model.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=False,
                output_hidden_states=True,
                inputs_embeds=input_embeddings,
                # We don't need the model's original logits, but 1 is the
                # minimum to compute
                num_logits_to_keep=1)
            # [B,L1,Q]
            last_hidden_state = outputs.hidden_states[-1]
            # [B,Q]: take the state of the last (appended bos) token
            last_hidden_state = last_hidden_state[:, -1, :].to(
                self._device).to(torch.float32)
            # [B,1]
            logits = self._reward_head(last_hidden_state)
        return logits
    def compute_reward(self,
                       instructions: List[str],
                       images: List[List[np.ndarray]]):
        """Given an instruction and one or more images, compute a reward
        denoting how well the instruction has been achieved given the images.

        Args:
            instructions: a list of instructions.
            images: a list of lists of images. Each inner list represents a
                set of images corresponding to one instruction. The length of
                the outer list should be equal to the length of
                ``instructions``. Each image should be a numpy array of shape
                ``(H,W,3)`` or ``(3,H,W)``.

        Example:

        .. code-block:: python

            instructions = ['pick up the spoon']
            # First and last frames of an episode
            images = [[first_frame, last_frame]]
            m = PaliGemma3b224Reward()
            reward = m.compute_reward(instructions, images)
        """
        model_inputs = prepare_paligemma_model_inputs(
            instructions, images, self._processor, self._device,
            reformat_text=True)
        reward = self.forward(**model_inputs)
        return reward

    def gradient_checkpointing_enable(self):
        self._model.gradient_checkpointing_enable()
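
# Illustrative only: compute a reward for one instruction with two dummy
# frames. This downloads the 3B model weights, and since the reward head is
# randomly initialized, the output is meaningless until the model is
# finetuned; the instruction and frames here are made up.
def _demo_compute_reward():
    m = PaliGemma3b224Reward()
    frame = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    logit = m.compute_reward(['pick up the spoon'], [[frame, frame]])
    print('success probability:', torch.sigmoid(logit).item())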

class PaliGemmaInputBatchAccumulator(object):
    """Accumulates episodes into training batches.

    For each episode, the (first, last) frame pair forms a positive example
    (task finished) and the (first, middle) frame pair forms a negative
    example (task not yet finished).
    """

    def __init__(self, batch_size: int = 1):
        self._reset()
        self._batch_size = batch_size
        labels = (torch.ones([batch_size, 1], dtype=torch.int32),
                  torch.zeros([batch_size, 1], dtype=torch.int32))
        # [2B,1]
        self._labels = torch.cat(labels, dim=0)

    def _reset(self):
        self._instrs = []
        self._pos_images = []
        self._neg_images = []

    def add_episode(self, episode) -> Optional[Dict[str, Any]]:
        (first_frame,
         middle_frame,
         last_frame,
         instr) = get_frames_and_instr(episode)
        self._instrs.append(instr)
        self._pos_images.append([first_frame, last_frame])
        self._neg_images.append([first_frame, middle_frame])
        batch = None
        if len(self._instrs) == self._batch_size:
            batch = {
                "instructions": self._instrs + self._instrs,
                "images": self._pos_images + self._neg_images,
                "labels": self._labels}
            self._reset()
        return batch
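
# Illustrative sketch of the accumulator contract (episodes come from
# get_dataset above; dataset availability is an assumption):
def _demo_batch_accumulator():
    acc = PaliGemmaInputBatchAccumulator(batch_size=2)
    for episode in get_dataset('bridge', split='train'):
        batch = acc.add_episode(episode)
        if batch is not None:
            # 2 positive + 2 negative examples, labels of shape [4, 1]
            print(len(batch["instructions"]), batch["labels"].shape)
            break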

class PaliGemmaEvalResult(object):

    def __init__(self,
                 instructions: Optional[List[str]] = None,
                 images: Optional[List[List[np.ndarray]]] = None,
                 pred_labels: Optional[torch.Tensor] = None,
                 pred_probs: Optional[torch.Tensor] = None,
                 labels: Optional[torch.Tensor] = None):
        """
        Args:
            instructions: a list of instructions. Length N.
            images: a list of lists of images. Length N.
            pred_labels: 1D int32 tensor.
            pred_probs: 1D float32 tensor.
            labels: 1D int32 tensor.
        """
        # Avoid mutable default arguments: ``add`` extends these containers
        # in place, so sharing defaults across instances would be a bug.
        if instructions is None:
            instructions = []
        if images is None:
            images = []
        if pred_labels is None:
            pred_labels = torch.empty(0, device='cpu')
        if pred_probs is None:
            pred_probs = torch.empty(0, device='cpu')
        if labels is None:
            labels = torch.empty(0, device='cpu')
        self._results = {
            "instructions": instructions,
            "images": images,
            "pred_labels": pred_labels.to('cpu'),
            "pred_probs": pred_probs.to('cpu'),
            "labels": labels.to('cpu'),
        }

    def add(self, res: 'PaliGemmaEvalResult'):
        """Add another PaliGemmaEvalResult to the current one."""
        for k, v in self._results.items():
            if isinstance(v, torch.Tensor):
                # TODO: this cat might be inefficient!
                self._results[k] = torch.cat([v, res._results[k]])
            else:
                self._results[k].extend(res._results[k])

    def compute_accuracy(self) -> float:
        """Given the current results, compute the prediction accuracy."""
        return float((
            self._results["pred_labels"] == self._results["labels"]).to(
                torch.float32).mean())

def eval_paligemma_reward(model, eval_datasets, batch_size: int = 2,
                          total_episodes: Optional[int] = None):
    model.eval()

    def _eval_one_batch(instructions, images, labels):
        with torch.no_grad():
            # [2B]
            logits = model.compute_reward(instructions, images).squeeze(-1)
        # [2B]
        pred_labels = (logits > 0).to(torch.int32)
        pred_probs = torch.sigmoid(logits)
        labels = labels.squeeze(-1)
        return PaliGemmaEvalResult(
            instructions, images, pred_labels, pred_probs, labels)

    batch_acc = PaliGemmaInputBatchAccumulator(batch_size)
    result = PaliGemmaEvalResult()
    episodes = 0
    terminate = False
    for ds in eval_datasets:
        for episode in iter(ds):
            batch = batch_acc.add_episode(episode)
            # Evaluate one batch
            if batch is not None:
                res = _eval_one_batch(**batch)
                result.add(res)
                episodes += batch_size
                if total_episodes is not None and episodes >= total_episodes:
                    terminate = True
                    break
        if terminate:
            break
    model.train()
    return result
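
# Illustrative only (assumes the 'bridge' dataset is reachable): evaluate an
# untrained reward model on 100 validation episodes.
def _demo_eval_reward():
    model = PaliGemma3b224Reward()
    eval_datasets = [get_dataset('bridge', split='val')]
    res = eval_paligemma_reward(model, eval_datasets, total_episodes=100)
    print('accuracy:', res.compute_accuracy())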

def train_paligemma_reward(model=None,
                           epochs: int = 10,
                           batch_size: int = 16,
                           grad_accumulation_steps: int = 4,
                           summary_interval: int = 100):
    """
    Args:
        grad_accumulation_steps: to save GPU memory for backprop, we divide
            a batch into ``grad_accumulation_steps`` splits. We compute the
            gradients for each split separately and accumulate them, followed
            by an optimizer step.
    """
    #datasets = ['berkeley_autolab_ur5',
    #            'stanford_hydra_dataset_converted_externally_to_rlds'
    #           ]
    datasets = ['bridge']
    # Only create the datasets once from disk
    train_datasets = [get_dataset(d, split='train') for d in datasets]
    eval_datasets = [get_dataset(d, split='val') for d in datasets]
    if model is None:
        # The entire model will be finetuned.
        model = PaliGemma3b224Reward()
    model.gradient_checkpointing_enable()
    model.train()
    loss_fn = torch.nn.BCEWithLogitsLoss()
    # Only optimize trainable parameters (matters when LoRA freezes the base)
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-4)

    def _compute_loss(instructions, images, labels):
        # [2B,1]
        rewards = model.compute_reward(instructions, images)
        loss = loss_fn(
            rewards, labels.to(dtype=torch.float32, device=model.device))
        return loss

    assert batch_size % grad_accumulation_steps == 0
    micro_batch_size = batch_size // grad_accumulation_steps
    batch_acc = PaliGemmaInputBatchAccumulator(micro_batch_size)
    # For debugging purposes
    res = eval_paligemma_reward(model, eval_datasets, total_episodes=100)
    logging.info("Initial eval accuracy: %f" % res.compute_accuracy())

    for e in range(epochs):
        logging.info("Starting epoch %d" % e)
        random.shuffle(train_datasets)
        losses = []
        n_batches = 0
        for ds in train_datasets:
            for episode in iter(ds):
                batch = batch_acc.add_episode(episode)
                if batch is not None:
                    # A new batch is ready
                    n_batches += 1
                    l = _compute_loss(**batch)
                    # Log the scalar value so the graph is not kept alive
                    losses.append(l.item())
                    l = l / grad_accumulation_steps
                    l.backward()
                    # Summarize the loss every ``summary_interval`` batches
                    if n_batches % summary_interval == 0:
                        logging.info(f"Loss: {sum(losses) / len(losses)}")
                        losses = []
                    # Optimizer step every ``grad_accumulation_steps`` batches
                    if n_batches % grad_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()
        # Eval after every training epoch
        res = eval_paligemma_reward(model, eval_datasets, total_episodes=100)
        logging.info(f"Eval accuracy at epoch {e}: {res.compute_accuracy()}")

def train_paligemma_reward_lora():
    # Define the LoRA configuration
    lora_config = LoraConfig(
        r=8,  # Rank of the low-rank matrices
        lora_alpha=32,  # Scaling factor
        # Layers to apply LoRA to
        target_modules=["self_attn.k_proj",  # vision & language
                        "self_attn.q_proj",  # vision & language
                        "self_attn.v_proj",  # vision & language
                        "self_attn.out_proj",  # vision
                        "mlp.fc1",  # vision
                        "mlp.fc2",  # vision
                        "self_attn.o_proj",  # language
                        "mlp.down_proj",  # language
                        "mlp.up_proj",  # language
                        "mlp.gate_proj",  # language
                        "multi_modal_projector.linear"],
        lora_dropout=0.1,  # Dropout for LoRA layers
        bias="none",  # Whether to add bias terms
        task_type="CAUSAL_LM",
    )
    model = PaliGemma3b224Reward(lora_config=lora_config)
    model._model.print_trainable_parameters()
    train_paligemma_reward(model)
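
# A small sanity-check sketch (illustrative, not part of training): print a
# few of the linear-layer module names that ``target_modules`` above is
# matched against. Module names can differ across transformers versions, so
# this helps verify the LoRA targets actually exist.
def _demo_list_lora_targets():
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        "google/paligemma2-3b-pt-224", torch_dtype=torch.bfloat16)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and (
                'q_proj' in name or 'fc1' in name):
            print(name)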

if __name__ == '__main__':
    if False:
        # Test sentence generation
        print(pali_gemma3b224_generation())
    elif False:
        # Finetune the full model
        train_paligemma_reward()
    elif True:
        # Finetune the LoRA model
        train_paligemma_reward_lora()