Finetuning PaliGemma2 on Bridge robot data for task success check
from typing import List, Tuple, Optional, Dict, Any
import numpy as np
import os
import cv2
import random
from pynvml import *
import logging
# Set the logging level to INFO
logging.basicConfig(level=logging.INFO)
import tensorflow_datasets as tfds
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
from transformers import TrainingArguments
from accelerate import Accelerator
from datasets import Dataset
from peft import LoraConfig, get_peft_model
import tensorflow as tf
import torch
def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")
# Filter options: single arm, language
DATASETS = [
    "bridge",
    "taco_play",
    "jaco_play",
    "roboturk",
    "viola",
    "berkeley_autolab_ur5",
    "language_table",
    "nyu_rot_dataset_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "maniskill_dataset_converted_externally_to_rlds",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "tokyo_u_lsmo_converted_externally_to_rlds",
    "dlr_edan_shared_control_converted_externally_to_rlds",
    "asu_table_top_converted_externally_to_rlds",
    "stanford_robocook_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "utaustin_mutex",
    "berkeley_fanuc_manipulation",
    "cmu_playing_with_food",
    "cmu_play_fusion",
    "droid",
    "fmb",
    "robo_set",
    "vima_converted_externally_to_rlds",
    "spoc"
]
def dataset2path(dataset_name):
    if dataset_name in ['robo_net', 'bridge']:
        version = '1.0.0'
    elif dataset_name == 'language_table':
        version = '0.0.1'
    else:
        version = '0.1.0'
    local_path = os.path.expanduser(f'~/tensorflow_datasets/{dataset_name}/{version}')
    if os.path.isdir(local_path):
        # If local path exists, we directly use it.
        return local_path
    return f'gs://gresearch/robotics/{dataset_name}/{version}'
def get_dataset(dataset_name, split):
    # We will convert TF tensors to numpy arrays anyway, so no need to use GPU
    with tf.device('/CPU:0'):
        b = tfds.builder_from_directory(builder_dir=dataset2path(dataset_name))
        ds = b.as_dataset(split=split).shuffle(buffer_size=10)
        ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds
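# A minimal usage sketch (illustrative, not part of the original gist); it assumes
# the dataset is available locally under ~/tensorflow_datasets or readable from
# the public gs://gresearch/robotics bucket:
#
#   ds = get_dataset('bridge', split='train')
#   episode = next(iter(ds))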
def get_frames_and_instr(episode) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str]:
    """Given an episode, extract its first frame, a random middle frame, its
    last frame, and its instruction.

    Args:
        episode: a dictionary representing an episode.

    Returns:
        tuple:
        - first_frame: the first frame of the episode.
        - middle_frame: a random middle frame of the episode, sampled from the
          first 90% of the episode.
        - last_frame: the last frame of the episode.
        - instr: the instruction of the episode.
    """
    rgb_key = 'image_0'
    # episode['steps'] is a _VariantDataset and can only be consumed once
    steps = [step for step in episode['steps']]
    frames = [step['observation'] for step in steps]
    length = len(frames)
    lower, upper = 0, int(length * 0.9)
    assert lower <= upper
    idx = random.randint(lower, upper)
    first_frame = frames[0][rgb_key].numpy()
    middle_frame = frames[idx][rgb_key].numpy()
    last_frame = frames[-1][rgb_key].numpy()
    if 'natural_language_instruction' in frames[0]:
        instr = frames[0]['natural_language_instruction'].numpy().decode('utf-8')
    else:
        instr = steps[0]['language_instruction'].numpy().decode('utf-8')
    return first_frame, middle_frame, last_frame, instr
def prepare_paligemma_model_inputs(texts: List[str],
                                   images: List[List[np.ndarray]],
                                   processor,
                                   device,
                                   reformat_text: bool = False):
    """Given instructions and one or more images each, prepare the model inputs
    for PaliGemma2.

    Args:
        texts: a list of sentences.
        images: a list of lists of images. Each inner list represents a set
            of images corresponding to one instruction. The length of the
            outer list should be equal to the length of ``texts``.
        processor: the ``PaliGemmaProcessor`` used to tokenize and preprocess.
        device: the device to move the model inputs to.
        reformat_text: if True, wrap each sentence into a yes/no question about
            task completion; otherwise the sentence is used as-is.

    Returns:
        dict: a dictionary of model inputs, with keys 'input_ids',
        'pixel_values', and 'attention_mask'.
    """
    assert len(texts) == len(images)
    # We need to insert special '<image>' tokens at the beginning of each
    # instruction.
    image_token = "<image>"
    updated_instructions = []
    for text, imgs in zip(texts, images):
        n_imgs = len(imgs)
        if reformat_text:
            question = f"Is the task '{text}' finished given the initial image [IMAGE1] and final image [IMAGE2]? Please only answer 'yes' or 'no'."
        else:
            question = text
        updated_instr = f"{image_token * n_imgs}{question}"
        updated_instructions.append(updated_instr)
    # Flatten the image list
    images = [i for imgs in images for i in imgs]
    model_inputs = processor(
        text=updated_instructions,
        images=images,
        return_tensors="pt",
        padding=True).to(torch.bfloat16).to(device)
    return model_inputs
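# A sketch of what this produces (illustrative values, not from the original
# gist): for texts = ['pick up the cup'] and images = [[first, last]], the
# prompt becomes '<image><image>pick up the cup' (or the yes/no question when
# reformat_text=True), and the returned dict holds 'input_ids',
# 'pixel_values' of shape [2, 3, 224, 224], and 'attention_mask'.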
class PaliGemma3b224Generation(torch.nn.Module):
    def __init__(self):
        super().__init__()
        model_id = "google/paligemma2-3b-pt-224"
        self._model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto")
        self._processor = PaliGemmaProcessor.from_pretrained(model_id)
        self._device = self._model.device

    def generate(self,
                 texts: List[str],
                 images: List[List[np.ndarray]],
                 max_response_tokens: int = 100,
                 sample: bool = False):
        """
        Args:
            texts: a list of prompts.
            images: a list of lists of images. Each inner list represents a set
                of images corresponding to one prompt. The length of the
                outer list should be equal to the length of ``texts``.
                Each image can have a shape of either (H, W, 3) or (3, H, W).
            max_response_tokens: maximum number of tokens to generate in the
                response (default 100).
            sample: whether to sample from the model. If False (default), argmax
                will be performed on each output token distribution.

        Returns:
            List[str]: a list of generated responses, one per user prompt.

        .. code-block:: python

            # Example:
            texts = ['Look at the image what do you see?',
                     'what is the difference between the two images?']
            images = [[image1], [image2, image3]]
            # The processed inputs to the language model will be two sequences.
            # The first sequence will contain N + 256 (image1) tokens, where N
            # is the number of text tokens.
            # The second sequence will contain M + 512 (image2 + image3) tokens,
            # where M is the number of text tokens.
        """
        self._model.eval()
        model_inputs = prepare_paligemma_model_inputs(
            texts, images, self._processor, self._device)
        input_len = model_inputs["input_ids"].shape[-1]
        responses = []
        with torch.inference_mode():
            generation = self._model.generate(
                **model_inputs,
                max_new_tokens=max_response_tokens,
                num_beams=3,
                do_sample=sample)
            # Without truncation, the prompt would also be included in the output
            generation = generation[:, input_len:]
            for i in range(generation.shape[0]):
                decoded = self._processor.decode(generation[i], skip_special_tokens=True)
                responses.append(decoded)
        return responses
def pali_gemma3b224_generation():
    texts = ['captioning in details',
             'What is the difference between the two images?']
    img = cv2.imread('/tmp/car.jpg')
    img2 = cv2.imread('/tmp/pexels-pixabay-45201.jpg')
    images = [[img], [img, img2]]
    m = PaliGemma3b224Generation()
    response = m.generate(texts, images)
    return response
class PaliGemma3b224Reward(torch.nn.Module):
    """This class adapts ``PaliGemmaForConditionalGeneration`` so that the model
    outputs a reward in [0, 1], denoting the probability that the robot has
    completed the instruction given the first and last frames of an episode.

    It loads a pretrained ``PaliGemmaForConditionalGeneration`` and adapts its
    output layers.
    """

    def __init__(self, lora_config=None):
        super().__init__()
        model_id = "google/paligemma2-3b-pt-224"
        self._model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto")
        if lora_config is not None:
            self._model = get_peft_model(self._model, lora_config)
        self._processor = PaliGemmaProcessor.from_pretrained(model_id)
        self._device = self._model.device
        self._bos_token_id = self._model.config.bos_token_id
        # Get the Gemma2 language model hidden size
        hidden_size = self._model.language_model.config.hidden_size
        # A small MLP head that maps the last hidden state to a scalar logit
        self._reward_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1))
        self._reward_head = self._reward_head.to(self._device)
    @property
    def device(self):
        """Return the device of the model."""
        return self._device

    def set_device(self, device):
        self._device = device
        self.to(device)
    def forward(self,
                input_ids: torch.Tensor,
                pixel_values: torch.Tensor,
                attention_mask: torch.Tensor):
        """Given a sentence represented as an input sequence of tokens,
        and one or more images, compute a reward denoting how well the sentence
        describes the scene represented by the given images.

        Args:
            input_ids: a sequence of token ids. Each token is either a normal text
                token or a special image token. For each complete 224x224 image,
                there are 256 (16x16) special image tokens already inserted
                in this sequence in a contiguous manner, and each image token
                will later be replaced by an image patch feature embedding. The
                normal text tokens will be encoded by a text embedding table.
                Shape: ``[B, sequence_length]``. Sequences are padded
                to have the same length.
            pixel_values: a batch of images. Shape: ``[N, 3, 224, 224]``. ``N``
                does not have to be equal to ``B``.
            attention_mask: mask to avoid performing attention on padding token
                indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                It is generated by ``PaliGemmaProcessor.__call__``.

        Returns:
            torch.Tensor: unnormalized (logit) reward
        """
        # Note: grad is (re-)enabled here even when called under an outer
        # ``torch.no_grad()`` context (e.g. during evaluation).
        with torch.set_grad_enabled(True):
            # First append self._bos_token_id to the input sequence; the hidden
            # state of this last token will be fed to the reward head.
            bos_tokens = torch.full(
                (input_ids.shape[0], 1), self._bos_token_id, device=input_ids.device)
            # [B,L1]
            input_ids = torch.cat([input_ids, bos_tokens], dim=1)
            # [B,L1,D]
            input_embeddings = self._model.get_input_embeddings()(input_ids)
            device = input_embeddings.device
            dtype = input_embeddings.dtype
            # [L1]
            position_ids = torch.arange(0, input_embeddings.shape[1], device=device)
            # PaliGemma positions are 1-indexed
            position_ids = position_ids + 1
            # [B,L1]
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
            image_token_index = self._model.config.image_token_index

            ## Merge image and text embeddings
            # [N,L2,D]
            image_features = self._model.get_image_features(pixel_values)
            # [B,L1,1]
            special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
            # [B,L1,D]
            special_image_mask = special_image_mask.expand_as(input_embeddings).to(
                device)
            if input_embeddings[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = torch.sum(input_ids == image_token_index)
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but "
                    f"{image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(device, dtype)
            # Fill image features into the image token placeholders
            input_embeddings = input_embeddings.masked_scatter(
                special_image_mask, image_features)

            # Update ``attention_mask`` to include the appended bos_token
            mask = torch.ones_like(bos_tokens).to(
                attention_mask.device, attention_mask.dtype)
            # [B,L1]
            attention_mask = torch.cat([attention_mask, mask], dim=1)
            outputs = self._model.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=False,
                output_hidden_states=True,
                inputs_embeds=input_embeddings,
                # We don't need the model's original logits, but 1 is the minimum to compute
                num_logits_to_keep=1)
            # [B,L1,Q]
            last_hidden_state = outputs.hidden_states[-1]
            # [B,Q] take the hidden state of the last (appended) token
            last_hidden_state = last_hidden_state[:, -1, :].to(
                self._device).to(torch.float32)
            # [B,1]
            logits = self._reward_head(last_hidden_state)
        return logits
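    # Shape sketch (illustrative, following the docstring above): for a batch of
    # B=2 prompts that each reference two 224x224 images, ``input_ids`` contains
    # 2 * 256 = 512 <image> tokens per sequence, ``pixel_values`` has shape
    # [4, 3, 224, 224], and the returned logits have shape [2, 1].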
    def compute_reward(self,
                       instructions: List[str],
                       images: List[List[np.ndarray]]):
        """Given instructions and one or more images each, compute a reward
        denoting how well each instruction has been achieved given the images.

        Args:
            instructions: a list of instructions.
            images: a list of lists of images. Each inner list represents a set
                of images corresponding to one instruction. The length of the
                outer list should be equal to the length of ``instructions``.
                Each image should be a numpy array of shape ``(H,W,3)`` or
                ``(3,H,W)``.

        Example:

        .. code-block:: python

            texts = ['captioning in details',
                     'What is the difference between the two images?']
            img = cv2.imread('/tmp/car.jpg')
            img2 = cv2.imread('/tmp/pexels-pixabay-45201.jpg')
            images = [[img], [img, img2]]
            m = PaliGemma3b224Reward()
            reward = m.compute_reward(texts, images)
        """
        model_inputs = prepare_paligemma_model_inputs(
            instructions, images, self._processor, self._device,
            reformat_text=True)
        reward = self.forward(**model_inputs)
        return reward

    def gradient_checkpointing_enable(self):
        self._model.gradient_checkpointing_enable()
class PaliGemmaInputBatchAccumulator(object):
    def __init__(self, batch_size: int = 1):
        self._reset()
        self._batch_size = batch_size
        labels = (torch.ones([batch_size, 1], dtype=torch.int32),
                  torch.zeros([batch_size, 1], dtype=torch.int32))
        # [2B,1]: 1 for positive pairs (first + last frame), 0 for negative
        # pairs (first + middle frame)
        self._labels = torch.cat(labels, dim=0)

    def _reset(self):
        self._instrs = []
        self._pos_images = []
        self._neg_images = []

    def add_episode(self, episode) -> Optional[Dict[str, Any]]:
        (first_frame,
         middle_frame,
         last_frame,
         instr) = get_frames_and_instr(episode)
        self._instrs.append(instr)
        self._pos_images.append([first_frame, last_frame])
        self._neg_images.append([first_frame, middle_frame])
        batch = None
        if len(self._instrs) == self._batch_size:
            batch = {
                "instructions": self._instrs + self._instrs,
                "images": self._pos_images + self._neg_images,
                "labels": self._labels}
            self._reset()
        return batch
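# Sketch of one accumulated batch (illustrative values, assuming batch_size=2):
#   batch["instructions"] = [instr_1, instr_2, instr_1, instr_2]
#   batch["images"]       = [[first_1, last_1], [first_2, last_2],   # positives
#                            [first_1, mid_1],  [first_2, mid_2]]    # negatives
#   batch["labels"]       = tensor of shape [4, 1] with values [1, 1, 0, 0]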
class PaliGemmaEvalResult(object):
    def __init__(self,
                 instructions: Optional[List[str]] = None,
                 images: Optional[List[List[np.ndarray]]] = None,
                 pred_labels: Optional[torch.Tensor] = None,
                 pred_probs: Optional[torch.Tensor] = None,
                 labels: Optional[torch.Tensor] = None):
        """
        Args:
            instructions: a list of instructions. Length N.
            images: a list of lists of images. Length N.
            pred_labels: 1D int32 tensor of predicted labels.
            pred_probs: 1D float32 tensor of predicted probabilities.
            labels: 1D int32 tensor of ground-truth labels.
        """
        # Use ``None`` defaults instead of mutable default arguments so that
        # fresh empty containers are created per instance; ``add`` mutates
        # them in place.
        self._results = {
            "instructions": instructions if instructions is not None else [],
            "images": images if images is not None else [],
            "pred_labels": (pred_labels if pred_labels is not None
                            else torch.empty(0, dtype=torch.int32)).to('cpu'),
            "pred_probs": (pred_probs if pred_probs is not None
                           else torch.empty(0, dtype=torch.float32)).to('cpu'),
            "labels": (labels if labels is not None
                       else torch.empty(0, dtype=torch.int32)).to('cpu'),
        }

    def add(self, res: 'PaliGemmaEvalResult'):
        """Add another PaliGemmaEvalResult to the current one."""
        for k, v in self._results.items():
            if isinstance(v, torch.Tensor):
                # TODO: this cat might be inefficient!
                self._results[k] = torch.cat([v, res._results[k]])
            else:
                self._results[k].extend(res._results[k])

    def compute_accuracy(self) -> float:
        """Given the current results, compute the prediction accuracy."""
        return float((
            self._results["pred_labels"] == self._results["labels"]).to(
                torch.float32).mean())
def eval_paligemma_reward(model, eval_datasets, batch_size: int = 2,
                          total_episodes: int = None):
    model.eval()

    def _eval_one_batch(instructions, images, labels):
        with torch.no_grad():
            # [2B]
            logits = model.compute_reward(instructions, images).squeeze(-1)
            # [2B]
            pred_labels = (logits > 0).to(torch.int32)
            pred_probs = torch.sigmoid(logits)
            labels = labels.squeeze(-1)
        return PaliGemmaEvalResult(
            instructions, images, pred_labels, pred_probs, labels)

    batch_acc = PaliGemmaInputBatchAccumulator(batch_size)
    result = PaliGemmaEvalResult()
    episodes = 0
    terminate = False
    for ds in eval_datasets:
        ds_iter = iter(ds)
        for i, episode in enumerate(ds_iter):
            batch = batch_acc.add_episode(episode)
            # Evaluate one batch
            if batch is not None:
                res = _eval_one_batch(**batch)
                result.add(res)
                episodes += batch_size
                if total_episodes is not None and episodes >= total_episodes:
                    terminate = True
                    break
        if terminate:
            break
    model.train()
    return result
def train_paligemma_reward(model=None,
                           epochs: int = 10,
                           batch_size: int = 16,
                           grad_accumulation_steps: int = 4,
                           summary_interval: int = 100):
    """
    Args:
        model: the reward model to finetune. If None, a fresh
            ``PaliGemma3b224Reward`` is created and fully finetuned.
        epochs: number of passes over the training datasets.
        batch_size: number of episodes contributing to each optimizer step
            (each episode yields one positive and one negative example).
        grad_accumulation_steps: to save GPU memory during backprop, we divide
            a batch into ``grad_accumulation_steps`` splits. We compute the grad
            for each split separately and accumulate them, followed by one
            optimizer step.
        summary_interval: log the average loss every this many micro-batches.
    """
    #datasets = ['berkeley_autolab_ur5',
    #            'stanford_hydra_dataset_converted_externally_to_rlds'
    #           ]
    datasets = ['bridge']
    # Only create the datasets once from disk
    train_datasets = [get_dataset(d, split='train') for d in datasets]
    eval_datasets = [get_dataset(d, split='val') for d in datasets]
    if model is None:
        # The entire model will be finetuned.
        model = PaliGemma3b224Reward()
    model.gradient_checkpointing_enable()
    model.train()
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    def _compute_loss(instructions, images, labels):
        # [2B,1]
        rewards = model.compute_reward(instructions, images)
        loss = loss_fn(rewards, labels.to(dtype=torch.float32, device=model.device))
        return loss

    assert batch_size % grad_accumulation_steps == 0
    micro_batch_size = batch_size // grad_accumulation_steps
    batch_acc = PaliGemmaInputBatchAccumulator(micro_batch_size)

    # For debugging purposes
    res = eval_paligemma_reward(model, eval_datasets, total_episodes=100)
    logging.info("Initial eval accuracy: %f" % res.compute_accuracy())

    for e in range(epochs):
        logging.info("Starting epoch %d" % e)
        random.shuffle(train_datasets)
        losses = []
        n_batches = 0
        for ds in train_datasets:
            ds_iter = iter(ds)
            for i, episode in enumerate(ds_iter):
                batch = batch_acc.add_episode(episode)
                if batch is not None:
                    # A new micro-batch is ready
                    n_batches += 1
                    l = _compute_loss(**batch)
                    # Detach the logged loss so the graph is not kept alive
                    losses.append(l.detach())
                    l = l / grad_accumulation_steps
                    l.backward()
                    # Log the average loss every ``summary_interval`` micro-batches
                    if n_batches % summary_interval == 0:
                        logging.info(f"Loss: {sum(losses) / len(losses)}")
                        losses = []
                    # Optimizer step every ``grad_accumulation_steps`` micro-batches
                    if n_batches % grad_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()
        # Eval after every training epoch
        res = eval_paligemma_reward(model, eval_datasets, total_episodes=100)
        logging.info(f"Eval accuracy at epoch {e}: {res.compute_accuracy()}")
def train_paligemma_reward_lora():
    # Define the LoRA configuration
    lora_config = LoraConfig(
        r=8,  # Rank of the low-rank matrices
        lora_alpha=32,  # Scaling factor
        target_modules=["self_attn.k_proj",  # vision & language
                        "self_attn.q_proj",  # vision & language
                        "self_attn.v_proj",  # vision & language
                        "self_attn.out_proj",  # vision
                        "mlp.fc1",  # vision
                        "mlp.fc2",  # vision
                        "self_attn.o_proj",  # language
                        "mlp.down_proj",  # language
                        "mlp.up_proj",  # language
                        "mlp.gate_proj",  # language
                        "multi_modal_projector.linear"
                        ],  # Layers to apply LoRA to
        lora_dropout=0.1,  # Dropout for LoRA layers
        bias="none",  # Whether to add bias terms
        task_type="CAUSAL_LM",
    )
    model = PaliGemma3b224Reward(lora_config=lora_config)
    model._model.print_trainable_parameters()
    train_paligemma_reward(model)
if __name__ == '__main__':
    if False:
        # Test sentence generation
        print(pali_gemma3b224_generation())
    elif False:
        # Finetune the full model
        train_paligemma_reward()
    elif True:
        # Finetune the LoRA model
        train_paligemma_reward_lora()