Skip to content

Instantly share code, notes, and snippets.

@VictorSanh
Created October 19, 2023 02:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save VictorSanh/99b313903302525fca49d5a53cc3bcf3 to your computer and use it in GitHub Desktop.
Save VictorSanh/99b313903302525fca49d5a53cc3bcf3 to your computer and use it in GitHub Desktop.
Packing and splitting OBELICS-style documents for IDEFICS training
import torch
import numpy as np
import logging
# Special tokens written into the serialized document text around/for each image.
IMAGE_TOKEN = "<image>"
FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"

# Documents shorter than this many tokens are dropped instead of being packed, and the
# last this-many start indices of long documents are never sampled.
_MIN_LENGTH_DOCUMENTS_TO_PACK = 5

# Sampling-score bonus granted to start indices whose window would contain an image token.
_IMAGE_BONUS_VALUE = 2

logger = logging.getLogger(__name__)
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    """Convert per-token image indices into a one-hot image attention mask.

    Each entry of ``incremental_mask`` is the index of the image a token may attend to,
    or -1 when it attends to none, e.g. ``[-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]``.

    Args:
        incremental_mask (torch.LongTensor): Image index per token, -1 meaning "no image".
            The tensor is not modified.
        num_classes (int, optional): Size of the trailing one-hot dimension. Indices
            ``>= num_classes`` (images beyond the maximum allowed) are treated as -1, so
            tokens after the last allowed image attend to nothing. With -1 the size is
            inferred by ``torch.nn.functional.one_hot`` from the largest index.

    Returns:
        torch.LongTensor: ``incremental_mask`` with an extra trailing one-hot dimension;
        rows for -1 entries are all zeros.
    """
    # Clone so the caller's tensor is not overwritten by the in-place edits below.
    incremental_mask = incremental_mask.clone()
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1
    negatives = incremental_mask == -1
    # one_hot rejects negative indices: map them to 0 temporarily and zero their rows after.
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask
def _last_position(flags):
    """For a 2D bool tensor, return per position the column of the latest True at or before it (-1 if none)."""
    positions = torch.arange(flags.size(1), device=flags.device).expand_as(flags)
    return torch.cummax(positions.masked_fill(~flags, -1), dim=1).values


def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    """Compute incremental "previous image" and "next image" indices for packed rows.

    For every token of every row of ``input_ids`` (a 2D batch of token ids), images are
    numbered 0, 1, 2, ... left to right within the row.

    - ``image_attention_mask``: index of the closest image at or before the token. It is -1
      if there is none, or if an end-of-document token lies strictly between that image and
      the token (the end-of-document token itself keeps the index).
    - ``next_image_attention_mask``: index of the closest image at or after the token. It is -1
      if there is none, or if an end-of-document token lies between the token (inclusive)
      and that image.

    Both are returned with the dtype/device of ``input_ids`` and are meant to be fed to
    ``incremental_to_binary_attention_mask``.

    Args:
        input_ids (torch.Tensor): 2D tensor of token ids (batch, seq).
        tokenizer: Provides ``convert_tokens_to_ids`` and ``eos_token_id`` (the
            end-of-document token; ``None`` means documents are never split).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(image_attention_mask, next_image_attention_mask)``.
    """
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id

    is_image = input_ids == image_token_id
    if eod_token_id is None:
        is_eod = torch.zeros_like(is_image)
    else:
        is_eod = input_ids == eod_token_id

    # Previous-image mask: running image count, cut when an EOD appears strictly after the
    # last image and strictly before the current token.
    image_index = torch.cumsum(is_image, dim=1, dtype=torch.long) - 1
    last_eod = _last_position(is_eod)
    eod_strictly_before = torch.full_like(last_eod, -1)
    eod_strictly_before[:, 1:] = last_eod[:, :-1]
    image_attention_mask = image_index.masked_fill(eod_strictly_before > _last_position(is_image), -1)

    # Next-image mask: the same scan right-to-left, where an EOD at the token itself also cuts.
    rev_is_image = is_image.flip(1)
    rev_index = torch.cumsum(rev_is_image, dim=1, dtype=torch.long) - 1
    rev_cut = _last_position(is_eod.flip(1)) > _last_position(rev_is_image)
    next_index = rev_index.masked_fill(rev_cut, -1).flip(1)
    # Convert right-to-left numbering into the left-to-right numbering of the row's images.
    last_image_index = is_image.sum(dim=1, keepdim=True) - 1
    next_image_attention_mask = torch.where(next_index >= 0, last_image_index - next_index, next_index)

    return image_attention_mask.to(input_ids.dtype), next_image_attention_mask.to(input_ids.dtype)
def _encode_document(raw_images, raw_texts, tokenizer, add_begin_of_doc_token, add_end_of_doc_token):
    """Serialize one interleaved document into ``(stacked_images, token_ids)``.

    Each image becomes ``<fake_token_around_image><image>``, and a run of images is closed by
    another fake token. Consecutive texts are joined with a single space. Returns ``None``
    when the document yields no text or no images.
    """
    images = []
    web_text = ""
    # Tracked per document so a trailing image cannot leak a fake token into the next document.
    last_was_image = False
    for image, text in zip(raw_images, raw_texts):
        if image is not None:
            web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
            images.append(torch.tensor(image))
            last_was_image = True
        elif text is not None:
            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                last_was_image = False
            else:
                web_text += f" {text}" if web_text != "" else text
    if last_was_image:
        web_text += FAKE_TOKEN_AROUND_IMAGE_V2
    web_text = web_text.strip(" ")
    # Sanity check; the caller already filters documents without images or text.
    if web_text == "" or not images:
        return None

    token_ids = list(tokenizer.encode(web_text, add_special_tokens=False))
    if add_end_of_doc_token:
        token_ids.append(tokenizer.eos_token_id)
    if add_begin_of_doc_token:
        token_ids.insert(0, tokenizer.bos_token_id)
    return torch.stack(images), token_ids


def _sample_start_indices(text, image_token_id, max_seq_len, max_num_samples_per_document, prefix_seed):
    """Draw candidate start indices for ``max_seq_len``-token windows of a long document.

    Indices are drawn without replacement with probability proportional to a score. Every
    index starts at 1 and gains ``_IMAGE_BONUS_VALUE`` for each image token that a window
    starting there would contain. Spreading the bonus over the whole range of such starts,
    rather than only just before or far before the image, avoids biasing images toward the
    beginning, middle or end of the sampled windows. The last
    ``_MIN_LENGTH_DOCUMENTS_TO_PACK`` indices are never drawn. The RNG is seeded with
    ``prefix_seed`` followed by the document length for reproducible draws.
    """
    scores = np.ones(len(text), dtype=np.int64)
    for img_token_idx in np.flatnonzero(np.asarray(text) == image_token_id):
        # Exactly the starts in [img_token_idx - max_seq_len + 1, img_token_idx] see the image.
        scores[max(0, img_token_idx - max_seq_len + 1) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
    # The last few start indices are reserved: they would only yield very short tails.
    scores = scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]
    # Draw more than needed (the caller stops once it has enough valid windows), but never
    # more than the number of candidates since sampling is without replacement.
    num_draws = min(len(text) - max_seq_len, len(scores), 2 * max_num_samples_per_document)
    if num_draws <= 0:
        return np.array([], dtype=np.int64)
    rng = np.random.default_rng(seed=list(prefix_seed) + [len(text)])
    return rng.choice(len(scores), num_draws, p=scores / scores.sum(), replace=False)


def _pad_images(images, max_num_images):
    """Zero-pad a stack of at most ``max_num_images`` images to exactly ``max_num_images`` along dim 0."""
    padded = torch.zeros(max_num_images, *images.size()[1:])
    padded[: images.size(0)] = images
    return padded


def _empty_batch():
    """Return the batch emitted when no sample could be built (dtypes match the non-empty batch)."""
    return {
        "input_ids": torch.tensor([], dtype=torch.long),
        "attention_mask": torch.tensor([], dtype=torch.long),
        "image_attention_mask": torch.tensor([], dtype=torch.long),
        "next_image_attention_mask": torch.tensor([], dtype=torch.long),
        "num_images": torch.tensor([], dtype=torch.long),
        "num_text_tokens": torch.tensor([], dtype=torch.long),
        "pixel_values": torch.tensor([], dtype=torch.float32),
    }


