Skip to content

Instantly share code, notes, and snippets.

@Kautenja
Last active September 4, 2017 15:35
Show Gist options
  • Save Kautenja/8161314e6563c8581dbd52ab3c53981f to your computer and use it in GitHub Desktop.
Save Kautenja/8161314e6563c8581dbd52ab3c53981f to your computer and use it in GitHub Desktop.
Solving the XOR function using a deep net in `keras`.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Solving XOR With A Deep Net\n",
"\n",
"$$\\veebar(x, y) = x'y + xy'$$"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>$x$</th>\n",
" <th>$y$</th>\n",
" <th>XOR</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" $x$ $y$ XOR\n",
"0 0 0 0\n",
"1 0 1 1\n",
"2 1 0 1\n",
"3 1 1 0"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pandas import DataFrame\n",
"# generate a design matrix for XOR\n",
"truthtable = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]\n",
"columns = ['$x$', '$y$', 'XOR']\n",
"df = DataFrame(truthtable, columns=columns)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X = df[['$x$', '$y$']]\n",
"Y = df[['XOR']]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Design\n",
"\n",
"![Network Design](network.png)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# reproducable random number seeds\n",
"## keras\n",
"from numpy.random import seed\n",
"seed(100)\n",
"## TensorFlow\n",
"from tensorflow import set_random_seed\n",
"set_random_seed(100)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Dense"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# build the network in the graph above\n",
"model = Sequential()\n",
"model.name = 'XOR'\n",
"model.add(Dense(2, input_dim=2, activation='tanh'))\n",
"model.add(Dense(2, activation='tanh'))\n",
"model.add(Dense(1, activation='sigmoid'))"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_22 (Dense) (None, 2) 6 \n",
"_________________________________________________________________\n",
"dense_23 (Dense) (None, 2) 6 \n",
"_________________________________________________________________\n",
"dense_24 (Dense) (None, 1) 3 \n",
"=================================================================\n",
"Total params: 15\n",
"Trainable params: 15\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"# fit the model, 4600 epochs was the minimum needed to reach a loss of ~0.020\n",
"# turning off logging speeds the process up a _lot_.\n",
"_ = model.fit(X.values, Y.values, epochs=4600, batch_size=4, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4/4 [==============================] - 0s\n"
]
},
{
"data": {
"text/plain": [
"[('loss', 0.0084844892844557762), ('acc', 1.0)]"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# evaluate the score and output some metrics\n",
"scores = model.evaluate(X.values, Y.values)\n",
"[(model.metrics_names[i], scores[i]) for i in range(0, len(scores))]"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.00854679],\n",
" [ 0.99831319],\n",
" [ 0.98393595],\n",
" [ 0.00744388]], dtype=float32)"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# manually check the predictions for all inputs\n",
"model.predict(X.values)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment