@morus12
Created April 5, 2016 14:15
# -*- coding: utf-8 -*-
"""
Simple example using an LSTM recurrent neural network to classify the IMDB
sentiment dataset.
References:
- Long Short-Term Memory, Sepp Hochreiter & Jürgen Schmidhuber, Neural
Computation 9(8): 1735-1780, 1997.
- Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng,
and Christopher Potts. (2011). Learning Word Vectors for Sentiment
Analysis. The 49th Annual Meeting of the Association for Computational
Linguistics (ACL 2011).
Links:
- http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
- http://ai.stanford.edu/~amaas/data/sentiment/
"""
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.datasets import imdb
# IMDB Dataset loading
train, val, test = imdb.load_data(path='imdb.pkl', maxlen=200,
                                  n_words=20000)
trainX, trainY = train
valX, valY = val
testX, testY = test
# Data preprocessing
# Sequence padding (pad/truncate every review to 50 word indices)
trainX = pad_sequences(trainX, maxlen=50, value=0.)
valX = pad_sequences(valX, maxlen=50, value=0.)
testX = pad_sequences(testX, maxlen=50, value=0.)
# Converting labels to binary vectors
trainY = to_categorical(trainY, nb_classes=2)
valY = to_categorical(valY, nb_classes=2)
testY = to_categorical(testY, nb_classes=2)
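# Illustration (not part of the original gist): to_categorical one-hot
# encodes the 0/1 sentiment labels, e.g. to_categorical([0, 1], nb_classes=2)
# yields [[1., 0.], [0., 1.]], matching the 2-unit softmax output below.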
# Network building
net = tflearn.input_data([None, 50])                            # batch of 50-step index sequences
net = tflearn.embedding(net, input_dim=20000, output_dim=128)   # word index -> 128-d vector
net = tflearn.lstm(net, 128)                                    # LSTM over the sequence, 128 units
net = tflearn.dropout(net, 0.5)
net = tflearn.fully_connected(net, 2, activation='softmax')     # 2-way sentiment output
net = tflearn.regression(net, optimizer='adam',
                         loss='categorical_crossentropy')
# Training (monitor the held-out validation split; the test split stays unseen)
model = tflearn.DNN(net, clip_gradients=0., tensorboard_verbose=0)
model.fit(trainX, trainY, validation_set=(valX, valY), show_metric=True,
          batch_size=128, n_epoch=1)
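# Sketch (not in the original gist): score the test split with TFLearn's
# DNN.evaluate, which returns the metric(s) defined by the regression layer
# (accuracy by default).
print('Test accuracy:', model.evaluate(testX, testY, batch_size=128))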
# Predict on a toy (padded) sequence, then check that a save/load round-trip
# reproduces the same prediction.
print(model.predict(pad_sequences([[1, 2, 3]], maxlen=50, value=0.)))
model.save('./dump')
model.load('./dump')
print(model.predict(pad_sequences([[1, 2, 3]], maxlen=50, value=0.)))