Skip to content

Instantly share code, notes, and snippets.

@halflearned
Created May 29, 2024 21:36
Show Gist options
  • Save halflearned/95e52981a16742dd94416db334072093 to your computer and use it in GitHub Desktop.
Partial code to interleave image embeddings into text embeddings
def interleave_image_embeddings(text_embeddings, image_embeddings, tensor_indices, image_token_positions):
    """Splice image embeddings into each sequence's text embeddings.

    Args:
        text_embeddings: iterable of per-sequence 2D embedding tensors (one per batch row).
        image_embeddings: image embedding tensors, indexable by the same mask as
            ``tensor_indices`` (entry j pairs with ``tensor_indices[j]``).
        tensor_indices: 1D tensor; entry j is the batch row the j-th image belongs to.
        image_token_positions: 1D tensor; entry j is the image-token position within
            that row's sequence.

    Returns:
        One padded tensor (batch first) with image embeddings interleaved in place of
        each image token.
    """
    # Work on a list: sequences temporarily have different lengths after insertion.
    text_embeddings = list(text_embeddings)
    for i in range(len(text_embeddings)):
        # Boolean mask selecting the image entries that belong to sequence i.
        idx = tensor_indices == i
        if not idx.any():
            continue
        # Replace each image-token slot with the corresponding image embedding rows.
        text_embeddings[i] = interleave_tensors_2d(
            text_embeddings[i], image_token_positions[idx], image_embeddings[idx]
        )
    # NOTE(review): padding embedding *values* with IGNORE_INDEX looks unusual -- that
    # constant is normally a label-padding sentinel (e.g. -100). Confirm it is intended here.
    return torch.nn.utils.rnn.pad_sequence(
        text_embeddings, batch_first=True, padding_value=IGNORE_INDEX
    )
def interleave_tensors_2d(base_tensor, insertion_indices, insertion_tensors):
    """Replace rows of ``base_tensor`` at ``insertion_indices`` with whole tensors.

    Each index ``insertion_indices[j]`` names a row of ``base_tensor`` (a placeholder,
    e.g. an image-token slot) that is replaced by all rows of ``insertion_tensors[j]``.
    Indices must be strictly increasing positions in the *original* tensor (as produced
    by ``torch.nonzero``) -- the same ordering the previous offset-based version assumed.

    Args:
        base_tensor: 2D tensor of shape (seq_len, dim), or None.
        insertion_indices: increasing row indices into ``base_tensor``.
        insertion_tensors: sequence of 2D tensors, one per index, each (k_j, dim).

    Returns:
        The interleaved 2D tensor, or None if ``base_tensor`` is None.
    """
    if base_tensor is None:
        return None
    # Collect untouched slices and insertions, then concatenate ONCE. The original
    # called torch.cat inside the loop, copying the whole tensor per insertion
    # (accidental O(n * k) growth); a single cat is linear and clearer.
    segments = []
    prev = 0
    for idx, tensor in zip(insertion_indices, insertion_tensors):
        # Rows before the placeholder, then the replacement rows; skip the
        # placeholder row itself (prev = idx + 1).
        segments.append(base_tensor[prev:idx])
        segments.append(tensor)
        prev = idx + 1
    segments.append(base_tensor[prev:])
    return torch.cat(segments, dim=0)
# Find the positions of the image tokens id (e.g., '<image>') in the input
# nonzero(..., as_tuple=True) yields one index tensor per dimension: tensor_indices[j]
# is the batch row and image_token_positions[j] the within-sequence position of the
# j-th image token. Assumes input_ids is a 2D (batch, seq) tensor of ids -- TODO confirm;
# input_ids and image_token_id are defined elsewhere (this gist is partial code).
tensor_indices, image_token_positions = torch.nonzero(input_ids == image_token_id, as_tuple=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment