"""
Used for inference. Make sure model object is loaded
"""
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
model.eval() #### Trained model object
### Generates tokens for the source sentence
if isinstance(sentence, str):
nlp = spacy.load('de')
tokens = [token.text.lower() for token in nlp(sentence)]
else:
tokens = [token.lower() for token in sentence]
### Pad by <sos> token and <eos> tokens
tokens = [src_field.init_token] + tokens + [src_field.eos_token]
### Numericalize the input src token. Needs the src vocab object
src_indexes = [src_field.vocab.stoi[token] for token in tokens]
### Convert to tensors
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
### Pass through encoder to get the two output vectors
with torch.no_grad():
encoder_conved, encoder_combined = model.encoder(src_tensor)
### Numericalize the input src token. Needs the trg vocab object
trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
### Loop and feed the trg tokens till max_len or <eos> prediction
for i in range(max_len):
### Convert trg tokens to respective tensors
trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
with torch.no_grad():
output, attention = model.decoder(trg_tensor, encoder_conved, encoder_combined)
pred_token = output.argmax(2)[:,-1].item()
trg_indexes.append(pred_token)
### Check if <eos> has been predicted yet
if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
break
trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
return trg_tokens[1:], attention
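

### Minimal usage sketch (assumptions: SRC, TRG, model and device are the
### torchtext fields, trained model and torch.device from the training script;
### example_src below is purely illustrative).
# example_src = ['ein', 'mann', 'geht', 'die', 'straße', 'entlang', '.']
# translation, attention = translate_sentence(example_src, SRC, TRG, model, device)
# print(f'predicted trg = {translation}')
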
##### Generates the attention visualization for a single translation
def display_attention(sentence, translation, attention):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ### Drop the batch dimension and move the attention matrix to the CPU
    attention = attention.squeeze(0).cpu().detach().numpy()
    cax = ax.matshow(attention, cmap='bone')
    ax.tick_params(labelsize=15)
    ### Source tokens (with <sos>/<eos>) along x, predicted tokens along y
    ax.set_xticklabels([''] + ['<sos>'] + [t.lower() for t in sentence] + ['<eos>'],
                       rotation=45)
    ax.set_yticklabels([''] + translation)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()
    plt.close()
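

### Attention visualization for the sketch above: reuses the (hypothetical)
### translation and attention returned by translate_sentence, with the source
### tokens passed without <sos>/<eos> since display_attention adds them.
# display_attention(example_src, translation, attention)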