Skip to content

Instantly share code, notes, and snippets.

@Akash-Rawat
Last active July 2, 2021 10:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Akash-Rawat/ddfa566a301dbca1c4dc105fbbaffba2 to your computer and use it in GitHub Desktop.
Save Akash-Rawat/ddfa566a301dbca1c4dc105fbbaffba2 to your computer and use it in GitHub Desktop.
Defining Captioner
class CaptionRNN(nn.Module):
    """GRU caption model conditioned on a fixed-size input code.

    ``mlp_l1`` maps the code to the GRU's initial hidden state, ``mlp_l2``
    maps a hidden state to vocabulary logits, and ``embedding`` turns the
    previously emitted word index into the next GRU input.
    """

    # Upper bound on the number of word indices generate_caption() returns.
    CAPTION_LIMIT = MAX_CAPTION_LEN

    def __init__(self, input_size, vocab_size, embedding_size, hidden_size, stop_index):
        """Build the layers.

        Args:
            input_size: Feature width of the conditioning code.
            vocab_size: Number of word indices (embedding rows, logit width).
            embedding_size: Width of a word embedding (the GRU input size).
            hidden_size: Width of the GRU hidden state.
            stop_index: Word index that ends generation once it is sampled.
        """
        super().__init__()
        # code -> initial hidden state; Tanh keeps it in the GRU's own
        # hidden-state range of (-1, 1).
        self.mlp_l1 = nn.Sequential(
            nn.Linear(in_features=input_size, out_features=input_size),
            nn.LeakyReLU(),
            nn.Linear(in_features=input_size, out_features=hidden_size),
            nn.Tanh(),
        )
        # hidden state -> unnormalised vocabulary logits.
        self.mlp_l2 = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=hidden_size),
            nn.LeakyReLU(),
            nn.Linear(in_features=hidden_size, out_features=vocab_size),
        )
        self.gru = nn.GRU(
            input_size=embedding_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.stop_index = stop_index

    def generate_caption(self, code):
        """Sample one caption, word by word, for a single code.

        Assumes ``code`` is a 1-D tensor of ``input_size`` features (TODO
        confirm with callers): the ``unsqueeze`` calls below add the batch
        and sequence axes themselves, and ``.item()`` needs exactly one
        sampled index per step.

        Args:
            code: Conditioning vector for the caption.

        Returns:
            List of sampled word indices (Python ints). Sampling stops once
            ``stop_index`` has been drawn (it is included in the list) or the
            list reaches ``CAPTION_LIMIT`` entries. At least one word is
            always returned.
        """
        # Only plain ints leave this method, so no gradient could ever flow
        # back through the result; skip building the autograd graph.
        with torch.no_grad():
            # The first word is predicted from the code alone.
            h_t = self.mlp_l1(code)
            y_t = torch.multinomial(F.softmax(self.mlp_l2(h_t), dim=-1), 1)
            words = [y_t.item()]

            # self.CAPTION_LIMIT (rather than CaptionRNN.CAPTION_LIMIT) so a
            # subclass or instance override of the limit is honoured.
            while len(words) < self.CAPTION_LIMIT and words[-1] != self.stop_index:
                w_t = self.embedding(y_t)
                # Shapes: input (1, 1, embedding), hidden (1, 1, hidden).
                output, _ = self.gru(
                    w_t.unsqueeze(0),
                    h_t.unsqueeze(0).unsqueeze(0),
                )
                # For one step of a one-layer GRU the output is the new
                # hidden state; strip the batch and sequence axes again.
                h_t = output.squeeze(0).squeeze(0)
                y_t = torch.multinomial(F.softmax(self.mlp_l2(h_t), dim=-1), 1)
                words.append(y_t.item())
        return words

    def caption_prob(self, code, caption):
        """Return the next-word distribution at every caption position.

        Assumes ``code`` is ``(batch, input_size)`` and ``caption`` is a
        ``(batch, length)`` tensor of word indices (TODO confirm with
        callers).

        Position 0 is predicted from ``code`` alone; position ``t > 0`` is
        predicted from the GRU output after it has consumed
        ``caption[:, t - 1]``. The GRU output for the final caption word is
        dropped because no word follows it.

        Args:
            code: Conditioning vectors.
            caption: Ground-truth word indices fed to the GRU as inputs.

        Returns:
            Tensor of probabilities, ``(batch, length, vocab_size)`` under
            the shape assumptions above.
        """
        hidden_1 = self.mlp_l1(code)
        probs_1 = F.softmax(self.mlp_l2(hidden_1), dim=-1)
        embedded = self.embedding(caption)
        # The GRU expects its initial state as (num_layers, batch, hidden).
        output, _ = self.gru(embedded, hidden_1.unsqueeze(0))
        probs_rest = F.softmax(self.mlp_l2(output[:, :-1]), dim=-1)
        return torch.cat([probs_1.unsqueeze(1), probs_rest], dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment