show_attend_tell_model
'''
Source code for an attention-based image caption generation system described
in:

Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
International Conference on Machine Learning (ICML), 2015
http://arxiv.org/abs/1502.03044
'''
import torch
import torch.nn as nn
import torchvision.models as models

MAX_LENGTH = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EncoderCNN(nn.Module):
    def __init__(self, encoded_image_size=14):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(pretrained=True)
        # Drop the last linear and pool layers; we only want the pretrained
        # feature extractor, so the classification head is unnecessary
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        # Resize the feature map to a fixed spatial size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune()

    def forward(self, images):
        out = self.resnet(images)      # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return out
    def fine_tune(self, fine_tune=True):
        for p in self.resnet.parameters():
            p.requires_grad = False
        # When fine-tuning, only unfreeze conv blocks 2 through 4. The first
        # conv block handles the most primitive stage of image processing,
        # finding lines, angles, and curves, so we leave it frozen
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune
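# A minimal usage sketch (illustrative only; the batch and image sizes below
# are made-up values, not part of the original gist). The encoder turns a
# batch of RGB images into a grid of 2048-dim feature vectors that the
# Attention module below can attend over:
#
#   encoder = EncoderCNN().to(device)
#   feats = encoder(torch.randn(8, 3, 224, 224).to(device))  # (8, 14, 14, 2048)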
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform the encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform the decoder output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to compute the values fed to the softmax
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to compute the attention weights
    def forward(self, encoder_out, decoder_hidden):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding, weights
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return attention_weighted_encoding, alpha
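# In equation form, this is the paper's soft attention: for encoder vectors a_i
# and decoder state h, e_i = w^T relu(W_a a_i + W_h h), alpha = softmax(e), and
# the context vector is z = sum_i alpha_i * a_i. A quick shape check (the dims
# here are illustrative, not from the original gist):
#
#   attn = Attention(encoder_dim=2048, decoder_dim=512, attention_dim=512)
#   z, alpha = attn(torch.randn(4, 196, 2048), torch.randn(4, 512))
#   # z: (4, 2048), alpha: (4, 196)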
class AttnDecoderRNN(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(AttnDecoderRNN, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find the LSTM's initial hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find the LSTM's initial cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()
    def init_weights(self):
        # Initialize the embedding and output layers with small uniform values
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        # Initialize the LSTM state from the mean of the encoded image
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c
    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing caption length so that, at each timestep, the
        # still-active sequences form a contiguous prefix of the batch
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We don't decode at the <end> position, so the decoding lengths are
        # the actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        # At each timestep, attention-weight the encoder's output based on the
        # decoder's previous hidden state, then generate a new word in the
        # decoder from the previous word and the attention-weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating vector, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
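
# A minimal smoke test, assuming made-up dimensions and a random vocabulary
# (not part of the original gist); it checks that the encoder and decoder
# compose with the shapes documented above. Note that constructing EncoderCNN
# downloads pretrained ResNet-101 weights on first use.
if __name__ == '__main__':
    vocab_size = 100
    encoder = EncoderCNN().to(device)
    decoder = AttnDecoderRNN(attention_dim=512, embed_dim=512, decoder_dim=512,
                             vocab_size=vocab_size).to(device)

    images = torch.randn(2, 3, 256, 256).to(device)  # random image batch
    captions = torch.randint(0, vocab_size, (2, MAX_LENGTH)).to(device)  # random token ids
    lengths = torch.tensor([[MAX_LENGTH], [MAX_LENGTH - 2]]).to(device)  # (batch_size, 1)

    encoder_out = encoder(images)  # (2, 14, 14, 2048)
    predictions, caps_sorted, decode_lengths, alphas, sort_ind = \
        decoder(encoder_out, captions, lengths)
    print(predictions.shape)  # (2, max(decode_lengths), vocab_size)
    print(alphas.shape)       # (2, max(decode_lengths), 14 * 14)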