def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """Turn a batch of interleaved image/text documents into fixed-length model samples.

    Documents longer than ``max_seq_len`` tokens are split: up to
    ``min(max_num_samples_per_document, len(doc) // max_seq_len)`` windows of exactly
    ``max_seq_len`` tokens are sampled. Start indices are skewed toward windows containing
    images, and windows without any image are discarded. Sampled windows that hit the end of
    the document are shorter than ``max_seq_len``. These tails, and whole documents of at most
    ``max_seq_len`` tokens, are concatenated, cut into ``max_seq_len`` chunks and right-padded
    with ``tokenizer.pad_token_id``. Each sample keeps at most ``max_num_images`` images,
    zero-padded up to ``max_num_images``.

    Args:
        sample (Dict): Holds ``"texts"`` and ``"images"``, each a list of documents. A document
            is two lists of equal length where, at each position, at most one of image/text is
            not None: ``images=[image1, None, None, image2, None]``,
            ``texts=[None, text1, text2, None, text3]``.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Length of every returned token sequence.
        max_num_images (int): Maximum number of images per returned sample; fewer are zero-padded.
        max_num_samples_per_document (int, optional): Maximum number of windows sampled from a
            long document. Defaults to 10.
        prefix_seed: Prefix of the seed sequence used for "reproducible randomness" when
            sampling start indices (the document length is appended to it).
        add_begin_of_doc_token (bool, optional): Prepend ``tokenizer.bos_token_id`` to each document.
        add_end_of_doc_token (bool, optional): Append ``tokenizer.eos_token_id`` to each document;
            this token also separates documents in the image attention masks.
        max_num_images_per_document (int, optional): Skip documents with more images than this.
            ``None``/0 disables the filter.

    Returns:
        Dict[str, torch.Tensor]: With ``N`` samples:
            ``input_ids`` (N, max_seq_len);
            ``attention_mask`` (N, max_seq_len), 1 for real tokens and 0 for padding;
            ``image_attention_mask`` / ``next_image_attention_mask`` (N, max_seq_len,
            max_num_images), one-hot masks toward the preceding / following image;
            ``num_images`` (N,);
            ``num_text_tokens`` (N,), the unpadded lengths;
            ``pixel_values`` (N, max_num_images, *image_shape), zero-padded.
        Every value is an empty 1-D tensor when no sample could be built.
    """
    mismatch_warning = (
        "Skipping this sample because of mismatch in actual number of images and "
        "the '<image>' tokens in the text"
    )
    text_batch = sample["texts"]
    image_batch = sample.get("images", None)
    if image_batch is None:
        # Every document needs at least one image, so nothing can be built.
        return _empty_batch()
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    # 1. Serialize every eligible document into (stacked images, token ids).
    all_images, all_texts = [], []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Need at least one image and one non-empty text. Test images against None:
        # bool() on an array image is ambiguous.
        num_doc_images = sum(image is not None for image in raw_images)
        if num_doc_images == 0 or not any(raw_texts):
            continue
        if max_num_images_per_document and num_doc_images > max_num_images_per_document:
            continue
        encoded = _encode_document(raw_images, raw_texts, tokenizer, add_begin_of_doc_token, add_end_of_doc_token)
        if encoded is None:
            continue
        all_images.append(encoded[0])
        all_texts.append(encoded[1])

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []
    input_ids_to_pack = []
    images_to_pack = []

    # 2. Sample full-length windows from long documents; collect short material for packing.
    for images, text in zip(all_images, all_texts):
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            # Packing consumes images in token order, so one bad document would misalign all later ones.
            if text.count(image_token_id) != len(images):
                logger.warning(mismatch_warning)
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
            continue

        max_num_samples_for_curr_document = min(len(text) // max_seq_len, max_num_samples_per_document)
        num_samples_kept = 0
        for start_index in _sample_start_indices(
            text, image_token_id, max_seq_len, max_num_samples_per_document, prefix_seed
        ):
            image_start_index = text[:start_index].count(image_token_id)
            text_sub_sequence = text[start_index : start_index + max_seq_len]
            image_count = text_sub_sequence.count(image_token_id)
            if image_count == 0:
                continue
            if len(text_sub_sequence) < max_seq_len:
                # Tail of the document: reserve it for packing. It runs to the end of the
                # document, so it owns every image from `image_start_index` on.
                tail_images = images[image_start_index:]
                if image_count != len(tail_images):
                    logger.warning(mismatch_warning)
                    continue
                input_ids_to_pack.extend(text_sub_sequence)
                images_to_pack.extend(tail_images)
                continue
            num_kept_images = min(max_num_images, image_count)
            current_images = images[image_start_index : image_start_index + num_kept_images]
            if len(current_images) != num_kept_images:
                # Something is off with this document (e.g. a stray `<image>` tag in the text
                # or an image that failed to parse): stop sampling from it.
                logger.warning(mismatch_warning)
                break
            output_images.append(_pad_images(current_images, max_num_images))
            output_num_images.append(num_kept_images)
            output_input_ids.append(torch.tensor(text_sub_sequence))
            output_num_text_tokens.append(len(text_sub_sequence))
            output_attention_masks.append(torch.ones((max_seq_len,), dtype=torch.long))
            num_samples_kept += 1
            if num_samples_kept >= max_num_samples_for_curr_document:
                break

    # 3. Pack the collected short material into max_seq_len chunks, right-padded.
    image_counter = 0
    for chunk_start in range(0, len(input_ids_to_pack), max_seq_len):
        current_input_ids = input_ids_to_pack[chunk_start : chunk_start + max_seq_len]
        unpadded_seq_len = len(current_input_ids)
        num_images = current_input_ids.count(image_token_id)
        if num_images == 0:
            continue
        current_images = images_to_pack[image_counter : image_counter + num_images]
        image_counter += num_images
        if unpadded_seq_len < max_seq_len:
            current_input_ids = current_input_ids + [tokenizer.pad_token_id] * (max_seq_len - unpadded_seq_len)
        try:
            stacked_images = torch.stack(current_images)[:max_num_images]
        except RuntimeError as err:
            # e.g. images of different shapes packed together; skip the chunk but say so.
            logger.warning("Skipping a packed sequence whose images could not be stacked: %s", err)
            continue
        attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
        attention_mask[:unpadded_seq_len] = 1
        output_images.append(_pad_images(stacked_images, max_num_images))
        output_input_ids.append(torch.tensor(current_input_ids))
        output_num_text_tokens.append(unpadded_seq_len)
        output_num_images.append(min(max_num_images, num_images))
        output_attention_masks.append(attention_mask)

    if not output_images or not output_input_ids:
        return _empty_batch()

    # 4. Stack and build the image attention masks.
    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)
    # Two image attention masks: in the normal one a text token attends to the image preceding
    # it, in the "next" one to the image following it.
    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )
    return {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
        "pixel_values": output_images,
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment