Skip to content

Instantly share code, notes, and snippets.

@parksunwoo
Last active December 25, 2018 06:56
Show Gist options
  • Save parksunwoo/2da126f7223a6367ea0b79790e973c1b to your computer and use it in GitHub Desktop.
Save parksunwoo/2da126f7223a6367ea0b79790e973c1b to your computer and use it in GitHub Desktop.
show_attend_tell_model
'''
Source code for an attention based image caption generation system described
in:
Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
International Conference on Machine Learning (ICML), 2015
http://arxiv.org/abs/1502.03044
'''
import torch
import torch.nn as nn
import torchvision.models as models
# Not referenced by the code in this file; presumably a maximum caption length
# used by separate training/inference scripts -- TODO confirm with callers.
MAX_LENGTH = 10
# Preferred compute device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EncoderCNN(nn.Module):
    """Pretrained ResNet-101 trunk that encodes images into a fixed-size feature grid.

    The classification head (the last two children of the torchvision ResNet)
    is dropped. The convolutional output is adaptively average-pooled to
    ``encoded_image_size x encoded_image_size`` and returned channels-last.
    """

    def __init__(self, encoded_image_size=14):
        super(EncoderCNN, self).__init__()
        backbone = models.resnet101(pretrained=True)
        # Drop the final pooling and linear layers: only the pretrained
        # convolutional features are needed, not the classifier.
        trunk_layers = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk_layers)
        # Pools the *feature map* (not the input image) to a fixed spatial size,
        # so the number of encoded pixels does not depend on the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune()

    def forward(self, images):
        """Encode a batch of images.

        :param images: image batch for the ResNet trunk (presumably
            (batch_size, 3, H, W) -- TODO confirm with the data pipeline)
        :return: features of shape (batch_size, encoded_image_size, encoded_image_size, channels)
        """
        pooled = self.adaptive_pool(self.resnet(images))  # (batch, channels, size, size)
        return pooled.permute(0, 2, 3, 1)  # channels-last, as the decoder expects

    def fine_tune(self, fine_tune=True):
        """Freeze the whole trunk, then set trainability of its later blocks.

        Children from index 5 onward (in torchvision's ResNet: layer2-layer4,
        i.e. conv blocks 2-4) get ``requires_grad = fine_tune``. The earlier
        children stay frozen because they learn the most basic features
        (lines, angles, curves) and are not worth fine-tuning.

        :param fine_tune: whether the later blocks should be trainable
        """
        for param in self.resnet.parameters():
            param.requires_grad = False
        later_blocks = list(self.resnet.children())[5:]
        for block in later_blocks:
            for param in block.parameters():
                param.requires_grad = fine_tune
class Attention(nn.Module):
    """Additive soft attention over encoded image pixels.

    Each pixel's score is ``full_att(relu(encoder_att(pixel) + decoder_att(hidden)))``.
    The scores are softmax-normalized across pixels, and the encoder features
    are averaged with those weights.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        # Project the encoded image and the decoder state into a shared attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Reduce each pixel's joint projection to one unnormalized score (softmax input).
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalize scores across the pixel axis (dim=1) into attention weights.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding (batch_size, encoder_dim), weights (batch_size, num_pixels)
        """
        pixel_proj = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        # Add a pixel axis so the decoder projection broadcasts over every pixel.
        state_proj = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch_size, 1, attention_dim)
        scores = self.full_att(self.relu(pixel_proj + state_proj)).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(scores)  # (batch_size, num_pixels)
        context = torch.sum(encoder_out * alpha.unsqueeze(2), dim=1)  # (batch_size, encoder_dim)
        return context, alpha
class AttnDecoderRNN(nn.Module):
    """LSTM caption decoder with soft attention over encoded image pixels.

    At every step the previous hidden state attends over the pixel features.
    A sigmoid gate computed from that state scales the attention-weighted
    encoding, which is concatenated with the current word embedding and fed
    to an ``nn.LSTMCell``.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(AttnDecoderRNN, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # ``self.dropout`` is the Dropout module; the rate is available as ``self.dropout.p``.
        self.dropout = nn.Dropout(p=dropout)
        # Decoder LSTM cell; its input is [word embedding ; gated attention encoding].
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial LSTM hidden state from the mean encoding
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial LSTM cell state from the mean encoding
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # produces the sigmoid gate on the attention encoding
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # hidden state -> vocabulary scores
        self.init_weights()

    def init_weights(self):
        """Initialize embedding and output-layer weights in U(-0.1, 0.1) and the output bias to 0."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding matrix with ``embeddings``.

        The expected shape is (vocab_size, embed_dim); this is assumed, not checked.
        The new ``nn.Parameter`` is trainable; call ``fine_tune_embeddings(False)`` to freeze it.
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Set whether the embedding parameters receive gradients."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """Compute the initial LSTM (h, c) from the mean over pixels.

        :param encoder_out: (batch_size, num_pixels, encoder_dim)
        :return: h, c, each (batch_size, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation with teacher forcing.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) or (batch_size,),
            or a list of ints; the batch does NOT need to be sorted by length
        :return: (predictions, encoded_captions, decode_lengths, alphas), all in the caller's batch order:
            predictions is (batch_size, max(decode_lengths), vocab_size) and is zero past each caption's
            decode length; encoded_captions is returned unchanged; decode_lengths is a list of ints
            (caption length - 1, floored at 0); alphas is (batch_size, max(decode_lengths), num_pixels)
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        # Flatten the spatial grid; reshape (unlike view) also accepts non-contiguous input.
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # No word is predicted from the final <end> token, so decode length = caption length - 1.
        lengths = torch.as_tensor(caption_lengths).reshape(-1).long().to(encoder_out.device)
        decode_lengths = (lengths - 1).clamp(min=0)

        # The loop shrinks the active batch to a prefix (cheap views, no copies of encoder_out),
        # which is only correct when rows are ordered by decreasing length. Sort here and
        # restore the caller's order at the end.
        sorted_lengths, sort_ind = decode_lengths.sort(dim=0, descending=True)
        sorted_lengths = sorted_lengths.tolist()
        encoder_out = encoder_out.index_select(0, sort_ind)
        sorted_captions = encoded_captions.index_select(0, sort_ind.to(encoded_captions.device))
        embeddings = self.embedding(sorted_captions)  # (batch_size, max_caption_length, embed_dim)

        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        max_decode_length = sorted_lengths[0] if sorted_lengths else 0
        # Allocate on the inputs' device rather than a module-level default device.
        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size, device=encoder_out.device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels, device=encoder_out.device)

        # At each step, attention weights over the encoder output are computed from the
        # decoder's previous hidden state. The previous word plus the attention-weighted
        # encoding then drive the decoder to produce the next word.
        for t in range(max_decode_length):
            # Rows still decoding at step t; after sorting they form a prefix of the batch.
            batch_size_t = sum(length > t for length in sorted_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        # Undo the sort so outputs line up with the caller's images and captions.
        restore_ind = sort_ind.argsort()
        predictions = predictions.index_select(0, restore_ind)
        alphas = alphas.index_select(0, restore_ind)
        return predictions, encoded_captions, decode_lengths.tolist(), alphas
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment