@seanie12 · Created September 3, 2019 11:57
import collections
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

# Project-specific components used below (CatTrainer, BiDAF, DiscreteVAE, EMA, write_predictions,
# evaluate, cal_running_avg_loss, progress_bar, eta, user_friendly_time, time_since) are assumed
# to be importable from the surrounding repository.

class RLTrainer(CatTrainer):
    def __init__(self, args):
        super(RLTrainer, self).__init__(args)
        self.qa_model = BiDAF(embedding_size=100,
                              vocab_size=self.vocab_size,
                              hidden_size=args.qa_hidden_size,
                              drop_prob=0.2)
        state_dict = torch.load(args.qa_file, map_location="cpu")
        self.qa_model.load_state_dict(state_dict)
        self.qa_model = self.qa_model.to(self.device)
        self.scheduler = lr_scheduler.LambdaLR(self.opt, lambda s: 1.)
        self.ema = EMA(self.qa_model, 0.999)
        params = self.qa_model.parameters()
        self.qa_opt = optim.Adam(params, self.args.qa_lr)
        pi_params = self.model.prior_encoder.parameters()
        self.pi_opt = optim.Adam(pi_params, self.args.lr)
    def init_model(self, args):
        # QG model
        sos_id = self.tokenizer.vocab["[CLS]"]
        eos_id = self.tokenizer.vocab["[SEP]"]
        model = DiscreteVAE(padding_idx=0,
                            sos_id=sos_id,
                            eos_id=eos_id,
                            bert_model="bert-base-uncased",
                            ntokens=len(self.tokenizer.vocab),
                            nhidden=512,
                            nlayers=1,
                            dropout=0.2,
                            nz=20,
                            nzdim=10,
                            freeze=self.args.freeze,
                            copy=True)
        model = model.to(self.device)
        state_dict = torch.load(self.args.qg_file, map_location="cpu")
        model.load_state_dict(state_dict)
        return model
    def get_opt(self):
        params = self.model.parameters()
        opt = optim.Adam(params, self.args.lr)
        return opt
    def process_batch(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions = batch
        # truncate each field to the longest non-pad length in the batch
        q_len = torch.sum(torch.sign(q_ids), 1)
        max_len = torch.max(q_len)
        q_ids = q_ids[:, :max_len]
        c_len = torch.sum(torch.sign(c_ids), 1)
        max_len = torch.max(c_len)
        c_ids = c_ids[:, :max_len]
        tag_ids = tag_ids[:, :max_len]
        a_len = torch.sum(torch.sign(ans_ids), 1)
        max_len = torch.max(a_len)
        ans_ids = ans_ids[:, :max_len]
        return q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions
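    # Adversarial training loop (sketch of what the code below does): each step first updates
    # the QA model on real (c, q, a) pairs plus QA pairs generated from the prior, then updates
    # the prior encoder with a REINFORCE-style objective whose reward is the QA loss on the
    # generated pairs, regularized by a KL term toward the posterior encoder's distribution.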
    def train(self):
        batch_num = len(self.train_loader)
        global_step = 1
        avg_qa_loss = 0
        avg_kl = 0
        avg_adv_loss = 0
        kl_div = nn.KLDivLoss(reduction="batchmean")
        best_f1 = 0
        for epoch in range(1, self.args.num_epochs + 1):
            start = time.time()
            self.model.train()
            self.qa_model.train()
            for step, batch in enumerate(self.train_loader, start=1):
                # allocate tensors to device
                q_ids, c_ids, tag_ids, _, \
                    start_positions, end_positions = self.process_batch(batch)
                ans_ids = (tag_ids != 0).long()
                real_loss = self.qa_model(c_ids, q_ids,
                                          start_positions=start_positions,
                                          end_positions=end_positions)
                # generate question from prior
                with torch.no_grad():
                    # sample z from prior distribution
                    prior_z_logits, _ = self.model.prior_encoder(c_ids)
                    # sample (q, a)
                    gen_q_ids, gen_start_positions, gen_end_positions, \
                        _, _ = self.model.generate(prior_z_logits, c_ids)
                # forward qa model for generated question
                fake_loss = self.qa_model(c_ids, gen_q_ids,
                                          start_positions=gen_start_positions,
                                          end_positions=gen_end_positions)
                # loss for generated question
                qa_loss = real_loss + self.args.adv_lambda * fake_loss
                self.qa_opt.zero_grad()
                qa_loss.backward()
                nn.utils.clip_grad_norm_(self.qa_model.parameters(), 5.0)
                self.qa_opt.step()
                self.scheduler.step()
                self.ema(self.qa_model, global_step)
                global_step += 1
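                # ---- policy update: train the prior encoder to propose questions the QA model
                # finds hard, regularized toward the posterior encoder's latent distribution ----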
                # sample z from prior
                prior_z_logits, prior_z_probs = self.model.prior_encoder(c_ids)
                flatten_prior_logits = prior_z_logits.view(-1, self.args.num_classes)
                log_prob_prior = F.log_softmax(flatten_prior_logits, dim=1)
                # sample z from posterior
                with torch.no_grad():
                    posterior_z_logits, posterior_z_prob = self.model.posterior_encoder(c_ids, q_ids, ans_ids)
                    flatten_posterior = posterior_z_logits.view(-1, self.args.num_classes)
                # regularization with kl-divergence
                prob_posterior = F.softmax(flatten_posterior, dim=1)
                kl = kl_div(log_prob_prior, prob_posterior.detach())
                with torch.no_grad():
                    gen_q_ids, gen_start_positions, gen_end_positions, \
                        latent_z = self.model.sample(prior_z_logits, c_ids)
                # reward is qa loss, so pi maximizes qa loss
                start_logits, end_logits = self.qa_model(c_ids, gen_q_ids)
                reward = self.get_reward(start_logits, end_logits,
                                         gen_start_positions, gen_end_positions)
                # probability of each sampled latent code under the prior (assuming latent_z is one-hot)
                action_probs = torch.sum(prior_z_probs * latent_z, dim=-1)
                log_prob = torch.log(action_probs + 1e-12)  # [b, num_vars]
                adv_loss = -(reward.unsqueeze(1).detach() * log_prob).sum(1).mean()
                # backward pass
                pi_loss = adv_loss + kl
                pi_loss.backward()
                self.pi_opt.step()
                self.pi_opt.zero_grad()
                avg_qa_loss = cal_running_avg_loss(qa_loss.item(), avg_qa_loss)
                avg_kl = cal_running_avg_loss(kl.item(), avg_kl)
                avg_adv_loss = cal_running_avg_loss(adv_loss.item(), avg_adv_loss)
                msg = "{}/{} {} - ETA : {} - QA loss: {:.4f}, KL: {:.4f}, adv loss: {:.4f}" \
                    .format(step, batch_num, progress_bar(step, batch_num),
                            eta(start, step, batch_num), avg_qa_loss, avg_kl, avg_adv_loss)
                print(msg, end="\r")
            # end-of-epoch evaluation and checkpointing
            if not self.args.debug:
                result_dict = self.eval(msg)
                f1 = result_dict["f1"]
                em = result_dict["exact_match"]
                print("Epoch {} took {} - F1: {:.4f}, EM: {:.4f}"
                      .format(epoch, user_friendly_time(time_since(start)), f1, em))
                if f1 > best_f1:
                    best_f1 = f1
                    self.save_qa_model(epoch, f1, em)
                    self.save_model(epoch, f1)
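    # KL divergence between the QA model's start/end distributions and a uniform distribution
    # over valid context positions (maximal uncertainty); not called elsewhere in this snippet.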
    @staticmethod
    def compute_fake_loss(start_logits, end_logits, context_len, context_mask):
        batch_size, time_step = start_logits.size()
        uniform_dist = torch.ones(batch_size, device=context_len.device) / context_len.float()
        uniform_dist = uniform_dist.unsqueeze(1)
        uniform_dist = uniform_dist.repeat([1, time_step]).masked_fill(context_mask, 0)
        kl_div = nn.KLDivLoss(reduction="batchmean")
        start_log_prob = F.log_softmax(start_logits, dim=1)
        end_log_prob = F.log_softmax(end_logits, dim=1)
        start_loss = kl_div(start_log_prob, uniform_dist)
        end_loss = kl_div(end_log_prob, uniform_dist)
        loss = 0.5 * (start_loss + end_loss)
        return loss
    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = (input_ids == eos_id).byte()
        num_eos = torch.sum(mask, 1)
        # move the mask to cpu because torch.argmax behaves differently on cuda and cpu,
        # whereas np.argmax consistently returns the first index of the maximum element
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert numpy array to Tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no eos in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 for eos
        seq_len = seq_len + 1
        return seq_len
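    # Per-example reward for the policy update: the unreduced cross-entropy span loss of the
    # QA model on the generated question, so harder questions receive larger reward.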
    @staticmethod
    def get_reward(start_logits, end_logits, start_positions, end_positions):
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index, reduction="none")
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
        return total_loss
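    # Builds BERT-style inputs from generated questions: [CLS] + question + context, with
    # segment ids (0 for [CLS] and the question, 1 for the context) and an attention mask
    # that zeroes out padding.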
    @staticmethod
    def post_process(q_ids, q_len, c_ids, cls_id):
        batch_size = q_ids.size(0)
        max_q_len = torch.max(q_len)
        cls_ids = cls_id * torch.ones((batch_size, 1), device=q_ids.device, dtype=torch.long)
        all_input_ids = []
        all_seg_ids = []
        for i in range(batch_size):
            q_length = q_len[i]
            q = q_ids[i, :q_length]  # exclude pad tokens
            c = c_ids[i, 1:]  # exclude [CLS]
            # input ids
            pads = torch.zeros((max_q_len - q_length), device=q_ids.device, dtype=torch.long)
            input_ids = torch.cat([q, c, pads], dim=0)
            all_input_ids.append(input_ids)
            # segment ids
            zeros = torch.zeros_like(q)
            ones = torch.ones_like(c)
            seg_ids = torch.cat([zeros, ones, pads], dim=0)
            all_seg_ids.append(seg_ids)
        all_input_ids = torch.stack(all_input_ids, dim=0)
        all_input_ids = torch.cat([cls_ids, all_input_ids], dim=1)
        # segment id for cls
        zeros = torch.zeros_like(cls_ids)
        all_seg_ids = torch.stack(all_seg_ids, dim=0)
        all_seg_ids = torch.cat([zeros, all_seg_ids], dim=1)
        # attention mask
        mask = (all_input_ids == 0).byte()
        all_seg_ids = all_seg_ids.masked_fill(mask, 0)
        attention_mask = 1 - mask
        return all_input_ids, all_seg_ids, attention_mask
    def save_qa_model(self, epoch, f1, em):
        f1 = round(f1, 2)
        em = round(em, 2)
        save_file = os.path.join(self.save_dir, "{}_{:.2f}_{:.2f}".format(epoch, f1, em))
        state_dict = self.qa_model.state_dict()
        torch.save(state_dict, save_file)
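    # Evaluation: temporarily swaps in the EMA weights of the QA model, runs SQuAD-style
    # prediction and scoring over the dev set, then restores the training weights.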
    def eval(self, msg):
        self.ema.assign(self.qa_model)
        self.qa_model.eval()
        all_results = []
        example_index = -1
        num_val_batches = len(self.dev_loader)
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        for i, batch in enumerate(self.dev_loader, start=1):
            q_ids, c_ids, tag_ids, ans_ids, _, _ = self.process_batch(batch)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = self.qa_model(c_ids, q_ids)
                batch_size = batch_end_logits.size(0)
            for j in range(batch_size):
                example_index += 1
                start_logits = batch_start_logits[j].detach().cpu().tolist()
                end_logits = batch_end_logits[j].detach().cpu().tolist()
                eval_feature = self.eval_features[example_index]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
            msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches)
            print(msg2, end="\r")
        output_prediction_file = os.path.join(self.save_dir, "adv_prediction.json")
        write_predictions(self.eval_examples, self.eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0,
                          noq_position=True)
        with open(self.args.dev_file) as f:
            data_json = json.load(f)
            dataset = data_json["data"]
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        results = evaluate(dataset, predictions)
        self.qa_model.train()
        self.ema.resume(self.qa_model)
        return results