Skip to content

Instantly share code, notes, and snippets.

@baudm

baudm/test.py Secret

Created October 9, 2021 05:59
Show Gist options
  • Save baudm/6d2bb86290b0ef30685b3288503879f6 to your computer and use it in GitHub Desktop.
ABINet LM test script
#!/usr/bin/env python3
import string
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from modules.model_language import BCNLanguage
from utils import Config
# Load the pretrained BCN language model and put it in eval mode.
lm = BCNLanguage(Config('configs/pretrain_language_model.yaml'))
lm.load('workdir/pretrain-language-model/pretrain-language-model.pth')
lm = lm.eval()

word = input('Target: ')

# Vocabulary: index 0 is the null/end token, then lowercase letters and digits.
itos = ['<null>'] + list(string.ascii_lowercase + '1234567890')
stoi = {s: i for i, s in enumerate(itos)}
max_len = 25

# Build a (1, max_len + 1) index tensor for the word.  The arange tensor is a
# dummy entry whose only purpose is to force pad_sequence to pad the real
# target out to max_len + 1 positions; it is dropped again via [:1].
target = [torch.as_tensor([stoi[c] for c in word]), torch.arange(max_len + 1)]
target = pad_sequence(target, batch_first=True, padding_value=0)[:1]  # exclude dummy target
lengths = torch.as_tensor([len(word) + 1])  # +1 accounts for the end (null) token
tgt = F.one_hot(target, len(itos)).float()
print(target, target.shape, lengths, lengths.shape)

# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    res = lm(tgt, lengths)
pred = res['logits'].argmax(-1)
print(pred)

decoded = ''.join(itos[i] for i in pred.squeeze())
# Truncate at the first null token.  str.find returns -1 when the token is
# absent, and the previous `decoded[:decoded.find('<null>')]` then silently
# dropped the LAST character of a full-length prediction — only slice when
# the null token was actually found.
end = decoded.find('<null>')
if end != -1:
    decoded = decoded[:end]
print(decoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment