import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max


def return_mask_lengths(ids):
    if ids.dim() == 3:  # a 3-d input means the ids are one-hot vectors
        mask = torch.sum(ids, dim=2)
    else:
        mask = torch.sign(ids).long()
    lengths = mask.sum(dim=1)
    return mask, lengths
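

# Illustrative sketch (not part of the original gist): return_mask_lengths with
# 0 as the padding id. The token ids below are made up.
def _example_return_mask_lengths():
    ids = torch.tensor([[5, 9, 2, 0, 0],
                        [7, 3, 0, 0, 0]])
    mask, lengths = return_mask_lengths(ids)
    assert mask.tolist() == [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
    assert lengths.tolist() == [3, 2]
    return mask, lengths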


def cal_attn(left, right, mask):
    # turn the 0/1 mask into additive logits: masked positions get -10000
    mask = (1.0 - mask.float()) * -10000.0
    attn_logits = torch.matmul(left, right.transpose(-1, -2).contiguous())
    attn_logits = attn_logits + mask
    attn_weights = F.softmax(input=attn_logits, dim=-1)
    attn_outputs = torch.matmul(attn_weights, right)
    return attn_outputs, attn_logits
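

# Illustrative sketch (not part of the original gist): cal_attn computes masked
# dot-product attention of `left` over `right`. The shapes below are hypothetical.
def _example_cal_attn():
    left = torch.randn(2, 5, 8)      # queries      [batch, q_len, dim]
    right = torch.randn(2, 7, 8)     # keys/values  [batch, kv_len, dim]
    mask = torch.ones(2, 5, 7)       # 1 = attend, 0 = ignore
    mask[:, :, -2:] = 0              # e.g. the last two key positions are padding
    attn_outputs, attn_logits = cal_attn(left, right, mask)
    assert attn_outputs.shape == (2, 5, 8) and attn_logits.shape == (2, 5, 7)
    return attn_outputs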


class BertEmbedding(nn.Module):
    def __init__(self, bert_model):
        super(BertEmbedding, self).__init__()
        bert_embedding = BertModel.from_pretrained(bert_model).embeddings
        self.word_embeddings = bert_embedding.word_embeddings
        self.position_embeddings = bert_embedding.position_embeddings
        self.token_type_embeddings = bert_embedding.token_type_embeddings
        self.LayerNorm = bert_embedding.LayerNorm
        self.dropout = bert_embedding.dropout

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if input_ids.dim() == 3:
            # one-hot (or soft) inputs: multiply by the embedding matrix directly
            word_embeddings = F.linear(input_ids, self.word_embeddings.weight.transpose(-1, -2).contiguous())
            input_size = input_ids[:, :, 0].size()
        else:
            word_embeddings = self.word_embeddings(input_ids)
            input_size = input_ids.size()
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(input_size)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_size).to(input_ids.device).long()
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    # gumbels = -(torch.empty_like(logits).exponential_() + eps).log()  # ~Gumbel(0,1)
    gumbels = -(-(torch.rand_like(logits) + eps).log() + eps).log()
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)
    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
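

# Illustrative sketch (not part of the original gist): drawing a straight-through
# one-hot sample from categorical logits with gumbel_softmax. The batch size and
# latent sizes below are made up.
def _example_gumbel_softmax():
    logits = torch.randn(4, 20, 10)                    # [batch, nz, nzdim]
    z_hard = gumbel_softmax(logits, tau=1.0, hard=True)
    # forward pass is exactly one-hot along the last dim ...
    assert z_hard.shape == logits.shape
    assert torch.allclose(z_hard.sum(dim=-1), torch.ones(4, 20))
    # ... while gradients flow through the soft softmax sample
    return z_hard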


class CatKLLoss(nn.Module):
    def __init__(self):
        super(CatKLLoss, self).__init__()

    def forward(self, log_qy, log_py):
        qy = torch.exp(log_qy)
        kl = torch.sum(qy * (log_qy - log_py), dim=-1)
        return torch.sum(kl, dim=-1)


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = nn.Dropout(dropout)
        # nn.LSTM only applies dropout between layers, so disable it for a single layer
        if dropout > 0.0 and num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=dropout,
                            bidirectional=bidirectional, batch_first=True)

    def forward(self, input, input_lengths, state=None):
        batch_size, total_length, _ = input.size()
        input_packed = pack_padded_sequence(input, input_lengths,
                                            batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        output_packed, state = self.lstm(input_packed, state)
        output, _ = pad_packed_sequence(output_packed, batch_first=True, total_length=total_length)
        output = self.dropout(output)
        return output, state
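

# Illustrative sketch (not part of the original gist): CustomLSTM packs
# variable-length sequences so that padded timesteps do not affect the hidden
# state, then pads the output back to total_length. Shapes below are hypothetical.
def _example_custom_lstm():
    lstm = CustomLSTM(input_size=16, hidden_size=8, num_layers=1,
                      dropout=0.1, bidirectional=True)
    x = torch.randn(3, 10, 16)            # [batch, time, feature]
    lengths = torch.tensor([10, 7, 4])    # true length of each sequence
    output, (h, c) = lstm(x, lengths)
    assert output.shape == (3, 10, 2 * 8)  # padded back to total_length
    assert h.shape == (2 * 1, 3, 8)        # [num_directions * num_layers, batch, hidden]
    return output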


class PosteriorEncoder(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 nz, nzdim,
                 dropout=0, freeze=False):
        super(PosteriorEncoder, self).__init__()
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.question_encoder = CustomLSTM(input_size=emsize,
                                           hidden_size=nhidden,
                                           num_layers=nlayers,
                                           dropout=dropout,
                                           bidirectional=True)
        self.question_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.context_answer_encoder = CustomLSTM(input_size=emsize,
                                                 hidden_size=nhidden,
                                                 num_layers=nlayers,
                                                 dropout=dropout,
                                                 bidirectional=True)
        self.context_answer_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.posterior_linear = nn.Linear(2 * 4 * nhidden, nz * nzdim)

    def forward(self, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        # question encoding
        q_embeddings = self.embedding(q_ids)
        q_hs, q_state = self.question_encoder(q_embeddings, q_lengths)
        q_h = q_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        q_h = q_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        # context + answer encoding (answer tags are fed as token-type ids)
        c_a_embeddings = self.embedding(c_ids, a_ids, None)
        c_a_hs, c_a_state = self.context_answer_encoder(c_a_embeddings, c_lengths)
        c_a_h = c_a_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_a_h = c_a_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        # cross attention between the two encodings
        mask = q_mask.unsqueeze(1)
        q_attned_by_ca, _ = cal_attn(self.question_linear(c_a_h).unsqueeze(1), q_hs, mask)
        q_attned_by_ca = q_attned_by_ca.squeeze(1)
        mask = c_mask.unsqueeze(1)
        ca_attned_by_q, _ = cal_attn(self.context_answer_linear(q_h).unsqueeze(1), c_a_hs, mask)
        ca_attned_by_q = ca_attned_by_q.squeeze(1)
        h = torch.cat([q_h, q_attned_by_ca, c_a_h, ca_attned_by_q], dim=-1)
        posterior_z_logits = self.posterior_linear(h).view(-1, self.nz, self.nzdim).contiguous()
        posterior_z_prob = F.softmax(posterior_z_logits, dim=-1)
        return posterior_z_logits, posterior_z_prob


class PriorEncoder(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 nz, nzdim,
                 dropout=0):
        super(PriorEncoder, self).__init__()
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)
        self.prior_linear = nn.Linear(2 * nhidden, nz * nzdim)

    def forward(self, c_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        # context encoding
        c_embeddings = self.embedding(c_ids)
        _, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        prior_z_logits = self.prior_linear(h).view(-1, self.nz, self.nzdim)
        prior_z_prob = F.softmax(prior_z_logits, dim=-1)
        return prior_z_logits, prior_z_prob


class AnswerDecoder(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, nlayers,
                 dropout=0):
        super(AnswerDecoder, self).__init__()
        self.nhidden = nhidden = int(0.5 * nhidden)
        self.nlayers = nlayers
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.start_lstm = CustomLSTM(input_size=4 * 2 * nhidden,
                                     hidden_size=nhidden,
                                     num_layers=nlayers,
                                     dropout=dropout,
                                     bidirectional=True)
        self.end_lstm = CustomLSTM(input_size=2 * nhidden,
                                   hidden_size=nhidden,
                                   num_layers=nlayers,
                                   dropout=dropout,
                                   bidirectional=True)
        self.start_linear = nn.Linear(4 * 2 * nhidden + 2 * nhidden, 1)
        self.end_linear = nn.Linear(4 * 2 * nhidden + 2 * nhidden, 1)

    def forward(self, init_state, c_ids):
        batch_size, max_c_len = c_ids.size()
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids)
        H, _ = self.context_lstm(c_embeddings, c_lengths)
        # fuse the context encoding H with the broadcast latent state U
        U = init_state.unsqueeze(1).repeat(1, max_c_len, 1)
        G = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
        M1, _ = self.start_lstm(G, c_lengths)
        M2, _ = self.end_lstm(M1, c_lengths)
        start_logits = self.start_linear(torch.cat([G, M1], dim=-1)).squeeze(-1)
        end_logits = self.end_linear(torch.cat([G, M2], dim=-1)).squeeze(-1)
        start_end_mask = c_mask == 0
        masked_start_logits = start_logits.masked_fill(start_end_mask, -10000.0)
        masked_end_logits = end_logits.masked_fill(start_end_mask, -10000.0)
        return masked_start_logits, masked_end_logits


class ContextEncoderforQG(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, nlayers, dropout=0):
        super(ContextEncoderforQG, self).__init__()
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
        self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

    def forward(self, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids, a_ids, None)
        c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)
        # self-attention over the context, followed by a gated fusion
        mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs), c_outputs, mask)
        c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
        c_fused = self.fusion(c_concat).tanh()
        c_gate = self.gate(c_concat).sigmoid()
        c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
        return c_outputs


class QuestionDecoder(nn.Module):
    def __init__(self, sos_id, eos_id,
                 embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 dropout=0, copy=True, max_q_len=64):
        super(QuestionDecoder, self).__init__()
        self.sos_id = sos_id
        self.eos_id = eos_id
        # max_q_len includes the sos and eos tokens
        self.max_q_len = max_q_len
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.copy = copy
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = ContextEncoderforQG(embedding, bert_model, emsize,
                                                nhidden, nlayers, dropout)
        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=2 * nhidden,
                                        num_layers=nlayers,
                                        dropout=dropout,
                                        bidirectional=False)
        self.question_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.concat_linear = nn.Linear(4 * nhidden, 2 * nhidden)
        self.logit_linear = nn.Linear(2 * nhidden, ntokens)

    def forward(self, init_state, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)
        batch_size, max_q_len = q_ids.size()
        # question decoding (teacher forcing)
        q_embeddings = self.embedding(q_ids, None, torch.zeros_like(q_ids))
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)
        # attention over the context
        mask = torch.matmul(q_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs), c_outputs, mask)
        # generation logits
        q_concat = self.concat_linear(torch.cat([q_outputs, c_attned_by_q], dim=2)).tanh()
        logits = self.logit_linear(q_concat)
        if self.copy:
            # copy logits: take the max attention logit for each context token id
            bq = batch_size * max_q_len
            c_ids = c_ids.unsqueeze(1).repeat(1, max_q_len, 1).view(bq, -1).contiguous()
            attn_logits = attn_logits.view(bq, -1).contiguous()
            out = torch.zeros(bq, self.ntokens).to(c_ids.device)
            out = out - 10000.0
            out, _ = scatter_max(attn_logits, c_ids, out=out)
            out = out.masked_fill(out == -10000.0, 0)
            out = out.view(batch_size, max_q_len, -1).contiguous()
            logits = logits + out
        return logits

    def generate(self, init_state, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)
        batch_size = c_ids.size(0)
        start_symbols = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        start_symbols = start_symbols.to(c_ids.device)
        position_ids = torch.zeros_like(start_symbols)
        q_embeddings = self.embedding(start_symbols, None, position_ids)
        state = init_state
        # greedy unrolling, one token at a time
        all_indices = [start_symbols]
        for _ in range(self.max_q_len - 1):
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)
            # attention over the context
            mask = c_mask.unsqueeze(1).float()
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs), c_outputs, mask)
            # generation logits
            q_concat = self.concat_linear(torch.cat([q_outputs, c_attned_by_q], dim=2)).tanh()
            logits = self.logit_linear(q_concat)
            if self.copy:
                # copy logits
                attn_logits = attn_logits.squeeze(1)
                out = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
                out = out - 10000.0
                out, _ = scatter_max(attn_logits, c_ids, out=out)
                out = out.masked_fill(out == -10000.0, 0)
                logits = logits + out.unsqueeze(1)
            indices = torch.argmax(logits, 2)
            all_indices.append(indices)
            q_embeddings = self.embedding(indices, None, position_ids)
        q_ids = torch.cat(all_indices, 1)
        # truncate each question at its first eos; sequences with no eos keep the full length
        eos_mask = q_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * (self.max_q_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        q_lengths = np.argmax(eos_mask, axis=1) + 1
        q_lengths = torch.tensor(q_lengths).to(q_ids.device).long() + no_eos_idx_sum
        batch_size, max_len = q_ids.size()
        idxes = torch.arange(max_len, device=q_ids.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        q_mask = (idxes < q_lengths.unsqueeze(1))
        q_ids = q_ids.long() * q_mask.long()
        return q_ids


class DiscreteVAE(nn.Module):
    def __init__(self, padding_idx, sos_id, eos_id,
                 bert_model,
                 nhidden, ntokens, nlayers,
                 nz, nzdim, freeze=False,
                 dropout=0, copy=True, max_q_len=64):
        super(DiscreteVAE, self).__init__()
        self.nhidden = nhidden
        if "large" in bert_model:
            emsize = 1024
        else:
            emsize = 768
        self.emsize = emsize
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        embedding = BertEmbedding(bert_model)
        if freeze:
            print("freeze bert embedding")
            for param in embedding.parameters():
                param.requires_grad = False
        self.posterior_encoder = PosteriorEncoder(embedding, bert_model, emsize,
                                                  nhidden, ntokens, nlayers, nz, nzdim, dropout)
        self.prior_encoder = PriorEncoder(embedding, bert_model, emsize,
                                          nhidden, ntokens, nlayers, nz, nzdim, dropout)
        self.answer_decoder = AnswerDecoder(embedding, bert_model, emsize,
                                            nhidden, nlayers, dropout)
        self.question_decoder = QuestionDecoder(sos_id, eos_id,
                                                embedding, bert_model, emsize,
                                                nhidden, ntokens, nlayers, dropout,
                                                copy, max_q_len)
        self.q_h_linear = nn.Linear(nz * nzdim, 2 * nlayers * nhidden, False)
        self.q_c_linear = nn.Linear(nz * nzdim, 2 * nlayers * nhidden, False)
        self.a_h_linear = nn.Linear(nz * nzdim, nhidden, False)
        self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.kl_criterion = CatKLLoss()

    def return_init_state(self, z_flatten):
        # map the flattened discrete latent to the initial (h, c) of the question
        # decoder and the initial state of the answer decoder
        batch_size = z_flatten.size(0)
        q_init_h = self.q_h_linear(z_flatten)
        q_init_c = self.q_c_linear(z_flatten)
        q_init_h = q_init_h.view(batch_size, self.nlayers, 2 * self.nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(batch_size, self.nlayers, 2 * self.nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)
        a_init_h = self.a_h_linear(z_flatten)
        a_init_state = a_init_h
        return q_init_state, a_init_state

    def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions, tau=1.0):
        max_c_len = c_ids.size(1)
        posterior_z_logits, posterior_z_prob = self.posterior_encoder(c_ids, q_ids, a_ids)
        posterior_z = gumbel_softmax(posterior_z_logits, tau=tau, hard=True)
        posterior_z_flatten = posterior_z.view(-1, self.nz * self.nzdim)
        prior_z_logits, prior_z_prob = self.prior_encoder(c_ids)
        q_init_state, a_init_state = self.return_init_state(posterior_z_flatten)
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        q_logits = self.question_decoder(q_init_state, c_ids, q_ids, a_ids)
        # question reconstruction loss
        loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          q_ids[:, 1:])
        # answer reconstruction loss
        a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
        start_positions.clamp_(0, max_c_len)
        end_positions.clamp_(0, max_c_len)
        loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
        loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
        loss_a_rec = (loss_start_a_rec + loss_end_a_rec) / 2
        # KL loss between the batch-averaged posterior and the prior
        posterior_avg_z_prob = posterior_z_prob.mean(dim=0)
        loss_kl = self.kl_criterion(posterior_avg_z_prob.log(), prior_z_prob.log()).mean(dim=0)
        loss = loss_q_rec + loss_a_rec + loss_kl
        return loss, loss_q_rec, loss_a_rec, loss_kl

    def recon_ans(self, c_ids, q_ids, a_ids):
        posterior_z_logits, posterior_z_prob = self.posterior_encoder(c_ids, q_ids, a_ids)
        posterior_z = gumbel_softmax(posterior_z_logits, hard=True)
        posterior_z_flatten = posterior_z.view(-1, self.nz * self.nzdim)
        q_init_state, a_init_state = self.return_init_state(posterior_z_flatten)
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        return start_logits, end_logits

    def generate(self, z_logits, c_ids):
        batch_size, max_c_len = c_ids.size()
        c_mask, _ = return_mask_lengths(c_ids)
        z = gumbel_softmax(z_logits, hard=True)
        z_flatten = z.view(-1, self.nz * self.nzdim)
        q_init_state, a_init_state = self.return_init_state(z_flatten)
        # decode an answer span first
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        mask = torch.triu(mask) == 0
        score = (F.log_softmax(start_logits, dim=1).unsqueeze(2)
                 + F.log_softmax(end_logits, dim=1).unsqueeze(1))
        score = score.masked_fill(mask, -10000.0)
        score, start_positions = score.max(dim=1)
        score, end_positions = score.max(dim=1)
        start_positions = torch.gather(start_positions, 1, end_positions.view(-1, 1)).squeeze(1)
        # convert the (start, end) span into 0/1 answer tags
        idxes = torch.arange(max_c_len, device=start_logits.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        start_positions = start_positions.unsqueeze(1)
        start_mask = (idxes >= start_positions).long()
        end_positions = end_positions.unsqueeze(1)
        end_mask = (idxes <= end_positions).long()
        generated_a_ids = start_mask + end_mask - 1
        # then decode the question conditioned on the predicted answer span
        q_ids = self.question_decoder.generate(q_init_state, c_ids, generated_a_ids)
        return q_ids, start_positions.squeeze(1), end_positions.squeeze(1), start_logits, end_logits
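

# Illustrative sketch (not part of the original gist): a minimal end-to-end call
# of DiscreteVAE with random token ids. The vocabulary size, special-token ids,
# sequence lengths and hyper-parameters below are made up, and running it needs
# pytorch_pretrained_bert weights for "bert-base-uncased" plus torch_scatter.
def _example_discrete_vae():
    vocab_size = 30522                        # bert-base-uncased vocabulary size
    model = DiscreteVAE(padding_idx=0, sos_id=101, eos_id=102,
                        bert_model="bert-base-uncased",
                        nhidden=64, ntokens=vocab_size, nlayers=1,
                        nz=20, nzdim=10, dropout=0.1, copy=True)
    c_ids = torch.randint(1, vocab_size, (2, 50))  # context token ids
    q_ids = torch.randint(1, vocab_size, (2, 20))  # question token ids
    q_ids[:, 0] = 101                              # start questions with [CLS]
    a_ids = torch.zeros(2, 50).long()              # 0/1 answer tags over the context
    a_ids[:, 5:10] = 1
    start_positions = torch.tensor([5, 5])
    end_positions = torch.tensor([9, 9])
    loss, loss_q_rec, loss_a_rec, loss_kl = model(c_ids, q_ids, a_ids,
                                                  start_positions, end_positions)
    return loss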


# --- Trainer (a separate file in this gist) --------------------------------
# CatTrainer, cal_running_avg_loss, progress_bar, eta, user_friendly_time,
# time_since, write_predictions and evaluate are assumed to be provided by the
# surrounding project and are not shown in this gist.
import collections
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from new_model import DiscreteVAE


class AttnTrainer(CatTrainer):
    def __init__(self, args):
        super(AttnTrainer, self).__init__(args)

    def init_model(self, args):
        sos_id = self.tokenizer.vocab["[CLS]"]
        eos_id = self.tokenizer.vocab["[SEP]"]
        model = DiscreteVAE(padding_idx=0,
                            sos_id=sos_id,
                            eos_id=eos_id,
                            bert_model="bert-base-uncased",
                            ntokens=len(self.tokenizer.vocab),
                            nhidden=512,
                            nlayers=1,
                            dropout=0.2,
                            nz=20,
                            nzdim=10,
                            freeze=self.args.freeze,
                            copy=True)
        model = model.to(self.device)
        return model

    def get_opt(self):
        # in case a pre-trained vae is used
        if self.args.save_file is not None:
            params = [param for name, param in self.model.named_parameters() if "vae_net" not in name]
        else:
            params = self.model.parameters()
        opt = optim.Adam(params, self.args.lr)
        return opt

    def process_batch(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions = batch
        # trim each tensor to the longest sequence in the batch
        q_len = torch.sum(torch.sign(q_ids), 1)
        max_len = torch.max(q_len)
        q_ids = q_ids[:, :max_len]
        c_len = torch.sum(torch.sign(c_ids), 1)
        max_len = torch.max(c_len)
        c_ids = c_ids[:, :max_len]
        tag_ids = tag_ids[:, :max_len]
        a_len = torch.sum(torch.sign(ans_ids), 1)
        max_len = torch.max(a_len)
        ans_ids = ans_ids[:, :max_len]
        return q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions

    def train(self):
        batch_num = len(self.train_loader)
        avg_q_rec = 0
        avg_a_rec = 0
        avg_kl = 0
        global_step = 1
        best_f1 = 0
        for epoch in range(1, self.args.num_epochs + 1):
            start = time.time()
            self.model.train()
            for step, batch in enumerate(self.train_loader, start=1):
                # move tensors to the device and trim padding
                q_ids, c_ids, tag_ids, _, start_positions, end_positions = self.process_batch(batch)
                # forward pass
                ans_ids = (tag_ids != 0).long()
                loss, q_rec, a_rec, kl = self.model(c_ids, q_ids, ans_ids,
                                                    start_positions, end_positions)
                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.opt.step()
                global_step += 1
                avg_q_rec = cal_running_avg_loss(q_rec.item(), avg_q_rec)
                avg_a_rec = cal_running_avg_loss(a_rec.item(), avg_a_rec)
                avg_kl = cal_running_avg_loss(kl.item(), avg_kl)
                msg = "{}/{} {} - ETA : {} - Q recon: {:.4f}, A recon: {:.4f}, kl: {:.4f}" \
                    .format(step, batch_num, progress_bar(step, batch_num),
                            eta(start, step, batch_num), avg_q_rec, avg_a_rec, avg_kl)
                print(msg, end="\r")
            if not self.args.debug:
                eval_dict = self.eval(msg)
                f1 = eval_dict["f1"]
                em = eval_dict["exact_match"]
                print("Epoch {} took {} - final Q-rec : {:.3f}, final A-rec: {:.3f}, "
                      "F1 : {:.2f}, EM: {:.2f} "
                      .format(epoch, user_friendly_time(time_since(start)),
                              avg_q_rec, avg_a_rec, f1, em))
                if f1 > best_f1:
                    best_f1 = f1
                    self.save_model_kl(epoch, f1, em)

    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = (input_ids == eos_id).byte()
        num_eos = torch.sum(mask, 1)
        # move the mask to cpu because torch.argmax breaks ties differently on cuda
        # and cpu, whereas np.argmax consistently returns the first maximal index
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no eos in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 for eos
        seq_len = seq_len + 1
        return seq_len

    def save_model_kl(self, epoch, f1, em):
        # checkpoint file name: <epoch>_<F1>_<EM>
        f1 = round(f1, 2)
        em = round(em, 2)
        save_file = os.path.join(self.save_dir, "{}_{:.2f}_{:.2f}".format(epoch, f1, em))
        state_dict = self.model.state_dict()
        torch.save(state_dict, save_file)

    def eval(self, msg):
        num_val_batches = len(self.dev_loader)
        all_results = []
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        example_index = -1
        self.model.eval()
        for i, batch in enumerate(self.dev_loader, start=1):
            q_ids, c_ids, tag_ids, _, _, _ = self.process_batch(batch)
            ans_ids = (tag_ids != 0).long()
            with torch.no_grad():
                batch_start_logits, batch_end_logits = self.model.recon_ans(c_ids, q_ids, ans_ids)
            batch_size = batch_start_logits.size(0)
            for j in range(batch_size):
                example_index += 1
                start_logits = batch_start_logits[j].detach().cpu().tolist()
                end_logits = batch_end_logits[j].detach().cpu().tolist()
                eval_feature = self.eval_features[example_index]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
            msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches)
            print(msg2, end="\r")
        output_prediction_file = os.path.join(self.save_dir, "recon_pred.json")
        write_predictions(self.eval_examples, self.eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0,
                          noq_position=True)
        with open(self.args.dev_file) as f:
            data_json = json.load(f)
            dataset = data_json["data"]
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        results = evaluate(dataset, predictions)
        return results