@zrbecker
Created July 4, 2017 03:12
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a Neural Network to Compute 'XOR' in scikit-learn\n",
"First we import some libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import sklearn.neural_network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"XOR has only four input cases. Training a neural network to compute such a simple function is certainly overkill, but because it is so simple, I expected it to be fairly easy."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"xs = np.array([\n",
" 0, 0,\n",
" 0, 1,\n",
" 1, 0,\n",
" 1, 1\n",
"]).reshape(4, 2)\n",
"\n",
"ys = np.array([0, 1, 1, 0]).reshape(4,)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I found a way to hand-code a neural network for 'XOR' with 4 neurons in the first hidden layer and 2 neurons in the second hidden layer. I want to see whether scikit-learn can find that network or another like it."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MLPClassifier(activation='logistic', alpha=0.0001, batch_size='auto',\n",
"       beta_1=0.9, beta_2=0.999, early_stopping=False, epsilon=1e-08,\n",
"       hidden_layer_sizes=(4, 2), learning_rate='constant',\n",
"       learning_rate_init=0.001, max_iter=10000, momentum=0.9,\n",
"       nesterovs_momentum=True, power_t=0.5, random_state=None,\n",
"       shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,\n",
"       verbose=False, warm_start=False)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = sklearn.neural_network.MLPClassifier(\n",
"    activation='logistic', max_iter=10000, hidden_layer_sizes=(4, 2))\n",
"model.fit(xs, ys)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, in this run scikit-learn's MLPClassifier does not converge to a correct solution: it predicts 0 for every input, which gives a score of only 0.5."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"score: 0.5\n",
"predictions: [0 0 0 0]\n",
"expected: [0 1 1 0]\n"
]
}
],
"source": [
"print('score:', model.score(xs, ys))\n",
"print('predictions:', model.predict(xs))\n",
"print('expected:', np.array([0, 1, 1, 0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a manually configured neural network that computes 'XOR'. Why can't scikit-learn find this model?"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"1\n",
"1\n",
"0\n"
]
}
],
"source": [
"theta1 = np.array([\n",
" 20, 0,\n",
" -20, 0,\n",
" 0, 20,\n",
" 0, -20,\n",
"]).reshape(4, 2)\n",
"\n",
"theta2 = np.array([\n",
" 20, 0, 0, 20,\n",
" 0, 20, 20, 0,\n",
"]).reshape(2, 4)\n",
"\n",
"theta3 = np.array([20, 20,]).reshape(1, 2)\n",
"\n",
"beta1 = np.array([-10, 10, -10, 10])\n",
"beta2 = np.array([-30, -30])\n",
"beta3 = np.array([-10])\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"def evaluate(x):\n",
"    a1 = sigmoid(theta1 @ x + beta1)\n",
"    a2 = sigmoid(theta2 @ a1 + beta2)\n",
"    a3 = sigmoid(theta3 @ a2 + beta3)\n",
"    return 1 if a3[0] >= 0.5 else 0\n",
"\n",
"print(evaluate([0, 0]))\n",
"print(evaluate([0, 1]))\n",
"print(evaluate([1, 0]))\n",
"print(evaluate([1, 1]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
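
To see why the hand-coded weights compute 'XOR', it may help to look at the rounded activations of each layer. The sketch below reuses the notebook's `theta1`..`theta3`, `beta1`..`beta3`, and `sigmoid`; the batched `forward` helper and the names `layer1`, `layer2`, and `predictions` are my own additions for illustration. Reading the weights, the first hidden layer appears to approximate `x1`, `NOT x1`, `x2`, `NOT x2`; the second approximates `x1 AND NOT x2` and `NOT x1 AND x2`; and the output unit approximates the `OR` of those two.

```python
import numpy as np

# Same weights and biases as the hand-coded network in the notebook.
theta1 = np.array([[20, 0], [-20, 0], [0, 20], [0, -20]])
theta2 = np.array([[20, 0, 0, 20], [0, 20, 20, 0]])
theta3 = np.array([[20, 20]])
beta1 = np.array([-10, 10, -10, 10])
beta2 = np.array([-30, -30])
beta3 = np.array([-10])


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


def forward(X):
    """Hypothetical helper: return every layer's activations for a batch of inputs."""
    a1 = sigmoid(X @ theta1.T + beta1)   # shape (n, 4)
    a2 = sigmoid(a1 @ theta2.T + beta2)  # shape (n, 2)
    a3 = sigmoid(a2 @ theta3.T + beta3)  # shape (n, 1)
    return a1, a2, a3


# Inputs in the same order as xs in the notebook.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
a1, a2, a3 = forward(X)

# The large weights saturate the sigmoids, so rounding gives a truth table.
layer1 = np.round(a1).astype(int)  # columns: x1, NOT x1, x2, NOT x2
layer2 = np.round(a2).astype(int)  # columns: x1 AND NOT x2, NOT x1 AND x2
predictions = (a3[:, 0] >= 0.5).astype(int)

print("layer 1:\n", layer1)
print("layer 2:\n", layer2)
print("predictions:", predictions)
```

Because every pre-activation here is close to ±10 or ±30, the sigmoids sit in their flat regions, where gradients are very small. That is only one possible reason a gradient-based solver starting from small random weights might not reach a solution like this one; the sketch does not test that hypothesis.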