Skip to content

Instantly share code, notes, and snippets.

@heisters
Created July 13, 2015 17:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save heisters/0964024cabbde2bb21df to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense, Activation, Dropout\n",
"from keras.layers.recurrent import LSTM\n",
"from keras.datasets.data_utils import get_file\n",
"import numpy as np\n",
"import random, sys"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"corpus length: 600901\n"
]
}
],
"source": [
"path = get_file('nietzsche.txt', origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\")\n",
"text = open(path).read().lower()\n",
"print('corpus length:', len(text))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total chars: 59\n",
"nb sequences: 200294\n"
]
}
],
"source": [
"chars = set(text)\n",
"print('total chars:', len(chars))\n",
"char_indices = dict((c, i) for i, c in enumerate(chars))\n",
"indices_char = dict((i, c) for i, c in enumerate(chars))\n",
"\n",
"# cut the text in semi-redundant sequences of maxlen characters\n",
"maxlen = 20\n",
"step = 3\n",
"sentences = []\n",
"next_chars = []\n",
"for i in range(0, len(text) - maxlen, step):\n",
" sentences.append(text[i : i + maxlen])\n",
" next_chars.append(text[i + maxlen])\n",
"print('nb sequences:', len(sentences))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Vectorization...\n",
"Done.\n"
]
}
],
"source": [
"print('Vectorization...')\n",
"X = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)\n",
"y = np.zeros((len(sentences), len(chars)), dtype=np.bool)\n",
"for i, sentence in enumerate(sentences):\n",
" for t, char in enumerate(sentence):\n",
" X[i, t, char_indices[char]] = 1\n",
" y[i, char_indices[next_chars[i]]] = 1\n",
"print(\"Done.\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Build model...\n"
]
}
],
"source": [
"# build the model: two stacked LSTM layers (each followed by dropout) and a softmax output over the character set\n",
"print('Build model...')\n",
"model = Sequential()\n",
"model.add(LSTM(len(chars), 512, return_sequences=True))\n",
"model.add(Dropout(0.2))\n",
"model.add(LSTM(512, 512, return_sequences=False))\n",
"model.add(Dropout(0.2))\n",
"model.add(Dense(512, len(chars)))\n",
"model.add(Activation('softmax'))\n",
"\n",
"model.compile(loss='categorical_crossentropy', optimizer='rmsprop')\n",
"print(\"Done.\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
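},
"outputs": [],
"source": [
"# helper function to sample an index from a probability array:\n",
"# with probability (1 - diversity) return the most likely index; otherwise draw\n",
"# indices uniformly at random and accept index i with probability a[i]\n",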
"def sample(a, diversity=0.75):\n",
" if random.random() > diversity:\n",
" return np.argmax(a)\n",
" while 1:\n",
" i = random.randint(0, len(a)-1)\n",
" if a[i] > random.random():\n",
" return i"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"--------------------------------------------------\n",
"Iteration 1\n",
"Epoch 0\n",
"200294/200294 [==============================] - 734s - loss: 3.0183 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"at which is termed \"\"\n",
"at which is termed \"oe ee oe ae to ae to ae toe to ae te ab ael ae to ae to ae ta ae to ae ties ae to ae ton ae to ae tro he ae to ad coe ae to ae ne ae to ae to ae to ae�k ae to ae to ae to ace ae te ae to an dei ai ae to as re ae ae tue ae tose to ae tof ae to ae to ae to ae to ae to ai ar oe ae hed ae to oe ve toe aptrei ae to aerg fe ae iy ar ae te ys ae ae to ae ua aeg ae torcte aeu ae de aes -te ae to ae to ae \n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"at which is termed \"\"\n",
"at which is termed \"o ne aif aa toe to ae tie to te oea ae to ak aeo ae oe ae ae to a\n",
"e as oer otte aertty aee tom af fe ae ae to ae coa leg tweh am ae to ae thf ae to oen ae ae poe ates an aa oa te ae to ae ton ae ta een ao at a ar aee to aeel to ae tor fe ati ae to aa to ae to aet aen ae to re se ad aedr ae to uer af ae toa tor an tte ae rar rd ace ae to ad ie an ae toc poe als al ieu io aes,ts,y ae, ho l\"e ae tpe\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"at which is termed \"\"\n",
"at which is termed \"ou datyrtte as adtte sry an tia te ae eee an as ated a- le nd an mte ae be ooutte ae wital ae \n",
"os a oar aeb hdrish ae touoir as ai aev ae to rhtec\n",
"te tda av \"mi ae, oor fse cro aos taidc ao dsalgl eurd oe ton ee wre of as sh ld aa as ae\n",
" aa to ai ai atlr! to iiad to aev ay ae ty ab aengr ym be; at: ii te aee ac oeb se ae ae, ae re oelstiw lenen ae, ogr oe fe ao ifenan oe opfe te ar aayik yn aers\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"at which is termed \"\"\n",
"at which is termed \"ntvxt,\n",
"d te eele\n",
" sans oips,wn de asc-e hil,t- nhe toy eh aet ae ,d ati a he eettrhrotfe to mf anl inl\n",
"rykl ioy ours aeegup aeir- fii aef ro am mse ahmur- rkets is lbm tfginr�e fotoenc' aes,isgf tale aase tot\n",
"e tain aei to on haed fifm se ai sitfalaao., ti he iie ued rvedpcpusdru,wtmeurn fe eeuiegivd chelo t oh we co nia rio \"a ae uoe et\n",
"ermetactoy oifrjst aan lile nirun flo neli bthcanst p oa t\n",
"\n",
"--------------------------------------------------\n",
"Iteration 2\n",
"Epoch 0\n",
" 6400/200294 [..............................] - ETA: 808s - loss: 2.8582"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-737ec135ece6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-'\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Iteration'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miteration\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnb_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mstart_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmaxlen\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/ian/Projects/wordgraph/python/lib/python2.7/site-packages/keras/models.pyc\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, show_accuracy)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0mbatch_logs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'accuracy'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0mbatch_logs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/ian/Projects/wordgraph/python/lib/python2.7/site-packages/theano/compile/function_module.pyc\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 593\u001b[0m \u001b[0mt0_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 594\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 595\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 596\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'position_of_error'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/ian/Projects/wordgraph/python/lib/python2.7/site-packages/theano/scan_module/scan_op.pyc\u001b[0m in \u001b[0;36mrval\u001b[0;34m(p, i, o, n, allow_gc)\u001b[0m\n\u001b[1;32m 670\u001b[0m def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,\n\u001b[1;32m 671\u001b[0m allow_gc=allow_gc):\n\u001b[0;32m--> 672\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 673\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 674\u001b[0m \u001b[0mcompute_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/ian/Projects/wordgraph/python/lib/python2.7/site-packages/theano/scan_module/scan_op.pyc\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(node, args, outs)\u001b[0m\n\u001b[1;32m 659\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 661\u001b[0;31m self, node)\n\u001b[0m\u001b[1;32m 662\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mImportError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheano\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgof\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMissingGXX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for iteration in range(1, 60):\n",
" print()\n",
" print('-' * 50)\n",
" print('Iteration', iteration)\n",
" model.fit(X, y, batch_size=128, nb_epoch=1)\n",
"\n",
" start_index = random.randint(0, len(text) - maxlen - 1)\n",
"\n",
" for diversity in [0.2, 0.4, 0.6, 0.8]:\n",
" print()\n",
" print('----- diversity:', diversity)\n",
"\n",
" generated = ''\n",
" sentence = text[start_index : start_index + maxlen]\n",
" generated += sentence\n",
" print('----- Generating with seed: \"' + sentence + '\"')\n",
" sys.stdout.write(generated)\n",
"\n",
" for iteration in range(400):\n",
" x = np.zeros((1, maxlen, len(chars)))\n",
" for t, char in enumerate(sentence):\n",
" x[0, t, char_indices[char]] = 1.\n",
"\n",
" preds = model.predict(x, verbose=0)[0]\n",
" next_index = sample(preds, diversity)\n",
" next_char = indices_char[next_index]\n",
"\n",
" generated += next_char\n",
" sentence = sentence[1:] + next_char\n",
"\n",
" sys.stdout.write(next_char)\n",
" sys.stdout.flush()\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment