Last active
March 20, 2021 12:28
-
-
Save PurpleBooth/921706b2ab46d233997fa4dd76ddaef7 to your computer and use it in GitHub Desktop.
Experimenting with the Keras IMDB dataset; based mostly on https://www.tensorflow.org/tutorials/keras/basic_text_classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import, division, print_function, unicode_literals | |
import tensorflow as tf | |
from tensorflow import keras | |
import matplotlib.pyplot as plt | |
import numpy as np | |
print(tf.__version__)

imdb = keras.datasets.imdb

# Download (cached after the first run) the IMDB reviews, keeping only the
# 10,000 most frequent words; rarer words are mapped to the OOV index.
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels)))
print(train_data[0])
# Raw reviews have different lengths (was a bare, dead expression — now printed).
print(len(train_data[0]), len(train_data[1]))

# A dictionary mapping words to an integer index.
word_index = imdb.get_word_index()

# Shift every index up by 3 so the first slots are free for control tokens.
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index["<PAD>"] = 0
word_index["<START>"] = 1
word_index["<UNK>"] = 2  # unknown / out-of-vocabulary
word_index["<UNUSED>"] = 3

# Inverse mapping: integer index -> word. A dict comprehension replaces the
# dict([(value, key) for ...]) anti-idiom (flake8-comprehensions C404).
reverse_word_index = {index: word for word, index in word_index.items()}
def decode_review(text):
    """Map a sequence of integer word ids back to a human-readable string.

    Unknown ids (not present in ``reverse_word_index``) render as ``'?'``.
    """
    words = (reverse_word_index.get(word_id, '?') for word_id in text)
    return ' '.join(words)
# Spot-check the decoder on the first training review (notebook leftover —
# the return value is discarded when run as a script).
decode_review(train_data[0])

# Pad/truncate every review to a fixed length of 256 ids so the data can be
# fed to the model as a dense matrix; padding is appended after the review.
_pad = keras.preprocessing.sequence.pad_sequences
train_data = _pad(train_data, value=word_index["<PAD>"], padding='post', maxlen=256)
test_data = _pad(test_data, value=word_index["<PAD>"], padding='post', maxlen=256)

len(train_data[0]), len(train_data[1])  # both are 256 now (notebook leftover)
print(train_data[0])
# Build a small classifier: embedding -> average pooling -> dense -> sigmoid.
# The input vocabulary matches the 10,000-word cap used when loading the data.
vocab_size = 10000

model = keras.Sequential([
    keras.layers.Embedding(vocab_size, 16),
    keras.layers.GlobalAveragePooling1D(),
    keras.layers.Dense(16, activation=tf.nn.relu),
    keras.layers.Dense(1, activation=tf.nn.sigmoid),
])
model.summary()

# Binary sentiment target -> binary cross-entropy; track accuracy as 'acc'
# (the plotting code below reads history under that key).
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['acc'],
)
# Hold out the first 10,000 training reviews for validation; fit on the rest.
validation_count = 10000
x_val, partial_x_train = train_data[:validation_count], train_data[validation_count:]
y_val, partial_y_train = train_labels[:validation_count], train_labels[validation_count:]

history = model.fit(
    partial_x_train,
    partial_y_train,
    epochs=40,
    batch_size=512,
    validation_data=(x_val, y_val),
    verbose=1,
)

# Final quality check against the untouched test split.
results = model.evaluate(test_data, test_labels)
print(results)

# Persist the trained weights/architecture for the prediction script.
model.save("model.h5")
history_dict = history.history
history_dict.keys()  # notebook leftover; value discarded when run as a script

acc = history_dict['acc']
val_acc = history_dict['val_acc']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

epochs = range(1, len(acc) + 1)


def _plot_curves(xs, train_vals, val_vals, train_label, val_label, title, ylabel):
    """Plot one training curve (blue dots) against its validation curve (solid blue)."""
    plt.plot(xs, train_vals, 'bo', label=train_label)
    plt.plot(xs, val_vals, 'b', label=val_label)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


_plot_curves(epochs, loss, val_loss,
             'Training loss', 'Validation loss',
             'Training and validation loss', 'Loss')

plt.clf()  # clear figure before drawing the accuracy chart

_plot_curves(epochs, acc, val_acc,
             'Training acc', 'Validation acc',
             'Training and validation accuracy', 'Accuracy')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[[source]]
name = "pypi"
url = "https://pypi.org/simple"
verify_ssl = true

[dev-packages]

[packages]
keras = "*"
# NOTE(review): tensorflow and tf_nightly are the stable and nightly builds
# of the same project; installing both can conflict on the shared
# `tensorflow` import namespace — confirm only one is actually needed.
tensorflow = "*"
tf_nightly = "*"
matplotlib = "*"
nltk = "*"

[requires]
python_version = "3.7"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import, division, print_function, unicode_literals | |
import tensorflow as tf | |
from tensorflow import keras | |
from keras.datasets import imdb | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from keras.preprocessing import sequence | |
import nltk | |
from nltk.tokenize import word_tokenize | |
nltk.download('punkt') | |
print(tf.__version__) | |
word2index = imdb.get_word_index() | |
test = [] | |
text = "awful film do not see" | |
for word in word_tokenize(text): | |
if word in word2index: | |
test.append(word2index[word]) | |
else: | |
test.append(2) | |
test = sequence.pad_sequences([test], maxlen=256, | |
value=0, | |
padding='post', ) | |
model = keras.models.load_model("model.h5") | |
predictions_single = model.predict(test) | |
print(predictions_single) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment