Created
July 16, 2017 22:22
-
-
Save tomtung/c030219cdb731ad67be00cb049b5dc22 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import keras\n", | |
"import numpy as np\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"N_DIGITS = 5\n", | |
"INPUT_LEN = N_DIGITS * 2 + 1\n", | |
"OUTPUT_LEN = N_DIGITS + 1\n", | |
"\n", | |
"CHARS = list(' 1234567890+')\n", | |
"CHAR_TO_INDEX = {\n", | |
" c: i\n", | |
" for i, c in enumerate(CHARS)\n", | |
"}\n", | |
"\n", | |
"TRAIN_DATA_SIZE = 600000\n", | |
"TEST_DATA_SIZE = 100000" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Data Generation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A random number: 3\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_random_number():\n", | |
" return random.randrange(0, 10 ** random.randint(1, N_DIGITS))\n", | |
"\n", | |
"print('A random number: {}'.format(generate_random_number()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(8, 27988)\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_addend_pair():\n", | |
" return generate_random_number(), generate_random_number()\n", | |
"\n", | |
"print(generate_addend_pair())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"An example: ('12+345 ', '357 ')\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_str_example(x, y):\n", | |
" input_str = '{}+{}'.format(x, y)\n", | |
" output_str = str(x + y)\n", | |
" \n", | |
" input_format_str = '{{:{}}}'.format(INPUT_LEN)\n", | |
" input_str = input_format_str.format(input_str)\n", | |
" \n", | |
" output_format_str = '{{:{}}}'.format(OUTPUT_LEN)\n", | |
" output_str = output_format_str.format(output_str)\n", | |
" \n", | |
" return input_str, output_str\n", | |
"\n", | |
"print('An example: {}'.format(generate_str_example(12, 345)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[(11, 12), (6, 12)]\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_example(x, y):\n", | |
" input_str, output_str = generate_str_example(x, y)\n", | |
"\n", | |
" input_ = np.zeros((INPUT_LEN, len(CHARS)))\n", | |
" for i, c in enumerate(input_str):\n", | |
" index = CHAR_TO_INDEX[c]\n", | |
" input_[i, index] = 1\n", | |
"\n", | |
" output = np.zeros((OUTPUT_LEN, len(CHARS)))\n", | |
" for i, c in enumerate(output_str):\n", | |
" index = CHAR_TO_INDEX[c]\n", | |
" output[i, index] = 1\n", | |
"\n", | |
" return input_, output\n", | |
"\n", | |
"print([array.shape for array in generate_example(12, 345)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"training_x shape: (600000, 11, 12)\n", | |
"training_y shape: (600000, 6, 12)\n", | |
"testing_x shape: (100000, 11, 12)\n", | |
"testing_y shape: (100000, 6, 12)\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_examples(n_train, n_test):\n", | |
" n_examples = n_train + n_test\n", | |
" \n", | |
" addend_pairs = set()\n", | |
" while len(addend_pairs) < n_examples:\n", | |
" addend_pairs.add(generate_addend_pair())\n", | |
" \n", | |
" inputs, outputs = zip(*[\n", | |
" generate_example(x, y)\n", | |
" for x, y in addend_pairs\n", | |
" ])\n", | |
" \n", | |
" return np.array(inputs[:n_train]), np.array(outputs[:n_train]), np.array(inputs[n_train:]), np.array(outputs[n_train:])\n", | |
"\n", | |
"training_x, training_y, testing_x, testing_y = generate_examples(TRAIN_DATA_SIZE, TEST_DATA_SIZE)\n", | |
"print('training_x shape:', training_x.shape)\n", | |
"print('training_y shape:', training_y.shape)\n", | |
"print('testing_x shape:', testing_x.shape)\n", | |
"print('testing_y shape:', testing_y.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"HIDDEN_SIZE = 128\n", | |
"BATCH_SIZE = 128\n", | |
"MAX_N_EPOCS = 1000" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = keras.models.Sequential([\n", | |
" keras.layers.wrappers.Bidirectional(\n", | |
" keras.layers.recurrent.LSTM(HIDDEN_SIZE),\n", | |
" input_shape=(INPUT_LEN, len(CHARS))\n", | |
" ),\n", | |
" keras.layers.core.RepeatVector(OUTPUT_LEN),\n", | |
" keras.layers.recurrent.LSTM(HIDDEN_SIZE, return_sequences=True),\n", | |
" keras.layers.wrappers.TimeDistributed(\n", | |
" keras.layers.Dense(len(CHARS), activation='softmax')\n", | |
" ),\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"bidirectional_1 (Bidirection (None, 256) 144384 \n", | |
"_________________________________________________________________\n", | |
"repeat_vector_1 (RepeatVecto (None, 6, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"lstm_2 (LSTM) (None, 6, 128) 197120 \n", | |
"_________________________________________________________________\n", | |
"time_distributed_1 (TimeDist (None, 6, 12) 1548 \n", | |
"=================================================================\n", | |
"Total params: 343,052\n", | |
"Trainable params: 343,052\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 480000 samples, validate on 120000 samples\n", | |
"Epoch 1/1000\n", | |
"157s - loss: 1.1147 - acc: 0.5823 - val_loss: 0.5931 - val_acc: 0.8039\n", | |
"Epoch 2/1000\n", | |
"146s - loss: 0.4039 - acc: 0.8568 - val_loss: 0.3069 - val_acc: 0.8849\n", | |
"Epoch 3/1000\n", | |
"142s - loss: 0.1948 - acc: 0.9316 - val_loss: 0.1232 - val_acc: 0.9608\n", | |
"Epoch 4/1000\n", | |
"142s - loss: 0.0904 - acc: 0.9708 - val_loss: 0.0693 - val_acc: 0.9774\n", | |
"Epoch 5/1000\n", | |
"142s - loss: 0.0559 - acc: 0.9816 - val_loss: 0.0418 - val_acc: 0.9864\n", | |
"Epoch 6/1000\n", | |
"142s - loss: 0.0373 - acc: 0.9880 - val_loss: 0.0304 - val_acc: 0.9900\n", | |
"Epoch 7/1000\n", | |
"142s - loss: 0.0280 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9941\n", | |
"Epoch 8/1000\n", | |
"142s - loss: 0.0210 - acc: 0.9935 - val_loss: 0.0241 - val_acc: 0.9919\n", | |
"Epoch 9/1000\n", | |
"142s - loss: 0.0181 - acc: 0.9946 - val_loss: 0.0103 - val_acc: 0.9970\n", | |
"Epoch 10/1000\n", | |
"141s - loss: 0.0137 - acc: 0.9960 - val_loss: 0.0108 - val_acc: 0.9967\n", | |
"Epoch 11/1000\n", | |
"141s - loss: 0.0143 - acc: 0.9959 - val_loss: 0.0148 - val_acc: 0.9958\n", | |
"Epoch 12/1000\n", | |
"141s - loss: 0.0114 - acc: 0.9968 - val_loss: 0.0046 - val_acc: 0.9988\n", | |
"Epoch 13/1000\n", | |
"141s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0067 - val_acc: 0.9980\n", | |
"Epoch 14/1000\n", | |
"141s - loss: 0.0076 - acc: 0.9978 - val_loss: 0.0048 - val_acc: 0.9986\n", | |
"Epoch 15/1000\n", | |
"141s - loss: 0.0093 - acc: 0.9975 - val_loss: 0.0081 - val_acc: 0.9975\n", | |
"Epoch 16/1000\n", | |
"141s - loss: 0.0073 - acc: 0.9979 - val_loss: 0.0080 - val_acc: 0.9975\n", | |
"Epoch 17/1000\n", | |
"141s - loss: 0.0059 - acc: 0.9983 - val_loss: 0.0043 - val_acc: 0.9987\n", | |
"Epoch 18/1000\n", | |
"141s - loss: 0.0068 - acc: 0.9981 - val_loss: 0.0046 - val_acc: 0.9986\n", | |
"Epoch 19/1000\n", | |
"141s - loss: 0.0058 - acc: 0.9984 - val_loss: 0.0044 - val_acc: 0.9987\n", | |
"Epoch 20/1000\n", | |
"141s - loss: 0.0064 - acc: 0.9982 - val_loss: 0.0094 - val_acc: 0.9972\n", | |
"Epoch 21/1000\n", | |
"141s - loss: 0.0053 - acc: 0.9985 - val_loss: 0.0039 - val_acc: 0.9989\n", | |
"Epoch 22/1000\n", | |
"141s - loss: 0.0041 - acc: 0.9989 - val_loss: 0.0042 - val_acc: 0.9987\n", | |
"Epoch 23/1000\n", | |
"141s - loss: 0.0050 - acc: 0.9986 - val_loss: 0.0172 - val_acc: 0.9949\n", | |
"Epoch 24/1000\n", | |
"141s - loss: 0.0038 - acc: 0.9989 - val_loss: 0.0033 - val_acc: 0.9990\n", | |
"Epoch 25/1000\n", | |
"141s - loss: 0.0051 - acc: 0.9987 - val_loss: 0.0020 - val_acc: 0.9995\n", | |
"Epoch 26/1000\n", | |
"142s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0023 - val_acc: 0.9994\n", | |
"Epoch 27/1000\n", | |
"141s - loss: 0.0044 - acc: 0.9988 - val_loss: 0.0018 - val_acc: 0.9995\n", | |
"Epoch 28/1000\n", | |
"141s - loss: 0.0032 - acc: 0.9991 - val_loss: 0.0029 - val_acc: 0.9992\n", | |
"Epoch 29/1000\n", | |
"144s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0085 - val_acc: 0.9974\n", | |
"Epoch 30/1000\n", | |
"144s - loss: 0.0033 - acc: 0.9991 - val_loss: 0.0019 - val_acc: 0.9995\n", | |
"Epoch 31/1000\n", | |
"157s - loss: 0.0039 - acc: 0.9990 - val_loss: 0.0014 - val_acc: 0.9997\n", | |
"Epoch 32/1000\n", | |
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0033 - val_acc: 0.9991\n", | |
"Epoch 33/1000\n", | |
"150s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0013 - val_acc: 0.9997\n", | |
"Epoch 34/1000\n", | |
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9993\n", | |
"Epoch 35/1000\n", | |
"152s - loss: 0.0032 - acc: 0.9992 - val_loss: 0.0038 - val_acc: 0.9988\n", | |
"Epoch 36/1000\n", | |
"157s - loss: 0.0026 - acc: 0.9993 - val_loss: 0.0037 - val_acc: 0.9989\n", | |
"Epoch 37/1000\n", | |
"152s - loss: 0.0024 - acc: 0.9993 - val_loss: 0.0016 - val_acc: 0.9996\n", | |
"Epoch 38/1000\n", | |
"153s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9992\n", | |
"Epoch 39/1000\n", | |
"155s - loss: 0.0025 - acc: 0.9993 - val_loss: 0.0031 - val_acc: 0.9990\n", | |
"Epoch 00038: early stopping\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x1a2ad1d6e48>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.fit(\n", | |
" training_x, training_y,\n", | |
" batch_size=BATCH_SIZE,\n", | |
" epochs=MAX_N_EPOCS,\n", | |
" verbose=2,\n", | |
" validation_split=.2,\n", | |
" callbacks=[\n", | |
" keras.callbacks.EarlyStopping(patience=5, verbose=2),\n", | |
" ],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"100000/100000 [==============================] - 32s \n", | |
"\n", | |
"Test loss: 0.0030354137425270163\n", | |
"Test acc: 0.9990433450508117\n" | |
] | |
} | |
], | |
"source": [ | |
"metrics_vals = model.evaluate(testing_x, testing_y)\n", | |
"\n", | |
"print('')\n", | |
"for metric_name, metric_val in zip(model.metrics_names, metrics_vals):\n", | |
" print('Test {}: {}'.format(metric_name, metric_val))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Model In Action" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def neural_addition(x, y):\n", | |
" input_, _ = generate_example(x, y)\n", | |
" output_ = model.predict_on_batch(np.array([input_]))[0]\n", | |
" indices = np.argmax(output_, axis=1)\n", | |
" return ''.join(CHARS[index] for index in indices)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"163 + 0 = \"163 \" (correct)\n", | |
"96 + 453 = \"549 \" (correct)\n", | |
"69 + 557 = \"626 \" (correct)\n", | |
"7721 + 98 = \"7819 \" (correct)\n", | |
"5112 + 79646 = \"84758 \" (correct)\n", | |
"493 + 43044 = \"43537 \" (correct)\n", | |
"51 + 489 = \"540 \" (correct)\n", | |
"84628 + 3457 = \"88085 \" (correct)\n", | |
"1 + 2236 = \"2237 \" (correct)\n", | |
"0 + 4622 = \"4622 \" (correct)\n", | |
"67 + 0 = \"67 \" (correct)\n", | |
"90642 + 68 = \"90710 \" (correct)\n", | |
"6 + 6 = \"12 \" (correct)\n", | |
"38973 + 23 = \"38996 \" (correct)\n", | |
"4 + 5945 = \"5949 \" (correct)\n", | |
"155 + 321 = \"476 \" (correct)\n", | |
"4987 + 2805 = \"7792 \" (correct)\n", | |
"70001 + 8 = \"70009 \" (correct)\n", | |
"1085 + 36 = \"1121 \" (correct)\n", | |
"13 + 2969 = \"2982 \" (correct)\n" | |
] | |
} | |
], | |
"source": [ | |
"for _ in range(20):\n", | |
" x, y = generate_addend_pair()\n", | |
" expected = x + y\n", | |
" result = neural_addition(x, y)\n", | |
" if result.strip() == str(expected):\n", | |
" print('{} + {} = \"{}\" (correct)'.format(x, y, result))\n", | |
" else:\n", | |
" print('{} + {} = \"{}\" (incorrect, should be {})'.format(x, y, result, expected))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment