@fogside
Created March 19, 2017 11:19
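The notebook below reports progress as perplexity. Perplexity is `exp` of the average negative log-probability that its `logprob` helper assigns to the true next characters. Here is a minimal NumPy sketch of that computation. The names `mean_neg_logprob` and `perplexity`, and the toy batch, are illustrative assumptions rather than code from the notebook.

A model that predicts a uniform distribution over the 27 symbols ([a-z] plus space) has perplexity 27. That is close to the 27.04 minibatch perplexity logged at step 0.

```python
import numpy as np

VOCAB_SIZE = 27  # [a-z] plus space, as in the notebook

def mean_neg_logprob(predictions, labels):
    """Average negative log-probability of the one-hot labels (clipped like logprob)."""
    # Clip to avoid log(0); unlike the notebook's helper, this does not modify the input in place.
    predictions = np.clip(predictions, 1e-10, 1.0)
    return np.sum(labels * -np.log(predictions)) / labels.shape[0]

def perplexity(predictions, labels):
    """Perplexity = exp(mean negative log-probability per character)."""
    return float(np.exp(mean_neg_logprob(predictions, labels)))

# Hypothetical batch: 4 characters, each one-hot over the vocabulary.
labels = np.zeros((4, VOCAB_SIZE))
labels[np.arange(4), [1, 2, 0, 26]] = 1.0

# An untrained-looking model: every character equally likely.
uniform = np.full((4, VOCAB_SIZE), 1.0 / VOCAB_SIZE)
print(perplexity(uniform, labels))  # close to 27

# A confident model: 0.9 on the true character, the rest spread evenly.
confident = labels * 0.9 + (1 - labels) * (0.1 / (VOCAB_SIZE - 1))
print(perplexity(confident, labels))  # 1 / 0.9, about 1.11
```

Perplexity can be read as the effective number of characters the model is choosing between at each step. That is why it falls from about 27 toward 4 to 5 in the training log.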
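Problem 1 in the notebook asks you to replace the four per-gate input matrices and the four per-gate output matrices with one matrix each, four times wider, and then split the product. The NumPy sketch below checks that the fused cell returns the same output and state as the four-multiplication cell.

It assumes the fused matrices are the per-gate matrices concatenated column-wise, in the same order the solution's `tf.split` unpacks them (input, forget, output, update). The sizes and random weights are hypothetical. The notebook itself uses TensorFlow and learns the fused weights directly.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, vocab, nodes = 3, 27, 8  # small hypothetical sizes

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Per-gate parameters, as in the original cell: input, forget, update (cell), output.
ix, fx, cx, ox = (rng.uniform(-0.1, 0.1, (vocab, nodes)) for _ in range(4))
im, fm, cm, om = (rng.uniform(-0.1, 0.1, (nodes, nodes)) for _ in range(4))
ib, fb, cb, ob = (rng.uniform(-0.1, 0.1, (1, nodes)) for _ in range(4))

def lstm_cell_separate(i, o, state):
    """Four matrix multiplications with the input and four with the previous output."""
    input_gate = sigmoid(i @ ix + o @ im + ib)
    forget_gate = sigmoid(i @ fx + o @ fm + fb)
    update = i @ cx + o @ cm + cb
    state = forget_gate * state + input_gate * np.tanh(update)
    output_gate = sigmoid(i @ ox + o @ om + ob)
    return output_gate * np.tanh(state), state

# Fused parameters: columns ordered input, forget, output, update to match the split below.
in_mtx = np.concatenate([ix, fx, ox, cx], axis=1)
out_mtx = np.concatenate([im, fm, om, cm], axis=1)
b_vec = np.concatenate([ib, fb, ob, cb], axis=1)

def lstm_cell_fused(i, o, state):
    """One multiplication with the input and one with the previous output, then a 4-way split."""
    product = i @ in_mtx + o @ out_mtx + b_vec
    input_gate, forget_gate, output_gate, update = np.split(product, 4, axis=1)
    state = sigmoid(forget_gate) * state + sigmoid(input_gate) * np.tanh(update)
    return sigmoid(output_gate) * np.tanh(state), state

i = np.eye(vocab)[rng.integers(0, vocab, batch)]  # one-hot input characters
o = rng.uniform(-1, 1, (batch, nodes))            # previous output
s = rng.uniform(-1, 1, (batch, nodes))            # previous cell state

out_a, state_a = lstm_cell_separate(i, o, s)
out_b, state_b = lstm_cell_fused(i, o, s)
print(np.allclose(out_a, out_b), np.allclose(state_a, state_b))  # True True
```

The equivalence holds because multiplying by a column-wise concatenation of matrices yields the concatenation of the individual products. The fused form therefore replaces eight small multiplications with two larger ones without changing what the model can represent.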
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "8tQJd2YSCfWR"
},
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "D7tqLMoKF6uq"
},
"source": [
"Deep Learning\n",
"=============\n",
"\n",
"Assignment 6\n",
"------------\n",
"\n",
"After training a skip-gram model in `5_word2vec.ipynb`, the goal of this notebook is to train an LSTM character model over [Text8](http://mattmahoney.net/dc/textdata) data."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"deletable": true,
"editable": true,
"id": "MvEblsgEXxrd"
},
"outputs": [],
"source": [
"# These are all the modules we'll be using later. Make sure you can import them\n",
"# before proceeding further.\n",
"from __future__ import print_function\n",
"import os\n",
"import numpy as np\n",
"import random\n",
"import string\n",
"import tensorflow as tf\n",
"import zipfile\n",
"from six.moves import range\n",
"from six.moves.urllib.request import urlretrieve"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 5993,
"status": "ok",
"timestamp": 1445965582896,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "RJ-o3UBUFtCw",
"outputId": "d530534e-0791-4a94-ca6d-1c8f1b908a9e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found and verified text8.zip\n"
]
}
],
"source": [
"url = 'http://mattmahoney.net/dc/'\n",
"\n",
"def maybe_download(filename, expected_bytes):\n",
"  \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
"  if not os.path.exists(filename):\n",
"    filename, _ = urlretrieve(url + filename, filename)\n",
"  statinfo = os.stat(filename)\n",
"  if statinfo.st_size == expected_bytes:\n",
"    print('Found and verified %s' % filename)\n",
"  else:\n",
"    print(statinfo.st_size)\n",
"    raise Exception(\n",
"      'Failed to verify ' + filename + '. Can you get to it with a browser?')\n",
"  return filename\n",
"\n",
"filename = maybe_download('text8.zip', 31344016)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 5982,
"status": "ok",
"timestamp": 1445965582916,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "Mvf09fjugFU_",
"outputId": "8f75db58-3862-404b-a0c3-799380597390"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data size 100000000\n"
]
}
],
"source": [
"def read_data(filename):\n",
"  with zipfile.ZipFile(filename) as f:\n",
"    name = f.namelist()[0]\n",
"    data = tf.compat.as_str(f.read(name))\n",
"  return data\n",
"\n",
"text = read_data(filename)\n",
"print('Data size %d' % len(text))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "ga2CYACE-ghb"
},
"source": [
"Create a small validation set."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 6184,
"status": "ok",
"timestamp": 1445965583138,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "w-oBpfFG-j43",
"outputId": "bdb96002-d021-4379-f6de-a977924f0d02"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"99999000 ons anarchists advocate social relations based upon voluntary as\n",
"1000 anarchism originated as a term of abuse first used against earl\n"
]
}
],
"source": [
"valid_size = 1000\n",
"valid_text = text[:valid_size]\n",
"train_text = text[valid_size:]\n",
"train_size = len(train_text)\n",
"print(train_size, train_text[:64])\n",
"print(valid_size, valid_text[:64])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "Zdw6i4F8glpp"
},
"source": [
"Utility functions to map characters to vocabulary IDs and back."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 6276,
"status": "ok",
"timestamp": 1445965583249,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "gAL1EECXeZsD",
"outputId": "88fc9032-feb9-45ff-a9a0-a26759cc1f2e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Unexpected character: ï\n",
"1 26 0 0\n",
"a z \n"
]
}
],
"source": [
"vocabulary_size = len(string.ascii_lowercase) + 1 # [a-z] + ' '\n",
"first_letter = ord(string.ascii_lowercase[0])\n",
"\n",
"def char2id(char):\n",
"  if char in string.ascii_lowercase:\n",
"    return ord(char) - first_letter + 1\n",
"  elif char == ' ':\n",
"    return 0\n",
"  else:\n",
"    print('Unexpected character: %s' % char)\n",
"    return 0\n",
"\n",
"def id2char(dictid):\n",
"  if dictid > 0:\n",
"    return chr(dictid + first_letter - 1)\n",
"  else:\n",
"    return ' '\n",
"\n",
"print(char2id('a'), char2id('z'), char2id(' '), char2id('ï'))\n",
"print(id2char(1), id2char(26), id2char(0))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "lFwoyygOmWsL"
},
"source": [
"Function to generate a training batch for the LSTM model."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 6473,
"status": "ok",
"timestamp": 1445965583467,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "d9wMtjy5hCj9",
"outputId": "3dd79c80-454a-4be0-8b71-4a4a357b3367"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ons anarchi', 'when milita', 'lleria arch', ' abbeys and', 'married urr', 'hel and ric', 'y and litur', 'ay opened f', 'tion from t', 'migration t', 'new york ot', 'he boeing s', 'e listed wi', 'eber has pr', 'o be made t', 'yer who rec', 'ore signifi', 'a fierce cr', ' two six ei', 'aristotle s', 'ity can be ', ' and intrac', 'tion of the', 'dy to pass ', 'f certain d', 'at it will ', 'e convince ', 'ent told hi', 'ampaign and', 'rver side s', 'ious texts ', 'o capitaliz', 'a duplicate', 'gh ann es d', 'ine january', 'ross zero t', 'cal theorie', 'ast instanc', ' dimensiona', 'most holy m', 't s support', 'u is still ', 'e oscillati', 'o eight sub', 'of italy la', 's the tower', 'klahoma pre', 'erprise lin', 'ws becomes ', 'et in a naz', 'the fabian ', 'etchy to re', ' sharman ne', 'ised empero', 'ting in pol', 'd neo latin', 'th risky ri', 'encyclopedi', 'fense the a', 'duating fro', 'treet grid ', 'ations more', 'appeal of d', 'si have mad']\n",
"['ists advoca', 'ary governm', 'hes nationa', 'd monasteri', 'raca prince', 'chard baer ', 'rgical lang', 'for passeng', 'the nationa', 'took place ', 'ther well k', 'seven six s', 'ith a gloss', 'robably bee', 'to recogniz', 'ceived the ', 'icant than ', 'ritic of th', 'ight in sig', 's uncaused ', ' lost as in', 'cellular ic', 'e size of t', ' him a stic', 'drugs confu', ' take to co', ' the priest', 'im to name ', 'd barred at', 'standard fo', ' such as es', 'ze on the g', 'e of the or', 'd hiver one', 'y eight mar', 'the lead ch', 'es classica', 'ce the non ', 'al analysis', 'mormons bel', 't or at lea', ' disagreed ', 'ing system ', 'btypes base', 'anguages th', 'r commissio', 'ess one nin', 'nux suse li', ' the first ', 'zi concentr', ' society ne', 'elatively s', 'etworks sha', 'or hirohito', 'litical ini', 'n most of t', 'iskerdoo ri', 'ic overview', 'air compone', 'om acnm acc', ' centerline', 'e than any ', 'devotional ', 'de such dev']\n",
"[' a']\n",
"['an']\n"
]
}
],
"source": [
"batch_size=64\n",
"num_unrollings=10\n",
"\n",
"class BatchGenerator(object):\n",
"  def __init__(self, text, batch_size, num_unrollings):\n",
"    self._text = text\n",
"    self._text_size = len(text)\n",
"    self._batch_size = batch_size\n",
"    self._num_unrollings = num_unrollings\n",
"    segment = self._text_size // batch_size\n",
"    self._cursor = [offset * segment for offset in range(batch_size)]\n",
"    self._last_batch = self._next_batch()\n",
"\n",
"  def _next_batch(self):\n",
"    \"\"\"Generate a single batch from the current cursor position in the data.\"\"\"\n",
"    # np.float was removed in NumPy 1.24; float32 also matches the tf.float32 placeholders.\n",
"    batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float32)\n",
"    for b in range(self._batch_size):\n",
"      batch[b, char2id(self._text[self._cursor[b]])] = 1.0\n",
"      self._cursor[b] = (self._cursor[b] + 1) % self._text_size\n",
"    return batch\n",
"\n",
"  def next(self):\n",
"    \"\"\"Generate the next array of batches from the data. The array consists of\n",
"    the last batch of the previous array, followed by num_unrollings new ones.\n",
"    \"\"\"\n",
"    batches = [self._last_batch]\n",
"    for step in range(self._num_unrollings):\n",
"      batches.append(self._next_batch())\n",
"    self._last_batch = batches[-1]\n",
"    return batches\n",
"\n",
"def characters(probabilities):\n",
"  \"\"\"Turn a 1-hot encoding or a probability distribution over the possible\n",
"  characters back into its (most likely) character representation.\"\"\"\n",
"  return [id2char(c) for c in np.argmax(probabilities, 1)]\n",
"\n",
"def batches2string(batches):\n",
"  \"\"\"Convert a sequence of batches back into their (most likely) string\n",
"  representation.\"\"\"\n",
"  s = [''] * batches[0].shape[0]\n",
"  for b in batches:\n",
"    s = [''.join(x) for x in zip(s, characters(b))]\n",
"  return s\n",
"\n",
"train_batches = BatchGenerator(train_text, batch_size, num_unrollings)\n",
"valid_batches = BatchGenerator(valid_text, 1, 1)\n",
"\n",
"print(batches2string(train_batches.next()))\n",
"print(batches2string(train_batches.next()))\n",
"print(batches2string(valid_batches.next()))\n",
"print(batches2string(valid_batches.next()))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"deletable": true,
"editable": true,
"id": "KyVd8FxT5QBc"
},
"outputs": [],
"source": [
"def logprob(predictions, labels):\n",
"  \"\"\"Average negative log-probability of the true labels in a predicted batch.\"\"\"\n",
"  predictions[predictions < 1e-10] = 1e-10\n",
"  return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]\n",
"\n",
"def sample_distribution(distribution):\n",
"  \"\"\"Sample one element from a distribution assumed to be an array of normalized\n",
"  probabilities.\n",
"  \"\"\"\n",
"  r = random.uniform(0, 1)\n",
"  s = 0\n",
"  for i in range(len(distribution)):\n",
"    s += distribution[i]\n",
"    if s >= r:\n",
"      return i\n",
"  return len(distribution) - 1\n",
"\n",
"def sample(prediction):\n",
"  \"\"\"Turn a (column) prediction into 1-hot encoded samples.\"\"\"\n",
"  p = np.zeros(shape=[1, vocabulary_size], dtype=np.float32)\n",
"  p[0, sample_distribution(prediction[0])] = 1.0\n",
"  return p\n",
"\n",
"def random_distribution():\n",
"  \"\"\"Generate a random column of probabilities.\"\"\"\n",
"  b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])\n",
"  return b/np.sum(b, 1)[:,None]"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "K8f67YXaDr4C"
},
"source": [
"Simple LSTM Model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"deletable": true,
"editable": true,
"id": "Q5rxZK6RDuGe"
},
"outputs": [],
"source": [
"num_nodes = 64\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
"\n",
"  # Parameters:\n",
"  # Note: tf.truncated_normal(shape, mean, stddev), so -0.1, 0.1 below are the mean and stddev.\n",
"  # Input gate: input, previous output, and bias.\n",
"  ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n",
"  im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n",
"  ib = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  # Forget gate: input, previous output, and bias.\n",
"  fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n",
"  fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n",
"  fb = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  # Memory cell: input, state and bias.\n",
"  cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n",
"  cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n",
"  cb = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  # Output gate: input, previous output, and bias.\n",
"  ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n",
"  om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n",
"  ob = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  # Variables saving state across unrollings.\n",
"  saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  # Classifier weights and biases.\n",
"  w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))\n",
"  b = tf.Variable(tf.zeros([vocabulary_size]))\n",
"\n",
"  # Definition of the cell computation.\n",
"  def lstm_cell(i, o, state):\n",
"    \"\"\"Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf\n",
"    Note that in this formulation, we omit the various connections between the\n",
"    previous state and the gates.\"\"\"\n",
"    input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)\n",
"    forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)\n",
"    update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb\n",
"    state = forget_gate * state + input_gate * tf.tanh(update)\n",
"    output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)\n",
"    return output_gate * tf.tanh(state), state\n",
"\n",
"  # Input data.\n",
"  train_data = list()\n",
"  for _ in range(num_unrollings + 1):\n",
"    train_data.append(\n",
"      tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))\n",
"  train_inputs = train_data[:num_unrollings]\n",
"  train_labels = train_data[1:]  # labels are inputs shifted by one time step.\n",
"\n",
"  # Unrolled LSTM loop.\n",
"  outputs = list()\n",
"  output = saved_output\n",
"  state = saved_state\n",
"  for i in train_inputs:\n",
"    output, state = lstm_cell(i, output, state)\n",
"    outputs.append(output)\n",
"\n",
"  # State saving across unrollings.\n",
"  with tf.control_dependencies([saved_output.assign(output),\n",
"                                saved_state.assign(state)]):\n",
"    # Classifier.\n",
"    logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)\n",
"    loss = tf.reduce_mean(\n",
"      tf.nn.softmax_cross_entropy_with_logits(\n",
"        labels=tf.concat(train_labels, 0), logits=logits))\n",
"\n",
"  # Optimizer.\n",
"  global_step = tf.Variable(0)\n",
"  # tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)\n",
"  # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)\n",
"  # global_step starts at 0 and apply_gradients increments it on every step; with\n",
"  # staircase=True the exponent is floored, so the rate is 10.0 for the first 5000 steps, then 1.0.\n",
"  learning_rate = tf.train.exponential_decay(\n",
"    10.0, global_step, 5000, 0.1, staircase=True)\n",
"  optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n",
"  gradients, v = zip(*optimizer.compute_gradients(loss))\n",
"  gradients, _ = tf.clip_by_global_norm(gradients, 1.25)\n",
"  optimizer = optimizer.apply_gradients(\n",
"    zip(gradients, v), global_step=global_step)\n",
"\n",
"  # Predictions.\n",
"  train_prediction = tf.nn.softmax(logits)\n",
"\n",
"  # Sampling and validation eval: batch 1, no unrolling.\n",
"  sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])\n",
"  saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  reset_sample_state = tf.group(\n",
"    saved_sample_output.assign(tf.zeros([1, num_nodes])),\n",
"    saved_sample_state.assign(tf.zeros([1, num_nodes])))\n",
"  sample_output, sample_state = lstm_cell(\n",
"    sample_input, saved_sample_output, saved_sample_state)\n",
"  with tf.control_dependencies([saved_sample_output.assign(sample_output),\n",
"                                saved_sample_state.assign(sample_state)]):\n",
"    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 41
},
{
"item_id": 80
},
{
"item_id": 126
},
{
"item_id": 144
}
]
},
"colab_type": "code",
"collapsed": false,
"deletable": true,
"editable": true,
"executionInfo": {
"elapsed": 199909,
"status": "ok",
"timestamp": 1445965877333,
"user": {
"color": "#1FA15D",
"displayName": "Vincent Vanhoucke",
"isAnonymous": false,
"isMe": true,
"permissionId": "05076109866853157986",
"photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
"sessionId": "6f6f07b359200c46",
"userId": "102167687554210253930"
},
"user_tz": 420
},
"id": "RD9zQCZTEaEm",
"outputId": "5e868466-2532-4545-ce35-b403cf5d9de6"
},
"outputs": [],
"source": [
"num_steps = 7001\n",
"summary_frequency = 100\n",
"\n",
"with tf.Session(graph=graph) as session:\n",
"  tf.global_variables_initializer().run()\n",
"  print('Initialized')\n",
"  mean_loss = 0\n",
"  for step in range(num_steps):\n",
"    batches = train_batches.next()\n",
"    feed_dict = dict()\n",
"    for i in range(num_unrollings + 1):\n",
"      feed_dict[train_data[i]] = batches[i]\n",
"    _, l, predictions, lr = session.run(\n",
"      [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)\n",
"    mean_loss += l\n",
"    if step % summary_frequency == 0:\n",
"      if step > 0:\n",
"        mean_loss = mean_loss / summary_frequency\n",
"      # The mean loss is an estimate of the loss over the last few batches.\n",
"      print(\n",
"        'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))\n",
"      mean_loss = 0\n",
"      labels = np.concatenate(list(batches)[1:])\n",
"      print('Minibatch perplexity: %.2f' % float(\n",
"        np.exp(logprob(predictions, labels))))\n",
"      if step % (summary_frequency * 10) == 0:\n",
"        # Generate some samples.\n",
"        print('=' * 80)\n",
"        for _ in range(5):\n",
"          feed = sample(random_distribution())\n",
"          sentence = characters(feed)[0]\n",
"          reset_sample_state.run()\n",
"          for _ in range(79):\n",
"            prediction = sample_prediction.eval({sample_input: feed})\n",
"            feed = sample(prediction)\n",
"            sentence += characters(feed)[0]\n",
"          print(sentence)\n",
"        print('=' * 80)\n",
"      # Measure validation set perplexity.\n",
"      reset_sample_state.run()\n",
"      valid_logprob = 0\n",
"      for _ in range(valid_size):\n",
"        b = valid_batches.next()\n",
"        predictions = sample_prediction.eval({sample_input: b[0]})\n",
"        valid_logprob = valid_logprob + logprob(predictions, b[1])\n",
"      print('Validation set perplexity: %.2f' % float(np.exp(\n",
"        valid_logprob / valid_size)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "pl4vtmFfa5nn"
},
"source": [
"---\n",
"Problem 1\n",
"---------\n",
"\n",
"You might have noticed that the definition of the LSTM cell involves 4 matrix multiplications with the input, and 4 matrix multiplications with the output. Simplify the expression by using a single matrix multiply for each, and variables that are 4 times larger.\n",
"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"num_nodes = 64\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
"\n",
"  # Parameters:\n",
"  # Note: tf.truncated_normal(shape, mean, stddev), so -0.1, 0.1 below are the mean and stddev.\n",
"  # Fused weights for all four gates (input, forget, output, update):\n",
"  # input, previous output, and bias, each num_nodes*4 columns wide.\n",
"  in_mtx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes*4], -0.1, 0.1))\n",
"  out_mtx = tf.Variable(tf.truncated_normal([num_nodes, num_nodes*4], -0.1, 0.1))\n",
"  b_vec = tf.Variable(tf.zeros([1, num_nodes*4]))\n",
"\n",
"  # Variables saving state across unrollings.\n",
"  saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  # Classifier weights and biases.\n",
"  w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))\n",
"  b = tf.Variable(tf.zeros([vocabulary_size]))\n",
"\n",
"  # Definition of the cell computation.\n",
"  def lstm_cell(i, o, state):\n",
"    \"\"\"Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf\n",
"    Note that in this formulation, we omit the various connections between the\n",
"    previous state and the gates.\"\"\"\n",
"    # One product holds all four pre-activations, laid out along axis 1 as\n",
"    # input_gate, forget_gate, output_gate, update.\n",
"    product_tmp = tf.matmul(i, in_mtx) + tf.matmul(o, out_mtx) + b_vec\n",
"    input_gate, forget_gate, output_gate, update = tf.split(product_tmp, num_or_size_splits=4, axis=1)\n",
"    input_gate = tf.sigmoid(input_gate)\n",
"    forget_gate = tf.sigmoid(forget_gate)\n",
"    output_gate = tf.sigmoid(output_gate)\n",
"    state = forget_gate * state + input_gate * tf.tanh(update)\n",
"    return output_gate * tf.tanh(state), state\n",
"\n",
"  # Input data.\n",
"  train_data = list()\n",
"  for _ in range(num_unrollings + 1):\n",
"    train_data.append(\n",
"      tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))\n",
"  train_inputs = train_data[:num_unrollings]\n",
"  train_labels = train_data[1:]  # labels are inputs shifted by one time step.\n",
"\n",
"  # Unrolled LSTM loop.\n",
"  outputs = list()\n",
"  output = saved_output\n",
"  state = saved_state\n",
"  for i in train_inputs:\n",
"    output, state = lstm_cell(i, output, state)\n",
"    outputs.append(output)\n",
"\n",
"  # State saving across unrollings.\n",
"  with tf.control_dependencies([saved_output.assign(output),\n",
"                                saved_state.assign(state)]):\n",
"    # Classifier.\n",
"    logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)\n",
"    loss = tf.reduce_mean(\n",
"      tf.nn.softmax_cross_entropy_with_logits(\n",
"        labels=tf.concat(train_labels, 0), logits=logits))\n",
"\n",
"  # Optimizer.\n",
"  global_step = tf.Variable(0)\n",
"  # tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)\n",
"  # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)\n",
"  # global_step starts at 0 and apply_gradients increments it on every step; with\n",
"  # staircase=True the exponent is floored, so the rate is 10.0 for the first 5000 steps, then 1.0.\n",
"  learning_rate = tf.train.exponential_decay(\n",
"    10.0, global_step, 5000, 0.1, staircase=True)\n",
"  optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n",
"  gradients, v = zip(*optimizer.compute_gradients(loss))\n",
"  gradients, _ = tf.clip_by_global_norm(gradients, 1.25)\n",
"  optimizer = optimizer.apply_gradients(\n",
"    zip(gradients, v), global_step=global_step)\n",
"\n",
"  # Predictions.\n",
"  train_prediction = tf.nn.softmax(logits)\n",
"\n",
"  # Sampling and validation eval: batch 1, no unrolling.\n",
"  sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])\n",
"  saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  reset_sample_state = tf.group(\n",
"    saved_sample_output.assign(tf.zeros([1, num_nodes])),\n",
"    saved_sample_state.assign(tf.zeros([1, num_nodes])))\n",
"  sample_output, sample_state = lstm_cell(\n",
"    sample_input, saved_sample_output, saved_sample_state)\n",
"  with tf.control_dependencies([saved_sample_output.assign(sample_output),\n",
"                                saved_sample_state.assign(sample_state)]):\n",
"    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized\n",
"Average loss at step 0: 3.297306 learning rate: 10.000000\n",
"Minibatch perplexity: 27.04\n",
"================================================================================\n",
"en atts zabxa fdlf m masoiychlvaweyakpz htkthexkfyxs pekxqpe hwesdoaavatgxoubce\n",
"l ahyradd lne imwv kmpdfck rv bo pijjm o xjminu ei tf xdqcfsqmizkoyldyw q sda\n",
"ynh dok yqofdyrchthfc gue gaxnk o x yyf tkm seasaatezni l lkgydalogcoefnhxa\n",
"fdpsstxms ewlizxegaaq j xsuuks rs rrl zicfpnoi cietfjllkvqtgd ej vtikpuddx j \n",
"qtefwicfieqmcgkhpgnvuafjih i leb en ptpynkpttgtghifx jizdpywcxugveiccksihthqgql\n",
"================================================================================\n",
"Validation set perplexity: 20.19\n",
"Average loss at step 100: 2.593243 learning rate: 10.000000\n",
"Minibatch perplexity: 10.56\n",
"Validation set perplexity: 9.87\n",
"Average loss at step 200: 2.239796 learning rate: 10.000000\n",
"Minibatch perplexity: 8.50\n",
"Validation set perplexity: 8.37\n",
"Average loss at step 300: 2.093813 learning rate: 10.000000\n",
"Minibatch perplexity: 7.34\n",
"Validation set perplexity: 8.18\n",
"Average loss at step 400: 1.997686 learning rate: 10.000000\n",
"Minibatch perplexity: 7.42\n",
"Validation set perplexity: 7.74\n",
"Average loss at step 500: 1.931805 learning rate: 10.000000\n",
"Minibatch perplexity: 6.64\n",
"Validation set perplexity: 7.18\n",
"Average loss at step 600: 1.907364 learning rate: 10.000000\n",
"Minibatch perplexity: 6.28\n",
"Validation set perplexity: 6.98\n",
"Average loss at step 700: 1.857062 learning rate: 10.000000\n",
"Minibatch perplexity: 6.41\n",
"Validation set perplexity: 6.71\n",
"Average loss at step 800: 1.813501 learning rate: 10.000000\n",
"Minibatch perplexity: 6.02\n",
"Validation set perplexity: 6.52\n",
"Average loss at step 900: 1.825168 learning rate: 10.000000\n",
"Minibatch perplexity: 6.87\n",
"Validation set perplexity: 6.21\n",
"Average loss at step 1000: 1.823168 learning rate: 10.000000\n",
"Minibatch perplexity: 5.71\n",
"================================================================================\n",
"ter where sake zero his and injevated bribk with assmald birle eight his stral s\n",
"ne lacces dicter centers it a mearions on auconsied min or groken of seprigritil\n",
"uts comused psksta to make s and diffelice plut word for productl at and trests \n",
"ed the frien of the fave the has beol orter treee wort itwessed dusucso vare pe \n",
"n berch to veqice is the seefiverder secment id use six eriver or the live memoc\n",
"================================================================================\n",
"Validation set perplexity: 6.14\n",
"Average loss at step 1100: 1.775161 learning rate: 10.000000\n",
"Minibatch perplexity: 5.53\n",
"Validation set perplexity: 5.88\n",
"Average loss at step 1200: 1.754109 learning rate: 10.000000\n",
"Minibatch perplexity: 5.02\n",
"Validation set perplexity: 5.63\n",
"Average loss at step 1300: 1.726914 learning rate: 10.000000\n",
"Minibatch perplexity: 5.57\n",
"Validation set perplexity: 5.61\n",
"Average loss at step 1400: 1.745069 learning rate: 10.000000\n",
"Minibatch perplexity: 6.06\n",
"Validation set perplexity: 5.53\n",
"Average loss at step 1500: 1.737422 learning rate: 10.000000\n",
"Minibatch perplexity: 4.76\n",
"Validation set perplexity: 5.39\n",
"Average loss at step 1600: 1.742193 learning rate: 10.000000\n",
"Minibatch perplexity: 5.73\n",
"Validation set perplexity: 5.43\n",
"Average loss at step 1700: 1.709692 learning rate: 10.000000\n",
"Minibatch perplexity: 5.49\n",
"Validation set perplexity: 5.32\n",
"Average loss at step 1800: 1.674778 learning rate: 10.000000\n",
"Minibatch perplexity: 5.19\n",
"Validation set perplexity: 5.14\n",
"Average loss at step 1900: 1.643579 learning rate: 10.000000\n",
"Minibatch perplexity: 5.08\n",
"Validation set perplexity: 5.11\n",
"Average loss at step 2000: 1.691691 learning rate: 10.000000\n",
"Minibatch perplexity: 5.66\n",
"================================================================================\n",
"janked three propes on angorical unfairntical is of they four the s leadent of \n",
"ver unled and oct such doders aboud state and nine sevont rcage wordress coign t\n",
"latesion od birloge from planial concluse in one nine four and agepoics on prous\n",
"x phigk many one nine eight aust pres desponumets brou nasce world byongrone spa\n",
"oppentary his popular was it he reveatations were opirdous c a finst preside s \n",
"================================================================================\n",
"Validation set perplexity: 5.09\n",
"Average loss at step 2100: 1.683549 learning rate: 10.000000\n",
"Minibatch perplexity: 5.06\n",
"Validation set perplexity: 4.84\n",
"Average loss at step 2200: 1.679229 learning rate: 10.000000\n",
"Minibatch perplexity: 6.25\n",
"Validation set perplexity: 5.00\n",
"Average loss at step 2300: 1.635878 learning rate: 10.000000\n",
"Minibatch perplexity: 5.02\n",
"Validation set perplexity: 4.80\n",
"Average loss at step 2400: 1.655348 learning rate: 10.000000\n",
"Minibatch perplexity: 4.88\n",
"Validation set perplexity: 4.70\n",
"Average loss at step 2500: 1.679176 learning rate: 10.000000\n",
"Minibatch perplexity: 5.20\n",
"Validation set perplexity: 4.63\n",
"Average loss at step 2600: 1.651054 learning rate: 10.000000\n",
"Minibatch perplexity: 5.52\n",
"Validation set perplexity: 4.59\n",
"Average loss at step 2700: 1.658052 learning rate: 10.000000\n",
"Minibatch perplexity: 4.59\n",
"Validation set perplexity: 4.61\n",
"Average loss at step 2800: 1.650663 learning rate: 10.000000\n",
"Minibatch perplexity: 5.72\n",
"Validation set perplexity: 4.58\n",
"Average loss at step 2900: 1.651291 learning rate: 10.000000\n",
"Minibatch perplexity: 5.71\n",
"Validation set perplexity: 4.58\n",
"Average loss at step 3000: 1.650964 learning rate: 10.000000\n",
"Minibatch perplexity: 4.96\n",
"================================================================================\n",
"will as with form consides obs monsk as as bllibagation was amogratles known sef\n",
"graphicity the latin are poblict terponomation ently readon filh bo mahaments of\n",
"fff one nine eigho three s carlican subjoce ruboy of eight in some fawn corrante\n",
"jorsbers his he dissust be tass he stant play of lacrable boy dused posiofied di\n",
"ther in the iny ciunt avtwoke and to villiem two weltwa the fsictame kassic comp\n",
"================================================================================\n",
"Validation set perplexity: 4.56\n",
"Average loss at step 3100: 1.629150 learning rate: 10.000000\n",
"Minibatch perplexity: 5.78\n",
"Validation set perplexity: 4.45\n",
"Average loss at step 3200: 1.648842 learning rate: 10.000000\n",
"Minibatch perplexity: 5.85\n",
"Validation set perplexity: 4.54\n",
"Average loss at step 3300: 1.637304 learning rate: 10.000000\n",
"Minibatch perplexity: 5.10\n",
"Validation set perplexity: 4.50\n",
"Average loss at step 3400: 1.664272 learning rate: 10.000000\n",
"Minibatch perplexity: 5.51\n",
"Validation set perplexity: 4.49\n",
"Average loss at step 3500: 1.656060 learning rate: 10.000000\n",
"Minibatch perplexity: 5.60\n",
"Validation set perplexity: 4.56\n",
"Average loss at step 3600: 1.665079 learning rate: 10.000000\n",
"Minibatch perplexity: 4.43\n",
"Validation set perplexity: 4.50\n",
"Average loss at step 3700: 1.644942 learning rate: 10.000000\n",
"Minibatch perplexity: 5.09\n",
"Validation set perplexity: 4.47\n",
"Average loss at step 3800: 1.640834 learning rate: 10.000000\n",
"Minibatch perplexity: 5.63\n",
"Validation set perplexity: 4.60\n",
"Average loss at step 3900: 1.635533 learning rate: 10.000000\n",
"Minibatch perplexity: 5.13\n",
"Validation set perplexity: 4.54\n",
"Average loss at step 4000: 1.654129 learning rate: 10.000000\n",
"Minibatch perplexity: 4.72\n",
"================================================================================\n",
"emote yrach nol beliphint speikf to gene reganed termatifally propeetreds day em\n",
"foremaghishone sluss were regamedlam chromosa clamed dibecsifled have any bryzed\n",
"jewism usicial populated words of historical roman paying or japreces unherial c\n",
"reposanke in ornentrie names and gentaltisnochine dechacistick only periovegs ha\n",
"ques shar palegal the repti avody genezate with cultanks the larke theoress with\n",
"================================================================================\n",
"Validation set perplexity: 4.53\n",
"Average loss at step 4100: 1.633774 learning rate: 10.000000\n",
"Minibatch perplexity: 5.34\n",
"Validation set perplexity: 4.64\n",
"Average loss at step 4200: 1.639684 learning rate: 10.000000\n",
"Minibatch perplexity: 5.40\n",
"Validation set perplexity: 4.41\n",
"Average loss at step 4300: 1.615680 learning rate: 10.000000\n",
"Minibatch perplexity: 5.13\n",
"Validation set perplexity: 4.49\n",
"Average loss at step 4400: 1.608433 learning rate: 10.000000\n",
"Minibatch perplexity: 4.94\n",
"Validation set perplexity: 4.34\n",
"Average loss at step 4500: 1.614877 learning rate: 10.000000\n",
"Minibatch perplexity: 5.24\n",
"Validation set perplexity: 4.51\n",
"Average loss at step 4600: 1.614870 learning rate: 10.000000\n",
"Minibatch perplexity: 5.10\n",
"Validation set perplexity: 4.52\n",
"Average loss at step 4700: 1.625532 learning rate: 10.000000\n",
"Minibatch perplexity: 5.29\n",
"Validation set perplexity: 4.45\n",
"Average loss at step 4800: 1.631490 learning rate: 10.000000\n",
"Minibatch perplexity: 4.34\n",
"Validation set perplexity: 4.41\n",
"Average loss at step 4900: 1.635893 learning rate: 10.000000\n",
"Minibatch perplexity: 5.08\n",
"Validation set perplexity: 4.60\n",
"Average loss at step 5000: 1.607185 learning rate: 1.000000\n",
"Minibatch perplexity: 4.56\n",
"================================================================================\n",
"us atte reablow where mally in his s bracin s quickens has that as the has paten\n",
"gulogision one kibels the a canor according a a deleains as his sene of had sexp\n",
"s american that quart wish it can mostro duas entryot montement actore games of \n",
"port and nas heard humbanding sinques pstherated by down beatk seady coaland mat\n",
"winano age divikaly two zers it with numal creations lew hyperted bas of usain i\n",
"================================================================================\n",
"Validation set perplexity: 4.61\n",
"Average loss at step 5100: 1.607503 learning rate: 1.000000\n",
"Minibatch perplexity: 5.01\n",
"Validation set perplexity: 4.42\n",
"Average loss at step 5200: 1.593635 learning rate: 1.000000\n",
"Minibatch perplexity: 4.75\n",
"Validation set perplexity: 4.35\n",
"Average loss at step 5300: 1.579082 learning rate: 1.000000\n",
"Minibatch perplexity: 4.48\n",
"Validation set perplexity: 4.37\n",
"Average loss at step 5400: 1.581546 learning rate: 1.000000\n",
"Minibatch perplexity: 5.01\n",
"Validation set perplexity: 4.34\n",
"Average loss at step 5500: 1.568294 learning rate: 1.000000\n",
"Minibatch perplexity: 4.90\n",
"Validation set perplexity: 4.31\n",
"Average loss at step 5600: 1.582512 learning rate: 1.000000\n",
"Minibatch perplexity: 4.84\n",
"Validation set perplexity: 4.30\n",
"Average loss at step 5700: 1.569742 learning rate: 1.000000\n",
"Minibatch perplexity: 4.52\n",
"Validation set perplexity: 4.32\n",
"Average loss at step 5800: 1.577648 learning rate: 1.000000\n",
"Minibatch perplexity: 4.95\n",
"Validation set perplexity: 4.30\n",
"Average loss at step 5900: 1.574286 learning rate: 1.000000\n",
"Minibatch perplexity: 4.93\n",
"Validation set perplexity: 4.30\n",
"Average loss at step 6000: 1.550640 learning rate: 1.000000\n",
"Minibatch perplexity: 5.05\n",
"================================================================================\n",
"ine tellsco by by arabasit insussenter demagreted were problem of aga pen a one \n",
"ote the a duslic such the compubroyicion sole more he anobising long produced ho\n",
"b regension of the les one one six he mytwlatter present aport on the bs the qua\n",
"le lawrid similuting dodich b one nate to the our by pointed founder about gauge\n",
"de but statq music has a has a one nine all the can zooa with excetsle s instrag\n",
"================================================================================\n",
"Validation set perplexity: 4.30\n",
"Average loss at step 6100: 1.565242 learning rate: 1.000000\n",
"Minibatch perplexity: 5.05\n",
"Validation set perplexity: 4.26\n",
"Average loss at step 6200: 1.532255 learning rate: 1.000000\n",
"Minibatch perplexity: 4.76\n",
"Validation set perplexity: 4.26\n",
"Average loss at step 6300: 1.549893 learning rate: 1.000000\n",
"Minibatch perplexity: 5.05\n",
"Validation set perplexity: 4.24\n",
"Average loss at step 6400: 1.541046 learning rate: 1.000000\n",
"Minibatch perplexity: 4.48\n",
"Validation set perplexity: 4.24\n",
"Average loss at step 6500: 1.558128 learning rate: 1.000000\n",
"Minibatch perplexity: 4.69\n",
"Validation set perplexity: 4.23\n",
"Average loss at step 6600: 1.592973 learning rate: 1.000000\n",
"Minibatch perplexity: 4.79\n",
"Validation set perplexity: 4.23\n",
"Average loss at step 6700: 1.578549 learning rate: 1.000000\n",
"Minibatch perplexity: 5.05\n",
"Validation set perplexity: 4.25\n",
"Average loss at step 6800: 1.601822 learning rate: 1.000000\n",
"Minibatch perplexity: 4.69\n",
"Validation set perplexity: 4.24\n",
"Average loss at step 6900: 1.579455 learning rate: 1.000000\n",
"Minibatch perplexity: 4.79\n",
"Validation set perplexity: 4.26\n",
"Average loss at step 7000: 1.573577 learning rate: 1.000000\n",
"Minibatch perplexity: 4.89\n",
"================================================================================\n",
"x shoming the entrocuder leh when icoco of a curviel to plete dee posided the fa\n",
"chnorious versasticals manizen azerbi selters this rines bandus by antist the po\n",
"ar chroneementary laster were evinism myrey in the smopered yerops but one nine \n",
"chism of consiple wherehames clalimity of relained to shum namio stangiamearnum \n",
"y beconsows and his in the vertel sometent of the queed brath work time law is b\n",
"================================================================================\n",
"Validation set perplexity: 4.25\n"
]
}
],
"source": [
"num_steps = 7001\n",
"summary_frequency = 100\n",
"\n",
"with tf.Session(graph=graph) as session:\n",
"  tf.global_variables_initializer().run()\n",
"  print('Initialized')\n",
"  mean_loss = 0\n",
"  for step in range(num_steps):\n",
"    batches = train_batches.next()\n",
"    feed_dict = dict()\n",
"    for i in range(num_unrollings + 1):\n",
"      feed_dict[train_data[i]] = batches[i]\n",
"    _, l, predictions, lr = session.run(\n",
"      [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)\n",
"    mean_loss += l\n",
"    if step % summary_frequency == 0:\n",
"      if step > 0:\n",
"        mean_loss = mean_loss / summary_frequency\n",
"      # The mean loss is an estimate of the loss over the last few batches.\n",
"      print(\n",
"        'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))\n",
"      mean_loss = 0\n",
"      labels = np.concatenate(list(batches)[1:])\n",
"      print('Minibatch perplexity: %.2f' % float(\n",
"        np.exp(logprob(predictions, labels))))\n",
"      if step % (summary_frequency * 10) == 0:\n",
"        # Generate some samples.\n",
"        print('=' * 80)\n",
"        for _ in range(5):\n",
"          feed = sample(random_distribution())\n",
"          sentence = characters(feed)[0]\n",
"          reset_sample_state.run()\n",
"          for _ in range(79):\n",
"            prediction = sample_prediction.eval({sample_input: feed})\n",
"            feed = sample(prediction)\n",
"            sentence += characters(feed)[0]\n",
"          print(sentence)\n",
"        print('=' * 80)\n",
"      # Measure validation set perplexity.\n",
"      reset_sample_state.run()\n",
"      valid_logprob = 0\n",
"      for _ in range(valid_size):\n",
"        b = valid_batches.next()\n",
"        predictions = sample_prediction.eval({sample_input: b[0]})\n",
"        valid_logprob = valid_logprob + logprob(predictions, b[1])\n",
"      print('Validation set perplexity: %.2f' % float(np.exp(\n",
"        valid_logprob / valid_size)))"
]
},
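{
"cell_type": "markdown",
"metadata": {},
"source": [
"The perplexities printed above are the exponential of the mean negative log-probability that the model assigns to the true next character. A sketch of the quantity, assuming `logprob` (defined earlier in the notebook) averages $-\\log p(y_n)$ over the $N$ rows of a batch:\n",
"\n",
"$$\\mathrm{PPL} = \\exp\\left(-\\frac{1}{N}\\sum_{n=1}^{N} \\log p(y_n)\\right)$$\n",
"\n",
"A model that spreads probability uniformly over $V$ characters scores $\\mathrm{PPL} = V$, so lower is better."
]
},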
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "4eErTCTybtph"
},
"source": [
"---\n",
"Problem 2\n",
"---------\n",
"\n",
"We want to train an LSTM over bigrams, that is, pairs of consecutive characters like 'ab' instead of single characters like 'a'. Since the number of possible bigrams is large, feeding them directly to the LSTM using 1-hot encodings leads to a very sparse representation that is computationally wasteful.\n",
"\n",
"a- Introduce an embedding lookup on the inputs, and feed the embeddings to the LSTM cell instead of the inputs themselves.\n",
"\n",
"b- Write a bigram-based LSTM, modeled on the character LSTM above.\n",
"\n",
"c- Introduce Dropout. For best practices on how to use Dropout in LSTMs, refer to this [article](http://arxiv.org/abs/1409.2329).\n",
"\n",
"---"
]
},
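{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell maps each bigram to a single row of the embedding matrix. With character ids $a, b \\in \\{0, \\dots, V-1\\}$, where $V$ is `vocabulary_size`, the index is\n",
"\n",
"$$\\mathrm{idx}(a, b) = a \\cdot V + b \\in \\{0, \\dots, V^2 - 1\\},$$\n",
"\n",
"so the embedding matrix needs $V^2$ rows. Assuming $V = 27$ (a-z plus space, as in the character model), that is $27^2 = 729$ bigrams. Each is mapped to a dense vector of `embedding_size` values instead of a 729-dimensional one-hot vector."
]
},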
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"num_nodes = 64\n",
"embedding_size = 128\n",
"keep_prob = 0.2  # Probability that each element is kept by dropout;\n",
"                 # the same value is used for inputs and outputs.\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
"\n",
"  # Parameters for all four gates, stacked column-wise:\n",
"  # input weights, previous-output weights, and bias.\n",
"  in_mtx = tf.Variable(tf.truncated_normal([embedding_size, num_nodes*4], -0.1, 0.1))\n",
"  out_mtx = tf.Variable(tf.truncated_normal([num_nodes, num_nodes*4], -0.1, 0.1))\n",
"  b_vec = tf.Variable(tf.zeros([1, num_nodes*4]))\n",
"\n",
"  # Variables saving state across unrollings.\n",
"  saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n",
"  # Classifier weights and biases.\n",
"  w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))\n",
"  b = tf.Variable(tf.zeros([vocabulary_size]))\n",
"  # One embedding row per bigram: vocabulary_size**2 rows in total.\n",
"  embeddings_mtx = tf.Variable(tf.truncated_normal([vocabulary_size*vocabulary_size, embedding_size], -0.1, 0.1), trainable=True)\n",
"\n",
"  # Definition of the cell computation.\n",
"  def lstm_cell(i, o, state):\n",
"    \"\"\"Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf\n",
"    Note that in this formulation, we omit the various connections between the\n",
"    previous state and the gates.\"\"\"\n",
"    # product_tmp holds, in order: input_gate, forget_gate, output_gate, update.\n",
"    product_tmp = tf.matmul(i, in_mtx) + tf.matmul(o, out_mtx) + b_vec\n",
"    input_gate, forget_gate, output_gate, update = tf.split(product_tmp, num_or_size_splits=4, axis=1)\n",
"    input_gate = tf.sigmoid(input_gate)\n",
"    forget_gate = tf.sigmoid(forget_gate)\n",
"    output_gate = tf.sigmoid(output_gate)\n",
"    state = forget_gate * state + input_gate * tf.tanh(update)\n",
"    output = output_gate * tf.tanh(state)\n",
"    return output, state\n",
"\n",
"  # Input data.\n",
"  train_data = list()\n",
"  for _ in range(num_unrollings + 1):\n",
"    train_data.append(\n",
"      tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))\n",
"  train_tmp = train_data[:num_unrollings]\n",
"  train_inputs = zip(train_tmp[:-1], train_tmp[1:])  # bigrams (c_t, c_t+1)\n",
"  train_labels = train_data[2:]  # the label for bigram (c_t, c_t+1) is c_t+2.\n",
"\n",
"  # Unrolled LSTM loop.\n",
"  # Note: dropout is applied to the input embeddings and to the output that is\n",
"  # fed back into the next step (a recurrent connection); the classifier sees\n",
"  # the un-dropped output. The article cited above (arXiv:1409.2329) instead\n",
"  # applies dropout only to non-recurrent connections.\n",
"  outputs = list()\n",
"  output_dropouted = saved_output\n",
"  state = saved_state\n",
"  for i in train_inputs:\n",
"    # Encode the bigram (a, b) as the single index a * vocabulary_size + b.\n",
"    input_idx = tf.argmax(i[0], axis=1)*vocabulary_size + tf.argmax(i[1], axis=1)\n",
"    current_input = tf.nn.embedding_lookup(embeddings_mtx, input_idx)\n",
"    input_dropouted = tf.nn.dropout(current_input, keep_prob)\n",
"    output, state = lstm_cell(input_dropouted, output_dropouted, state)\n",
"    output_dropouted = tf.nn.dropout(output, keep_prob)\n",
"    outputs.append(output)\n",
"\n",
"  # State saving across unrollings.\n",
"  with tf.control_dependencies([saved_output.assign(output),\n",
"                                saved_state.assign(state)]):\n",
"    # Classifier.\n",
"    logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)\n",
"    loss = tf.reduce_mean(\n",
"      tf.nn.softmax_cross_entropy_with_logits(\n",
"        labels=tf.concat(train_labels, 0), logits=logits))\n",
"\n",
"  # Optimizer.\n",
"  global_step = tf.Variable(0)\n",
"  # tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)\n",
"  # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)\n",
"  # global_step starts at 0 and is incremented by apply_gradients() below.\n",
"  learning_rate = tf.train.exponential_decay(\n",
"    10.0, global_step, 5000, 0.1, staircase=True)\n",
"  optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n",
"  gradients, v = zip(*optimizer.compute_gradients(loss))\n",
"  gradients, _ = tf.clip_by_global_norm(gradients, 1.25)\n",
"  optimizer = optimizer.apply_gradients(\n",
"    zip(gradients, v), global_step=global_step)\n",
"\n",
"  # Predictions.\n",
"  train_prediction = tf.nn.softmax(logits)\n",
"\n",
"  # Sampling and validation eval: batch 1, no unrolling, no dropout.\n",
"  sample_input = [tf.placeholder(tf.float32, shape=[1, vocabulary_size]),\n",
"                  tf.placeholder(tf.float32, shape=[1, vocabulary_size])]\n",
"  bigrams_idx = tf.argmax(sample_input[0], axis=1)*vocabulary_size + tf.argmax(sample_input[1], axis=1)\n",
"  sample_embeddings = tf.nn.embedding_lookup(embeddings_mtx, bigrams_idx)\n",
"  saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))\n",
"  reset_sample_state = tf.group(\n",
"    saved_sample_output.assign(tf.zeros([1, num_nodes])),\n",
"    saved_sample_state.assign(tf.zeros([1, num_nodes])))\n",
"  sample_output, sample_state = lstm_cell(\n",
"    sample_embeddings, saved_sample_output, saved_sample_state)\n",
"  with tf.control_dependencies([saved_sample_output.assign(sample_output),\n",
"                                saved_sample_state.assign(sample_state)]):\n",
"    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized\n",
"Average loss at step 0: 3.318331 learning rate: 10.000000\n",
"Minibatch perplexity: 27.61\n",
"================================================================================\n",
"uo p v z s p b n c a m x c o s g l b i z n eed z p f jes b f p e i h \n",
"fb l c u k p nv kenoz teq its k y c r y g ieg m ttq u inx v b n a t g rem v \n",
"pl s m w k n x i e xey t t t cew q f a z e p qet p w aer f f y \n",
"fi g z ey c s y g o k y en xe s c i g o c v e a eee k i r f a o r j \n",
"an e i y q d o e l a s r c u e k h m zei o c o s f je o h d v t d j bey \n",
"================================================================================\n",
"Validation set perplexity: 79.23\n",
"Average loss at step 100: 2.824483 learning rate: 10.000000\n",
"Minibatch perplexity: 13.67\n",
"Validation set perplexity: 11.29\n",
"Average loss at step 200: 2.376306 learning rate: 10.000000\n",
"Minibatch perplexity: 10.26\n",
"Validation set perplexity: 9.64\n",
"Average loss at step 300: 2.265620 learning rate: 10.000000\n",
"Minibatch perplexity: 8.66\n",
"Validation set perplexity: 9.06\n",
"Average loss at step 400: 2.208451 learning rate: 10.000000\n",
"Minibatch perplexity: 8.11\n",
"Validation set perplexity: 8.87\n",
"Average loss at step 500: 2.220826 learning rate: 10.000000\n",
"Minibatch perplexity: 8.32\n",
"Validation set perplexity: 8.37\n",
"Average loss at step 600: 2.166269 learning rate: 10.000000\n",
"Minibatch perplexity: 8.75\n",
"Validation set perplexity: 8.27\n",
"Average loss at step 700: 2.157839 learning rate: 10.000000\n",
"Minibatch perplexity: 9.19\n",
"Validation set perplexity: 8.19\n",
"Average loss at step 800: 2.142931 learning rate: 10.000000\n",
"Minibatch perplexity: 9.06\n",
"Validation set perplexity: 8.22\n",
"Average loss at step 900: 2.146872 learning rate: 10.000000\n",
"Minibatch perplexity: 6.82\n",
"Validation set perplexity: 8.19\n",
"Average loss at step 1000: 2.097487 learning rate: 10.000000\n",
"Minibatch perplexity: 8.56\n",
"================================================================================\n",
"xkir taftkers dive zero zerally all abhlainteaaught or anatherryour caliwuse nin \n",
"es to two du ala glems in wentio the creapfall ude sweigwa argame capy looriritio\n",
"mir the eight popecerse to pageady ssion repaton it hostreses aterge amite tftaen\n",
" on wel elequasicit is the of eight s bek mruch the and pectiond eight rough the \n",
"ywric off mosa nom of three a recribaccuce u emeulcne eight whis ace of withe li\n",
"================================================================================\n",
"Validation set perplexity: 8.20\n",
"Average loss at step 1100: 2.076751 learning rate: 10.000000\n",
"Minibatch perplexity: 8.33\n",
"Validation set perplexity: 8.36\n",
"Average loss at step 1200: 2.110654 learning rate: 10.000000\n",
"Minibatch perplexity: 8.35\n",
"Validation set perplexity: 8.21\n",
"Average loss at step 1300: 2.101799 learning rate: 10.000000\n",
"Minibatch perplexity: 7.95\n",
"Validation set perplexity: 8.12\n",
"Average loss at step 1400: 2.081900 learning rate: 10.000000\n",
"Minibatch perplexity: 7.60\n",
"Validation set perplexity: 8.07\n",
"Average loss at step 1500: 2.092468 learning rate: 10.000000\n",
"Minibatch perplexity: 8.71\n",
"Validation set perplexity: 8.06\n",
"Average loss at step 1600: 2.067980 learning rate: 10.000000\n",
"Minibatch perplexity: 8.55\n",
"Validation set perplexity: 8.22\n",
"Average loss at step 1700: 2.087291 learning rate: 10.000000\n",
"Minibatch perplexity: 8.73\n",
"Validation set perplexity: 8.24\n",
"Average loss at step 1800: 2.072904 learning rate: 10.000000\n",
"Minibatch perplexity: 8.08\n",
"Validation set perplexity: 8.01\n",
"Average loss at step 1900: 2.069282 learning rate: 10.000000\n",
"Minibatch perplexity: 7.61\n",
"Validation set perplexity: 7.94\n",
"Average loss at step 2000: 2.080945 learning rate: 10.000000\n",
"Minibatch perplexity: 8.16\n",
"================================================================================\n",
"hcatatal one on tynts of the frenked one stencunct was esacter stoversahuld kqcs \n",
"mnive matene sentrans cast your the dephirar the ple empeme leach da forde a plal\n",
"fhneme s thee mandmy mews to memicaded seight zeroecate yeap its canited howpeigo\n",
"jnation mist heral mosee buth on bey irshow our the extere her the motle amrever \n",
"uy pledimenal hadpollase of te of the cuse mettata as the a use unity of is of a\n",
"================================================================================\n",
"Validation set perplexity: 7.99\n",
"Average loss at step 2100: 2.072666 learning rate: 10.000000\n",
"Minibatch perplexity: 7.60\n",
"Validation set perplexity: 8.21\n",
"Average loss at step 2200: 2.050317 learning rate: 10.000000\n",
"Minibatch perplexity: 8.46\n",
"Validation set perplexity: 8.02\n",
"Average loss at step 2300: 2.066610 learning rate: 10.000000\n",
"Minibatch perplexity: 7.70\n",
"Validation set perplexity: 7.92\n",
"Average loss at step 2400: 2.061169 learning rate: 10.000000\n",
"Minibatch perplexity: 8.92\n",
"Validation set perplexity: 8.10\n",
"Average loss at step 2500: 2.079461 learning rate: 10.000000\n",
"Minibatch perplexity: 8.37\n",
"Validation set perplexity: 8.14\n",
"Average loss at step 2600: 2.065772 learning rate: 10.000000\n",
"Minibatch perplexity: 7.20\n",
"Validation set perplexity: 8.02\n",
"Average loss at step 2700: 2.083565 learning rate: 10.000000\n",
"Minibatch perplexity: 7.62\n",
"Validation set perplexity: 7.93\n",
"Average loss at step 2800: 2.050979 learning rate: 10.000000\n",
"Minibatch perplexity: 7.45\n",
"Validation set perplexity: 7.74\n",
"Average loss at step 2900: 2.047830 learning rate: 10.000000\n",
"Minibatch perplexity: 7.46\n",
"Validation set perplexity: 7.94\n",
"Average loss at step 3000: 2.065596 learning rate: 10.000000\n",
"Minibatch perplexity: 8.26\n",
"================================================================================\n",
"krairton marchen and ampredoms brociareaan grestermr fyor conce six thro zero zer\n",
" famsean whiclabe of devenera polowhise is action mishumulal ivs the for eighhar \n",
"lve is carres mist in rocpshere tal exampury ususanciethe one by ire is of sox cs\n",
"hcriskition aras st horisling of and modix deven pout ags tix zero zero zero zero\n",
"xbentation native wor cithnhich and sternal bute by the sonnived watkal iftory an\n",
"================================================================================\n",
"Validation set perplexity: 7.87\n",
"Average loss at step 3100: 2.053678 learning rate: 10.000000\n",
"Minibatch perplexity: 7.91\n",
"Validation set perplexity: 7.96\n",
"Average loss at step 3200: 2.066419 learning rate: 10.000000\n",
"Minibatch perplexity: 7.86\n",
"Validation set perplexity: 7.91\n",
"Average loss at step 3300: 2.037368 learning rate: 10.000000\n",
"Minibatch perplexity: 7.39\n",
"Validation set perplexity: 7.88\n",
"Average loss at step 3400: 2.047022 learning rate: 10.000000\n",
"Minibatch perplexity: 7.19\n",
"Validation set perplexity: 8.01\n",
"Average loss at step 3500: 2.037646 learning rate: 10.000000\n",
"Minibatch perplexity: 8.57\n",
"Validation set perplexity: 8.05\n",
"Average loss at step 3600: 2.052229 learning rate: 10.000000\n",
"Minibatch perplexity: 8.03\n",
"Validation set perplexity: 7.77\n",
"Average loss at step 3700: 2.045109 learning rate: 10.000000\n",
"Minibatch perplexity: 7.19\n",
"Validation set perplexity: 7.84\n",
"Average loss at step 3800: 2.034123 learning rate: 10.000000\n",
"Minibatch perplexity: 7.39\n",
"Validation set perplexity: 7.79\n",
"Average loss at step 3900: 2.039331 learning rate: 10.000000\n",
"Minibatch perplexity: 8.23\n",
"Validation set perplexity: 7.75\n",
"Average loss at step 4000: 2.034395 learning rate: 10.000000\n",
"Minibatch perplexity: 7.75\n",
"================================================================================\n",
"johning paalsorianica of from iric foll learst an that thur thrree oparive swonos\n",
"cging dixm homitial a intresicar sr iss sadidi wontiorken the ping was wanto the \n",
"qxsultali weipaws eight severo and in incrowat dual wilitingeniway conbainkcdat d\n",
"w ald rover une gix nocsquampitina narimers orgide hat ght licreble press payer u\n",
"xzmr red of timply recepromer is weroped jass a nonal five reasasikisa vic to nin\n",
"================================================================================\n",
"Validation set perplexity: 7.87\n",
"Average loss at step 4100: 2.039861 learning rate: 10.000000\n",
"Minibatch perplexity: 7.56\n",
"Validation set perplexity: 8.02\n",
"Average loss at step 4200: 2.030850 learning rate: 10.000000\n",
"Minibatch perplexity: 7.58\n",
"Validation set perplexity: 7.97\n",
"Average loss at step 4300: 2.012156 learning rate: 10.000000\n",
"Minibatch perplexity: 7.46\n",
"Validation set perplexity: 7.92\n",
"Average loss at step 4400: 2.032554 learning rate: 10.000000\n",
"Minibatch perplexity: 7.97\n",
"Validation set perplexity: 7.99\n",
"Average loss at step 4500: 2.044585 learning rate: 10.000000\n",
"Minibatch perplexity: 7.63\n",
"Validation set perplexity: 8.04\n",
"Average loss at step 4600: 2.049579 learning rate: 10.000000\n",
"Minibatch perplexity: 7.88\n",
"Validation set perplexity: 7.91\n",
"Average loss at step 4700: 2.028848 learning rate: 10.000000\n",
"Minibatch perplexity: 7.97\n",
"Validation set perplexity: 7.86\n",
"Average loss at step 4800: 2.016727 learning rate: 10.000000\n",
"Minibatch perplexity: 8.78\n",
"Validation set perplexity: 7.93\n",
"Average loss at step 4900: 2.033456 learning rate: 10.000000\n",
"Minibatch perplexity: 7.64\n",
"Validation set perplexity: 8.19\n",
"Average loss at step 5000: 2.041955 learning rate: 1.000000\n",
"Minibatch perplexity: 7.90\n",
"================================================================================\n",
"zturn two zero one nine eight svallowelvanrn ardionayistrack three nine faall his\n",
"uzero leny six sycontrationsing of spolascounes comporrefadate frob eight thrisca\n",
"fe oterely the spepon by afc sayea one five a be of and ponss lyuling three piest\n",
"tjuva of lededs in jamemme which is crenke plimsor from dunds vollecric erace dis\n",
"his comling rallo ins he vamo and eopita if pive sude fa u plitecfriey a the mefa\n",
"================================================================================\n",
"Validation set perplexity: 7.95\n",
"Average loss at step 5100: 2.049792 learning rate: 1.000000\n",
"Minibatch perplexity: 7.88\n",
"Validation set perplexity: 7.95\n",
"Average loss at step 5200: 2.045247 learning rate: 1.000000\n",
"Minibatch perplexity: 7.62\n",
"Validation set perplexity: 7.84\n",
"Average loss at step 5300: 2.027454 learning rate: 1.000000\n",
"Minibatch perplexity: 6.98\n",
"Validation set perplexity: 7.80\n",
"Average loss at step 5400: 2.015190 learning rate: 1.000000\n",
"Minibatch perplexity: 7.22\n",
"Validation set perplexity: 7.82\n",
"Average loss at step 5500: 2.010666 learning rate: 1.000000\n",
"Minibatch perplexity: 6.88\n",
"Validation set perplexity: 7.83\n",
"Average loss at step 5600: 2.032782 learning rate: 1.000000\n",
"Minibatch perplexity: 8.98\n",
"Validation set perplexity: 7.82\n",
"Average loss at step 5700: 1.992788 learning rate: 1.000000\n",
"Minibatch perplexity: 6.38\n",
"Validation set perplexity: 7.86\n",
"Average loss at step 5800: 1.996959 learning rate: 1.000000\n",
"Minibatch perplexity: 7.57\n",
"Validation set perplexity: 7.82\n",
"Average loss at step 5900: 2.017644 learning rate: 1.000000\n",
"Minibatch perplexity: 7.46\n",
"Validation set perplexity: 7.78\n",
"Average loss at step 6000: 1.988973 learning rate: 1.000000\n",
"Minibatch perplexity: 6.87\n",
"================================================================================\n",
"fnoing heoned seven peattanion one prol one zero eight the und and ifson ofterns \n",
"wqold in emhrean ade deverke n of the three or osneen giour probirj hit dure of o\n",
"lxd id a the by hesnebsion tenconnek of three five seven fferai one five that mal\n",
"pcene five five timationiter the gint revent is capaa be lovnald two seangtheb it\n",
"bwa boling one nissle hithmat pah the prover with tralost theprecan the nof one n\n",
"================================================================================\n",
"Validation set perplexity: 7.75\n",
"Average loss at step 6100: 1.998669 learning rate: 1.000000\n",
"Minibatch perplexity: 7.04\n",
"Validation set perplexity: 7.76\n",
"Average loss at step 6200: 2.016346 learning rate: 1.000000\n",
"Minibatch perplexity: 7.01\n",
"Validation set perplexity: 7.82\n",
"Average loss at step 6300: 2.012220 learning rate: 1.000000\n",
"Minibatch perplexity: 7.59\n",
"Validation set perplexity: 7.81\n",
"Average loss at step 6400: 2.030418 learning rate: 1.000000\n",
"Minibatch perplexity: 8.16\n",
"Validation set perplexity: 7.80\n",
"Average loss at step 6500: 2.030910 learning rate: 1.000000\n",
"Minibatch perplexity: 7.42\n",
"Validation set perplexity: 7.80\n",
"Average loss at step 6600: 2.013080 learning rate: 1.000000\n",
"Minibatch perplexity: 7.97\n",
"Validation set perplexity: 7.77\n",
"Average loss at step 6700: 2.013069 learning rate: 1.000000\n",
"Minibatch perplexity: 7.76\n",
"Validation set perplexity: 7.74\n",
"Average loss at step 6800: 1.988502 learning rate: 1.000000\n",
"Minibatch perplexity: 6.48\n",
"Validation set perplexity: 7.75\n",
"Average loss at step 6900: 1.977013 learning rate: 1.000000\n",
"Minibatch perplexity: 7.64\n",
"Validation set perplexity: 7.70\n",
"Average loss at step 7000: 1.985839 learning rate: 1.000000\n",
"Minibatch perplexity: 7.26\n",
"================================================================================\n",
"planoters gian rinsade of vate hop thippulre wil of and perth if the kining of in\n",
"know co of thesi s precischerst of refaind as froguth swith oth of s titkh ins ge\n",
"which reading baschra clucichs knogrogoleator thow int and frembers is se gausitr\n",
"ybseed the eown handu liny offia inters refq s gis fsmital govinaters wor the man\n",
"b in dcasch one the eogrowerse wom the me and intil this imperno ducted chen rize\n",
"================================================================================\n",
"Validation set perplexity: 7.71\n"
]
}
],
"source": [
"import collections\n",
"num_steps = 7001\n",
"summary_frequency = 100\n",
"\n",
"# Batch size 1 and 2 unrollings: each validation batch holds three consecutive\n",
"# characters, i.e. one bigram plus the character that follows it.\n",
"valid_batches = BatchGenerator(valid_text, 1, 2)\n",
"\n",
"with tf.Session(graph=graph) as session:\n",
"  tf.global_variables_initializer().run()\n",
"  print('Initialized')\n",
"  mean_loss = 0\n",
"  for step in range(num_steps):\n",
"    batches = train_batches.next()\n",
"    feed_dict = dict()\n",
"    for i in range(num_unrollings + 1):\n",
"      feed_dict[train_data[i]] = batches[i]\n",
"    _, l, predictions, lr = session.run(\n",
"      [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)\n",
"    mean_loss += l\n",
"    if step % summary_frequency == 0:\n",
"      if step > 0:\n",
"        mean_loss = mean_loss / summary_frequency\n",
"      # The mean loss is an estimate of the loss over the last few batches.\n",
"      print(\n",
"        'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))\n",
"      mean_loss = 0\n",
"      # Skip two characters: the first bigram predicts the third character.\n",
"      labels = np.concatenate(list(batches)[2:])\n",
"      print('Minibatch perplexity: %.2f' % float(\n",
"        np.exp(logprob(predictions, labels))))\n",
"      if step % (summary_frequency * 10) == 0:\n",
"        # Generate some samples.\n",
"        print('=' * 80)\n",
"        for _ in range(5):\n",
"          # Seed with two random characters; deque(maxlen=2) then always\n",
"          # holds the most recent bigram.\n",
"          feed = collections.deque(maxlen=2)\n",
"          for _ in range(2):\n",
"            feed.append(sample(random_distribution()))\n",
"          sentence = characters(feed[0])[0] + characters(feed[1])[0]\n",
"          reset_sample_state.run()\n",
"          for _ in range(79):\n",
"            prediction = sample_prediction.eval({sample_input[0]: feed[0], sample_input[1]: feed[1]})\n",
"            feed.append(sample(prediction))  # the oldest character is dropped\n",
"            sentence += characters(feed[1])[0]\n",
"          print(sentence)\n",
"        print('=' * 80)\n",
"      # Measure validation set perplexity.\n",
"      reset_sample_state.run()\n",
"      valid_logprob = 0\n",
"      for _ in range(valid_size):\n",
"        b = valid_batches.next()\n",
"        predictions = sample_prediction.eval({sample_input[0]: b[0], sample_input[1]: b[1]})\n",
"        valid_logprob = valid_logprob + logprob(predictions, b[2])\n",
"      print('Validation set perplexity: %.2f' % float(np.exp(\n",
"        valid_logprob / valid_size)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"deletable": true,
"editable": true,
"id": "Y5tapX3kpcqZ"
},
"source": [
"---\n",
"Problem 3\n",
"---------\n",
"\n",
"(difficult!)\n",
"\n",
"Write a sequence-to-sequence LSTM which mirrors all the words in a sentence. For example, if your input is:\n",
"\n",
" the quick brown fox\n",
" \n",
"the model should attempt to output:\n",
"\n",
" eht kciuq nworb xof\n",
" \n",
"Refer to the lecture on how to put together a sequence-to-sequence model, as well as [this article](http://arxiv.org/abs/1409.3215) for best practices.\n",
"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"default_view": {},
"name": "6_lstm.ipynb",
"provenance": [],
"version": "0.3.2",
"views": {}
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
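For Problem 3, the training pairs can be built directly from the text. The target is the input with the characters of each word reversed, while word order and spacing stay unchanged.

The following is a minimal sketch under stated assumptions:

- The helpers `mirror_words` and `make_pair` are hypothetical names and are not part of the assignment.
- Words are split on the space character, as in Text8.
- The optional `reverse_source` flag is an assumed character-level analog of the source-reversal trick reported in the linked article (arXiv:1409.3215). That article reversed the order of the source words.
- A real seq2seq model would also need GO/EOS symbols and padding, which are omitted here.

```python
def mirror_words(sentence):
    """Reverse the characters of every word, keeping word order and spacing.

    Hypothetical helper: builds the target string that the seq2seq model
    should learn to emit for Problem 3.
    """
    # split(' ') keeps empty strings for repeated spaces, so join restores them.
    return ' '.join(word[::-1] for word in sentence.split(' '))


def make_pair(sentence, reverse_source=False):
    """Return an (encoder_input, decoder_target) pair of strings.

    Assumption: reverse_source=True reverses the whole character sequence
    fed to the encoder, a character-level analog of the source-reversal
    trick from arXiv:1409.3215. The target is never reversed.
    """
    source = sentence[::-1] if reverse_source else sentence
    return source, mirror_words(sentence)


print(mirror_words('the quick brown fox'))
# eht kciuq nworb xof
print(make_pair('the quick brown fox', reverse_source=True))
# ('xof nworb kciuq eht', 'eht kciuq nworb xof')
```

Because `split(' ')` keeps empty strings between repeated spaces, `mirror_words` preserves the original spacing. As a result, the encoder input and the decoder target always have the same length.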