Skip to content

Instantly share code, notes, and snippets.

@niazangels
Created May 27, 2017 07:44
Show Gist options
  • Save niazangels/372248eb0a5aa8163bffe200f76f67a5 to your computer and use it in GitHub Desktop.
Save niazangels/372248eb0a5aa8163bffe200f76f67a5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Lesson 6"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from theano.sandbox import cuda\n",
"cuda.use('gpu0')"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import utils; reload(utils)\n",
"from utils import *\n",
"from __future__ import division, print_function"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"corpus length: 600901\n"
]
}
],
"source": [
"path = get_file('nietzsche.txt', origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\")\n",
"# path = get_file('sherlock.txt', origin=\"https://sherlock-holm.es/stories/plain-text/cano.txt\")\n",
"\n",
"# Read the whole corpus into memory; the context manager closes the file handle\n",
"# instead of leaving it open for the lifetime of the kernel.\n",
"with open(path) as corpus_file:\n",
"    text = corpus_file.read()\n",
"\n",
"print('corpus length:', len(text))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total chars: 86\n"
]
}
],
"source": [
"# Distinct characters of the corpus, in sorted order.\n",
"chars = sorted(set(text))\n",
"# The vocabulary size is one more than the number of distinct characters.\n",
"vocab_size = 1 + len(chars)\n",
"print('total chars:', vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Put the NUL character at index 0 of the vocabulary. The guard keeps this cell\n",
"# idempotent: re-running it must not insert a second entry and shift every index.\n",
"if chars[:1] != [\"\\0\"]:\n",
"    chars.insert(0, \"\\0\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"''.join(chars[1:-6])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Lookup tables between characters and their integer ids (position in `chars`).\n",
"char_indices = {ch: pos for pos, ch in enumerate(chars)}\n",
"indices_char = dict(enumerate(chars))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"idx = [char_indices[c] for c in text]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrib'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text[:200]"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# 3 Char Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"skip = 3\n",
"\n",
"# Start offsets of the 4-character windows; consecutive windows are `skip` apart.\n",
"window_starts = range(0, len(idx)-1-skip, skip)\n",
"\n",
"# Column k holds the k-th character of every window.\n",
"c1_data = [idx[i] for i in window_starts]\n",
"c2_data = [idx[i+1] for i in window_starts]\n",
"c3_data = [idx[i+2] for i in window_starts]\n",
"c4_data = [idx[i+3] for i in window_starts]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200299,)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1 = np.stack(c1_data)\n",
"x1.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200299,)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Alternatively\n",
"xx1 = np.array(c1_data)\n",
"xx1.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"<br>Our inputs"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# One index array per input position of the model.\n",
"x1, x2, x3 = (np.stack(column) for column in (c1_data, c2_data, c3_data))"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"And output"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y = np.stack(c4_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's check them out"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(array([40, 30, 29, 1]), array([42, 25, 1, 43]), array([29, 27, 1, 45]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[:4], x2[:4], x3[:4]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([30, 29, 1, 40])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[:4]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"((200299,), (200299,), (200299,), (200299,))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape, x2.shape, x3.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"n_fac = 42"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def embedding_input(name, n_in, n_out):\n",
"    \"\"\"Create a single-index input together with its flattened embedding.\n",
"\n",
"    name: name given to the Input layer.\n",
"    n_in: first argument of the Embedding layer (the vocabulary size at the call sites).\n",
"    n_out: second argument of the Embedding layer (the embedding width at the call sites).\n",
"\n",
"    Returns a tuple (input tensor, flattened embedding tensor).\n",
"    \"\"\"\n",
"    index_input = Input(shape=(1,), dtype='int64', name=name)\n",
"    embedded = Embedding(n_in, n_out, input_length=1)(index_input)\n",
"    flatten = Flatten()\n",
"    return index_input, flatten(embedded)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# One (input, flattened embedding) pair per character position, built in order.\n",
"(c1_in, c1), (c2_in, c2), (c3_in, c3) = (\n",
"    embedding_input(layer_name, vocab_size, n_fac) for layer_name in ('c1', 'c2', 'c3')\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### Create and train the model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"n_hidden = 256"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_in = Dense(n_hidden, activation='relu')"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_hidden = Dense(n_hidden, activation='relu')\n",
"# dense_hidden = Dense(n_hidden, activation='tanh')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_out = Dense(vocab_size, activation='softmax')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Step 1: the first character goes through the input layer and then through the\n",
"# hidden-to-hidden layer (the same `dense_hidden` instance is reused in step 2).\n",
"c1_hidden = dense_in(c1)\n",
"hidden_2 = dense_hidden(c1_hidden)\n",
"\n",
"# Step 2: merge the second character with the state carried over from step 1.\n",
"c2_dense = dense_in(c2)\n",
"c2_hidden = merge([c2_dense, hidden_2])\n",
"hidden_3 = dense_hidden(c2_hidden)\n",
"\n",
"# Step 3: merge the third character with the state from step 2. This has to be\n",
"# hidden_3; merging hidden_2 again would leave hidden_3 unused and cut the second\n",
"# character out of the prediction entirely.\n",
"c3_dense = dense_in(c3)\n",
"c3_hidden = merge([c3_dense, hidden_3])\n",
"\n",
"c4_out = dense_out(c3_hidden)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model = Model([c1_in, c2_in, c3_in], c4_out)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# model.optimizer.lr=0.1 # Took 37s - yielding a whopping loss: 11.4472\n",
"# Update the optimizer's learning-rate variable in place rather than rebinding the\n",
"# attribute to a plain Python float, so the optimizer keeps its backend variable.\n",
"model.optimizer.lr.set_value(0.00001)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"26s - loss: 3.7297\n",
"Epoch 2/4\n",
"26s - loss: 3.0560\n",
"Epoch 3/4\n",
"27s - loss: 2.9648\n",
"Epoch 4/4\n",
"25s - loss: 2.8690\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fe71a428690>"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def get_next(inp):\n",
"    \"\"\"Return the most probable next character for the string `inp`.\n",
"\n",
"    Each character of `inp` is turned into its own length-1 index array, and the\n",
"    list of arrays is handed to `model.predict` (one array per character).\n",
"    \"\"\"\n",
"    model_inputs = [np.array([char_indices[ch]]) for ch in inp]\n",
"    probabilities = model.predict(model_inputs)\n",
"    return chars[np.argmax(probabilities)]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'n'"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next('phi')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"' '"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next('the')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'e'"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next(' th')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's try to show the top 10 predictions and their probabilities."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def get_next_prob(inp, top_n=10):\n",
"    \"\"\"Return the `top_n` most probable next characters for `inp`.\n",
"\n",
"    inp: string whose characters are fed to `model.predict`, one length-1 index\n",
"        array per character.\n",
"    top_n: how many candidates to return (defaults to 10).\n",
"\n",
"    Returns (top_preds, top_probs): two parallel lists holding the characters and\n",
"    their predicted probabilities, ordered from most to least probable.\n",
"    \"\"\"\n",
"    idxs = [char_indices[c] for c in inp]\n",
"    arrs = [np.array([i]) for i in idxs]\n",
"\n",
"    # Drop the leading axis of the prediction so that it is indexed by character id.\n",
"    probs = np.squeeze(model.predict(arrs))\n",
"\n",
"    # argsort is ascending, so reverse it to rank the most probable index first.\n",
"    ranked = np.argsort(probs)[::-1][:top_n]\n",
"\n",
"    top_preds = list(np.array(chars)[ranked])\n",
"    top_probs = list(probs[ranked])\n",
"    return (top_preds, top_probs)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(['n', 't', 's', 'o', 'r', 'l', 'a', ' ', 'e', 'i'],\n",
" [0.15723342,\n",
" 0.11152012,\n",
" 0.097292215,\n",
" 0.060666647,\n",
" 0.057199709,\n",
" 0.053245645,\n",
" 0.047986366,\n",
" 0.04335808,\n",
" 0.042990349,\n",
" 0.039894816])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('phi')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(['e', ' ', 'i', 'a', 'o', 'h', 'r', 's', 't', 'l'],\n",
" [0.3403981,\n",
" 0.18252306,\n",
" 0.065477088,\n",
" 0.051549714,\n",
" 0.037764899,\n",
" 0.034287382,\n",
" 0.033996776,\n",
" 0.030441631,\n",
" 0.027145658,\n",
" 0.023002297])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob(' th')"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(['n', 't', ' ', 'r', 's', 'o', 'a', 'l', 'e', 'i'],\n",
" [0.12739956,\n",
" 0.078190982,\n",
" 0.077934235,\n",
" 0.07212282,\n",
" 0.070048511,\n",
" 0.050404333,\n",
" 0.050341681,\n",
" 0.045512386,\n",
" 0.038728461,\n",
" 0.036639702])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('nno')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's do a running prediction for 500 characters"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def running_prediction(seed, nb_chars=500):\n",
"    \"\"\"Extend `seed` by `nb_chars` characters, one `get_next` prediction at a time.\n",
"\n",
"    After every prediction the window slides by one character: its oldest\n",
"    character is dropped and the new prediction is appended. Returns the original\n",
"    seed followed by all predicted characters.\n",
"    \"\"\"\n",
"    window = seed\n",
"    pieces = [seed]\n",
"    for _ in range(nb_chars):\n",
"        next_char = get_next(window)\n",
"        pieces.append(next_char)\n",
"        window = window[1:] + next_char\n",
"    return ''.join(pieces)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"' can the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the th'"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"running_prediction(' ca')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'esthe the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the t'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"running_prediction('est')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'care the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the th'"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"running_prediction('car')"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"20s - loss: 2.7869\n",
"Epoch 2/4\n",
"26s - loss: 2.7278\n",
"Epoch 3/4\n",
"23s - loss: 2.6866\n",
"Epoch 4/4\n",
"27s - loss: 2.6565\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fe707a36850>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"22s - loss: 2.6331\n",
"Epoch 2/4\n",
"20s - loss: 2.6143\n",
"Epoch 3/4\n",
"20s - loss: 2.5985\n",
"Epoch 4/4\n",
"22s - loss: 2.5850\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fe707a36f10>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rebinding `model.optimizer.lr` after training has started does not reach the\n",
"# training function that was already built. Recompile with a fresh Adam so that\n",
"# the new rate is really used: the weights are kept, the optimizer state is reset.\n",
"model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=0.01))\n",
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'care the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the th'"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"running_prediction('car')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment