Created November 8, 2017 19:03
-
-
Save fabianp/50597e6c629ca4d6ec1781292fcdf60b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Explicit import instead of the deprecated %pylab magic, which star-imports\n", | |
"# numpy and matplotlib names into the interactive namespace.\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Minimize the least-squares objective\n", | |
"#   f(x) = 1/(2n) ||A x - b||^2 = (1/n) sum_i (a_i . x - b_i)^2 / 2,\n", | |
"# whose i-th summand has gradient a_i * (a_i . x - b_i).\n", | |
"n_samples, n_features = 100, 10\n", | |
"A = np.random.randn(n_samples, n_features)\n", | |
"b = np.random.randn(n_samples)\n", | |
"\n", | |
"\n", | |
"def partial_grad(x):\n", | |
"    \"\"\"Sample a row index i uniformly and return (i, gradient of summand i at x).\"\"\"\n", | |
"    i = np.random.randint(n_samples)\n", | |
"    return i, A[i] * (A[i].dot(x) - b[i])\n", | |
"\n", | |
"\n", | |
"def saga(pg, x, step_size, n_samples, max_iter=None):\n", | |
"    \"\"\"Minimize a finite sum of n_samples functions with SAGA.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    pg : callable\n", | |
"        pg(x) -> (i, g): a sampled index i in [0, n_samples) and the gradient g\n", | |
"        of summand i at x.\n", | |
"    x : array_like\n", | |
"        Starting point. It is copied (as float), so the caller's array is not modified.\n", | |
"    step_size : float\n", | |
"        Constant step size.\n", | |
"    n_samples : int\n", | |
"        Number of summands in the objective.\n", | |
"    max_iter : int, optional\n", | |
"        Number of iterations; defaults to 1000 * n_samples.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    ndarray\n", | |
"        The last iterate.\n", | |
"    \"\"\"\n", | |
"    if max_iter is None:\n", | |
"        # derived from the argument, not from the global n_samples at def time\n", | |
"        max_iter = 1000 * n_samples\n", | |
"    x = np.array(x, dtype=float)\n", | |
"    memory_gradients = np.zeros((n_samples,) + x.shape)\n", | |
"    # Running mean of memory_gradients: updating it costs O(n_features) per step,\n", | |
"    # versus O(n_samples * n_features) for recomputing memory_gradients.mean(0).\n", | |
"    grad_mean = np.zeros(x.shape)\n", | |
"    for _ in range(max_iter):\n", | |
"        i, cur_grad = pg(x)\n", | |
"        delta = cur_grad - memory_gradients[i]\n", | |
"        x -= step_size * (delta + grad_mean)\n", | |
"        grad_mean += delta / n_samples\n", | |
"        memory_gradients[i] = cur_grad\n", | |
"    return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0.082629 , 0.07154017, 0.15254069, -0.03220931, -0.03776542,\n", | |
" -0.09992818, 0.03298484, 0.11354239, 0.09035411, -0.13045454])" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x_saga = saga(partial_grad, np.zeros(n_features), 0.01, n_samples)\n", | |
"x_saga" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0.082629 , 0.07154017, 0.15254069, -0.03220931, -0.03776542,\n", | |
" -0.09992818, 0.03298484, 0.11354239, 0.09035411, -0.13045454])" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# sanity check: SAGA should reach the exact least-squares solution\n", | |
"from scipy import linalg\n", | |
"x_lstsq = linalg.lstsq(A, b)[0]\n", | |
"np.testing.assert_allclose(x_saga, x_lstsq, rtol=1e-6, atol=1e-8)\n", | |
"x_lstsq" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.