Skip to content

Instantly share code, notes, and snippets.

@yiidtw
Created October 18, 2018 15:25
Show Gist options
  • Save yiidtw/088b20b07fe9d2562101800968cfb3ac to your computer and use it in GitHub Desktop.
Save yiidtw/088b20b07fe9d2562101800968cfb3ac to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "keras_imdb_lstm.ipynb",
"version": "0.3.2",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"metadata": {
"id": "QDWBkKp_4Ii7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "9272cfc0-4987-4d09-f96a-6de683d8a476"
},
"cell_type": "code",
"source": [
"# fork from https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py\n",
"'''Trains an LSTM model on the IMDB sentiment classification task.\n",
"The dataset is actually too small for LSTM to be of any advantage\n",
"compared to simpler, much faster methods such as TF-IDF + LogReg.\n",
"# Notes\n",
"- RNNs are tricky. Choice of batch size is important,\n",
"choice of loss and optimizer is critical, etc.\n",
"Some configurations won't converge.\n",
"- LSTM loss decrease patterns during training can be quite different\n",
"from what you see with CNNs/MLPs/etc.\n",
"'''\n",
"from __future__ import print_function\n",
"\n",
"from keras.preprocessing import sequence\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Embedding\n",
"from keras.layers import LSTM\n",
"from keras.datasets import imdb"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
],
"name": "stderr"
}
]
},
{
"metadata": {
"id": "33l0XIii5g-n",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "0323da4b-ecde-4921-b3e2-01e0832b13ef"
},
"cell_type": "code",
"source": [
"max_features = 20000\n",
"# cut texts after this number of words (among top max_features most common words)\n",
"maxlen = 80\n",
"batch_size = 32\n",
"\n",
"print('Loading data...')\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n",
"print(len(x_train), 'train sequences')\n",
"print(len(x_test), 'test sequences')"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"Loading data...\n",
"25000 train sequences\n",
"25000 test sequences\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "z4uMELzs50zN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "25727658-a079-4074-b21b-f6d21f2245b5"
},
"cell_type": "code",
"source": [
"print('Pad sequences (samples x time)')\n",
"x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n",
"x_test = sequence.pad_sequences(x_test, maxlen=maxlen)\n",
"print('x_train shape:', x_train.shape)\n",
"print('x_test shape:', x_test.shape)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Pad sequences (samples x time)\n",
"x_train shape: (25000, 80)\n",
"x_test shape: (25000, 80)\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "MzdFSDcG5kP4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "e82f032e-e9d7-48fb-9551-efd5ed28bdb1"
},
"cell_type": "code",
"source": [
"print('Build model...')\n",
"model = Sequential()\n",
"model.add(Embedding(max_features, 128))\n",
"model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))\n",
"model.add(Dense(1, activation='sigmoid'))\n",
"\n",
"# try using different optimizers and different optimizer configs\n",
"model.compile(loss='binary_crossentropy',\n",
" optimizer='adam',\n",
" metrics=['accuracy'])"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"Build model...\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "Wc531CLR5qxI",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 598
},
"outputId": "2bf7474b-4372-42c1-a1b3-17605ee0b8a7"
},
"cell_type": "code",
"source": [
"print('Train...')\n",
"model.fit(x_train, y_train,\n",
" batch_size=batch_size,\n",
" epochs=15,\n",
" validation_data=(x_test, y_test))"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"Train...\n",
"Train on 25000 samples, validate on 25000 samples\n",
"Epoch 1/15\n",
"25000/25000 [==============================] - 159s 6ms/step - loss: 0.4603 - acc: 0.7816 - val_loss: 0.4377 - val_acc: 0.8022\n",
"Epoch 2/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.3028 - acc: 0.8764 - val_loss: 0.3759 - val_acc: 0.8344\n",
"Epoch 3/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.2229 - acc: 0.9121 - val_loss: 0.4648 - val_acc: 0.8110\n",
"Epoch 4/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.1566 - acc: 0.9412 - val_loss: 0.4597 - val_acc: 0.8304\n",
"Epoch 5/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.1090 - acc: 0.9596 - val_loss: 0.6787 - val_acc: 0.8198\n",
"Epoch 6/15\n",
"25000/25000 [==============================] - 156s 6ms/step - loss: 0.0852 - acc: 0.9690 - val_loss: 0.6086 - val_acc: 0.8231\n",
"Epoch 7/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0598 - acc: 0.9792 - val_loss: 0.7145 - val_acc: 0.8192\n",
"Epoch 8/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0435 - acc: 0.9860 - val_loss: 0.8142 - val_acc: 0.8016\n",
"Epoch 9/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0350 - acc: 0.9877 - val_loss: 0.9446 - val_acc: 0.8176\n",
"Epoch 10/15\n",
"25000/25000 [==============================] - 156s 6ms/step - loss: 0.0277 - acc: 0.9914 - val_loss: 0.8735 - val_acc: 0.8134\n",
"Epoch 11/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0230 - acc: 0.9926 - val_loss: 0.9833 - val_acc: 0.8079\n",
"Epoch 12/15\n",
"25000/25000 [==============================] - 156s 6ms/step - loss: 0.0179 - acc: 0.9949 - val_loss: 1.0808 - val_acc: 0.8093\n",
"Epoch 13/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0108 - acc: 0.9968 - val_loss: 1.1897 - val_acc: 0.8026\n",
"Epoch 14/15\n",
"25000/25000 [==============================] - 156s 6ms/step - loss: 0.0087 - acc: 0.9972 - val_loss: 1.0943 - val_acc: 0.8121\n",
"Epoch 15/15\n",
"25000/25000 [==============================] - 157s 6ms/step - loss: 0.0152 - acc: 0.9951 - val_loss: 1.1105 - val_acc: 0.8165\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f3acfe67208>"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"metadata": {
"id": "iSUcTbCK6u8p",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "15c9ee51-815a-4d47-a9f6-cec061f75af5"
},
"cell_type": "code",
"source": [
"score, acc = model.evaluate(x_test, y_test,\n",
" batch_size=batch_size)\n",
"print('Test score:', score)\n",
"print('Test accuracy:', acc)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 26s 1ms/step\n",
"Test score: 1.1104922486197948\n",
"Test accuracy: 0.81648\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "BOg7XIUtD3xz",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment