Gist zrbecker/6173ac01ed30be4eea9cc96e21f4896f, created July 4, 2017 03:12
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Neural Network to Compute 'XOR' in scikit-learn\n",
    "First, we import NumPy and scikit-learn's neural network module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sklearn.neural_network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are only four possible inputs. Training a neural network to compute such a simple function is certainly overkill, but because the function is so simple, it should be fairly easy to do (or so I thought)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The four possible inputs (x1, x2), one per row\n",
    "xs = np.array([\n",
    "    0, 0,\n",
    "    0, 1,\n",
    "    1, 0,\n",
    "    1, 1\n",
    "]).reshape(4, 2)\n",
    "\n",
    "# XOR label for each row: 1 when exactly one input is 1\n",
    "ys = np.array([0, 1, 1, 0])"
   ]
  },
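  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the labels should equal the bitwise XOR of the two input columns. `np.bitwise_xor` is one way to confirm that the truth table above is typed in correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each label should be x1 XOR x2 for the corresponding row of xs\n",
    "assert np.array_equal(ys, np.bitwise_xor(xs[:, 0], xs[:, 1]))"
   ]
  },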
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I found a way to hand-code a neural network for 'XOR' with 4 neurons in the first hidden layer and 2 neurons in the second hidden layer. I want to see whether scikit-learn can find that network, or another one like it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLPClassifier(activation='logistic', alpha=0.0001, batch_size='auto',\n",
       " beta_1=0.9, beta_2=0.999, early_stopping=False, epsilon=1e-08,\n",
       " hidden_layer_sizes=(4, 2), learning_rate='constant',\n",
       " learning_rate_init=0.001, max_iter=10000, momentum=0.9,\n",
       " nesterovs_momentum=True, power_t=0.5, random_state=None,\n",
       " shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,\n",
       " verbose=False, warm_start=False)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = sklearn.neural_network.MLPClassifier(\n",
    "    activation='logistic', max_iter=10000, hidden_layer_sizes=(4, 2))\n",
    "model.fit(xs, ys)"
   ]
  },
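  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch for inspecting what was learned: `MLPClassifier` exposes the fitted weights as `coefs_` and `intercepts_`, with one `(n_inputs, n_outputs)` matrix per layer. For `hidden_layer_sizes=(4, 2)` on this binary problem, the expected shapes are `(2, 4)`, `(4, 2)` and `(2, 1)`. The attributes `n_iter_` and `loss_` report how many iterations ran and the final training loss. Together they can hint at whether the optimiser stalled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One weight matrix and one bias vector per layer (input->h1, h1->h2, h2->output)\n",
    "for i, (W, b) in enumerate(zip(model.coefs_, model.intercepts_), start=1):\n",
    "    print('layer', i, 'weights', W.shape, 'biases', b.shape)\n",
    "\n",
    "# Number of iterations actually run and the final value of the training loss\n",
    "print('iterations:', model.n_iter_)\n",
    "print('final loss:', model.loss_)"
   ]
  },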
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, scikit-learn's `MLPClassifier` does not seem to converge to a correct solution. Even on the training data itself, it gets only half of the four cases right."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score: 0.5\n",
      "predictions: [0 0 0 0]\n",
      "expected: [0 1 1 0]\n"
     ]
    }
   ],
   "source": [
    "print('score:', model.score(xs, ys))\n",
    "print('predictions:', model.predict(xs))\n",
    "print('expected:', ys)"
   ]
  },
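  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible explanation is that this particular run started from an unlucky random initialisation and got stuck. Because `random_state` is `None` above, each fit starts from different weights. One way to test this is to refit the same architecture with several fixed seeds. The scikit-learn documentation also notes that the `'lbfgs'` solver can converge faster and perform better on small datasets, so this sketch compares it with `'adam'`. The seed range and the two solvers are arbitrary choices, and the scores may vary across scikit-learn versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit the same (4, 2) logistic network with fixed seeds and two solvers,\n",
    "# recording the training accuracy on the four XOR cases for each fit\n",
    "for solver in ['adam', 'lbfgs']:\n",
    "    scores = []\n",
    "    for seed in range(10):\n",
    "        clf = sklearn.neural_network.MLPClassifier(\n",
    "            activation='logistic', max_iter=10000, hidden_layer_sizes=(4, 2),\n",
    "            solver=solver, random_state=seed)\n",
    "        clf.fit(xs, ys)\n",
    "        scores.append(clf.score(xs, ys))\n",
    "    print(solver, scores)"
   ]
  },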
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a manually configured neural network that computes 'XOR'. Why can't scikit-learn find this model?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "1\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Layer 1 (4 units), approximately: x1, NOT x1, x2, NOT x2\n",
    "theta1 = np.array([\n",
    "    20, 0,\n",
    "    -20, 0,\n",
    "    0, 20,\n",
    "    0, -20,\n",
    "]).reshape(4, 2)\n",
    "\n",
    "# Layer 2 (2 units): (x1 AND NOT x2), (NOT x1 AND x2)\n",
    "theta2 = np.array([\n",
    "    20, 0, 0, 20,\n",
    "    0, 20, 20, 0,\n",
    "]).reshape(2, 4)\n",
    "\n",
    "# Output unit: OR of the two layer-2 units, which equals XOR of the inputs\n",
    "theta3 = np.array([20, 20]).reshape(1, 2)\n",
    "\n",
    "# Biases for layers 1, 2 and 3\n",
    "beta1 = np.array([-10, 10, -10, 10])\n",
    "beta2 = np.array([-30, -30])\n",
    "beta3 = np.array([-10])\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1 / (1 + np.exp(-z))\n",
    "\n",
    "def evaluate(x):\n",
    "    # Forward pass through the three layers, then threshold the output at 0.5\n",
    "    a1 = sigmoid(theta1 @ x + beta1)\n",
    "    a2 = sigmoid(theta2 @ a1 + beta2)\n",
    "    a3 = sigmoid(theta3 @ a2 + beta3)\n",
    "    return 1 if a3[0] >= 0.5 else 0\n",
    "\n",
    "print(evaluate([0, 0]))\n",
    "print(evaluate([1, 0]))\n",
    "print(evaluate([0, 1]))\n",
    "print(evaluate([1, 1]))"
   ]
  }
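  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check whether the problem lies in training rather than in the architecture is to copy the hand-coded weights into a copy of the fitted `MLPClassifier`. This sketch relies on overwriting the fitted `coefs_` and `intercepts_` attributes directly. That is not a documented way to set weights and may behave differently across scikit-learn versions. scikit-learn computes `X @ W + b` for each layer, so each `theta` matrix is transposed, and the weights are cast to float to match the fitted arrays. If the copy is done correctly, the model should predict `[0 1 1 0]`. That would suggest the earlier failure comes from the optimisation, not from the `(4, 2)` logistic network's ability to represent XOR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Work on a copy so the trained model above is left untouched\n",
    "hand_model = copy.deepcopy(model)\n",
    "\n",
    "# scikit-learn stores weights as (n_inputs, n_outputs), i.e. theta transposed\n",
    "hand_model.coefs_ = [t.T.astype(float) for t in (theta1, theta2, theta3)]\n",
    "hand_model.intercepts_ = [b.astype(float) for b in (beta1, beta2, beta3)]\n",
    "\n",
    "print('predictions:', hand_model.predict(xs))\n",
    "print('score:', hand_model.score(xs, ys))"
   ]
  }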
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}