@aravindpai aravindpai/inference.py
Last active Jan 27, 2020

# load the trained weights (assumes `model` was already constructed
# with the same architecture used during training)
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
path = '/content/saved_weights.pt'
model.load_state_dict(torch.load(path, map_location=device))
model.eval()  # switch off dropout/batch-norm updates for inference

# inference
import spacy

nlp = spacy.load('en_core_web_sm')  # the bare 'en' shortcut is deprecated in spaCy 3.x

def predict(model, sentence):
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]  # tokenize the sentence
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]          # map tokens to ids via the TorchText vocab built at training time
    length = [len(indexed)]                                    # number of tokens
    tensor = torch.LongTensor(indexed).to(device)              # shape: [seq_len]
    tensor = tensor.unsqueeze(0)                               # reshape to [batch=1, seq_len]
    length_tensor = torch.LongTensor(length)                   # lengths tensor (e.g. for packed sequences)
    prediction = model(tensor, length_tensor)                  # forward pass
    return prediction.item()
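The `TEXT.vocab.stoi` lookup above comes from the TorchText vocabulary built during training. For illustration, here is a minimal, self-contained sketch of that token-to-index step with a toy vocabulary; the `stoi` contents and the convention that index 0 is `<unk>` are assumptions for the example, not values from the original notebook:

```python
# Toy stand-in for TEXT.vocab.stoi: maps token strings to integer ids.
# TorchText's stoi behaves like a defaultdict that returns the <unk>
# index for out-of-vocabulary tokens; we emulate that with dict.get.
stoi = {"<unk>": 0, "<pad>": 1, "the": 2, "movie": 3, "was": 4, "great": 5}

def to_indices(tokens):
    """Convert a list of tokens to an integer sequence, sending
    out-of-vocabulary tokens to the <unk> index (0 here)."""
    return [stoi.get(tok, stoi["<unk>"]) for tok in tokens]

print(to_indices(["the", "movie", "was", "great"]))  # [2, 3, 4, 5]
print(to_indices(["the", "plot", "was", "great"]))   # "plot" is OOV -> [2, 0, 4, 5]
```

With a real vocabulary, the resulting list is what gets wrapped in `torch.LongTensor` and reshaped to a batch of one before the forward pass.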