# Gist by @seanie12, created February 15, 2020 07:23.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max
from pytorch_transformers import BertModel, BertTokenizer
def return_mask_lengths(ids):
mask = torch.sign(ids).float()
lengths = mask.sum(dim=1).long()
return mask, lengths
def return_num(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params
def cal_attn(left, right, mask):
mask = (1.0 - mask.float()) * -10000.0
attn_logits = torch.matmul(left, right.transpose(-1, -2).contiguous())
attn_logits = attn_logits + mask
attn_weights = F.softmax(input=attn_logits, dim=-1)
attn_outputs = torch.matmul(attn_weights, right)
return attn_outputs, attn_logits
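# Note: cal_attn is plain dot-product attention (no 1/sqrt(d) scaling); masked
# positions receive a large negative logit before the softmax.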
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
# type: (Tensor, float, bool, float, int) -> Tensor
gumbels = -(torch.empty_like(logits).exponential_() + eps).log() # ~Gumbel(0,1)
gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
if hard:
# Straight through.
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
# Re-parametrization trick.
ret = y_soft
return ret
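# Note: mirrors torch.nn.functional.gumbel_softmax. With hard=True the forward
# pass returns one-hot samples while gradients flow through the soft sample
# (straight-through estimator).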
class CategoricalKLLoss(nn.Module):
def __init__(self):
super(CategoricalKLLoss, self).__init__()
def forward(self, P, Q):
log_P = P.log()
log_Q = Q.log()
kl = (P * (log_P - log_Q)).sum(dim=-1).sum(dim=-1)
return kl.mean(dim=0)
class GaussianKLLoss(nn.Module):
def __init__(self):
super(GaussianKLLoss, self).__init__()
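    # Closed-form KL divergence between two diagonal Gaussians:
    # KL(N(mu1, var1) || N(mu2, var2))
    #   = 0.5 * sum(logvar2 - logvar1 + (var1 + (mu1 - mu2)^2) / var2 - 1)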
def forward(self, mu1, logvar1, mu2, logvar2):
numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
fraction = torch.div(numerator, (logvar2.exp()))
kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, dim=1)
return kl.mean(dim=0)
class Embedding(nn.Module):
def __init__(self, bert_model):
super(Embedding, self).__init__()
bert_embeddings = BertModel.from_pretrained(bert_model).embeddings
self.word_embeddings = bert_embeddings.word_embeddings
self.token_type_embeddings = bert_embeddings.token_type_embeddings
self.position_embeddings = bert_embeddings.position_embeddings
self.LayerNorm = bert_embeddings.LayerNorm
self.dropout = bert_embeddings.dropout
def forward(self, input_ids, token_type_ids=None, position_ids=None):
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
if position_ids is None:
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
words_embeddings = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class ContextualizedEmbedding(nn.Module):
def __init__(self, bert_model):
super(ContextualizedEmbedding, self).__init__()
bert = BertModel.from_pretrained(bert_model)
self.embedding = bert.embeddings
self.encoder = bert.encoder
self.num_hidden_layers = bert.config.num_hidden_layers
def forward(self, input_ids, attention_mask, token_type_ids=None):
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).float()
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
head_mask = [None] * self.num_hidden_layers
embedding_output = self.embedding(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
sequence_output = encoder_outputs[0]
return sequence_output
class CustomLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
super(CustomLSTM, self).__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.bidirectional = bidirectional
self.dropout = nn.Dropout(dropout)
if dropout > 0.0 and num_layers == 1:
dropout = 0.0
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
num_layers=num_layers, dropout=dropout,
bidirectional=bidirectional, batch_first=True)
def forward(self, input, input_lengths, state=None):
batch_size, total_length, _ = input.size()
input_packed = pack_padded_sequence(input, input_lengths,
batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
output_packed, state = self.lstm(input_packed, state)
output = pad_packed_sequence(output_packed, batch_first=True, total_length=total_length)[0]
output = self.dropout(output)
return output, state
class PosteriorEncoder(nn.Module):
def __init__(self, embedding, emsize,
nhidden, nlayers,
nzqdim, nza, nzadim,
dropout=0.0):
super(PosteriorEncoder, self).__init__()
self.embedding = embedding
self.nhidden = nhidden
self.nlayers = nlayers
self.nzqdim = nzqdim
self.nza = nza
self.nzadim = nzadim
self.question_encoder = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.context_encoder = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.context_answer_encoder = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.question_attention = nn.Linear(2 * nhidden, 2 * nhidden)
self.context_attention = nn.Linear(2 * nhidden, 2 * nhidden)
self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)
self.zq_linear = nn.Linear(4 * 2 * nhidden, 2 * nzqdim)
self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)
def forward(self, c_ids, q_ids, a_ids):
c_mask, c_lengths = return_mask_lengths(c_ids)
q_mask, q_lengths = return_mask_lengths(q_ids)
# question enc
q_embeddings = self.embedding(q_ids)
q_hs, q_state = self.question_encoder(q_embeddings, q_lengths)
q_h = q_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
q_h = q_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
# context enc
c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
# context and answer enc
c_a_embeddings = self.embedding(c_ids, a_ids, None)
        c_a_hs, c_a_state = self.context_answer_encoder(c_a_embeddings, c_lengths)
c_a_h = c_a_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
c_a_h = c_a_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        # attention: context attended by question
mask = c_mask.unsqueeze(1)
c_attned_by_q, _ = cal_attn(self.question_attention(q_h).unsqueeze(1), c_hs, mask)
c_attned_by_q = c_attned_by_q.squeeze(1)
        # attention: question attended by context
mask = q_mask.unsqueeze(1)
q_attned_by_c, _ = cal_attn(self.context_attention(c_h).unsqueeze(1), q_hs, mask)
q_attned_by_c = q_attned_by_c.squeeze(1)
h = torch.cat([q_h, q_attned_by_c, c_h, c_attned_by_q], dim=-1)
zq_mu, zq_logvar = torch.split(self.zq_linear(h), self.nzqdim, dim=1)
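        # reparameterization trick: zq = mu + sigma * eps, with eps ~ N(0, I)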
zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)
# attention zq, c_a
mask = c_mask.unsqueeze(1)
c_a_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_a_hs, mask)
c_a_attned_by_zq = c_a_attned_by_zq.squeeze(1)
h = torch.cat([zq, c_a_attned_by_zq, c_a_h], dim=-1)
za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
za_prob = F.softmax(za_logits, dim=-1)
za = gumbel_softmax(za_logits, hard=True)
return zq_mu, zq_logvar, zq, za_prob, za
class PriorEncoder(nn.Module):
def __init__(self, embedding, emsize,
nhidden, nlayers,
nzqdim, nza, nzadim,
dropout=0):
super(PriorEncoder, self).__init__()
self.embedding = embedding
self.nhidden = nhidden
self.nlayers = nlayers
self.nzqdim = nzqdim
self.nza = nza
self.nzadim = nzadim
self.context_encoder = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)
self.zq_linear = nn.Linear(2 * nhidden, 2 * nzqdim)
self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)
def forward(self, c_ids):
c_mask, c_lengths = return_mask_lengths(c_ids)
c_embeddings = self.embedding(c_ids)
c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
zq_mu, zq_logvar = torch.split(self.zq_linear(c_h), self.nzqdim, dim=1)
zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)
mask = c_mask.unsqueeze(1)
c_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_hs, mask)
c_attned_by_zq = c_attned_by_zq.squeeze(1)
h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)
za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
za_prob = F.softmax(za_logits, dim=-1)
za = gumbel_softmax(za_logits, hard=True)
return zq_mu, zq_logvar, zq, za_prob, za
def interpolation(self, c_ids, zq):
c_mask, c_lengths = return_mask_lengths(c_ids)
c_embeddings = self.embedding(c_ids)
c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
mask = c_mask.unsqueeze(1)
c_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_hs, mask)
c_attned_by_zq = c_attned_by_zq.squeeze(1)
h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)
za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
za_prob = F.softmax(za_logits, dim=-1)
za = gumbel_softmax(za_logits, hard=True)
return za
class AnswerDecoder(nn.Module):
def __init__(self, embedding, emsize,
nhidden, nlayers,
dropout=0.0):
super(AnswerDecoder, self).__init__()
self.embedding = embedding
self.context_lstm = CustomLSTM(input_size=4 * emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.start_linear = nn.Linear(2 * nhidden, 1)
self.end_linear = nn.Linear(2 * nhidden, 1)
self.ls = nn.LogSoftmax(dim=1)
def forward(self, init_state, c_ids):
batch_size, max_c_len = c_ids.size()
c_mask, c_lengths = return_mask_lengths(c_ids)
H = self.embedding(c_ids, c_mask)
U = init_state.unsqueeze(1).repeat(1, max_c_len, 1)
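        # fuse the contextualized token states H with the answer-conditioned
        # init state U via concatenation, elementwise product and absolute difference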
G = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
M, _ = self.context_lstm(G, c_lengths)
start_logits = self.start_linear(M).squeeze(-1)
end_logits = self.end_linear(M).squeeze(-1)
start_end_mask = (c_mask == 0)
masked_start_logits = start_logits.masked_fill(start_end_mask, -10000.0)
masked_end_logits = end_logits.masked_fill(start_end_mask, -10000.0)
return masked_start_logits, masked_end_logits
def generate(self, init_state, c_ids):
start_logits, end_logits = self.forward(init_state, c_ids)
c_mask, _ = return_mask_lengths(c_ids)
batch_size, max_c_len = c_ids.size()
mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
mask = torch.triu(mask) == 0
score = (self.ls(start_logits).unsqueeze(2) + self.ls(end_logits).unsqueeze(1))
score = score.masked_fill(mask, -10000.0)
score, start_positions = score.max(dim=1)
score, end_positions = score.max(dim=1)
start_positions = torch.gather(start_positions, 1, end_positions.view(-1, 1)).squeeze(1)
        idxes = torch.arange(0, max_c_len, dtype=torch.long)
idxes = idxes.unsqueeze(0).to(start_logits.device).repeat(batch_size, 1)
start_positions = start_positions.unsqueeze(1)
start_mask = (idxes >= start_positions).long()
end_positions = end_positions.unsqueeze(1)
end_mask = (idxes <= end_positions).long()
a_ids = start_mask + end_mask - 1
return a_ids, start_positions.squeeze(1), end_positions.squeeze(1)
class ContextEncoderforQG(nn.Module):
def __init__(self, embedding, emsize,
nhidden, nlayers,
dropout=0.0):
super(ContextEncoderforQG, self).__init__()
self.embedding = embedding
self.context_lstm = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=True)
self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
def forward(self, c_ids, a_ids):
c_mask, c_lengths = return_mask_lengths(c_ids)
c_embeddings = self.embedding(c_ids, c_mask, a_ids)
c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)
# attention
mask = torch.matmul(c_mask.unsqueeze(2), c_mask.unsqueeze(1))
c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs),
c_outputs,
mask)
c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
c_fused = self.fusion(c_concat).tanh()
c_gate = self.gate(c_concat).sigmoid()
c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
return c_outputs
class QuestionDecoder(nn.Module):
def __init__(self, sos_id, eos_id,
embedding, contextualized_embedding, emsize,
nhidden, ntokens, nlayers,
dropout=0.0,
max_q_len=64):
super(QuestionDecoder, self).__init__()
self.sos_id = sos_id
self.eos_id = eos_id
self.emsize = emsize
self.embedding = embedding
self.nhidden = nhidden
self.ntokens = ntokens
self.nlayers = nlayers
        # max_q_len includes the sos and eos tokens
self.max_q_len = max_q_len
self.context_lstm = ContextEncoderforQG(contextualized_embedding, emsize,
nhidden // 2, nlayers, dropout)
self.question_lstm = CustomLSTM(input_size=emsize,
hidden_size=nhidden,
num_layers=nlayers,
dropout=dropout,
bidirectional=False)
self.question_linear = nn.Linear(nhidden, nhidden)
self.concat_linear = nn.Sequential(nn.Linear(2 * nhidden, 2 * nhidden),
nn.Dropout(dropout),
nn.Linear(2 * nhidden, 2 * emsize))
self.logit_linear = nn.Linear(emsize, ntokens, bias=False)
        # tie the output projection to the word embedding matrix and keep it frozen
self.logit_linear.weight = embedding.word_embeddings.weight
for param in self.logit_linear.parameters():
param.requires_grad = False
self.discriminator = nn.Bilinear(emsize, nhidden, 1)
def postprocess(self, q_ids):
eos_mask = q_ids == self.eos_id
no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * (self.max_q_len - 1)
eos_mask = eos_mask.cpu().numpy()
q_lengths = np.argmax(eos_mask, axis=1) + 1
q_lengths = torch.tensor(q_lengths).to(q_ids.device).long() + no_eos_idx_sum
batch_size, max_len = q_ids.size()
        idxes = torch.arange(0, max_len, dtype=torch.long)
idxes = idxes.unsqueeze(0).to(q_ids.device).repeat(batch_size, 1)
q_mask = (idxes < q_lengths.unsqueeze(1))
q_ids = q_ids.long() * q_mask.long()
return q_ids
def forward(self, init_state, c_ids, q_ids, a_ids):
batch_size, max_q_len = q_ids.size()
c_outputs = self.context_lstm(c_ids, a_ids)
c_mask, c_lengths = return_mask_lengths(c_ids)
q_mask, q_lengths = return_mask_lengths(q_ids)
# question dec
q_embeddings = self.embedding(q_ids)
q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)
# attention
mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
c_outputs,
mask)
# gen logits
q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
q_concated = self.concat_linear(q_concated)
q_maxouted, _ = q_concated.view(batch_size, max_q_len, self.emsize, 2).max(dim=-1)
gen_logits = self.logit_linear(q_maxouted)
# copy logits
bq = batch_size * max_q_len
c_ids = c_ids.unsqueeze(1).repeat(1, max_q_len, 1).view(bq, -1).contiguous()
attn_logits = attn_logits.view(bq, -1).contiguous()
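        # for each vocabulary id occurring in the context, keep the maximum
        # attention logit over its occurrences (scatter_max); ids that never
        # occur are reset from -10000 to 0 below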
copy_logits = torch.zeros(bq, self.ntokens).to(c_ids.device)
copy_logits = copy_logits - 10000.0
copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
copy_logits = copy_logits.view(batch_size, max_q_len, -1).contiguous()
logits = gen_logits + copy_logits
        # mutual information between answer and question
a_emb = c_outputs * a_ids.float().unsqueeze(2)
a_mean_emb = torch.sum(a_emb, dim=1) / a_ids.sum(dim=1).unsqueeze(1).float()
fake_a_mean_emb = torch.cat([a_mean_emb[-1].unsqueeze(0), a_mean_emb[:-1]], dim=0)
q_emb = q_maxouted * q_mask.unsqueeze(2)
q_mean_emb = torch.sum(q_emb, dim=1) / q_lengths.unsqueeze(1).float()
fake_q_mean_emb = torch.cat([q_mean_emb[-1].unsqueeze(0), q_mean_emb[:-1]], dim=0)
bce_loss = nn.BCEWithLogitsLoss()
true_logits = self.discriminator(q_mean_emb, a_mean_emb)
true_labels = torch.ones_like(true_logits)
fake_a_logits = self.discriminator(q_mean_emb, fake_a_mean_emb)
fake_q_logits = self.discriminator(fake_q_mean_emb, a_mean_emb)
fake_logits = torch.cat([fake_a_logits, fake_q_logits], dim=0)
fake_labels = torch.zeros_like(fake_logits)
true_loss = bce_loss(true_logits, true_labels)
fake_loss = 0.5 * bce_loss(fake_logits, fake_labels)
loss_info = 0.5 * (true_loss + fake_loss)
return logits, loss_info
def get_mi(self, init_state, c_ids, q_ids, a_ids):
batch_size, max_q_len = q_ids.size()
c_outputs = self.context_lstm(c_ids, a_ids)
c_mask, c_lengths = return_mask_lengths(c_ids)
q_mask, q_lengths = return_mask_lengths(q_ids)
# question dec
q_embeddings = self.embedding(q_ids)
q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)
# attention
mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
c_outputs,
mask)
# gen logits
q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
q_concated = self.concat_linear(q_concated)
q_maxouted, _ = q_concated.view(batch_size, max_q_len, self.emsize, 2).max(dim=-1)
        # mutual information between answer and question
a_emb = c_outputs * a_ids.float().unsqueeze(2)
a_mean_emb = torch.sum(a_emb, dim=1) / a_ids.sum(dim=1).unsqueeze(1).float()
q_emb = q_maxouted * q_mask.unsqueeze(2)
q_mean_emb = torch.sum(q_emb, dim=1) / q_lengths.unsqueeze(1).float()
logits = self.discriminator(q_mean_emb, a_mean_emb)
logits = logits.squeeze(1)
return logits
def generate(self, init_state, c_ids, a_ids):
c_mask, c_lengths = return_mask_lengths(c_ids)
c_outputs = self.context_lstm(c_ids, a_ids)
batch_size = c_ids.size(0)
q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
q_ids = q_ids.to(c_ids.device)
token_type_ids = torch.zeros_like(q_ids)
position_ids = torch.zeros_like(q_ids)
q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)
state = init_state
# unroll
all_q_ids = list()
all_q_ids.append(q_ids)
for _ in range(self.max_q_len - 1):
position_ids = position_ids + 1
q_outputs, state = self.question_lstm.lstm(q_embeddings, state)
# attention
mask = c_mask.unsqueeze(1)
c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
c_outputs,
mask)
# gen logits
q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
q_concated = self.concat_linear(q_concated)
q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
gen_logits = self.logit_linear(q_maxouted)
# copy logits
attn_logits = attn_logits.squeeze(1)
copy_logits = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
copy_logits = copy_logits - 10000.0
copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
logits = gen_logits + copy_logits.unsqueeze(1)
q_ids = torch.argmax(logits, 2)
all_q_ids.append(q_ids)
q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)
q_ids = torch.cat(all_q_ids, 1)
q_ids = self.postprocess(q_ids)
return q_ids
def sample(self, init_state, c_ids, a_ids):
c_mask, c_lengths = return_mask_lengths(c_ids)
c_outputs = self.context_lstm(c_ids, a_ids)
batch_size = c_ids.size(0)
q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
q_ids = q_ids.to(c_ids.device)
token_type_ids = torch.zeros_like(q_ids)
position_ids = torch.zeros_like(q_ids)
q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)
state = init_state
# unroll
all_q_ids = list()
all_q_ids.append(q_ids)
for _ in range(self.max_q_len - 1):
position_ids = position_ids + 1
q_outputs, state = self.question_lstm.lstm(q_embeddings, state)
# attention
mask = c_mask.unsqueeze(1)
c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
c_outputs,
mask)
# gen logits
q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
q_concated = self.concat_linear(q_concated)
q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
gen_logits = self.logit_linear(q_maxouted)
# copy logits
attn_logits = attn_logits.squeeze(1)
copy_logits = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
copy_logits = copy_logits - 10000.0
copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
logits = gen_logits + copy_logits.unsqueeze(1)
logits = logits.squeeze(1)
            logits = self.top_k_top_p_filtering(logits, top_k=2, top_p=0.8)
probs = F.softmax(logits, dim=-1)
q_ids = torch.multinomial(probs, num_samples=1) # [b,1]
all_q_ids.append(q_ids)
q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)
q_ids = torch.cat(all_q_ids, 1)
q_ids = self.postprocess(q_ids)
return q_ids
def top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size x vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
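# DiscreteVAE ties the modules above together: posterior/prior encoders infer a
# Gaussian latent zq and categorical latents za, the answer decoder predicts an
# answer span over the context from za, and the question decoder generates the
# question from zq with a copy mechanism over context tokens plus an auxiliary
# question-answer mutual-information (discriminator) loss.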
class DiscreteVAE(nn.Module):
def __init__(self, args):
super(DiscreteVAE, self).__init__()
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
padding_idx = tokenizer.vocab['[PAD]']
sos_id = tokenizer.vocab['[CLS]']
eos_id = tokenizer.vocab['[SEP]']
ntokens = len(tokenizer.vocab)
bert_model = args.bert_model
if "large" in bert_model:
emsize = 1024
else:
emsize = 768
enc_nhidden = args.enc_nhidden
enc_nlayers = args.enc_nlayers
enc_dropout = args.enc_dropout
dec_a_nhidden = args.dec_a_nhidden
dec_a_nlayers = args.dec_a_nlayers
dec_a_dropout = args.dec_a_dropout
self.dec_q_nhidden = dec_q_nhidden = args.dec_q_nhidden
self.dec_q_nlayers = dec_q_nlayers = args.dec_q_nlayers
dec_q_dropout = args.dec_q_dropout
self.nzqdim = nzqdim = args.nzqdim
self.nza = nza = args.nza
self.nzadim = nzadim = args.nzadim
self.lambda_kl = args.lambda_kl
self.lambda_info = args.lambda_info
max_q_len = args.max_q_len
embedding = Embedding(bert_model)
contextualized_embedding = ContextualizedEmbedding(bert_model)
for param in embedding.parameters():
param.requires_grad = False
for param in contextualized_embedding.parameters():
param.requires_grad = False
self.posterior_encoder = PosteriorEncoder(embedding, emsize,
enc_nhidden, enc_nlayers,
nzqdim, nza, nzadim,
enc_dropout)
self.prior_encoder = PriorEncoder(embedding, emsize,
enc_nhidden, enc_nlayers,
nzqdim, nza, nzadim, enc_dropout)
self.answer_decoder = AnswerDecoder(contextualized_embedding, emsize,
dec_a_nhidden, dec_a_nlayers,
dec_a_dropout)
self.question_decoder = QuestionDecoder(sos_id, eos_id,
embedding, contextualized_embedding, emsize,
dec_q_nhidden, ntokens, dec_q_nlayers,
dec_q_dropout,
max_q_len)
self.q_h_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
self.q_c_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
self.a_linear = nn.Linear(nza * nzadim, emsize, False)
self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
self.gaussian_kl_criterion = GaussianKLLoss()
self.categorical_kl_criterion = CategoricalKLLoss()
def return_init_state(self, zq, za):
q_init_h = self.q_h_linear(zq)
q_init_c = self.q_c_linear(zq)
q_init_h = q_init_h.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
q_init_c = q_init_c.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
q_init_state = (q_init_h, q_init_c)
za_flatten = za.view(-1, self.nza * self.nzadim)
a_init_state = self.a_linear(za_flatten)
return q_init_state, a_init_state
def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions):
posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
posterior_za_prob, posterior_za \
= self.posterior_encoder(c_ids, q_ids, a_ids)
prior_zq_mu, prior_zq_logvar, prior_zq, \
prior_za_prob, prior_za \
= self.prior_encoder(c_ids)
q_init_state, a_init_state = self.return_init_state(posterior_zq, posterior_za)
# answer decoding
start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
# question decoding
q_logits, loss_info = self.question_decoder(q_init_state, c_ids, q_ids, a_ids)
# q rec loss
loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
q_ids[:, 1:])
# a rec loss
max_c_len = c_ids.size(1)
a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
start_positions.clamp_(0, max_c_len)
end_positions.clamp_(0, max_c_len)
loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
loss_a_rec = 0.5 * (loss_start_a_rec + loss_end_a_rec)
# kl loss
loss_zq_kl = self.gaussian_kl_criterion(posterior_zq_mu,
posterior_zq_logvar,
prior_zq_mu,
prior_zq_logvar)
loss_za_kl = self.categorical_kl_criterion(posterior_za_prob,
prior_za_prob)
loss_kl = self.lambda_kl * (loss_zq_kl + loss_za_kl)
loss_info = self.lambda_info * loss_info
loss = loss_q_rec + loss_a_rec + loss_kl + loss_info
return loss, \
loss_q_rec, loss_a_rec, \
loss_zq_kl, loss_za_kl, \
loss_info
def generate(self, zq, za, c_ids):
q_init_state, a_init_state = self.return_init_state(zq, za)
a_ids, start_positions, end_positions = self.answer_decoder.generate(a_init_state, c_ids)
q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)
return q_ids, start_positions, end_positions, a_ids
def nucleus_sample(self, zq, za, c_ids):
q_init_state, a_init_state = self.return_init_state(zq, za)
a_ids, start_positions, end_positions = self.answer_decoder.generate(a_init_state, c_ids)
q_ids = self.question_decoder.sample(q_init_state, c_ids, a_ids)
return q_ids, start_positions, end_positions, a_ids
def return_answer_logits(self, zq, za, c_ids):
q_init_state, a_init_state = self.return_init_state(zq, za)
start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
return start_logits, end_logits
def question_generate(self, c_ids, a_ids):
zq_mu, zq_logvar, _, _, _ = self.prior_encoder(c_ids)
zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)
zq_rand = torch.rand_like(zq_mu)
zq = zq[0].unsqueeze(0)
zq = torch.cat([zq, zq_rand[1:]], dim=0)
q_init_h = self.q_h_linear(zq)
q_init_c = self.q_c_linear(zq)
q_init_h = q_init_h.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
q_init_c = q_init_c.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
q_init_state = (q_init_h, q_init_c)
q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)
return q_ids
def estimate_mi(self, c_ids, q_ids, a_ids):
posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
posterior_za_prob, posterior_za \
= self.posterior_encoder(c_ids, q_ids, a_ids)
q_init_state, a_init_state = self.return_init_state(posterior_zq, posterior_za)
mi = self.question_decoder.get_mi(q_init_state, c_ids, q_ids, a_ids)
return mi
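# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). A minimal smoke test showing
# how DiscreteVAE might be instantiated and run; the hyperparameter values
# below are illustrative assumptions, not the author's settings. Requires the
# pytorch_transformers and torch_scatter packages imported above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        bert_model="bert-base-uncased",
        enc_nhidden=300, enc_nlayers=1, enc_dropout=0.2,
        dec_a_nhidden=300, dec_a_nlayers=1, dec_a_dropout=0.2,
        dec_q_nhidden=900, dec_q_nlayers=2, dec_q_dropout=0.3,
        nzqdim=50, nza=20, nzadim=10,
        lambda_kl=0.1, lambda_info=1.0, max_q_len=64,
    )
    model = DiscreteVAE(args)

    # Dummy batch: 0 is the [PAD] id, a_ids is a 0/1 answer-span mask over the
    # context tokens (it is consumed as BERT token_type_ids inside the model).
    batch_size, c_len, q_len = 2, 30, 10
    c_ids = torch.randint(1, 1000, (batch_size, c_len))
    q_ids = torch.randint(1, 1000, (batch_size, q_len))
    a_ids = torch.zeros_like(c_ids)
    a_ids[:, 5:8] = 1
    start_positions = torch.full((batch_size,), 5, dtype=torch.long)
    end_positions = torch.full((batch_size,), 7, dtype=torch.long)

    # training step: total loss plus its individual terms
    loss, loss_q_rec, loss_a_rec, loss_zq_kl, loss_za_kl, loss_info = model(
        c_ids, q_ids, a_ids, start_positions, end_positions)
    loss.backward()

    # QA-pair generation from latents sampled from the prior
    with torch.no_grad():
        zq = torch.randn(batch_size, args.nzqdim)
        za = F.one_hot(torch.randint(0, args.nzadim, (batch_size, args.nza)),
                       num_classes=args.nzadim).float()
        gen_q_ids, start_pos, end_pos, gen_a_ids = model.generate(zq, za, c_ids)
        print(gen_q_ids.shape, gen_a_ids.shape)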