Created
April 4, 2016 19:05
-
-
Save rjpower/13ae9db09a29d4e5ee0844d399500f3b to your computer and use it in GitHub Desktop.
Grid search over a number of topologies.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import json\n", | |
"import numpy as np\n", | |
"\n", | |
"from keras.models import Sequential, Graph\n", | |
"from keras.preprocessing import text\n", | |
"from keras.preprocessing.text import Tokenizer\n", | |
"from keras.layers.embeddings import Embedding\n", | |
"from keras.layers.core import Dense, Activation, Dropout, TimeDistributedDense\n", | |
"from keras.layers.recurrent import LSTM\n", | |
"\n", | |
"SEQ_LENGTH = 64\n", | |
"BATCH_SIZE = 16\n", | |
"HIDDEN_SIZE = 4\n", | |
"NUM_WORDS = 10000\n", | |
"\n", | |
"examples = [json.loads(line) for line in open('/tmp/topic-sample.json').readlines() if line]\n", | |
"abstracts = [ e['abstract'].encode('utf8') for e in examples ]\n", | |
"tokenizer = Tokenizer(nb_words=NUM_WORDS, split=' ')\n", | |
"tokenizer.fit_on_texts(abstracts)\n", | |
"\n", | |
"vocab_size = tokenizer.nb_words" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def build_model(layers):\n", | |
" model = Sequential()\n", | |
" for i, l in enumerate(layers):\n", | |
" model.add(Dense(l, input_dim=NUM_WORDS if i == 0 else None))\n", | |
" model.add(Activation('softmax'))\n", | |
" model.compile(loss='categorical_crossentropy', optimizer='adagrad')\n", | |
" return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"X = tokenizer.texts_to_matrix(abstracts)\n", | |
"Y = np.zeros((len(abstracts), 2))\n", | |
"for i, ex in enumerate(examples):\n", | |
" if ex['topic'] == 'compsci': Y[i,0] = 1\n", | |
" else: Y[i, 1] = 1\n", | |
" \n", | |
"shuffle_idx = np.arange(len(Y))\n", | |
"np.random.shuffle(shuffle_idx)\n", | |
"X = X[shuffle_idx]\n", | |
"Y = Y[shuffle_idx]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"results = {}\n", | |
"\n", | |
"for a in [0, 4, 8, 16, 32, 64]:\n", | |
" for b in [0, 4, 8, 16, 32, 64]:\n", | |
" layers = [a, b, 2]\n", | |
" layers = [l for l in layers if l > 0]\n", | |
" model = build_model(layers)\n", | |
" val_acc = model.fit(X, Y, batch_size=BATCH_SIZE, show_accuracy=True,\n", | |
" verbose=0, nb_epoch=1, validation_split=0.1).history['val_acc']\n", | |
" print layers, val_acc[0]\n", | |
" results[(a, b)] = val_acc" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[((32, 4), 0.84562996594778661),\n", | |
" ((64, 32), 0.84903518728717364),\n", | |
" ((0, 16), 0.85357548240635639),\n", | |
" ((0, 8), 0.85584562996594782),\n", | |
" ((16, 32), 0.85698070374574342),\n", | |
" ((32, 32), 0.85811577752553914),\n", | |
" ((8, 16), 0.85925085130533485),\n", | |
" ((16, 4), 0.85925085130533485),\n", | |
" ((4, 64), 0.85925085130533485),\n", | |
" ((64, 64), 0.86038592508513057),\n", | |
" ((4, 32), 0.86038592508513057),\n", | |
" ((0, 4), 0.86152099886492617),\n", | |
" ((32, 16), 0.86265607264472188),\n", | |
" ((4, 8), 0.86265607264472188),\n", | |
" ((32, 0), 0.86265607264472188),\n", | |
" ((0, 32), 0.8637911464245176),\n", | |
" ((0, 0), 0.8637911464245176),\n", | |
" ((16, 8), 0.8637911464245176),\n", | |
" ((8, 32), 0.86492622020431331),\n", | |
" ((16, 16), 0.86492622020431331),\n", | |
" ((4, 0), 0.86606129398410892),\n", | |
" ((64, 16), 0.86606129398410892),\n", | |
" ((16, 64), 0.86606129398410892),\n", | |
" ((8, 64), 0.86606129398410892),\n", | |
" ((4, 16), 0.86606129398410892),\n", | |
" ((8, 8), 0.86606129398410892),\n", | |
" ((32, 64), 0.86719636776390463),\n", | |
" ((4, 4), 0.86719636776390463),\n", | |
" ((64, 0), 0.86833144154370034),\n", | |
" ((64, 8), 0.86833144154370034),\n", | |
" ((64, 4), 0.86833144154370034),\n", | |
" ((8, 4), 0.86946651532349606),\n", | |
" ((8, 0), 0.87060158910329166),\n", | |
" ((0, 64), 0.87287173666288309),\n", | |
" ((32, 8), 0.87287173666288309),\n", | |
" ((16, 0), 0.88195232690124858)]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sorted(results.items(), key=lambda kv: kv[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment