Skip to content

Instantly share code, notes, and snippets.

@Sandeep42
Created August 9, 2017 08:57
Show Gist options
  • Save Sandeep42/b4af6aaf1fa92a69f26558fda246dd6f to your computer and use it in GitHub Desktop.
prediction_single.py
def get_predictions_single(val_tokens, word_attn_model, sent_attn_model, use_cuda=True):
    """Run the word- and sentence-level attention models on one document.

    Each sentence in ``val_tokens`` is encoded by ``word_attn_model``.
    The resulting per-sentence vectors are concatenated along dim 0 and
    passed to ``sent_attn_model``.

    Args:
        val_tokens: iterable of sentences. Each sentence is converted with
            ``np.array``; presumably a 1-D sequence of token ids
            (TODO confirm against the caller).
        word_attn_model: word-level model exposing ``init_hidden()``. It is
            called as ``model(sent, state) -> (sent_vec, new_state, word_attn)``.
        sent_attn_model: sentence-level model exposing ``init_hidden()``. It
            is called as ``model(sent_vecs, state) -> (y_pred, new_state, sent_attn)``.
        use_cuda: if True (the default), move hidden states and inputs to
            the GPU with ``.cuda()``. Pass False to run on CPU.

    Returns:
        ``(y_pred, word_attns, sent_attns)``. ``word_attns`` is a list that
        holds the word model's third output for each sentence, in order.

    Raises:
        ValueError: if ``val_tokens`` contains no sentences.
    """
    def _to_device(x):
        # Keep the historical GPU behaviour by default; allow CPU inference.
        return x.cuda() if use_cuda else x

    state_word = _to_device(word_attn_model.init_hidden())
    state_sent = _to_device(sent_attn_model.init_hidden())
    sent_vecs = []
    word_attns = []
    for sent_tokens in val_tokens:
        # unsqueeze(0) then transpose(0, 1) turns a 1-D sentence of length L
        # into shape (L, 1): a single-element batch laid out sequence-first.
        # `volatile=True` is the pre-0.4 PyTorch way to disable autograd for
        # inference. Newer releases ignore it and emit a warning.
        sent = _to_device(Variable(
            torch.from_numpy(np.array(sent_tokens)).unsqueeze(0).transpose(0, 1),
            requires_grad=False, volatile=True))
        # The word-level hidden state is threaded from one sentence to the next.
        # The third output is collected as the word attention (confirm
        # against the model's definition).
        sent_vec, state_word, word_attn = word_attn_model(sent, state_word)
        sent_vecs.append(sent_vec)
        word_attns.append(word_attn)
    if not sent_vecs:
        raise ValueError("val_tokens must contain at least one sentence")
    # Concatenate once, rather than re-concatenating after every sentence.
    s = torch.cat(sent_vecs, 0)
    y_pred, state_sent, sent_attns = sent_attn_model(s, state_sent)
    return y_pred, word_attns, sent_attns
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment