Created
May 29, 2024 21:36
-
-
Save halflearned/95e52981a16742dd94416db334072093 to your computer and use it in GitHub Desktop.
Partial code to interleave image embeddings into text embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def interleave_image_embeddings(text_embeddings, image_embeddings, tensor_indices, image_token_positions):
    """Splice image embeddings into each sequence's text embeddings.

    For every sequence in the batch, the rows of ``image_embeddings`` whose
    ``tensor_indices`` entry matches that sequence are inserted at the
    corresponding ``image_token_positions`` (replacing the placeholder token's
    embedding there). The resulting variable-length sequences are re-padded
    into one rectangular batch tensor.
    """
    # Work on a list of per-sequence tensors, since lengths diverge once
    # image embeddings (usually longer than one token) are spliced in.
    sequences = list(text_embeddings)
    for batch_idx, seq in enumerate(sequences):
        # Select the image tokens that belong to this sequence.
        mask = tensor_indices == batch_idx
        if not mask.any():
            continue
        sequences[batch_idx] = interleave_tensors_2d(
            seq, image_token_positions[mask], image_embeddings[mask]
        )
    # Pad back to a uniform length and stack into a single batch tensor.
    # NOTE(review): padding embeddings with IGNORE_INDEX (a label-style
    # constant) looks unusual — confirm downstream masking expects this.
    return torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=IGNORE_INDEX
    )
def interleave_tensors_2d(base_tensor, insertion_indices, insertion_tensors):
    """Replace single rows of ``base_tensor`` with multi-row tensors.

    Each index in ``insertion_indices`` names a row of ``base_tensor`` that is
    removed and replaced by the corresponding tensor from
    ``insertion_tensors``. Indices refer to positions in the *original*
    tensor; an internal shift accounts for earlier insertions, so the indices
    are expected in ascending order (as produced by ``torch.nonzero``).

    Returns the rebuilt 2-D tensor, or ``None`` if ``base_tensor`` is ``None``.
    """
    if base_tensor is None:
        return
    result = base_tensor
    shift = 0
    for position, chunk in zip(insertion_indices, insertion_tensors):
        # Translate the original index into the current (grown) tensor.
        target = position + shift
        # Drop the row at `target` and splice the replacement chunk in.
        result = torch.cat((result[:target], chunk, result[target + 1:]), dim=0)
        # Net growth: the chunk's rows minus the one row it replaced.
        shift += chunk.size(0) - 1
    return result
# Locate every image placeholder token (e.g. '<image>') in the batch.
# `tensor_indices` holds the batch row of each match and
# `image_token_positions` the position within that row; both are returned
# in ascending order, as `interleave_tensors_2d` requires.
tensor_indices, image_token_positions = (input_ids == image_token_id).nonzero(as_tuple=True)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.