Skip to content

Instantly share code, notes, and snippets.

@chck
Last active September 18, 2018 11:35
Show Gist options
  • Save chck/294f06e2032f70d7303fc528564dd17f to your computer and use it in GitHub Desktop.
Save chck/294f06e2032f70d7303fc528564dd17f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"https://github.com/tbennun/keras-bucketed-sequence.git\n",
"Bucketing technique for NLP by Keras\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, LSTM, Dense\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from absl import app, flags\n",
"import numpy as np\n",
"\n",
"from bucketed_sequence import BucketedSequence"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"UNK = -1.0\n",
"batch_size = 64\n",
"epochs = 20\n",
"lstm_units = 100\n",
"dense_breadth = 32\n",
"\n",
"dataset_size = 10000\n",
"val_size = 1000\n",
"seqlen_mean = 50\n",
"seqlen_stddev = 200\n",
"\n",
"buckets = 10"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-1.00000000e+000, -1.00000000e+000, -1.00000000e+000,\n",
" -1.00000000e+000, -1.00000000e+000, 0.00000000e+000,\n",
" 1.73060038e-077, 2.23111331e-314, 2.23093651e-314,\n",
" 2.23093663e-314, 2.23077450e-314, 2.23080756e-314,\n",
" 0.00000000e+000, 2.15575018e-314, 7.14433543e-309]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def pad(seqs, maxlen):\n",
" # Note: prepends data\n",
" padded = np.array(pad_sequences(seqs, maxlen=maxlen, value=UNK, dtype=seqs[0].dtype))\n",
" return np.vstack([np.expand_dims(x, axis=0) for x in padded])\n",
"\n",
"pad(seqs=np.empty((1,10)), maxlen=15)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((1000, 594, 1), (1000,), (1000,))"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def gen_dataset(set_size):\n",
" sequence_lendths = np.random.normal(loc=seqlen_mean, scale=seqlen_stddev, size=set_size).astype(np.int32)\n",
" max_length = np.max(sequence_lendths)\n",
" # Clamp range to start from three elements\n",
" sequence_lendths = np.clip(sequence_lendths, 3, max_length)\n",
" \n",
" # Generate random sequences\n",
" seq_x = [np.random.uniform(1.0, 50.0, sl) for sl in sequence_lendths]\n",
" seq_y = np.array([seq[2] for seq in seq_x], dtype=np.float32)\n",
" \n",
" # Pad sequences\n",
" padded_x = pad(seq_x, max_length)\n",
" padded_x = np.reshape(padded_x, (len(sequence_lendths), max_length, 1))\n",
" \n",
" # Return dataset\n",
" return padded_x, seq_y, sequence_lendths\n",
"\n",
"a, b, c = gen_dataset(set_size=val_size)\n",
"a.shape, b.shape, c.shape"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg height=\"264pt\" viewBox=\"0.00 0.00 112.25 264.00\" width=\"112pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 260)\">\n",
"<title>G</title>\n",
"<polygon fill=\"#ffffff\" points=\"-4,4 -4,-260 108.252,-260 108.252,4 -4,4\" stroke=\"transparent\"/>\n",
"<!-- 4719619544 -->\n",
"<g class=\"node\" id=\"node1\">\n",
"<title>4719619544</title>\n",
"<polygon fill=\"none\" points=\"3.8896,-219.5 3.8896,-255.5 100.3623,-255.5 100.3623,-219.5 3.8896,-219.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-233.3\">in: InputLayer</text>\n",
"</g>\n",
"<!-- 4719354656 -->\n",
"<g class=\"node\" id=\"node2\">\n",
"<title>4719354656</title>\n",
"<polygon fill=\"none\" points=\"9.7036,-146.5 9.7036,-182.5 94.5483,-182.5 94.5483,-146.5 9.7036,-146.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-160.3\">lstm: LSTM</text>\n",
"</g>\n",
"<!-- 4719619544&#45;&gt;4719354656 -->\n",
"<g class=\"edge\" id=\"edge1\">\n",
"<title>4719619544-&gt;4719354656</title>\n",
"<path d=\"M52.126,-219.4551C52.126,-211.3828 52.126,-201.6764 52.126,-192.6817\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"55.6261,-192.5903 52.126,-182.5904 48.6261,-192.5904 55.6261,-192.5903\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4719394376 -->\n",
"<g class=\"node\" id=\"node3\">\n",
"<title>4719394376</title>\n",
"<polygon fill=\"none\" points=\"0,-73.5 0,-109.5 104.252,-109.5 104.252,-73.5 0,-73.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-87.3\">dense_6: Dense</text>\n",
"</g>\n",
"<!-- 4719354656&#45;&gt;4719394376 -->\n",
"<g class=\"edge\" id=\"edge2\">\n",
"<title>4719354656-&gt;4719394376</title>\n",
"<path d=\"M52.126,-146.4551C52.126,-138.3828 52.126,-128.6764 52.126,-119.6817\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"55.6261,-119.5903 52.126,-109.5904 48.6261,-119.5904 55.6261,-119.5903\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4719489208 -->\n",
"<g class=\"node\" id=\"node4\">\n",
"<title>4719489208</title>\n",
"<polygon fill=\"none\" points=\"0,-.5 0,-36.5 104.252,-36.5 104.252,-.5 0,-.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"52.126\" y=\"-14.3\">dense_7: Dense</text>\n",
"</g>\n",
"<!-- 4719394376&#45;&gt;4719489208 -->\n",
"<g class=\"edge\" id=\"edge3\">\n",
"<title>4719394376-&gt;4719489208</title>\n",
"<path d=\"M52.126,-73.4551C52.126,-65.3828 52.126,-55.6764 52.126,-46.6817\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"55.6261,-46.5903 52.126,-36.5904 48.6261,-46.5904 55.6261,-46.5903\" stroke=\"#000000\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def build_model():\n",
" # Set up a single network (LSTM + Dense))\n",
" inp = Input(shape=(None, 1), dtype=\"float32\", name=\"in\")\n",
" lstm = LSTM(lstm_units, return_sequences=False, name=\"lstm\")(inp)\n",
" dense = Dense(dense_breadth, kernel_initializer='normal', activation='relu')(lstm)\n",
" outputs = Dense(1, kernel_initializer='normal')(dense)\n",
" return Model(inputs=inp, outputs=outputs)\n",
"\n",
"model = build_model()\n",
"model.compile(optimizer=\"adam\", loss=\"mean_squared_error\", metrics=['mae'])\n",
"\n",
"from IPython.display import SVG\n",
"from tensorflow.python.keras.utils.vis_utils import model_to_dot\n",
"\n",
"SVG(model_to_dot(model).create(prog='dot', format='svg'))"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# Generate Dataset\n",
"x_train, y_train, len_train = gen_dataset(dataset_size)\n",
"x_val, y_val, len_val = gen_dataset(dataset_size)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training with 10 non-empty buckets\n",
"Training with 10 non-empty buckets\n",
"Epoch 1/20\n",
"163/163 [==============================] - 35s 218ms/step - loss: 195.1114 - mean_absolute_error: 10.3456 - val_loss: 118.8506 - val_mean_absolute_error: 7.5571\n",
"Epoch 2/20\n",
"163/163 [==============================] - 36s 219ms/step - loss: 120.9951 - mean_absolute_error: 7.6279 - val_loss: 121.1408 - val_mean_absolute_error: 7.4981\n",
"Epoch 3/20\n",
"163/163 [==============================] - 37s 230ms/step - loss: 122.0102 - mean_absolute_error: 7.6568 - val_loss: 118.5691 - val_mean_absolute_error: 7.4961\n",
"Epoch 4/20\n",
"163/163 [==============================] - 37s 229ms/step - loss: 124.0299 - mean_absolute_error: 7.7591 - val_loss: 122.0257 - val_mean_absolute_error: 7.5287\n",
"Epoch 5/20\n",
"163/163 [==============================] - 39s 239ms/step - loss: 124.0682 - mean_absolute_error: 7.8174 - val_loss: 119.2843 - val_mean_absolute_error: 7.7904\n",
"Epoch 6/20\n",
"163/163 [==============================] - 37s 229ms/step - loss: 119.0774 - mean_absolute_error: 7.5436 - val_loss: 118.4170 - val_mean_absolute_error: 7.4253\n",
"Epoch 7/20\n",
"163/163 [==============================] - 44s 271ms/step - loss: 121.2110 - mean_absolute_error: 7.6321 - val_loss: 118.0953 - val_mean_absolute_error: 7.5249\n",
"Epoch 8/20\n",
"163/163 [==============================] - 43s 263ms/step - loss: 121.8157 - mean_absolute_error: 7.7071 - val_loss: 117.6462 - val_mean_absolute_error: 7.3396\n",
"Epoch 9/20\n",
"163/163 [==============================] - 41s 252ms/step - loss: 115.4723 - mean_absolute_error: 7.5024 - val_loss: 91.7777 - val_mean_absolute_error: 6.6707\n",
"Epoch 10/20\n",
"163/163 [==============================] - 47s 286ms/step - loss: 53.6427 - mean_absolute_error: 5.0593 - val_loss: 15.4274 - val_mean_absolute_error: 2.6593\n",
"Epoch 11/20\n",
"163/163 [==============================] - 37s 227ms/step - loss: 11.2502 - mean_absolute_error: 2.2828 - val_loss: 5.5709 - val_mean_absolute_error: 1.3830\n",
"Epoch 12/20\n",
"163/163 [==============================] - 36s 224ms/step - loss: 8.2588 - mean_absolute_error: 1.8053 - val_loss: 3.8581 - val_mean_absolute_error: 1.1801\n",
"Epoch 13/20\n",
"163/163 [==============================] - 36s 224ms/step - loss: 5.0152 - mean_absolute_error: 1.2891 - val_loss: 13.0625 - val_mean_absolute_error: 2.6455\n",
"Epoch 14/20\n",
"163/163 [==============================] - 36s 224ms/step - loss: 5.0283 - mean_absolute_error: 1.4202 - val_loss: 2.6965 - val_mean_absolute_error: 0.9099\n",
"Epoch 15/20\n",
"163/163 [==============================] - 36s 224ms/step - loss: 3.2636 - mean_absolute_error: 1.0700 - val_loss: 2.5475 - val_mean_absolute_error: 0.9087\n",
"Epoch 16/20\n",
"163/163 [==============================] - 37s 226ms/step - loss: 2.5395 - mean_absolute_error: 0.9430 - val_loss: 1.8564 - val_mean_absolute_error: 0.7822\n",
"Epoch 17/20\n",
"163/163 [==============================] - 38s 231ms/step - loss: 2.4931 - mean_absolute_error: 0.9311 - val_loss: 1.3757 - val_mean_absolute_error: 0.6267\n",
"Epoch 18/20\n",
"163/163 [==============================] - 37s 230ms/step - loss: 2.2431 - mean_absolute_error: 0.8908 - val_loss: 1.2690 - val_mean_absolute_error: 0.6079\n",
"Epoch 19/20\n",
"163/163 [==============================] - 37s 229ms/step - loss: 3.3244 - mean_absolute_error: 1.1000 - val_loss: 1.8255 - val_mean_absolute_error: 0.8506\n",
"Epoch 20/20\n",
"163/163 [==============================] - 36s 223ms/step - loss: 1.4627 - mean_absolute_error: 0.7011 - val_loss: 1.1646 - val_mean_absolute_error: 0.5915\n"
]
}
],
"source": [
"if buckets > 0:\n",
" # Create Sequence objects\n",
" train_generator = BucketedSequence(buckets, batch_size, len_train, x_train, y_train)\n",
" val_generator = BucketedSequence(buckets, batch_size, len_val, x_val, y_val)\n",
" \n",
" model.fit_generator(train_generator, \n",
" epochs=epochs, \n",
" validation_data=val_generator, \n",
" shuffle=True, \n",
" verbose=True)\n",
" \n",
"else:\n",
" model.fit(x=x_train, y=y_train, \n",
" epochs=epochs, \n",
" validation_data=(x_val, y_val), \n",
" batch_size=batch_size, \n",
" verbose=True, \n",
" shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment