Skip to content

Instantly share code, notes, and snippets.

@weakish
Last active November 14, 2018 09:03
Show Gist options
  • Save weakish/a97f656c454c8188be1340b050d145b6 to your computer and use it in GitHub Desktop.
Save weakish/a97f656c454c8188be1340b050d145b6 to your computer and use it in GitHub Desktop.
Keras_LSTM_TPU.ipynb (train on k80)
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Keras_LSTM_TPU.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/weakish/a97f656c454c8188be1340b050d145b6/keras_lstm_tpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "CB43mV-TD1vb",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Tutorial - [How to train Keras model x20 times faster with TPU for free](https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/)"
]
},
{
"metadata": {
"id": "ya06BE0ZU526",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.datasets import imdb\n",
"from tensorflow.keras.preprocessing import sequence\n",
"from tensorflow.python.keras.layers import Input, LSTM, Bidirectional, Dense, Embedding"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "_uSZchXTVOHr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "b64c0d3e-554d-418a-918e-9e7d8eaa276a"
},
"cell_type": "code",
"source": [
"\n",
"# Number of words to consider as features\n",
"max_features = 10000\n",
"# Cut texts after this number of words (among top max_features most common words)\n",
"maxlen = 500\n",
"\n",
"# Load data\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n",
"\n",
"# Reverse sequences\n",
"x_train = [x[::-1] for x in x_train]\n",
"x_test = [x[::-1] for x in x_test]\n",
"\n",
"# Pad sequences\n",
"x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n",
"x_test = sequence.pad_sequences(x_test, maxlen=maxlen)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n",
"17465344/17464789 [==============================] - 0s 0us/step\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "p35nSfjbVVBE",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def make_model(batch_size=None):\n",
" source = Input(shape=(maxlen,), batch_size=batch_size, dtype=tf.int32, name='Input')\n",
" embedding = Embedding(input_dim=max_features, output_dim=128, name='Embedding')(source)\n",
" # lstm = Bidirectional(LSTM(32, name = 'LSTM'), name='Bidirectional')(embedding)\n",
" lstm = LSTM(32, name = 'LSTM')(embedding)\n",
" predicted_var = Dense(1, activation='sigmoid', name='Output')(lstm)\n",
" model = tf.keras.Model(inputs=[source], outputs=[predicted_var])\n",
" model.compile(\n",
" optimizer=tf.train.RMSPropOptimizer(learning_rate=0.01),\n",
" loss='binary_crossentropy',\n",
" metrics=['acc'])\n",
" return model\n",
" "
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "bivVZS0jZhxg",
"colab_type": "code",
"outputId": "a4491799-1844-493f-d1a5-37256e5c9f6a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 278
}
},
"cell_type": "code",
"source": [
"tf.keras.backend.clear_session()\n",
"model = make_model(batch_size = 128)\n",
"model.summary()"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"Input (InputLayer) (128, 500) 0 \n",
"_________________________________________________________________\n",
"Embedding (Embedding) (128, 500, 128) 1280000 \n",
"_________________________________________________________________\n",
"LSTM (LSTM) (128, 32) 20608 \n",
"_________________________________________________________________\n",
"Output (Dense) (128, 1) 33 \n",
"=================================================================\n",
"Total params: 1,300,641\n",
"Trainable params: 1,300,641\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "km1iz0HD--7M",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 278
},
"outputId": "a59909a9-474a-4e4f-f7f3-810052569455"
},
"cell_type": "code",
"source": [
"variable_model = make_model(batch_size = None)\n",
"variable_model.summary()"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"Input (InputLayer) (None, 500) 0 \n",
"_________________________________________________________________\n",
"Embedding (Embedding) (None, 500, 128) 1280000 \n",
"_________________________________________________________________\n",
"LSTM (LSTM) (None, 32) 20608 \n",
"_________________________________________________________________\n",
"Output (Dense) (None, 1) 33 \n",
"=================================================================\n",
"Total params: 1,300,641\n",
"Trainable params: 1,300,641\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "XlSm1vd5bteH",
"colab_type": "code",
"outputId": "c0bc1fe8-ec6a-4582-853b-9cb52c08d24f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 746
}
},
"cell_type": "code",
"source": [
"import time\n",
"start_time = time.time()\n",
"\n",
"history = variable_model.fit(x_train, y_train,\n",
" batch_size = 128,\n",
" epochs=20,\n",
" validation_split=0.2)\n",
"\n",
"print(\"--- %s seconds ---\" % (time.time() - start_time))\n"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"Train on 20000 samples, validate on 5000 samples\n",
"Epoch 1/20\n",
"20000/20000 [==============================] - 305s 15ms/step - loss: 0.6927 - acc: 0.5116 - val_loss: 0.6822 - val_acc: 0.5774\n",
"Epoch 2/20\n",
"20000/20000 [==============================] - 313s 16ms/step - loss: 0.6282 - acc: 0.6599 - val_loss: 0.6214 - val_acc: 0.6306\n",
"Epoch 3/20\n",
"20000/20000 [==============================] - 305s 15ms/step - loss: 0.4379 - acc: 0.8089 - val_loss: 0.3874 - val_acc: 0.8482\n",
"Epoch 4/20\n",
"20000/20000 [==============================] - 301s 15ms/step - loss: 0.3218 - acc: 0.8712 - val_loss: 0.3600 - val_acc: 0.8486\n",
"Epoch 5/20\n",
"20000/20000 [==============================] - 297s 15ms/step - loss: 0.2678 - acc: 0.8940 - val_loss: 0.3046 - val_acc: 0.8818\n",
"Epoch 6/20\n",
"20000/20000 [==============================] - 296s 15ms/step - loss: 0.2022 - acc: 0.9233 - val_loss: 0.3462 - val_acc: 0.8606\n",
"Epoch 7/20\n",
"20000/20000 [==============================] - 295s 15ms/step - loss: 0.1492 - acc: 0.9451 - val_loss: 0.3575 - val_acc: 0.8718\n",
"Epoch 8/20\n",
"20000/20000 [==============================] - 294s 15ms/step - loss: 0.1039 - acc: 0.9617 - val_loss: 0.3718 - val_acc: 0.8784\n",
"Epoch 9/20\n",
"20000/20000 [==============================] - 296s 15ms/step - loss: 0.0710 - acc: 0.9771 - val_loss: 0.4218 - val_acc: 0.8756\n",
"Epoch 10/20\n",
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0486 - acc: 0.9845 - val_loss: 0.4529 - val_acc: 0.8786\n",
"Epoch 11/20\n",
"20000/20000 [==============================] - 302s 15ms/step - loss: 0.0395 - acc: 0.9869 - val_loss: 0.5026 - val_acc: 0.8666\n",
"Epoch 12/20\n",
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0331 - acc: 0.9894 - val_loss: 0.5990 - val_acc: 0.8588\n",
"Epoch 13/20\n",
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0237 - acc: 0.9928 - val_loss: 0.6152 - val_acc: 0.8684\n",
"Epoch 14/20\n",
"20000/20000 [==============================] - 301s 15ms/step - loss: 0.0153 - acc: 0.9950 - val_loss: 0.6993 - val_acc: 0.8676\n",
"Epoch 15/20\n",
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0151 - acc: 0.9949 - val_loss: 0.7196 - val_acc: 0.8732\n",
"Epoch 16/20\n",
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0118 - acc: 0.9964 - val_loss: 0.7559 - val_acc: 0.8736\n",
"Epoch 17/20\n",
"20000/20000 [==============================] - 299s 15ms/step - loss: 0.0081 - acc: 0.9973 - val_loss: 0.7615 - val_acc: 0.8642\n",
"Epoch 18/20\n",
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0098 - acc: 0.9968 - val_loss: 0.7584 - val_acc: 0.8616\n",
"Epoch 19/20\n",
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0103 - acc: 0.9964 - val_loss: 0.8302 - val_acc: 0.8684\n",
"Epoch 20/20\n",
"20000/20000 [==============================] - 300s 15ms/step - loss: 0.0083 - acc: 0.9974 - val_loss: 0.8563 - val_acc: 0.8636\n",
"--- 6004.776646375656 seconds ---\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "uAB1CDE3W0Lo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "efd387b5-aebc-424b-921d-489fcc4cedc9"
},
"cell_type": "code",
"source": [
"variable_model.evaluate(x_test, y_test)"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 499s 20ms/step\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[0.8893412690807879, 0.85656]"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"metadata": {
"id": "apwRGvwWDnau",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment