@mmsamiei
Created August 22, 2020 10:41
# Collect decoded model outputs and gold responses from the validation set.
# Assumes `model`, `dec_tokenizer`, `valid_loader`, and `dev` (the torch device)
# are defined earlier in the script/notebook.
references = []
predictions = []

# Generation settings: single-beam greedy decoding, sampling disabled.
kwargs = {
    'num_beams': 1,
    'num_return_sequences': 1,
    'temperature': 1,
    'max_length': 50,
    'early_stopping': True,
    'no_repeat_ngram_size': 3,
    'decoder_start_token_id': 0,
    'eos_token_id': 2,
    # 'do_sample': True
}

for batch_idx, batch in enumerate(valid_loader):
    pair_batch, segment_batch, response_batch = batch
    pair_batch = pair_batch.to(dev)
    segment_batch = segment_batch.to(dev)
    response_batch = response_batch.to(dev)

    # Generate responses for the input pairs, then decode both the generated
    # sequences and the gold responses back to text.
    generateds = model.seq2seq.generate(pair_batch, **kwargs)
    new_predictions = dec_tokenizer.batch_decode(generateds, skip_special_tokens=True)
    new_references = dec_tokenizer.batch_decode(response_batch, skip_special_tokens=True)

    references.extend(new_references)
    predictions.extend(new_predictions)
    break  # only the first validation batch is evaluated
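
# The gist stops after collecting the references and predictions lists and does
# not show how they are used. A minimal follow-up sketch, assuming the intent is
# corpus-level BLEU and that the sacrebleu package is available (this step is an
# assumption, not part of the original gist):
import sacrebleu

# sacrebleu expects a list of hypothesis strings and a list of reference
# streams (one inner list per reference set).
bleu = sacrebleu.corpus_bleu(predictions, [references])
print(f'BLEU: {bleu.score:.2f}')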