Last active
November 14, 2018 09:03
-
-
Save weakish/a97f656c454c8188be1340b050d145b6 to your computer and use it in GitHub Desktop.
Keras_LSTM_TPU.ipynb (trained on a K80 GPU)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Keras_LSTM_TPU.ipynb", | |
"version": "0.3.2", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/weakish/a97f656c454c8188be1340b050d145b6/keras_lstm_tpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "CB43mV-TD1vb", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
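"# Tutorial - [How to train Keras model x20 times faster with TPU for free](https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/)\n", | |
"\n", | |
"Note: this copy trains the model on a K80 GPU rather than a TPU, and times the training run for comparison." | |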
] | |
}, | |
{ | |
"metadata": { | |
"id": "ya06BE0ZU526", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import tensorflow as tf\n", | |
"from tensorflow.keras.datasets import imdb\n", | |
"from tensorflow.keras.preprocessing import sequence\n", | |
"from tensorflow.python.keras.layers import Input, LSTM, Bidirectional, Dense, Embedding" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "_uSZchXTVOHr", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 52 | |
}, | |
"outputId": "b64c0d3e-554d-418a-918e-9e7d8eaa276a" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"\n", | |
"# Number of words to consider as features\n", | |
"max_features = 10000\n", | |
"# Cut texts after this number of words (among top max_features most common words)\n", | |
"maxlen = 500\n", | |
"\n", | |
"# Load data\n", | |
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n", | |
"\n", | |
"# Reverse sequences\n", | |
"x_train = [x[::-1] for x in x_train]\n", | |
"x_test = [x[::-1] for x in x_test]\n", | |
"\n", | |
"# Pad sequences\n", | |
"x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n", | |
"x_test = sequence.pad_sequences(x_test, maxlen=maxlen)" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n", | |
"17465344/17464789 [==============================] - 0s 0us/step\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "p35nSfjbVVBE", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"def make_model(batch_size=None):\n", | |
" source = Input(shape=(maxlen,), batch_size=batch_size, dtype=tf.int32, name='Input')\n", | |
" embedding = Embedding(input_dim=max_features, output_dim=128, name='Embedding')(source)\n", | |
" # lstm = Bidirectional(LSTM(32, name = 'LSTM'), name='Bidirectional')(embedding)\n", | |
" lstm = LSTM(32, name = 'LSTM')(embedding)\n", | |
" predicted_var = Dense(1, activation='sigmoid', name='Output')(lstm)\n", | |
" model = tf.keras.Model(inputs=[source], outputs=[predicted_var])\n", | |
" model.compile(\n", | |
" optimizer=tf.train.RMSPropOptimizer(learning_rate=0.01),\n", | |
" loss='binary_crossentropy',\n", | |
" metrics=['acc'])\n", | |
" return model\n", | |
" " | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "bivVZS0jZhxg", | |
"colab_type": "code", | |
"outputId": "a4491799-1844-493f-d1a5-37256e5c9f6a", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 278 | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"tf.keras.backend.clear_session()\n", | |
"model = make_model(batch_size = 128)\n", | |
"model.summary()" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"Input (InputLayer) (128, 500) 0 \n", | |
"_________________________________________________________________\n", | |
"Embedding (Embedding) (128, 500, 128) 1280000 \n", | |
"_________________________________________________________________\n", | |
"LSTM (LSTM) (128, 32) 20608 \n", | |
"_________________________________________________________________\n", | |
"Output (Dense) (128, 1) 33 \n", | |
"=================================================================\n", | |
"Total params: 1,300,641\n", | |
"Trainable params: 1,300,641\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "km1iz0HD--7M", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 278 | |
}, | |
"outputId": "a59909a9-474a-4e4f-f7f3-810052569455" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"variable_model = make_model(batch_size = None)\n", | |
"variable_model.summary()" | |
], | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"Input (InputLayer) (None, 500) 0 \n", | |
"_________________________________________________________________\n", | |
"Embedding (Embedding) (None, 500, 128) 1280000 \n", | |
"_________________________________________________________________\n", | |
"LSTM (LSTM) (None, 32) 20608 \n", | |
"_________________________________________________________________\n", | |
"Output (Dense) (None, 1) 33 \n", | |
"=================================================================\n", | |
"Total params: 1,300,641\n", | |
"Trainable params: 1,300,641\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "XlSm1vd5bteH", | |
"colab_type": "code", | |
"outputId": "c0bc1fe8-ec6a-4582-853b-9cb52c08d24f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 746 | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import time\n", | |
"start_time = time.time()\n", | |
"\n", | |
"history = variable_model.fit(x_train, y_train,\n", | |
" batch_size = 128,\n", | |
" epochs=20,\n", | |
" validation_split=0.2)\n", | |
"\n", | |
"print(\"--- %s seconds ---\" % (time.time() - start_time))\n" | |
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Train on 20000 samples, validate on 5000 samples\n", | |
"Epoch 1/20\n", | |
"20000/20000 [==============================] - 305s 15ms/step - loss: 0.6927 - acc: 0.5116 - val_loss: 0.6822 - val_acc: 0.5774\n", | |
"Epoch 2/20\n", | |
"20000/20000 [==============================] - 313s 16ms/step - loss: 0.6282 - acc: 0.6599 - val_loss: 0.6214 - val_acc: 0.6306\n", | |
"Epoch 3/20\n", | |
"20000/20000 [==============================] - 305s 15ms/step - loss: 0.4379 - acc: 0.8089 - val_loss: 0.3874 - val_acc: 0.8482\n", | |
"Epoch 4/20\n", | |
"20000/20000 [==============================] - 301s 15ms/step - loss: 0.3218 - acc: 0.8712 - val_loss: 0.3600 - val_acc: 0.8486\n", | |
"Epoch 5/20\n", | |
"20000/20000 [==============================] - 297s 15ms/step - loss: 0.2678 - acc: 0.8940 - val_loss: 0.3046 - val_acc: 0.8818\n", | |
"Epoch 6/20\n", | |
"20000/20000 [==============================] - 296s 15ms/step - loss: 0.2022 - acc: 0.9233 - val_loss: 0.3462 - val_acc: 0.8606\n", | |
"Epoch 7/20\n", | |
"20000/20000 [==============================] - 295s 15ms/step - loss: 0.1492 - acc: 0.9451 - val_loss: 0.3575 - val_acc: 0.8718\n", | |
"Epoch 8/20\n", | |
"20000/20000 [==============================] - 294s 15ms/step - loss: 0.1039 - acc: 0.9617 - val_loss: 0.3718 - val_acc: 0.8784\n", | |
"Epoch 9/20\n", | |
"20000/20000 [==============================] - 296s 15ms/step - loss: 0.0710 - acc: 0.9771 - val_loss: 0.4218 - val_acc: 0.8756\n", | |
"Epoch 10/20\n", | |
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0486 - acc: 0.9845 - val_loss: 0.4529 - val_acc: 0.8786\n", | |
"Epoch 11/20\n", | |
"20000/20000 [==============================] - 302s 15ms/step - loss: 0.0395 - acc: 0.9869 - val_loss: 0.5026 - val_acc: 0.8666\n", | |
"Epoch 12/20\n", | |
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0331 - acc: 0.9894 - val_loss: 0.5990 - val_acc: 0.8588\n", | |
"Epoch 13/20\n", | |
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0237 - acc: 0.9928 - val_loss: 0.6152 - val_acc: 0.8684\n", | |
"Epoch 14/20\n", | |
"20000/20000 [==============================] - 301s 15ms/step - loss: 0.0153 - acc: 0.9950 - val_loss: 0.6993 - val_acc: 0.8676\n", | |
"Epoch 15/20\n", | |
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0151 - acc: 0.9949 - val_loss: 0.7196 - val_acc: 0.8732\n", | |
"Epoch 16/20\n", | |
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0118 - acc: 0.9964 - val_loss: 0.7559 - val_acc: 0.8736\n", | |
"Epoch 17/20\n", | |
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0081 - acc: 0.9973 - val_loss: 0.7615 - val_acc: 0.8642\n", | |
"Epoch 18/20\n", | |
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0098 - acc: 0.9968 - val_loss: 0.7584 - val_acc: 0.8616\n", | |
"Epoch 19/20\n", | |
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0103 - acc: 0.9964 - val_loss: 0.8302 - val_acc: 0.8684\n", | |
"Epoch 20/20\n", | |
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0083 - acc: 0.9974 - val_loss: 0.8563 - val_acc: 0.8636\n", | |
"--- 6004.776646375656 seconds ---\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "uAB1CDE3W0Lo", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 52 | |
}, | |
"outputId": "efd387b5-aebc-424b-921d-489fcc4cedc9" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"variable_model.evaluate(x_test, y_test)" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"25000/25000 [==============================] - 499s 20ms/step\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[0.8893412690807879, 0.85656]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 23 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "apwRGvwWDnau", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment