# Gist by @seanie12, created August 9, 2019 00:22
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch_scatter import scatter_max
import numpy as np
from torch.distributions.categorical import Categorical

INF = 1e12
EOS_ID = 102  # [SEP] token id in the bert-base-uncased vocabulary


class CatEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size,
                 num_vars, num_classes):
        super(CatEncoder, self).__init__()
        # frozen BERT embedding layer (word + position + token-type)
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            bidirectional=True, num_layers=1)
        self.recog_layer = nn.Linear(2 * hidden_size, num_vars * num_classes)

    def forward(self, q_ids, q_len):
        # q_ids is either hard token ids [b, t] or a vocab distribution [b, t, |V|]
        if q_ids.dim() == 2:
            embedded = self.embedding(q_ids)
        else:
            embedded = self.get_embedding(q_ids)
        packed = pack_padded_sequence(embedded, q_len,
                                      batch_first=True,
                                      enforce_sorted=False)
        output, states = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        hiddens = states[0]  # [2, b, d]
        _, b, d = hiddens.size()
        concat_hidden = torch.cat([h for h in hiddens], dim=-1)  # [b, 2*d]
        # logits for K categorical variables
        qz_logits = self.recog_layer(concat_hidden)
        return qz_logits

    def get_embedding(self, vocab_dist):
        # vocab_dist : [b, t, |V|]; expected BERT embedding under the vocab distribution
        batch_size, nsteps, _ = vocab_dist.size()
        token_type_ids = torch.zeros((batch_size, nsteps), dtype=torch.long).to(vocab_dist.device)
        position_ids = torch.arange(nsteps, dtype=torch.long, device=vocab_dist.device)
        position_ids = position_ids.unsqueeze(0).repeat([batch_size, 1])
        embedding_matrix = self.embedding.word_embeddings.weight
        word_embeddings = torch.matmul(vocab_dist, embedding_matrix)
        position_embeddings = self.embedding.position_embeddings(position_ids)
        token_type_embeddings = self.embedding.token_type_embeddings(token_type_ids)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding.LayerNorm(embeddings)
        embeddings = self.embedding.dropout(embeddings)
        return embeddings
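

# A minimal sketch (not part of the original gist) of the "soft embedding" idea used in
# CatEncoder.get_embedding: when the input is a distribution over the vocabulary instead of
# hard token ids, the word embedding is the expectation of the embedding matrix under that
# distribution. The sizes and the helper name below are illustrative assumptions.
def _soft_embedding_sketch():
    vocab_size, embedding_size = 8, 4
    embedding_matrix = torch.randn(vocab_size, embedding_size)     # stands in for BERT's word embeddings
    vocab_dist = F.softmax(torch.randn(2, 3, vocab_size), dim=-1)  # [b=2, t=3, |V|]
    soft_embedded = torch.matmul(vocab_dist, embedding_matrix)     # [b, t, embedding_size]
    # a one-hot distribution recovers an ordinary embedding lookup
    one_hot = F.one_hot(torch.tensor([[1, 2, 3]]), num_classes=vocab_size).float()
    assert torch.allclose(torch.matmul(one_hot, embedding_matrix),
                          embedding_matrix[[1, 2, 3]].unsqueeze(0))
    return soft_embedded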


class CatDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(CatDecoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=1)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, q_ids, init_states):
        # teacher forcing: feed the gold token at every step
        batch_size, max_len = q_ids.size()
        logits = []
        states = init_states
        for i in range(max_len):
            q_i = q_ids[:, i]
            embedded = self.embedding(q_i.unsqueeze(1))
            hidden, states = self.lstm(embedded, states)
            logit = self.logit_layer(hidden)  # [b, 1, |V|]
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        return logits

    def decode(self, sos_tokens, init_states, max_step):
        # greedy decoding: feed back the argmax token at every step
        inputs = sos_tokens
        prev_states = init_states
        decoded_ids = []
        for i in range(max_step):
            embedded = self.embedding(inputs.unsqueeze(1))
            output, prev_states = self.lstm(embedded, prev_states)
            logit = self.logit_layer(output).squeeze(1)  # [b, |V|]
            inputs = torch.argmax(logit, 1)
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, dim=1)
        return decoded_ids


class CatVAE(nn.Module):
    def __init__(self, vocab_size, embedding_size,
                 hidden_size, num_vars, num_classes):
        super(CatVAE, self).__init__()
        # Encoder-Decoder for question
        self.encoder = CatEncoder(embedding_size, hidden_size,
                                  num_vars, num_classes)
        self.decoder = CatDecoder(vocab_size, embedding_size,
                                  hidden_size)
        self.linear_h = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.linear_c = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.num_vars = num_vars
        self.num_classes = num_classes

    def forward(self, q_ids):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        q_len = torch.sum(torch.sign(eos_q_ids), 1)
        qz_logits = self.encoder(eos_q_ids, q_len)  # exclude [CLS]
        flatten_logits = qz_logits.view(-1, self.num_classes)
        # sample categorical variables with the Gumbel-softmax trick
        z_samples = self.gumbel_softmax(flatten_logits, tau=1.0).view(-1, self.num_vars * self.num_classes)
        init_h = self.linear_h(z_samples).unsqueeze(0)  # [1, b, d]
        init_c = self.linear_c(z_samples).unsqueeze(0)  # [1, b, d]
        init_states = (init_h, init_c)
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        logits = self.decoder(sos_q_ids, init_states)
        batch_size, nsteps, _ = logits.size()
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        # KL(q(z) || p(z)), where p(z) is the uniform distribution
        log_qz = F.log_softmax(flatten_logits, dim=-1)  # [b*num_vars, num_classes]
        qz_probs = torch.exp(log_qz.view(-1, self.num_vars, self.num_classes))
        avg_log_qz = torch.log(torch.mean(qz_probs, dim=0) + 1e-15)  # [num_vars, num_classes]
        log_uniform_z = torch.log(torch.ones(1, device=log_qz.device) / self.num_classes)
        qz = torch.exp(avg_log_qz)
        kl = torch.sum(qz * (avg_log_qz - log_uniform_z), dim=1)
        avg_kl = kl.mean()
        # mutual information I(Z; X) = H(Z) - H(Z|X)
        mi = self.entropy(avg_log_qz) - self.entropy(log_qz)
        return nll, avg_kl, mi

    @staticmethod
    def gumbel_softmax(logits, tau=1.0, eps=1e-20):
        # add Gumbel(0, 1) noise and apply a temperature-scaled softmax
        u = torch.rand_like(logits)
        sample = -torch.log(-torch.log(u + eps) + eps)
        y = logits + sample.to(logits.device)
        return F.softmax(y / tau, dim=-1)

    @staticmethod
    def entropy(log_prob):
        # log_prob: [b, K]
        prob = torch.exp(log_prob)
        h = torch.sum(-log_prob * prob, dim=1)
        return h.mean()
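

# A hedged usage sketch (not from the original gist) of how CatVAE.forward might be called
# during VAE pre-training. The hyper-parameters, the toy batch, and the assumption that
# token id 0 is padding are illustrative; the function is defined but not executed here.
def _catvae_usage_sketch():
    vae = CatVAE(vocab_size=30522, embedding_size=768,
                 hidden_size=300, num_vars=20, num_classes=10)
    # fake batch: [CLS]-like start token, a few word ids, [SEP]/EOS, then padding
    q_ids = torch.tensor([[101, 2054, 2003, 1996, 3007, 102, 0, 0]])
    nll, kl, mi = vae(q_ids)
    loss = nll + kl  # the gist also returns mi; how the three terms are weighted is outside this code
    loss.backward()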


class AnsEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, dropout):
        super(AnsEncoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.num_layers = num_layers
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(embedding_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)

    def forward(self, ans_ids, ans_len):
        embedded = self.embedding(ans_ids)
        packed = pack_padded_sequence(embedded, ans_len, batch_first=True,
                                      enforce_sorted=False)
        _, states = self.lstm(packed)
        h, c = states
        _, b, d = h.size()
        # concatenate the forward and backward final states per layer
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return concat_states


class Encoder(nn.Module):
    def __init__(self, embedding_size,
                 hidden_size, num_layers, dropout, use_tag):
        super(Encoder, self).__init__()
        self.use_tag = use_tag
        self.num_layers = num_layers
        # answer-tag embedding (3 tag symbols)
        if use_tag:
            self.tag_embedding = nn.Embedding(3, 3)
            lstm_input_size = embedding_size + 3
        else:
            lstm_input_size = embedding_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)
        self.linear_trans = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False)
        self.update_layer = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)
        self.gate = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)

    def gated_self_attn(self, queries, memories, mask):
        # queries : [b, t, d]
        # memories: [b, t, d]
        # mask    : [b, t]
        energies = torch.matmul(queries, memories.transpose(1, 2))  # [b, t, t]
        energies = energies.masked_fill(mask.unsqueeze(1), value=-1e12)
        scores = F.softmax(energies, dim=2)
        context = torch.matmul(scores, queries)
        inputs = torch.cat((queries, context), dim=2)
        # gated fusion of the self-attention context and the original encoding
        f_t = torch.tanh(self.update_layer(inputs))
        g_t = torch.sigmoid(self.gate(inputs))
        updated_output = g_t * f_t + (1 - g_t) * queries
        return updated_output

    def forward(self, src_seq, src_len, tag_seq):
        total_length = src_seq.size(1)
        embedded = self.embedding(src_seq)
        if self.use_tag and tag_seq is not None:
            tag_embedded = self.tag_embedding(tag_seq)
            embedded = torch.cat((embedded, tag_embedded), dim=2)
        packed = pack_padded_sequence(embedded, src_len, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        outputs, states = self.lstm(packed)  # states : tuple of [2*n_layers, b, d]
        outputs, _ = pad_packed_sequence(outputs, batch_first=True,
                                         total_length=total_length)  # [b, t, d]
        h, c = states
        # gated self-attention; padded positions have all-zero outputs, which defines the mask
        output_sums = outputs.sum(dim=-1)
        mask = (output_sums == 0).byte()
        memories = self.linear_trans(outputs)
        outputs = self.gated_self_attn(outputs, memories, mask)
        _, b, d = h.size()
        # concatenate the forward and backward final states per layer
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return outputs, concat_states
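

# A minimal sketch (an assumption, not from the gist) of the gated fusion used in
# Encoder.gated_self_attn: the self-attention context is concatenated with the original
# encoding, squashed through tanh, and mixed back in through a learned sigmoid gate.
# The layer names and toy sizes below are illustrative only.
def _gated_fusion_sketch():
    b, t, d = 2, 5, 6
    queries = torch.randn(b, t, d)
    context = torch.randn(b, t, d)             # stands in for the self-attention context
    update_layer = nn.Linear(2 * d, d, bias=False)
    gate = nn.Linear(2 * d, d, bias=False)
    inputs = torch.cat((queries, context), dim=2)
    f_t = torch.tanh(update_layer(inputs))     # candidate update
    g_t = torch.sigmoid(gate(inputs))          # element-wise gate in (0, 1)
    return g_t * f_t + (1 - g_t) * queries     # convex mix of update and original encoding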


class Decoder(nn.Module):
    def __init__(self, embedding_size, vocab_size, enc_size,
                 hidden_size, num_layers, dropout,
                 pointer=False):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        if num_layers == 1:
            dropout = 0.0
        self.encoder_trans = nn.Linear(enc_size, hidden_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=num_layers, bidirectional=False, dropout=dropout)
        self.concat_layer = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)
        self.pointer = pointer

    @staticmethod
    def attention(query, memories, mask):
        # query   : [b, 1, d]
        # memories: [b, t, d]
        # mask    : [b, t]
        energy = torch.matmul(query, memories.transpose(1, 2))  # [b, 1, t]
        energy = energy.squeeze(1).masked_fill(mask, value=-1e12)
        attn_dist = F.softmax(energy, dim=1).unsqueeze(dim=1)  # [b, 1, t]
        context_vector = torch.matmul(attn_dist, memories)  # [b, 1, d]
        return context_vector, energy

    def get_encoder_features(self, encoder_outputs):
        return self.encoder_trans(encoder_outputs)

    def forward(self, q_ids, c_ids, init_states, encoder_outputs, enc_mask):
        # q_ids           : [b, t]
        # c_ids           : [b, t_c]
        # init_states     : a tuple of [num_layers, b, d]
        # encoder_outputs : [b, t_c, d]
        batch_size, max_len = q_ids.size()
        memories = self.get_encoder_features(encoder_outputs)
        logits = []
        prev_states = init_states
        self.lstm.flatten_parameters()
        for i in range(max_len):
            y_i = q_ids[:, i].unsqueeze(dim=1)
            embedded = self.embedding(y_i)
            hidden, states = self.lstm(embedded, prev_states)
            # encoder-decoder attention
            context, energy = self.attention(hidden, memories, enc_mask)
            concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
            logit_input = torch.tanh(self.concat_layer(concat_input))
            logit = self.logit_layer(logit_input)  # [b, |V|]
            # maxout pointer network: max attention energy per source token id as a copy score
            if self.pointer:
                num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
                zeros = torch.zeros((batch_size, num_oov), device=logit.device)
                extended_logit = torch.cat((logit, zeros), dim=1)
                out = torch.zeros_like(extended_logit) - INF
                out, _ = scatter_max(energy, c_ids, out=out)
                out = out.masked_fill(out == -INF, 0)
                logit = extended_logit + out
                logit = logit.masked_fill(logit == 0, -INF)
            logits.append(logit)
            # update the recurrent state for the next step
            prev_states = states
        logits = torch.stack(logits, dim=1)  # [b, t, |V|]
        return logits

    def decode(self, q_id, c_ids, prev_states, memories, enc_mask):
        # single greedy decoding step; q_id : [b]
        self.lstm.flatten_parameters()
        embedded = self.embedding(q_id.unsqueeze(1))
        batch_size = c_ids.size(0)
        hidden, states = self.lstm(embedded, prev_states)
        # encoder-decoder attention
        context, energy = self.attention(hidden, memories, enc_mask)
        concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
        logit_input = torch.tanh(self.concat_layer(concat_input))
        logit = self.logit_layer(logit_input)  # [b, |V|]
        if self.pointer:
            num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
            zeros = torch.zeros((batch_size, num_oov), device=logit.device)
            extended_logit = torch.cat((logit, zeros), dim=1)
            out = torch.zeros_like(extended_logit) - INF
            out, _ = scatter_max(energy, c_ids, out=out)
            out = out.masked_fill(out == -INF, 0)
            logit = extended_logit + out
            logit = logit.masked_fill(logit == 0, -INF)
        return logit, states
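

# A hedged sketch (not part of the gist) of the max-out pointer step in Decoder: scatter_max
# collapses the attention energies over source positions into a single copy score per source
# token id, taking the maximum energy when a token appears more than once. The sizes are toy
# assumptions with vocab_size = 6 and no OOV extension.
def _maxout_pointer_sketch():
    vocab_size = 6
    c_ids = torch.tensor([[2, 3, 2, 5]])               # source token ids; token 2 repeats
    energy = torch.tensor([[0.1, 0.7, 0.9, 0.2]])      # attention energies per source position
    logit = torch.zeros(1, vocab_size)                 # generator logits (zeros for clarity)
    out = torch.zeros_like(logit) - INF
    out, _ = scatter_max(energy, c_ids, out=out)       # copy score of id 2 becomes max(0.1, 0.9)
    out = out.masked_fill(out == -INF, 0)              # ids never seen in the source get no copy score
    return logit + out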


class AttnQG(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, vae_hidden_size, num_layers,
                 dropout, num_vars, num_classes, save_path, pointer=False):
        super(AttnQG, self).__init__()
        self.num_vars = num_vars
        self.num_classes = num_classes
        # pre-trained categorical VAE over questions
        self.vae_net = CatVAE(vocab_size,
                              embedding_size,
                              vae_hidden_size,
                              num_vars,
                              num_classes)
        if save_path is not None:
            state_dict = torch.load(save_path, map_location="cpu")
            self.vae_net.load_state_dict(state_dict)
        self.encoder = Encoder(embedding_size, hidden_size, num_layers, dropout, use_tag=True)
        self.ans_encoder = AnsEncoder(embedding_size, hidden_size, num_layers, dropout)
        self.decoder = Decoder(embedding_size, vocab_size, 2 * hidden_size, 2 * hidden_size,
                               num_layers, dropout, pointer)
        self.linear_trans = nn.Linear(4 * hidden_size, 4 * hidden_size)
        self.prior_net = nn.Linear(4 * hidden_size, num_vars * num_classes)
        self.linear_h = nn.Linear(num_vars * num_classes, 2 * hidden_size)
        self.linear_c = nn.Linear(num_vars * num_classes, 2 * hidden_size)

    def forward(self, c_ids, tag_ids, q_ids, ans_ids, use_prior=False):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        # sample z from the frozen recognition network q(z|x)
        with torch.no_grad():
            q_len = torch.sum(torch.sign(eos_q_ids), 1)
            qz_logits = self.vae_net.encoder(eos_q_ids, q_len)
            flatten_logits = qz_logits.view(-1, self.num_classes)
        # encode passage
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # compute prior p(z|c, a)
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        log_pz = F.log_softmax(prior_logits.view(-1, self.num_classes), dim=-1)
        log_qz = F.log_softmax(flatten_logits, dim=-1).detach()
        if use_prior:
            probs = torch.exp(log_pz)
        else:
            # (alternative: relaxed samples via Gumbel-softmax)
            # z_samples = self.vae_net.gumbel_softmax(flatten_logits, tau)
            # z_samples = z_samples.view(-1, self.num_vars * self.num_classes)
            probs = F.softmax(flatten_logits, dim=1)
        m = Categorical(probs)
        z_samples = m.sample()
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples.detach())
        init_c = self.linear_c(z_samples.detach())
        init_h = init_h + states[0]
        init_c = init_c + states[1]
        new_states = (init_h, init_c)
        # KL(q(z|x) || p(z|c, a))
        qz = torch.exp(log_qz)
        prior_kl = torch.sum(qz * (log_qz - log_pz), dim=1)
        prior_kl = prior_kl.mean()
        c_mask = (c_ids == 0).byte()
        logits = self.decoder(sos_q_ids, c_ids, new_states, enc_outputs, c_mask)
        # \hat{x} ~ p(x|c, z)
        decoded_ids = torch.argmax(logits, dim=-1)
        seq_len = self.get_seq_len(decoded_ids, EOS_ID)
        vocab_dist = F.softmax(logits, dim=-1)
        # z ~ q(z|\hat{x}): auxiliary loss encourages the decoder to preserve z
        z_logits = self.vae_net.encoder(vocab_dist, seq_len)
        flatten_z_logits = z_logits.view(-1, self.num_classes)
        true_z = z_samples.view(-1, self.num_classes)
        true_z = torch.argmax(true_z, dim=-1)
        aux_criterion = nn.CrossEntropyLoss()
        aux_loss = aux_criterion(flatten_z_logits, true_z)
        batch_size, nsteps, _ = logits.size()
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        return nll, prior_kl, aux_loss

    def generate(self, c_ids, tag_ids, ans_ids, sos_ids, max_step):
        # encode passage
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        c_mask = (c_ids == 0).byte()
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # sample z from the prior distribution p(z|c, a)
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        prior_logits = prior_logits.view(-1, self.num_classes)
        prior_prob = F.softmax(prior_logits, dim=-1)  # [b*num_vars, num_classes]
        m = Categorical(prior_prob)
        z_samples = m.sample()  # [b*num_vars]
        # one-hot vector
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples)
        init_c = self.linear_c(z_samples)
        new_h = init_h + states[0]
        new_c = init_c + states[1]
        new_states = (new_h, new_c)
        inputs = sos_ids
        prev_states = new_states
        decoded_ids = []
        memories = self.decoder.get_encoder_features(enc_outputs)
        # greedy decoding
        for _ in range(max_step):
            logit, prev_states = self.decoder.decode(inputs, c_ids,
                                                     prev_states, memories,
                                                     c_mask)
            inputs = torch.argmax(logit, 1)  # [b]
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, 1)
        return decoded_ids

    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id   : scalar
        mask = input_ids == eos_id
        num_eos = torch.sum(mask, 1)
        # move the mask to CPU and use np.argmax, which consistently returns the first index
        # of the maximum element (torch.argmax behaves differently on CUDA and CPU)
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a Tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no eos in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 for eos
        seq_len = seq_len + 1
        return seq_len
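

if __name__ == "__main__":
    # A hedged end-to-end sketch (not part of the original gist) of how AttnQG might be
    # instantiated and used. The hyper-parameters, the toy batch, and the BERT special-token
    # ids ([CLS]=101, [SEP]=102, [PAD]=0) are assumptions; save_path=None skips loading
    # pre-trained CatVAE weights, and pre-trained BERT weights must be available locally.
    model = AttnQG(vocab_size=30522, embedding_size=768, hidden_size=300,
                   vae_hidden_size=300, num_layers=2, dropout=0.3,
                   num_vars=20, num_classes=10, save_path=None, pointer=False)
    c_ids = torch.tensor([[101, 1996, 3007, 1997, 2605, 2003, 3000, 102]])  # toy passage
    tag_ids = torch.zeros_like(c_ids)                                        # answer-position tags
    ans_ids = torch.tensor([[101, 3000, 102]])                               # toy answer span
    q_ids = torch.tensor([[101, 2054, 2003, 1996, 3007, 102]])               # toy gold question
    # training-style forward pass: reconstruction NLL, prior KL, auxiliary z-prediction loss
    nll, prior_kl, aux_loss = model(c_ids, tag_ids, q_ids, ans_ids)
    # greedy generation from the prior, starting from a [CLS] token
    sos_ids = torch.full((c_ids.size(0),), 101, dtype=torch.long)
    decoded = model.generate(c_ids, tag_ids, ans_ids, sos_ids, max_step=20)
    print(decoded.size())  # [batch, max_step]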