Last active
August 4, 2019 16:56
-
-
Save samarth-agrawal-86/90476a428c64059956ba5050298892f6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from string import punctuation | |
def tokenize_review(test_review, vocab=None):
    """Lowercase a review, strip punctuation, and map its words to integer ids.

    Args:
        test_review: Raw review text.
        vocab: Optional word -> int mapping. When omitted, the module-level
            ``vocab_to_int`` dictionary is used (the original behaviour).

    Returns:
        A list containing a single list of token ids, i.e. a batch of one
        review, which is what this script passes to ``pad_features``.

    Words that are not in the vocabulary are skipped rather than raising
    ``KeyError``, so reviews containing unseen words can still be scored.
    """
    if vocab is None:
        vocab = vocab_to_int
    text = test_review.lower()
    # Delete every ASCII punctuation character in a single pass.
    text = text.translate(str.maketrans('', '', punctuation))
    words = text.split()
    # Outer list makes the result a batch containing exactly one review.
    return [[vocab[word] for word in words if word in vocab]]
# --- Sanity check of the preprocessing pipeline on a single sample review. ---
# NOTE(review): `test_review_neg`, `pad_features`, and `torch` are defined or
# imported elsewhere (not visible here). These statements run at module level
# and bind module-level names that later code may use.

# test code and generate tokenized review
# Prints a list holding one list of word ids (a batch of one review).
test_ints = tokenize_review(test_review_neg)
print(test_ints)

# test sequence padding
# Pad to a fixed length; presumably yields a NumPy array of shape
# (1, seq_length) - TODO confirm against pad_features.
seq_length=200
features = pad_features(test_ints, seq_length)
print(features)

# test conversion to tensor (the model itself is not called here)
# torch.from_numpy shares memory with the NumPy array instead of copying it.
feature_tensor = torch.from_numpy(features)
print(feature_tensor.size())
def predict(net, test_review, sequence_length=200):
    """Classify one raw review as positive or negative and print the result.

    Args:
        net: Trained sentiment model. It is switched to eval mode, must
            provide ``init_hidden(batch_size)``, and is called as
            ``net(inputs, hidden)`` returning ``(output, hidden)``.
        test_review: Raw review text.
        sequence_length: Target length passed to ``pad_features``; presumably
            this should match the length used during training.

    Relies on ``tokenize_review``, ``pad_features`` and ``train_on_gpu``
    being defined elsewhere in this script. Returns None; results are printed.
    """
    net.eval()

    # Tokenize and pad: the result is a batch containing one review.
    test_ints = tokenize_review(test_review)
    features = pad_features(test_ints, sequence_length)
    feature_tensor = torch.from_numpy(features)
    batch_size = feature_tensor.size(0)

    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        h = net.init_hidden(batch_size)
        if train_on_gpu:
            # NOTE(review): only the inputs are moved here; this assumes
            # init_hidden already places the hidden state on the GPU - confirm.
            feature_tensor = feature_tensor.cuda()
        output, h = net(feature_tensor, h)

    # Round the (presumably sigmoid) output to the predicted class, 0 or 1.
    pred = torch.round(output.squeeze())
    # output.item() requires a single-element output, i.e. a batch of one.
    print('Prediction value, pre-rounding: {:.6f}'.format(output.item()))
    if pred.item() == 1:
        print("Positive review detected!")
    else:
        print("Negative review detected.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment