-
-
Save rajy4683/52ffdca66bc8bca06dd7570acc3d6a79 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Used for inference. Make sure model object is loaded | |
""" | |
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50): | |
model.eval() #### Trained model object | |
### Generates tokens for the source sentence | |
if isinstance(sentence, str): | |
nlp = spacy.load('de') | |
tokens = [token.text.lower() for token in nlp(sentence)] | |
else: | |
tokens = [token.lower() for token in sentence] | |
### Pad by <sos> token and <eos> tokens | |
tokens = [src_field.init_token] + tokens + [src_field.eos_token] | |
### Numericalize the input src token. Needs the src vocab object | |
src_indexes = [src_field.vocab.stoi[token] for token in tokens] | |
### Convert to tensors | |
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) | |
### Pass through encoder to get the two output vectors | |
with torch.no_grad(): | |
encoder_conved, encoder_combined = model.encoder(src_tensor) | |
### Numericalize the input src token. Needs the trg vocab object | |
trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]] | |
### Loop and feed the trg tokens till max_len or <eos> prediction | |
for i in range(max_len): | |
### Convert trg tokens to respective tensors | |
trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) | |
with torch.no_grad(): | |
output, attention = model.decoder(trg_tensor, encoder_conved, encoder_combined) | |
pred_token = output.argmax(2)[:,-1].item() | |
trg_indexes.append(pred_token) | |
### Check if <eos> has been predicted yet | |
if pred_token == trg_field.vocab.stoi[trg_field.eos_token]: | |
break | |
trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes] | |
return trg_tokens[1:], attention | |
##### Generates attention visualization
def display_attention(sentence, translation, attention):
    """Render the decoder attention matrix as a heat map.

    Args:
        sentence: source tokens (without <sos>/<eos>; they are added here).
        translation: predicted target tokens, labelling the rows.
        attention: attention tensor of shape (1, trg_len, src_len) —
            assumed to match len(translation) x (len(sentence) + 2);
            TODO confirm against the decoder's output.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    # Detach from the autograd graph and move to CPU before numpy conversion.
    attention = attention.squeeze(0).cpu().detach().numpy()

    cax = ax.matshow(attention, cmap='bone')
    ax.tick_params(labelsize=15)

    # Fix tick positions explicitly BEFORE assigning labels: matplotlib >= 3.5
    # rejects FixedFormatter labels unless the locator is a FixedLocator
    # ("FixedFormatter should only be used together with FixedLocator").
    x_labels = ['<sos>'] + [t.lower() for t in sentence] + ['<eos>']
    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=45)

    y_labels = list(translation)
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    plt.show()
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment