Skip to content

Instantly share code, notes, and snippets.

@eisenjulian
Last active April 9, 2018 11:48
Show Gist options
  • Save eisenjulian/4e73fde6303396d9b3b3daf7ff303c39 to your computer and use it in GitHub Desktop.
Save eisenjulian/4e73fde6303396d9b3b3daf7ff303c39 to your computer and use it in GitHub Desktop.
def text_to_index(sentence, vocab=None):
    """Convert a raw sentence into an array of word ids.

    Punctuation is stripped except for the apostrophe, so contractions such
    as "don't" stay a single token. The text is then lower-cased, split on
    whitespace, and each token is looked up in the vocabulary.

    Args:
        sentence: The raw input text.
        vocab: Optional token -> id mapping. When omitted, the module-level
            ``word_index`` is used, matching the original behaviour.

    Returns:
        A 1-D numpy array that starts with ``1`` (presumably the
        start-of-sequence id -- confirm against the training setup) and is
        followed by one id per token. Tokens missing from the vocabulary are
        mapped to ``2``.
    """
    if vocab is None:
        vocab = word_index
    # Remove punctuation characters except for the apostrophe
    translator = str.maketrans('', '', string.punctuation.replace("'", ''))
    tokens = sentence.translate(translator).lower().split()
    # One lookup per token; 2 is the id used for out-of-vocabulary tokens.
    return np.array([1] + [vocab.get(t, 2) for t in tokens])
def print_predictions(sentences, classifier):
    """Print the classifier's ``'logistic'`` output for each raw sentence.

    Each sentence is converted to word ids with ``text_to_index``. The id
    sequences are padded with 0 at the end, or cut to ``sentence_size``, and
    fed to ``classifier.predict`` together with their (capped) lengths.

    Args:
        sentences: Iterable of raw text strings.
        classifier: An estimator whose ``predict`` yields dicts containing a
            ``'logistic'`` entry; the first element of that entry is printed.
    """
    indexes = [text_to_index(sentence) for sentence in sentences]
    if not indexes:
        # Nothing to score; skip building an empty input pipeline.
        print([])
        return
    # Truncate from the end so the leading start id (1) that text_to_index
    # prepends survives for sentences longer than sentence_size; the default
    # truncating='pre' would drop it together with the start of the text.
    x = sequence.pad_sequences(indexes,
                               maxlen=sentence_size,
                               truncating='post',
                               padding='post',
                               value=0)
    # Unpadded length of each sequence, capped at sentence_size to match x.
    length = np.array([min(len(idx), sentence_size) for idx in indexes])
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": x, "len": length}, shuffle=False)
    predictions = [p['logistic'][0]
                   for p in classifier.predict(input_fn=predict_input_fn)]
    print(predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment