Skip to content

Instantly share code, notes, and snippets.

@aravindpai
Last active January 27, 2020 13:55
Show Gist options
  • Save aravindpai/4af719d02d3802794c2c15b1ad09222c to your computer and use it in GitHub Desktop.
Save aravindpai/4af719d02d3802794c2c15b1ad09222c to your computer and use it in GitHub Desktop.
# load weights
# map_location='cpu': torch.load otherwise restores every tensor onto the
# device it was saved from, which fails on a CPU-only machine when the
# checkpoint was written from a GPU session. load_state_dict copies the loaded
# values into the model's existing parameters, so the model stays on whatever
# device it already lives on.
path = '/content/saved_weights.pt'
model.load_state_dict(torch.load(path, map_location='cpu'))
# Evaluation mode (affects layers such as dropout) before running inference.
model.eval()
# inference
import spacy
# spaCy 3 removed shortcut links such as 'en' (spacy.load('en') raises an
# OSError there). Loading the English small-model package by its full name
# works on spaCy 2.x, where `download en` installed this package, and on 3.x.
# NOTE(review): presumably this should be the same tokenizer that was used to
# build the training vocabulary — confirm against the training setup.
nlp = spacy.load('en_core_web_sm')
def predict(model, sentence):
    """Run ``model`` on one raw sentence and return its output as a Python scalar.

    The sentence is tokenized with the module-level spaCy pipeline ``nlp``,
    mapped to integer ids through ``TEXT.vocab.stoi`` and passed to the model
    as a batch of one, together with its length in tokens.

    Args:
        model: the trained network, called as ``model(text, lengths)``. It
            should already be in evaluation mode (``model.eval()``).
        sentence: the raw input text.

    Returns:
        ``prediction.item()``, the model output as a Python number. ``.item()``
        only succeeds when the model returns exactly one element for a batch
        of one (presumably a single score — confirm against the model).
    """
    tokens = [tok.text for tok in nlp.tokenizer(sentence)]
    # NOTE(review): out-of-vocabulary handling depends on TEXT.vocab.stoi
    # (presumably mapping unknown words to <unk>) — verify.
    ids = [TEXT.vocab.stoi[t] for t in tokens]

    # (n_words,) -> (1, n_words): batch-first layout, batch size one.
    text_tensor = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    # Lengths are deliberately left on the CPU — presumably consumed by
    # pack_padded_sequence, which expects CPU lengths; confirm in the model.
    length_tensor = torch.tensor([len(ids)], dtype=torch.long)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(text_tensor, length_tensor)
    return prediction.item()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment