Skip to content

Instantly share code, notes, and snippets.

@tomtung
Created July 16, 2017 22:22
Show Gist options
  • Save tomtung/c030219cdb731ad67be00cb049b5dc22 to your computer and use it in GitHub Desktop.
Save tomtung/c030219cdb731ad67be00cb049b5dc22 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import keras\n",
"import numpy as np\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Problem size: every addend is drawn from range(10 ** N_DIGITS).\n",
"N_DIGITS = 5\n",
"# Longest possible question: two N_DIGITS-digit addends joined by a single '+'.\n",
"INPUT_LEN = N_DIGITS * 2 + 1\n",
"# The sum of two N_DIGITS-digit numbers has at most N_DIGITS + 1 digits.\n",
"OUTPUT_LEN = N_DIGITS + 1\n",
"\n",
"# Vocabulary for the one-hot encoding. The space at index 0 doubles as the\n",
"# padding character used to bring every string up to its fixed length.\n",
"CHARS = list(' 1234567890+')\n",
"CHAR_TO_INDEX = {\n",
"    c: i\n",
"    for i, c in enumerate(CHARS)\n",
"}\n",
"\n",
"TRAIN_DATA_SIZE = 600000\n",
"TEST_DATA_SIZE = 100000\n",
"\n",
"# Seed the global RNGs so that Restart & Run All regenerates the same dataset.\n",
"# NOTE(review): the Keras backend may keep RNG state of its own, so the training\n",
"# curves are not guaranteed to be bit-for-bit repeatable - confirm if that matters.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Generation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A random number: 3\n"
]
}
],
"source": [
"def generate_random_number():\n",
"    \"\"\"Return a random integer in range(10 ** N_DIGITS), biased towards short numbers.\n",
"\n",
"    A digit budget is first drawn uniformly from 1..N_DIGITS, and the number is then\n",
"    drawn uniformly from range(10 ** budget). Short numbers are therefore sampled far\n",
"    more often than they would be by a single flat draw over the whole range.\n",
"    \"\"\"\n",
"    digit_budget = random.randint(1, N_DIGITS)\n",
"    return random.randrange(10 ** digit_budget)\n",
"\n",
"print('A random number: {}'.format(generate_random_number()))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 27988)\n"
]
}
],
"source": [
"def generate_addend_pair():\n",
"    \"\"\"Return a tuple (x, y) of two independently drawn random addends.\"\"\"\n",
"    first_addend = generate_random_number()\n",
"    second_addend = generate_random_number()\n",
"    return first_addend, second_addend\n",
"\n",
"print(generate_addend_pair())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"An example: ('12+345 ', '357 ')\n"
]
}
],
"source": [
"def generate_str_example(x, y):\n",
"    \"\"\"Render the problem x + y as a fixed-width (question, answer) pair of strings.\n",
"\n",
"    The question is 'x+y' and the answer is str(x + y). Both are left-aligned and\n",
"    padded on the right with spaces, to INPUT_LEN and OUTPUT_LEN characters\n",
"    respectively. A string that is already longer than its target width is returned\n",
"    as is: it is neither padded nor truncated.\n",
"    \"\"\"\n",
"    question = '{}+{}'.format(x, y).ljust(INPUT_LEN)\n",
"    answer = str(x + y).ljust(OUTPUT_LEN)\n",
"    return question, answer\n",
"\n",
"print('An example: {}'.format(generate_str_example(12, 345)))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(11, 12), (6, 12)]\n"
]
}
],
"source": [
"def _one_hot_encode(text, length):\n",
"    \"\"\"Return a (length, len(CHARS)) array whose row i is the one-hot code of text[i].\"\"\"\n",
"    encoded = np.zeros((length, len(CHARS)))\n",
"    char_indices = [CHAR_TO_INDEX[char] for char in text]\n",
"    # Pair each position with the vocabulary index of its character and set that cell.\n",
"    encoded[np.arange(len(char_indices)), char_indices] = 1\n",
"    return encoded\n",
"\n",
"\n",
"def generate_example(x, y):\n",
"    \"\"\"Return the one-hot encoded (input, output) arrays for the problem x + y.\n",
"\n",
"    The input has shape (INPUT_LEN, len(CHARS)) and the output has shape\n",
"    (OUTPUT_LEN, len(CHARS)); each row encodes one character of the padded strings\n",
"    produced by generate_str_example.\n",
"    \"\"\"\n",
"    question, answer = generate_str_example(x, y)\n",
"    return _one_hot_encode(question, INPUT_LEN), _one_hot_encode(answer, OUTPUT_LEN)\n",
"\n",
"print([array.shape for array in generate_example(12, 345)])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training_x shape: (600000, 11, 12)\n",
"training_y shape: (600000, 6, 12)\n",
"testing_x shape: (100000, 11, 12)\n",
"testing_y shape: (100000, 6, 12)\n"
]
}
],
"source": [
"def generate_examples(n_train, n_test):\n",
"    \"\"\"Build one-hot encoded training and testing sets of distinct addition problems.\n",
"\n",
"    Args:\n",
"        n_train: number of training examples to return.\n",
"        n_test: number of testing examples to return.\n",
"\n",
"    Returns:\n",
"        A tuple (train_x, train_y, test_x, test_y) of arrays with shapes\n",
"        (n_train, INPUT_LEN, len(CHARS)), (n_train, OUTPUT_LEN, len(CHARS)),\n",
"        (n_test, INPUT_LEN, len(CHARS)) and (n_test, OUTPUT_LEN, len(CHARS)).\n",
"        No (x, y) addend pair occurs more than once across the two sets.\n",
"\n",
"    Raises:\n",
"        ValueError: if either size is negative, or if more examples are requested\n",
"            than there are distinct addend pairs.\n",
"    \"\"\"\n",
"    if n_train < 0 or n_test < 0:\n",
"        raise ValueError('n_train and n_test must be non-negative, got {} and {}'.format(n_train, n_test))\n",
"\n",
"    n_examples = n_train + n_test\n",
"\n",
"    # Each addend comes from range(10 ** N_DIGITS), so only this many distinct pairs\n",
"    # exist. Asking for more would make the sampling loop below spin forever.\n",
"    n_distinct_pairs = (10 ** N_DIGITS) ** 2\n",
"    if n_examples > n_distinct_pairs:\n",
"        raise ValueError('requested {} examples but only {} distinct addend pairs exist'.format(n_examples, n_distinct_pairs))\n",
"\n",
"    # A set de-duplicates the draws, so no problem can land in both train and test.\n",
"    addend_pairs = set()\n",
"    while len(addend_pairs) < n_examples:\n",
"        addend_pairs.add(generate_addend_pair())\n",
"\n",
"    # Fill preallocated arrays in place instead of collecting one small array per\n",
"    # example and copying them all afterwards.\n",
"    inputs = np.zeros((n_examples, INPUT_LEN, len(CHARS)))\n",
"    outputs = np.zeros((n_examples, OUTPUT_LEN, len(CHARS)))\n",
"    for i, (x, y) in enumerate(addend_pairs):\n",
"        inputs[i], outputs[i] = generate_example(x, y)\n",
"\n",
"    return inputs[:n_train], outputs[:n_train], inputs[n_train:], outputs[n_train:]\n",
"\n",
"training_x, training_y, testing_x, testing_y = generate_examples(TRAIN_DATA_SIZE, TEST_DATA_SIZE)\n",
"print('training_x shape:', training_x.shape)\n",
"print('training_y shape:', training_y.shape)\n",
"print('testing_x shape:', testing_x.shape)\n",
"print('testing_y shape:', testing_y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Number of units in each LSTM layer of the model.\n",
"HIDDEN_SIZE = 128\n",
"# Number of examples per gradient update during training.\n",
"BATCH_SIZE = 128\n",
"# Upper bound on training epochs; early stopping is expected to end training sooner.\n",
"MAX_N_EPOCS = 1000"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model = keras.models.Sequential()\n",
"\n",
"# Encoder: read the one-hot question in both directions and summarise it as a\n",
"# single vector.\n",
"model.add(keras.layers.wrappers.Bidirectional(\n",
"    keras.layers.recurrent.LSTM(HIDDEN_SIZE),\n",
"    input_shape=(INPUT_LEN, len(CHARS))\n",
"))\n",
"\n",
"# Feed that summary vector to the decoder once per output character.\n",
"model.add(keras.layers.core.RepeatVector(OUTPUT_LEN))\n",
"\n",
"# Decoder: emit one hidden state per output position.\n",
"model.add(keras.layers.recurrent.LSTM(HIDDEN_SIZE, return_sequences=True))\n",
"\n",
"# Turn every hidden state into a probability distribution over CHARS.\n",
"model.add(keras.layers.wrappers.TimeDistributed(\n",
"    keras.layers.Dense(len(CHARS), activation='softmax')\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"bidirectional_1 (Bidirection (None, 256) 144384 \n",
"_________________________________________________________________\n",
"repeat_vector_1 (RepeatVecto (None, 6, 256) 0 \n",
"_________________________________________________________________\n",
"lstm_2 (LSTM) (None, 6, 128) 197120 \n",
"_________________________________________________________________\n",
"time_distributed_1 (TimeDist (None, 6, 12) 1548 \n",
"=================================================================\n",
"Total params: 343,052\n",
"Trainable params: 343,052\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Print each layer's output shape and parameter count as a sanity check.\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Every output position is a classification over CHARS, hence categorical\n",
"# cross-entropy; the reported accuracy is tracked alongside the loss.\n",
"model.compile(\n",
"    optimizer='adam',\n",
"    loss='categorical_crossentropy',\n",
"    metrics=['accuracy'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 480000 samples, validate on 120000 samples\n",
"Epoch 1/1000\n",
"157s - loss: 1.1147 - acc: 0.5823 - val_loss: 0.5931 - val_acc: 0.8039\n",
"Epoch 2/1000\n",
"146s - loss: 0.4039 - acc: 0.8568 - val_loss: 0.3069 - val_acc: 0.8849\n",
"Epoch 3/1000\n",
"142s - loss: 0.1948 - acc: 0.9316 - val_loss: 0.1232 - val_acc: 0.9608\n",
"Epoch 4/1000\n",
"142s - loss: 0.0904 - acc: 0.9708 - val_loss: 0.0693 - val_acc: 0.9774\n",
"Epoch 5/1000\n",
"142s - loss: 0.0559 - acc: 0.9816 - val_loss: 0.0418 - val_acc: 0.9864\n",
"Epoch 6/1000\n",
"142s - loss: 0.0373 - acc: 0.9880 - val_loss: 0.0304 - val_acc: 0.9900\n",
"Epoch 7/1000\n",
"142s - loss: 0.0280 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9941\n",
"Epoch 8/1000\n",
"142s - loss: 0.0210 - acc: 0.9935 - val_loss: 0.0241 - val_acc: 0.9919\n",
"Epoch 9/1000\n",
"142s - loss: 0.0181 - acc: 0.9946 - val_loss: 0.0103 - val_acc: 0.9970\n",
"Epoch 10/1000\n",
"141s - loss: 0.0137 - acc: 0.9960 - val_loss: 0.0108 - val_acc: 0.9967\n",
"Epoch 11/1000\n",
"141s - loss: 0.0143 - acc: 0.9959 - val_loss: 0.0148 - val_acc: 0.9958\n",
"Epoch 12/1000\n",
"141s - loss: 0.0114 - acc: 0.9968 - val_loss: 0.0046 - val_acc: 0.9988\n",
"Epoch 13/1000\n",
"141s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0067 - val_acc: 0.9980\n",
"Epoch 14/1000\n",
"141s - loss: 0.0076 - acc: 0.9978 - val_loss: 0.0048 - val_acc: 0.9986\n",
"Epoch 15/1000\n",
"141s - loss: 0.0093 - acc: 0.9975 - val_loss: 0.0081 - val_acc: 0.9975\n",
"Epoch 16/1000\n",
"141s - loss: 0.0073 - acc: 0.9979 - val_loss: 0.0080 - val_acc: 0.9975\n",
"Epoch 17/1000\n",
"141s - loss: 0.0059 - acc: 0.9983 - val_loss: 0.0043 - val_acc: 0.9987\n",
"Epoch 18/1000\n",
"141s - loss: 0.0068 - acc: 0.9981 - val_loss: 0.0046 - val_acc: 0.9986\n",
"Epoch 19/1000\n",
"141s - loss: 0.0058 - acc: 0.9984 - val_loss: 0.0044 - val_acc: 0.9987\n",
"Epoch 20/1000\n",
"141s - loss: 0.0064 - acc: 0.9982 - val_loss: 0.0094 - val_acc: 0.9972\n",
"Epoch 21/1000\n",
"141s - loss: 0.0053 - acc: 0.9985 - val_loss: 0.0039 - val_acc: 0.9989\n",
"Epoch 22/1000\n",
"141s - loss: 0.0041 - acc: 0.9989 - val_loss: 0.0042 - val_acc: 0.9987\n",
"Epoch 23/1000\n",
"141s - loss: 0.0050 - acc: 0.9986 - val_loss: 0.0172 - val_acc: 0.9949\n",
"Epoch 24/1000\n",
"141s - loss: 0.0038 - acc: 0.9989 - val_loss: 0.0033 - val_acc: 0.9990\n",
"Epoch 25/1000\n",
"141s - loss: 0.0051 - acc: 0.9987 - val_loss: 0.0020 - val_acc: 0.9995\n",
"Epoch 26/1000\n",
"142s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0023 - val_acc: 0.9994\n",
"Epoch 27/1000\n",
"141s - loss: 0.0044 - acc: 0.9988 - val_loss: 0.0018 - val_acc: 0.9995\n",
"Epoch 28/1000\n",
"141s - loss: 0.0032 - acc: 0.9991 - val_loss: 0.0029 - val_acc: 0.9992\n",
"Epoch 29/1000\n",
"144s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0085 - val_acc: 0.9974\n",
"Epoch 30/1000\n",
"144s - loss: 0.0033 - acc: 0.9991 - val_loss: 0.0019 - val_acc: 0.9995\n",
"Epoch 31/1000\n",
"157s - loss: 0.0039 - acc: 0.9990 - val_loss: 0.0014 - val_acc: 0.9997\n",
"Epoch 32/1000\n",
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0033 - val_acc: 0.9991\n",
"Epoch 33/1000\n",
"150s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0013 - val_acc: 0.9997\n",
"Epoch 34/1000\n",
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9993\n",
"Epoch 35/1000\n",
"152s - loss: 0.0032 - acc: 0.9992 - val_loss: 0.0038 - val_acc: 0.9988\n",
"Epoch 36/1000\n",
"157s - loss: 0.0026 - acc: 0.9993 - val_loss: 0.0037 - val_acc: 0.9989\n",
"Epoch 37/1000\n",
"152s - loss: 0.0024 - acc: 0.9993 - val_loss: 0.0016 - val_acc: 0.9996\n",
"Epoch 38/1000\n",
"153s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9992\n",
"Epoch 39/1000\n",
"155s - loss: 0.0025 - acc: 0.9993 - val_loss: 0.0031 - val_acc: 0.9990\n",
"Epoch 00038: early stopping\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x1a2ad1d6e48>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# End training once the validation loss has gone 5 epochs without improving,\n",
"# rather than running all MAX_N_EPOCS epochs.\n",
"early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=2)\n",
"\n",
"model.fit(\n",
"    x=training_x,\n",
"    y=training_y,\n",
"    batch_size=BATCH_SIZE,\n",
"    epochs=MAX_N_EPOCS,\n",
"    verbose=2,\n",
"    # Hold out 20% of the training data to compute val_loss / val_acc.\n",
"    validation_split=.2,\n",
"    callbacks=[early_stopping],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100000/100000 [==============================] - 32s \n",
"\n",
"Test loss: 0.0030354137425270163\n",
"Test acc: 0.9990433450508117\n"
]
}
],
"source": [
"# Score the trained model on the test set. The returned values are paired below\n",
"# with model.metrics_names to label each one.\n",
"metrics_vals = model.evaluate(testing_x, testing_y)\n",
"\n",
"# Blank line so the report does not run into the progress bar printed by evaluate().\n",
"print('')\n",
"for label, score in zip(model.metrics_names, metrics_vals):\n",
"    print('Test {}: {}'.format(label, score))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model In Action"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def neural_addition(x, y):\n",
"    \"\"\"Ask the trained model for x + y and return its answer as a string.\n",
"\n",
"    The result holds one character of CHARS per output position of the model, so it\n",
"    may contain padding spaces; callers that want the bare number should strip it.\n",
"    \"\"\"\n",
"    encoded_question, _ = generate_example(x, y)\n",
"    # The model works on batches, so add a leading batch axis of size one.\n",
"    batch = encoded_question[np.newaxis]\n",
"    char_scores = model.predict_on_batch(batch)[0]\n",
"    # Keep the highest-scoring character at every output position.\n",
"    best_indices = char_scores.argmax(axis=1)\n",
"    return ''.join(CHARS[index] for index in best_indices)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"163 + 0 = \"163 \" (correct)\n",
"96 + 453 = \"549 \" (correct)\n",
"69 + 557 = \"626 \" (correct)\n",
"7721 + 98 = \"7819 \" (correct)\n",
"5112 + 79646 = \"84758 \" (correct)\n",
"493 + 43044 = \"43537 \" (correct)\n",
"51 + 489 = \"540 \" (correct)\n",
"84628 + 3457 = \"88085 \" (correct)\n",
"1 + 2236 = \"2237 \" (correct)\n",
"0 + 4622 = \"4622 \" (correct)\n",
"67 + 0 = \"67 \" (correct)\n",
"90642 + 68 = \"90710 \" (correct)\n",
"6 + 6 = \"12 \" (correct)\n",
"38973 + 23 = \"38996 \" (correct)\n",
"4 + 5945 = \"5949 \" (correct)\n",
"155 + 321 = \"476 \" (correct)\n",
"4987 + 2805 = \"7792 \" (correct)\n",
"70001 + 8 = \"70009 \" (correct)\n",
"1085 + 36 = \"1121 \" (correct)\n",
"13 + 2969 = \"2982 \" (correct)\n"
]
}
],
"source": [
"# Spot-check the model on 20 freshly drawn problems.\n",
"for _ in range(20):\n",
"    x, y = generate_addend_pair()\n",
"    expected = x + y\n",
"    result = neural_addition(x, y)\n",
"\n",
"    # The prediction may carry padding spaces, so strip it before comparing.\n",
"    if result.strip() != str(expected):\n",
"        print('{} + {} = \"{}\" (incorrect, should be {})'.format(x, y, result, expected))\n",
"        continue\n",
"\n",
"    print('{} + {} = \"{}\" (correct)'.format(x, y, result))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment