Skip to content

Instantly share code, notes, and snippets.

@Akash-Rawat
Created July 2, 2021 10:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save Akash-Rawat/93c29ae8818d88e54d3e3dae7fea42af to your computer and use it in GitHub Desktop.
Defining VAECaptioner Model
class VAECaptioner(nn.Module):
    """VAE-based image captioner.

    Wraps three project components:
      * ``Encoder``  - encodes an image into a code tensor plus pooling
        indices and the latent Gaussian parameters (mean, log std).
      * ``Decoder``  - reconstructs the image from the code and indices.
      * ``CaptionRNN`` - produces caption token probabilities from the
        flattened code.
    """

    def __init__(self, in_channel, code_channels, image_dim, vocab):
        """Build the encoder, decoder and caption RNN.

        Parameters
        ----------
        in_channel : int
            Number of channels in the input image.
        code_channels : int
            Number of channels in the encoder's code tensor.
        image_dim : tuple
            (height, width) of the input image.
        vocab : mapping
            Token -> index vocabulary; must contain the "." token,
            which is used as the end-of-caption marker.
        """
        super().__init__()
        LATENT_DIM = 300
        EMBEDDING_SIZE = 600
        HIDDEN_SIZE = 512
        # Flattened size of the code tensor: spatial area is reduced by
        # POOLING_FACTOR along each axis.
        # NOTE(review): POOLING_FACTOR is a module-level constant defined
        # elsewhere in the project — confirm it matches the Encoder's pooling.
        CODE_FLAT = code_channels * ((image_dim[0] * image_dim[1]) // (POOLING_FACTOR ** 2))
        self.vocab = vocab
        self.encoder = Encoder(in_channel, code_channels, image_dim, LATENT_DIM)
        self.decoder = Decoder(code_channels, in_channel, image_dim)
        self.captionr = CaptionRNN(CODE_FLAT, len(vocab), EMBEDDING_SIZE, HIDDEN_SIZE, vocab["."])

    def forward(self, x, y):
        """Run a full training pass.

        Parameters
        ----------
        x : Tensor
            Batch of input images.
        y : Tensor
            Target captions forwarded to ``CaptionRNN.caption_prob``.

        Returns
        -------
        tuple
            (reconstructed images, caption probabilities, latent mean,
            latent log std) — mean/log_std are returned so the caller can
            compute the KL term of the VAE loss.
        """
        c, indices_1, indices_2, indices_3, mean, log_std = self.encoder(x)
        reconstructed = self.decoder(c, indices_1, indices_2, indices_3)
        caption_prob = self.captionr.caption_prob(c, y)
        return reconstructed, caption_prob, mean, log_std

    def generate_caption(self, x):
        """Generate a caption for the first image in batch ``x``.

        Only the code tensor is needed here; the pooling indices and the
        latent distribution parameters are discarded.
        """
        c, *_ = self.encoder(x)
        # c[0]: caption only the first element of the batch.
        return self.captionr.generate_caption(c[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment