@georgehc
Created November 27, 2019 04:07
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Recurrent Neural Networks Demo\n",
"\n",
"Author: Runshan Fu (Fall 2017 95-865 TA), George H. Chen\n",
"\n",
"In this demo, we will implement RNN models for sentiment analysis on IMDB reviews. We will start from the original review texts and predict the sentiment (positive or negative) for each review. This demo is borrowed from the book *Deep Learning with Python* by Francois Chollet and also uses code from user `mdaoust` in [this stackoverflow post](https://stackoverflow.com/questions/42821330/restore-original-text-from-keras-s-imdb-dataset)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the dataset\n",
"We load the data directly from Keras as lists of integers, where each integer is the index of a word. We restrict the movie reviews to the 2000 most common words (strictly speaking, to word indices below 2000, a few of which are reserved for special tokens) and make every review exactly 200 words long. Longer reviews are truncated by removing the *start* of the review and keeping only the last 200 words; shorter reviews are padded by adding a special padding token to the *start* of the review."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.datasets import imdb\n",
"from keras.preprocessing import sequence\n",
"\n",
"# load the dataset and only keep the top words (most frequently occurring)\n",
"vocab_size = 2000\n",
"INDEX_FROM = 2  # offset for actual word indices, so that indices 0, 1, 2 are reserved for <PAD>, <START>, <UNK>\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size, index_from=INDEX_FROM)\n",
"\n",
"word_to_idx = imdb.get_word_index()\n",
"word_to_idx = {word:(word_idx + INDEX_FROM) for word, word_idx in word_to_idx.items()}\n",
"word_to_idx['<PAD>'] = 0\n",
"word_to_idx['<START>'] = 1\n",
"word_to_idx['<UNK>'] = 2\n",
"\n",
"idx_to_word = {word_idx:word for word, word_idx in word_to_idx.items()}\n",
"\n",
"# turn the lists of integers into a 2D integer tensor of shape `(samples, maxlen)`\n",
"x_train = sequence.pad_sequences(x_train, maxlen=200)\n",
"x_test = sequence.pad_sequences(x_test, maxlen=200)"
]
},
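{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the truncation and padding behavior described above (an added cell, not part of the original demo). The two toy sequences below are made up for illustration. Since `pad_sequences` defaults to `padding='pre'` and `truncating='pre'`, the first sequence should come out as `[0 5 6 7]` (a padding index added at the start) and the second as `[3 4 5 6]` (its first two entries removed)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from keras.preprocessing import sequence\n",
"\n",
"# toy integer sequences (illustrative only, not IMDB data)\n",
"toy_sequences = [[5, 6, 7], [1, 2, 3, 4, 5, 6]]\n",
"\n",
"# same call as above, but with maxlen=4 so the effect is easy to see\n",
"print(sequence.pad_sequences(toy_sequences, maxlen=4))"
]
},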
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(25000, 200)\n"
]
}
],
"source": [
"print(x_train.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 4 24 99 42 837 111 49 669 2 8 34 479 283 4\n",
" 149 3 171 111 166 2 335 384 38 3 171 2 1110 16\n",
" 545 37 12 446 3 191 49 15 5 146 2 18 13 21\n",
" 3 1919 2 468 3 21 70 86 11 15 42 529 37 75\n",
" 14 12 1246 3 21 16 514 16 11 15 625 17 2 4\n",
" 61 385 11 7 315 7 105 4 3 2 2 15 479 65\n",
" 2 32 3 129 11 15 37 618 4 24 123 50 35 134\n",
" 47 24 1414 32 5 21 11 214 27 76 51 4 13 406\n",
" 15 81 2 7 3 106 116 2 14 255 3 2 6 2\n",
" 4 722 35 70 42 529 475 25 399 316 45 6 3 2\n",
" 1028 12 103 87 3 380 14 296 97 31 2 55 25 140\n",
" 5 193 2 17 3 225 21 20 133 475 25 479 4 143\n",
" 29 2 17 50 35 27 223 91 24 103 3 225 64 15\n",
" 37 1333 87 11 15 282 4 15 2 112 102 31 14 15\n",
" 2 18 177 31]\n"
]
}
],
"source": [
"print(x_train[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is an example of a review with positive sentiment (this particular review has been truncated, so we only see the last 200 words of the review):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"and you could just imagine being there robert <UNK> is an amazing actor and now the same being director <UNK> father came from the same <UNK> island as myself so i loved the fact there was a real <UNK> with this film the witty <UNK> throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for <UNK> and would recommend it to everyone to watch and the <UNK> <UNK> was amazing really <UNK> at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also <UNK> to the two little <UNK> that played the <UNK> of <UNK> and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all <UNK> up are such a big <UNK> for the whole film but these children are amazing and should be <UNK> for what they have done don't you think the whole story was so lovely because it was true and was <UNK> life after all that was <UNK> with us all\n"
]
}
],
"source": [
"print(' '.join(idx_to_word[idx] for idx in x_train[0]))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is an example of a review with negative sentiment (this particular review has been padded at the beginning; note that whenever the beginning of a review is visible, the first token of the actual review is the special token `<START>`):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <START> in a far away <UNK> is a planet called <UNK> it's <UNK> people <UNK> <UNK> but the dog people <UNK> war upon these <UNK> loving people and they have no choice but to go to earth and <UNK> people up for food this is one of the <UNK> f <UNK> <UNK> ideas for a movie i've seen leave it to <UNK> <UNK> to make a movie more <UNK> than the already low standard he set in previous films it's like he <UNK> playing in a <UNK> game of <UNK> how low can he go the only <UNK> in the <UNK> are us the viewer mr <UNK> and his silly little <UNK> <UNK> actually has people who still buy this crap br br my grade f br br dvd <UNK> commentary by <UNK> <UNK> the story behind the making of 9 and a half minutes <UNK> minutes 15 seconds of behind the scenes footage <UNK> <UNK> <UNK> and <UNK> for the <UNK> <UNK> girl in gold <UNK> the <UNK> <UNK> ten violent women featuring nudity blood <UNK> of the she <UNK> the <UNK> <UNK>\n"
]
}
],
"source": [
"print(' '.join(idx_to_word[idx] for idx in x_train[-3]))"
]
},
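{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `<PAD>` tokens carry no content, so when reading decoded reviews it can be convenient to drop them. A small helper for this (an added sketch; the name `decode_review` is ours, not part of Keras):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_review(idx_seq):\n",
"    # map word indices back to words, skipping the padding index 0 (<PAD>)\n",
"    return ' '.join(idx_to_word[idx] for idx in idx_seq if idx != 0)\n",
"\n",
"print(decode_review(x_train[-3]))"
]
},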
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train[-3]"
]
},
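{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we will later compare raw classification accuracies, it is worth checking how balanced the two classes are (an added check, not part of the original demo): `np.bincount` counts how many reviews have label 0 (negative) and label 1 (positive)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# number of negative (position 0) and positive (position 1) reviews in each split\n",
"print('training label counts:', np.bincount(y_train))\n",
"print('test label counts:', np.bincount(y_test))"
]
},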
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use pre-trained word embeddings\n",
"We use GloVe embeddings instead of learning our own task-specific word embedding. First download the pre-computed embeddings from https://nlp.stanford.edu/projects/glove/ (specifically the ones trained on 6 billion tokens from Wikipedia 2014 and Gigaword 5, `glove.6B.zip`). Unzip the archive so that `glove.6B.100d.txt` is located in the directory `./glove/`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first create a dictionary that maps each English word to its corresponding 100-dimensional GloVe embedding."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 400000 word vectors.\n"
]
}
],
"source": [
"word_to_embedding = {}\n",
"\n",
"# we will use the 100-dimensional embedding vectors\n",
"with open(\"./glove/glove.6B.100d.txt\", encoding=\"utf-8\") as f:\n",
"    # each row represents a word vector\n",
"    for line in f:\n",
"        values = line.split()\n",
"        # the first part is the word\n",
"        word = values[0]\n",
"        # the rest of the values form the embedding vector\n",
"        embedding = np.asarray(values[1:], dtype='float32')\n",
"        word_to_embedding[word] = embedding\n",
"\n",
"print('Found %s word vectors.' % len(word_to_embedding))"
]
},
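{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what these vectors capture, we can compare pairs of words by cosine similarity (an added illustration, not part of the original demo; the word pairs are arbitrary choices, and they are lowercase because the 6B GloVe vectors are uncased)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_similarity(u, v):\n",
"    # cosine of the angle between two vectors (1 means same direction)\n",
"    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))\n",
"\n",
"for w1, w2 in [('good', 'great'), ('good', 'terrible'), ('good', 'banana')]:\n",
"    print(w1, w2, cosine_similarity(word_to_embedding[w1], word_to_embedding[w2]))"
]
},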
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create an embedding matrix whose i-th row holds the GloVe embedding for the word with index i. The exceptions are i=0 (the special padding token `<PAD>`), i=1 (the special `<START>` token), and i=2 (the special `<UNK>` token): for these, the embedding vector is left as all zeros, as it is for any vocabulary word that has no GloVe vector."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"embedding_dim = 100\n",
"\n",
"embedding_matrix = np.zeros((vocab_size, embedding_dim))\n",
"for idx in range(vocab_size):\n",
"    word = idx_to_word[idx]\n",
"    # words with no GloVe vector keep an all-zero row\n",
"    if word in word_to_embedding:\n",
"        embedding_matrix[idx] = word_to_embedding[word]"
]
},
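{
"cell_type": "markdown",
"metadata": {},
"source": [
"Words that do not appear in GloVe keep an all-zero row, just like the special tokens. A quick way to see how much of the vocabulary actually received a pre-trained vector (an added check, not part of the original demo):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a row counts as having a vector if at least one of its entries is nonzero\n",
"rows_with_vectors = np.any(embedding_matrix != 0, axis=1)\n",
"print('%d of %d word indices have a GloVe vector' % (rows_with_vectors.sum(), vocab_size))"
]
},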
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feedforward network with embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This first neural net is *not* a recurrent neural net: it does nothing special to account for the sequential structure of the text. It is meant to be a baseline against which we compare a recurrent neural net. To make the comparison somewhat fair, in both models the second-to-last layer outputs a 64-dimensional vector (here a Dense layer with 64 neurons and `relu` activation; in the LSTM model, an LSTM layer with 64 units), and the last layer is a Dense layer with 1 neuron and `sigmoid` activation."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_1 (Embedding) (None, 200, 100) 200000 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 20000) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 64) 1280064 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 1) 65 \n",
"=================================================================\n",
"Total params: 1,480,129\n",
"Trainable params: 1,480,129\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Embedding, Flatten, Dense\n",
"# initialize the model\n",
"feedforward_model = Sequential()\n",
"feedforward_model.add(Embedding(vocab_size, embedding_dim, input_length=200))\n",
"feedforward_model.add(Flatten())\n",
"feedforward_model.add(Dense(64, activation='relu'))\n",
"feedforward_model.add(Dense(1, activation='sigmoid'))\n",
"feedforward_model.summary()"
]
},
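{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in this summary can be checked by hand. The embedding layer stores one 100-dimensional vector per vocabulary index, the first Dense layer sees the flattened $200 \\times 100 = 20{,}000$-dimensional input, and each Dense layer also has one bias per neuron:\n",
"\n",
"$$2000 \\times 100 = 200{,}000, \\qquad 20{,}000 \\times 64 + 64 = 1{,}280{,}064, \\qquad 64 \\times 1 + 1 = 65,$$\n",
"\n",
"for a total of $200{,}000 + 1{,}280{,}064 + 65 = 1{,}480{,}129$ parameters."
]
},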
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_1 (Embedding) (None, 200, 100) 200000 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 20000) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 64) 1280064 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 1) 65 \n",
"=================================================================\n",
"Total params: 1,480,129\n",
"Trainable params: 1,280,129\n",
"Non-trainable params: 200,000\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# load the GloVe embeddings into the model\n",
"feedforward_model.layers[0].set_weights([embedding_matrix])\n",
"# freeze the embedding layer so its weights do not change during training\n",
"# (this needs to happen before compile() for it to take effect)\n",
"feedforward_model.layers[0].trainable = False\n",
"\n",
"feedforward_model.summary()  # the summary changes after we freeze the 0-th layer (note the \"Trainable params\" and \"Non-trainable params\" lines)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 20000 samples, validate on 5000 samples\n",
"Epoch 1/10\n",
"20000/20000 [==============================] - 14s 686us/step - loss: 0.6309 - acc: 0.6500 - val_loss: 0.5866 - val_acc: 0.6856\n",
"Epoch 2/10\n",
"20000/20000 [==============================] - 13s 651us/step - loss: 0.4821 - acc: 0.7722 - val_loss: 0.5610 - val_acc: 0.7232\n",
"Epoch 3/10\n",
"20000/20000 [==============================] - 13s 660us/step - loss: 0.4004 - acc: 0.8179 - val_loss: 0.5807 - val_acc: 0.7176\n",
"Epoch 4/10\n",
"20000/20000 [==============================] - 13s 673us/step - loss: 0.3227 - acc: 0.8597 - val_loss: 0.6451 - val_acc: 0.7126\n",
"Epoch 5/10\n",
"20000/20000 [==============================] - 13s 653us/step - loss: 0.2503 - acc: 0.8947 - val_loss: 0.7348 - val_acc: 0.7100\n",
"Epoch 6/10\n",
"20000/20000 [==============================] - 13s 629us/step - loss: 0.1787 - acc: 0.9286 - val_loss: 0.8210 - val_acc: 0.7106\n",
"Epoch 7/10\n",
"20000/20000 [==============================] - 14s 686us/step - loss: 0.1028 - acc: 0.9650 - val_loss: 0.9893 - val_acc: 0.7102\n",
"Epoch 8/10\n",
"20000/20000 [==============================] - 13s 674us/step - loss: 0.0681 - acc: 0.9772 - val_loss: 1.1859 - val_acc: 0.7082\n",
"Epoch 9/10\n",
"20000/20000 [==============================] - 13s 653us/step - loss: 0.0470 - acc: 0.9856 - val_loss: 1.3526 - val_acc: 0.7100\n",
"Epoch 10/10\n",
"20000/20000 [==============================] - 13s 642us/step - loss: 0.0588 - acc: 0.9783 - val_loss: 1.4186 - val_acc: 0.7048\n"
]
}
],
"source": [
"# compile and train the model\n",
"feedforward_model.compile(optimizer='adam',\n",
"                          loss='binary_crossentropy',\n",
"                          metrics=['acc'])\n",
"\n",
"history = feedforward_model.fit(x_train, y_train,\n",
"                                validation_split=0.2,\n",
"                                epochs=10,\n",
"                                batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x649c8ffd0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# plot the accuracy rates for each epoch on training and validation data\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"epochs = range(1, len(acc) + 1)\n",
"plt.plot(epochs, acc, 'bo', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.legend()"
]
},
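{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same `history` object also records the loss for each epoch. Plotting it the same way (an added cell, not part of the original demo) makes overfitting easier to spot: in the training log above, validation loss is lowest at epoch 2 and grows afterwards, while training loss keeps falling through epoch 9."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# plot the loss for each epoch on training and validation data\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"epochs = range(1, len(loss) + 1)\n",
"plt.figure()\n",
"plt.plot(epochs, loss, 'bo', label='Training loss')\n",
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
"plt.title('Training and validation loss')\n",
"plt.legend()\n",
"plt.show()"
]
},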
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LSTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we use an LSTM recurrent neural net. If you want to use a different kind of RNN such as `SimpleRNN` or `GRU`, simply replace `LSTM` with `SimpleRNN` or `GRU` (both in importing the layer and in adding the layer to the model)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_2 (Embedding) (None, 200, 100) 200000 \n",
"_________________________________________________________________\n",
"lstm_1 (LSTM) (None, 64) 42240 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 1) 65 \n",
"=================================================================\n",
"Total params: 242,305\n",
"Trainable params: 42,305\n",
"Non-trainable params: 200,000\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Embedding, LSTM, Dense\n",
"rnn_model = Sequential()\n",
"rnn_model.add(Embedding(vocab_size, embedding_dim, input_length=200))\n",
"rnn_model.add(LSTM(64))\n",
"rnn_model.add(Dense(1, activation='sigmoid'))\n",
"\n",
"# load the GloVe embeddings in the model\n",
"rnn_model.layers[0].set_weights([embedding_matrix])\n",
"rnn_model.layers[0].trainable = False\n",
"\n",
"rnn_model.summary()"
]
},
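{
"cell_type": "markdown",
"metadata": {},
"source": [
"The LSTM layer's 42,240 parameters can also be checked by hand. A Keras `LSTM` layer has four gates (input, forget, cell candidate, output), and each one has a $100 \\times 64$ input weight matrix, a $64 \\times 64$ recurrent weight matrix, and 64 biases, so\n",
"\n",
"$$4 \\times (100 \\times 64 + 64 \\times 64 + 64) = 4 \\times 10{,}560 = 42{,}240.$$\n",
"\n",
"Adding the output layer's $64 + 1 = 65$ parameters gives the 42,305 trainable parameters; the frozen embedding accounts for the remaining 200,000. That is roughly 30 times fewer trainable parameters than the feedforward net (42,305 versus 1,280,129)."
]
},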
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 20000 samples, validate on 5000 samples\n",
"Epoch 1/10\n",
"20000/20000 [==============================] - 86s 4ms/step - loss: 0.5707 - acc: 0.7016 - val_loss: 0.5012 - val_acc: 0.7534\n",
"Epoch 2/10\n",
"20000/20000 [==============================] - 85s 4ms/step - loss: 0.4279 - acc: 0.8048 - val_loss: 0.4189 - val_acc: 0.8126\n",
"Epoch 3/10\n",
"20000/20000 [==============================] - 83s 4ms/step - loss: 0.3690 - acc: 0.8366 - val_loss: 0.3670 - val_acc: 0.8394\n",
"Epoch 4/10\n",
"20000/20000 [==============================] - 85s 4ms/step - loss: 0.3420 - acc: 0.8520 - val_loss: 0.3361 - val_acc: 0.8522\n",
"Epoch 5/10\n",
"20000/20000 [==============================] - 87s 4ms/step - loss: 0.3173 - acc: 0.8627 - val_loss: 0.3382 - val_acc: 0.8510\n",
"Epoch 6/10\n",
"20000/20000 [==============================] - 83s 4ms/step - loss: 0.2980 - acc: 0.8733 - val_loss: 0.3764 - val_acc: 0.8352\n",
"Epoch 7/10\n",
"20000/20000 [==============================] - 86s 4ms/step - loss: 0.2806 - acc: 0.8801 - val_loss: 0.3622 - val_acc: 0.8526\n",
"Epoch 8/10\n",
"20000/20000 [==============================] - 82s 4ms/step - loss: 0.2617 - acc: 0.8915 - val_loss: 0.3367 - val_acc: 0.8544\n",
"Epoch 9/10\n",
"20000/20000 [==============================] - 83s 4ms/step - loss: 0.2462 - acc: 0.8951 - val_loss: 0.3822 - val_acc: 0.8280\n",
"Epoch 10/10\n",
"20000/20000 [==============================] - 83s 4ms/step - loss: 0.2273 - acc: 0.9056 - val_loss: 0.3363 - val_acc: 0.8638\n"
]
}
],
"source": [
"# compile and train the model\n",
"rnn_model.compile(optimizer='adam',\n",
" loss='binary_crossentropy',\n",
" metrics=['acc'])\n",
"history = rnn_model.fit(x_train, y_train,\n",
" validation_split=0.2,\n",
" epochs=10,\n",
" batch_size=32)"
]
},
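{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training log reports 20000 training and 5000 validation samples because `validation_split=0.2` holds out 20% of `x_train`. Per the Keras documentation, the validation data are the *last* samples of the arrays passed to `fit`, selected before any shuffling. The sketch below mimics that split with NumPy on zero-filled placeholder arrays (hypothetical names `x_demo`, `y_demo`) that have the same shape as the padded IMDB training data; it illustrates the rule and is not the Keras implementation itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# placeholder arrays with the same shapes as the padded IMDB training data\n",
"num_samples, seq_len = 25000, 200\n",
"x_demo = np.zeros((num_samples, seq_len), dtype='int32')\n",
"y_demo = np.zeros(num_samples, dtype='int32')\n",
"\n",
"validation_split = 0.2\n",
"# the validation data are the LAST samples, taken before any shuffling\n",
"split_at = int(num_samples * (1.0 - validation_split))\n",
"x_tr, x_val = x_demo[:split_at], x_demo[split_at:]\n",
"y_tr, y_val = y_demo[:split_at], y_demo[split_at:]\n",
"print(x_tr.shape, x_val.shape)  # (20000, 200) (5000, 200)"
]
},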
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# plot the accuracy rates for each epoch on training and validation data\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"epochs = range(1, len(acc) + 1)\n",
"plt.plot(epochs, acc, 'bo', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.legend()\n",
"plt.show()"
]
},
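{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation accuracy curve is noisy, and the two validation metrics in the training log do not agree on which epoch is best. The sketch below copies the per-epoch `val_loss` and `val_acc` values printed in the training log (in practice you would read them from `history.history`, as in the plotting cell) and finds the best epoch under each criterion. The lowest validation loss occurs at epoch 4, while the highest validation accuracy occurs at epoch 10, so a single run's curves are a rough guide rather than a precise recipe for choosing the number of epochs. The Keras `EarlyStopping` callback automates this kind of choice for one monitored metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# validation metrics copied from the training log above (epochs 1-10)\n",
"logged_val_loss = [0.5012, 0.4189, 0.3670, 0.3361, 0.3382,\n",
"                   0.3764, 0.3622, 0.3367, 0.3822, 0.3363]\n",
"logged_val_acc = [0.7534, 0.8126, 0.8394, 0.8522, 0.8510,\n",
"                  0.8352, 0.8526, 0.8544, 0.8280, 0.8638]\n",
"\n",
"# epochs are numbered from 1, so add 1 to the 0-based argmin/argmax\n",
"best_epoch_by_loss = int(np.argmin(logged_val_loss)) + 1\n",
"best_epoch_by_acc = int(np.argmax(logged_val_acc)) + 1\n",
"print('lowest validation loss at epoch', best_epoch_by_loss)      # 4\n",
"print('highest validation accuracy at epoch', best_epoch_by_acc)  # 10"
]
},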
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finally, evaluate on test data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now compare the raw test-set classification accuracies of the feedforward neural net and the LSTM model. Keep in mind that we set up both models so that, right before the final classification layer (a logistic regression, i.e., a single sigmoid unit), each review is represented as a feature vector of length 64. The LSTM model learns a much better 64-dimensional feature space, as evidenced by its substantially higher test accuracy (about 86.6% versus 70.0% for the feedforward net in the cells below)."
]
},
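{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final `Dense(1, activation='sigmoid')` layer is a logistic regression on the 64-dimensional feature vector that the preceding layer produces: it outputs `sigmoid(w . h + b)` using 64 weights and 1 bias, which are the 65 parameters listed for the Dense layer in the model summary. The NumPy sketch below spells this out; the feature vector `demo_h`, weights `demo_w`, and bias `demo_b` are random or hand-picked placeholders, not values taken from either trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"# hypothetical placeholders: a 64-dimensional review representation (for the LSTM,\n",
"# its final hidden state) and the final layer's weight vector and bias\n",
"demo_h = rng.normal(size=64)\n",
"demo_w = rng.normal(size=64)\n",
"demo_b = 0.1\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"# Dense(1, activation='sigmoid') applied to demo_h is logistic regression on demo_h\n",
"prob_positive = sigmoid(np.dot(demo_w, demo_h) + demo_b)\n",
"demo_label = int(prob_positive > 0.5)\n",
"print('P(positive) =', prob_positive, '-> label', demo_label)\n",
"print('parameters:', demo_w.size + 1)  # 64 weights + 1 bias = 65"
]
},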
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 2s 85us/step\n",
"Test accuracy: 0.700439989566803\n"
]
}
],
"source": [
"test_loss, test_acc = feedforward_model.evaluate(x_test, y_test)\n",
"print('Test accuracy:', test_acc)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 22s 868us/step\n",
"Test accuracy: 0.8664799928665161\n"
]
}
],
"source": [
"test_loss, test_acc = rnn_model.evaluate(x_test, y_test)\n",
"print('Test accuracy:', test_acc)"
]
},
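{
"cell_type": "markdown",
"metadata": {},
"source": [
"`evaluate` returns the loss followed by the metrics passed to `compile`, so `test_loss` is the binary cross-entropy and `test_acc` is the accuracy on the test set. For intuition, the sketch below computes both by hand on a made-up batch of four labels and predicted probabilities (hypothetical values, not IMDB predictions). Keras clips predicted probabilities slightly away from 0 and 1 before taking logarithms; the small `eps` below imitates that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical labels and predicted probabilities of positive sentiment\n",
"y_true_demo = np.array([0, 1, 1, 0])\n",
"p_demo = np.array([0.1, 0.9, 0.6, 0.4])\n",
"\n",
"eps = 1e-7  # small constant that keeps log() finite\n",
"p_clipped = np.clip(p_demo, eps, 1 - eps)\n",
"# binary cross-entropy: average negative log-likelihood of the true labels\n",
"bce = -np.mean(y_true_demo * np.log(p_clipped)\n",
"               + (1 - y_true_demo) * np.log(1 - p_clipped))\n",
"# accuracy: fraction of examples whose thresholded prediction matches the label\n",
"accuracy = np.mean((p_demo > 0.5).astype(int) == y_true_demo)\n",
"print(round(float(bce), 4), float(accuracy))  # 0.3081 1.0"
]
},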
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get predicted labels, we use the `predict_classes` method of `Sequential` models. There is also a `predict` method, which returns the network's raw final output; here that is the predicted probability of positive sentiment for each test example, since the final layer is a Dense layer with 1 neuron and a sigmoid activation. (`predict_classes` is available in the Keras version used here but has been removed in newer TensorFlow releases.)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"predicted_labels = rnn_model.predict_classes(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0],\n",
" [1],\n",
" [1],\n",
" ...,\n",
" [0],\n",
" [0],\n",
" [1]], dtype=int32)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_labels"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.86648"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(predicted_labels.flatten() == y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In case you're wondering how the predicted labels are computed: `predict_classes` simply takes the raw neural net outputs (which are probabilities) and thresholds them at 0.5, i.e., it declares every test example whose predicted probability of positive sentiment is greater than 0.5 to be in the positive sentiment class and declares all other test examples to have negative sentiment."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"test_set_predicted_probs = rnn_model.predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.00192991],\n",
" [0.9505013 ],\n",
" [0.7850129 ],\n",
" ...,\n",
" [0.00563908],\n",
" [0.06085095],\n",
" [0.7172326 ]], dtype=float32)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_set_predicted_probs"
]
},
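{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the thresholding rule in action, the sketch below takes the six probabilities visible in the truncated output above (the first three and the last three test reviews) and thresholds them at 0.5. The resulting labels match the corresponding entries of `predicted_labels` printed earlier. The same one-line comparison applied to the full output of `predict` reproduces `predict_classes` for a model with a single sigmoid output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# the six probabilities visible in the printed output of test_set_predicted_probs\n",
"shown_probs = np.array([[0.00192991], [0.9505013], [0.7850129],\n",
"                        [0.00563908], [0.06085095], [0.7172326]],\n",
"                       dtype='float32')\n",
"# threshold at 0.5 and cast to int, keeping the (n, 1) shape of predict_classes\n",
"shown_labels = (shown_probs > 0.5).astype('int32')\n",
"print(shown_labels.flatten())  # [0 1 1 0 0 1]"
]
},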
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below computes the test set accuracy of `rnn_model` directly from the predicted probabilities. The comparison `test_set_predicted_probs >= .5` converts them into actual classifications (1 if the predicted probability is at least .5 and 0 otherwise; using `>=` rather than `>` only matters for a probability of exactly .5). Flattening is needed because `test_set_predicted_probs >= .5` is a 2D array of shape (25000, 1), whereas `y_test` is a 1D array of length 25000."
]
},
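{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why flattening matters: if a column of shape `(n, 1)` is compared with a 1D array of shape `(n,)` using `==`, NumPy broadcasting produces an `(n, n)` array that compares every prediction with every label, and the mean of that array is not an accuracy. The sketch below shows the effect on three made-up predictions and labels (hypothetical names `demo_preds`, `demo_labels`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical column of predictions, shape (3, 1), and 1D labels, shape (3,)\n",
"demo_preds = np.array([[1], [0], [1]])\n",
"demo_labels = np.array([1, 0, 0])\n",
"\n",
"wrong = demo_preds == demo_labels            # broadcasts to a (3, 3) matrix\n",
"right = demo_preds.flatten() == demo_labels  # elementwise, shape (3,)\n",
"print(wrong.shape, right.shape)              # (3, 3) (3,)\n",
"print(np.mean(right))                        # 2 of 3 predictions correct"
]
},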
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.86648"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean((test_set_predicted_probs >= .5).flatten() == y_test)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}