@louis030195
Created May 28, 2023 09:17
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

import data
from models import imagebind_model
from models.imagebind_model import ModalityType

class ImageBindGPTJ(nn.Module):
    """Bridge ImageBind multimodal embeddings into GPT-J's input space."""

    def __init__(self, imagebind):
        super().__init__()
        self.imagebind = imagebind
        self.gptj = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
        input_embedding_size = self.gptj.config.hidden_size
        # HACK: hardcoded ImageBind output embedding dim (1024 for imagebind_huge)
        self.embedding_proj = nn.Linear(1024, input_embedding_size)

    def forward(self, inputs, labels=None):
        # Get the per-modality embeddings from ImageBind (kept frozen)
        with torch.no_grad():
            embeddings = self.imagebind(inputs)
        # Average the modality embeddings into a single vector per sample
        embeddings = torch.stack(list(embeddings.values())).mean(dim=0)
        # Project into GPT-J's input embedding space
        embeddings_proj = self.embedding_proj(embeddings)
        # Feed the projected embedding to GPT-J as a one-token soft prefix
        input_token_tensors = torch.zeros(
            (embeddings_proj.shape[0], 1, embeddings_proj.shape[1])
        ).to(embeddings_proj.device)
        input_token_tensors[:, 0, :] = embeddings_proj
        gptj_out = self.gptj(inputs_embeds=input_token_tensors, labels=labels)
        return gptj_out.loss, gptj_out.logits

text_list = ["A dog.", "A car", "A bird"]
image_paths = ["client/dog_image.jpg", "client/car_image.jpg", "client/bird_image.jpg"]
audio_paths = ["client/dog_audio.wav", "client/car_audio.wav", "client/bird_audio.wav"]

device = "cpu"

pretrained_imagebind = imagebind_model.imagebind_huge(pretrained=True)
model = ImageBindGPTJ(pretrained_imagebind).to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

generated_text = None
try:
    while True:
        loss, logits = model(inputs)
        # Greedy decode: pick the argmax token at each output position
        generated_ids = torch.argmax(logits, dim=-1)
        generated_text = tokenizer.decode(generated_ids[0])
        print(generated_text)
        # Slide the text window: drop the oldest token, append the new one
        inputs[ModalityType.TEXT] = torch.cat(
            (inputs[ModalityType.TEXT][:, 1:], generated_ids), dim=1
        )
except KeyboardInterrupt:
    print("Stopped generating text.")