@kersulis
Last active November 20, 2020 11:36
Interactive Eigenimages. Live version on Binder: http://mybinder.org/repo/kersulis/SVD-eigenimages
import ipywidgets as ipw
from pylab import *


def classify_image(tst,trn,k):
    """Classify test images by nearest digit subspace.

    Input:
    tst: an n x T matrix whose columns are the T images to be classified
    trn: n x m x 10 array containing m training samples for each digit
    k: integer between 1 and n (number of left singular vectors kept)

    Returns a length-T array holding the predicted digit for each image.
    """
    n,T = tst.shape
    # compute projection matrices:
    P = zeros((n,n,10))
    for i in range(10):
        U,S,V = svd(trn[:,:,i])
        U1 = U[:,0:k]
        P[:,:,i] = dot(U1,U1.T)
    # find errors (squared residual after projecting onto each digit's subspace):
    sqerr = zeros((10,T))
    for i in range(10):
        sqerr[i,:] = sum((tst - dot(P[:,:,i],tst))**2,0)
    return argmin(sqerr,0)


def linear_combo(a1,a2,a3,digit,trn):
    # weighted sum of the first three left singular vectors for one digit
    U,S,V = svd(trn[:,:,digit])
    U1 = U[:,0]
    U2 = U[:,1]
    U3 = U[:,2]
    v = a1*U1 + a2*U2 + a3*U3
    return v


def vec2mat(v):
    # shift to zero:
    v = (v - min(v))
    # scale to 1:
    v = v/max(v)
    # shift to [-0.5,0.5]:
    v = v - 0.5
    return reshape(v,(16,16))


def return_widgets():
    # sliders for the coefficients a1, a2, a3 and a dropdown for the digit
    c1=ipw.FloatSlider(
        value=1,
        min=0.1,
        max=1,
        step=0.1,
        width=50,
        height=10)
    c2=ipw.FloatSlider(
        value=0.0,
        min=-0.5,
        max=0.5,
        step=0.1,
        width=50,
        height=10)
    c3=ipw.FloatSlider(
        value=0.0,
        min=-0.5,
        max=0.5,
        step=0.1,
        width=50,
        height=10,
        padding=0)
    d = ipw.Dropdown(
        options=[str(i) for i in range(10)],
        value='0',
        description='Digit:')
    return c1,c2,c3,d
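The `classify_image` function above assigns each test image to the digit whose k-dimensional subspace leaves the smallest squared residual. The following is a minimal sketch of the same idea, not the author's code: it uses plain NumPy, skips the n x n projection matrices by relying on the identity ||x - U1 U1^T x||^2 = ||x||^2 - ||U1^T x||^2 (valid because U1 has orthonormal columns), and runs on synthetic data. The name `classify_by_subspace` and all of the demo data are assumptions made for illustration.

```python
import numpy as np


def classify_by_subspace(tst, trn, k):
    """Nearest-subspace classifier (illustrative sketch).

    tst: (n, T) array, one test image per column.
    trn: (n, m, C) array, m training images for each of C classes.
    k:   number of left singular vectors kept per class.
    """
    n_classes = trn.shape[2]
    sqerr = np.empty((n_classes, tst.shape[1]))
    total = np.sum(tst**2, axis=0)  # ||x||^2 for every test column
    for c in range(n_classes):
        # economy SVD: U has shape (n, min(n, m))
        U, _, _ = np.linalg.svd(trn[:, :, c], full_matrices=False)
        U1 = U[:, :k]
        coeffs = U1.T @ tst  # (k, T) coordinates inside the class subspace
        # squared residual without forming the n x n projector U1 @ U1.T
        sqerr[c] = total - np.sum(coeffs**2, axis=0)
    return np.argmin(sqerr, axis=0)


# Synthetic demo (hypothetical data, not the course's digit files):
# each class lives near its own random k-dimensional subspace.
rng = np.random.default_rng(0)
n, m, n_classes, k = 20, 30, 3, 2
bases = [np.linalg.qr(rng.standard_normal((n, k)))[0] for _ in range(n_classes)]
trn = np.stack(
    [B @ rng.standard_normal((k, m)) + 0.01 * rng.standard_normal((n, m))
     for B in bases],
    axis=2)
true_labels = np.repeat(np.arange(n_classes), 5)
tst = np.column_stack([bases[c] @ rng.standard_normal(k) for c in true_labels])

predictions = classify_by_subspace(tst, trn, k)
accuracy = np.mean(predictions == true_labels)
print("fraction classified correctly:", accuracy)
```

For 16 x 16 images the saving is small, but the same residual formula avoids storing ten 256 x 256 matrices and makes it cheap to sweep over several values of `k`.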
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EECS 453/551\n",
"# Eigenimages\n",
"\n",
"What can SVD tell us about the way people write digits?\n",
"___"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Before we begin\n",
"\n",
"If you are new to the Jupyter notebook interface, take the tour by clicking Help -> User Interface Tour. The most important thing to know is that you can run a code cell (like the one below) by clicking on it and pressing Ctrl+Enter.\n",
"\n",
"Run the code cell below to load the Python code and data we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# import code for plotting, widgets, etc.\n",
"%pylab inline\n",
"from scipy.io import loadmat\n",
"import ipywidgets as ipw\n",
"\n",
"# contains classify_image and related fcns:\n",
"from eigenimages import *\n",
"\n",
"# load data\n",
"trn = loadmat(\"TRAIN_DIGITS.mat\")[\"TRAIN_DIGITS\"]\n",
"testdata = loadmat(\"TEST_DIGITS.mat\")\n",
"tst = testdata[\"TEST_DIGITS\"]\n",
"labels = testdata[\"TEST_DIGIT_LABELS\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparing digit predictions with labels\n",
"\n",
"This section should be familiar. Run the following cell, then use the slider to scroll through the test dataset and see which digits are correctly classified.\n",
"\n",
"The predictions were made using a `classify_image` function just like the one you wrote."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# number of components to take:\n",
"k = 10\n",
"\n",
"# classify all images:\n",
"predictions = classify_image(tst,trn,k)\n",
"\n",
"fig1 = figure(figsize=(6,6))\n",
"ax1 = fig1.add_subplot(111)\n",
"ax1.axis('off')\n",
"set_cmap('gray_r')\n",
"im1 = ax1.matshow(vec2mat(tst[:,0]))\n",
"ttl = ax1.text(0,-1,\"Predicted: \",size=16)\n",
"def on_change(val):\n",
"    i = [val]\n",
"    test_image = tst[:,i]\n",
"    correct_label = labels[0,i]\n",
"    which_digit = predictions[i]\n",
"    im1.set_data(vec2mat(tst[:,i]))\n",
"    ttl.set_text(\"Predicted: \" +\n",
"                 str(which_digit[0]) +\n",
"                 \" Actual: \" +\n",
"                 str(correct_label[0]))\n",
"    ttl.set_color('black' if which_digit[0] ==\n",
"                  correct_label[0] else 'red')\n",
"    fig1.canvas.draw()\n",
"    return fig1\n",
"\n",
"n,T = tst.shape\n",
"ipw.interact(on_change, val=ipw.IntSlider(\n",
"    min=0,\n",
"    max=T-1,\n",
"    description=\"Digit index:\",\n",
"    width=200))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interactive Eigenimages\n",
"\n",
"We know SVD can do better than mean-based classification. Why is this? What insight do we gain by taking the SVD of a set of images instead of just using average images?\n",
"\n",
"Run the following cell to generate an interactive figure. The top row of plots shows the first three left singular vectors for a particular digit: $U[:,1]$, $U[:,2]$, and $U[:,3]$. The bottom plot shows the linear combination $a_1 U[:,1] + a_2 U[:,2] + a_3 U[:,3]$. Think of $U[:,1]$ as the \"base image\" and $U[:,2]$ and $U[:,3]$ as the two most common deviations from the base image. By adding and subtracting $U[:,2]$ and $U[:,3]$ through the coefficients $a_2$ and $a_3$, we are modifying the base image by adding and subtracting pixels.\n",
"\n",
"Set \"Digit\" to 0 and play with the sliders. What does this tell you about the way people write \"0\"?\n",
"\n",
"*Note: you can drag a slider or use the arrow keys to change its value.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# load widgets:\n",
"c1,c2,c3,d = return_widgets()\n",
"\n",
"# group widgets together:\n",
"container = ipw.Box(\n",
"    children=[d,ipw.HBox(children=[c1,c2,c3],width=150)])\n",
"\n",
"# specify interaction behavior\n",
"n,T = tst.shape\n",
"ncomps = 3\n",
"Uvecs = zeros((n,10,ncomps))\n",
"for i in range(10):\n",
"    U,S,V = svd(trn[:,:,i])\n",
"    Uvecs[:,i,:] = U[:,:ncomps]\n",
"\n",
"fig2 = figure(figsize=(8,8))\n",
"set_cmap('bwr')\n",
"ax21 = subplot2grid((3,2), (0,0),colspan=2)\n",
"ax21.axis('off')\n",
"ax21.text(6,-1,\"u1\",size=16)\n",
"ax21.text(27,-1,\"u2\",size=16)\n",
"ax21.text(49,-1,\"u3\",size=16)\n",
"\n",
"ax22 = subplot2grid((3,2), (1,0),colspan=2,rowspan=2)\n",
"ax22.axis('off')\n",
"ax22.text(2,17,\"a1*u1 + a2*u2 + a3*u3\",size=16)\n",
"\n",
"z = zeros((16,16))\n",
"ws = zeros((16,5))\n",
"im21 = ax21.matshow(hstack((z,ws,z,ws,z)),vmin=-0.5,vmax=0.5)\n",
"\n",
"lc = vec2mat(linear_combo(1.0,0.0,0.0,0,trn))\n",
"im22 = ax22.matshow(lc,vmin=-0.5,vmax=0.5)\n",
"\n",
"def on_change(a1,a2,a3,digit):\n",
"    digit = int(digit)\n",
"\n",
"    v1,v2,v3 = [vec2mat(Uvecs[:,digit,i]) for i in range(3)]\n",
"    im21.set_data(hstack((v1,ws,v2,ws,v3)))\n",
"\n",
"    v = vec2mat(linear_combo(a1,a2,a3,digit,trn))\n",
"    im22.set_data(v)\n",
"\n",
"    fig2.canvas.draw()\n",
"    return fig2\n",
"\n",
"w = ipw.interactive(on_change,a1=c1,a2=c2,a3=c3,digit=d)\n",
"container"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot first three vectors for each digit\n",
"\n",
"Run the following cell to see the first three left singular vectors for all ten digits.\n",
"\n",
"Now save the figure, [print it][1], and hang it in your room. (optional)\n",
"\n",
"[1]: http://www.itcs.umich.edu/sites/printing/poster.php"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"fig3 = figure(figsize=(17,4))\n",
"set_cmap('bwr')\n",
"for i in arange(3)+1:\n",
"    for j in arange(10)+1:\n",
"        ax3 = fig3.add_subplot(3,10,(i-1)*10 + j)\n",
"        v = Uvecs[:,j-1,i-1]\n",
"        ax3.matshow(reshape(v,(16,16)))\n",
"        ax3.axis('off')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Singular value \"knee\"\n",
"\n",
"In class we plotted $P_{correct}$ versus $k$ and found that $P_{correct}$ was highest around $k=11$. Why did accuracy decrease when we moved away from this value? In general, prediction accuracy is highest when we capture the most signal and the least noise, and we can use singular value magnitudes to distinguish the two.\n",
"\n",
"Run the cell below to plot singular value magnitudes for the training set of a particular digit. Use the top slider to vary the digit. Use the bottom slider to set a cutoff value for $k$ and compute the fraction \n",
"\n",
"$$\\frac{\\text{sum}(S[1:k])}{\\text{sum}(S)}.$$\n",
"\n",
"A couple things to think about:\n",
"\n",
"* How many points \"break away\" from the smooth (lower-right) portion of the plot?\n",
"* What fraction of the typical 16x16 image of a digit is signal?\n",
"* Why is there such a dramatic separation between $S[1]$ and $S[2]$ for the digit \"1\"?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"fig4 = figure(figsize=(8,8))\n",
"ax4 = fig4.add_subplot(1,1,1)\n",
"ax4.set_xlabel(\"index\")\n",
"ax4.set_ylabel(\"singular value magnitude\")\n",
"ax4.axis([-2,258,0,250])\n",
"U,S,V = svd(trn[:,:,0])\n",
"line, = ax4.plot([10.5,10.5],[0,250])\n",
"pts, = ax4.plot(S,lw=0,marker=\"o\",c=\"k\",markersize=4)\n",
"ttl41 = ax4.text(100,255,\"Digit: \",size=16)\n",
"ttl42 = ax4.text(50,230,\"sum(S<cutoff)/sum(S): \",size=14)\n",
"def int_change(digit,cutoff):\n",
"    U,S,V = svd(trn[:,:,digit])\n",
"    line.set_xdata([cutoff,cutoff])\n",
"    pts.set_ydata(S)\n",
"    ttl41.set_text(\"Digit: \" + str(digit))\n",
"    pct = round(100*sum(S[:int(cutoff)])/sum(S),1)\n",
"    ttl42.set_text(str(pct) +\n",
"                   \"% of sum(S) is captured in first\\n\" +\n",
"                   str(round(100*cutoff/256,1)) +\n",
"                   \"% of components\")\n",
"    fig4.canvas.draw()\n",
"    return fig4\n",
"\n",
"ipw.interact(\n",
"    int_change,\n",
"    digit=(0,9),\n",
"    cutoff=ipw.FloatSlider(\n",
"        min=0,\n",
"        max=256,\n",
"        step=1,\n",
"        value=10))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
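The last notebook cell reports the fraction sum(S[1:k])/sum(S) for a chosen cutoff k. As a rough illustration of that computation, here is a small sketch on a made-up spectrum. The helper name `captured_fraction` and the singular values are hypothetical, chosen so that three "signal" values stand clear of a flat noise floor. Note that the notebook's $S[1:k]$ is 1-based and inclusive, which corresponds to `S[:k]` in Python.

```python
import numpy as np


def captured_fraction(S, k):
    """Fraction of sum(S) contributed by the k largest singular values.

    Assumes S is sorted in descending order, as returned by numpy.linalg.svd.
    """
    S = np.asarray(S, dtype=float)
    return S[:k].sum() / S.sum()


# Hypothetical spectrum: three strong values above a flat noise floor.
S = np.array([50.0, 20.0, 10.0] + [1.0] * 20)

# fractions[k-1] equals captured_fraction(S, k) for every cutoff k at once.
fractions = np.cumsum(S) / S.sum()

for k in (1, 3, 10):
    print(k, captured_fraction(S, k))
```

Plotting `fractions` against k gives a curve that rises steeply over the first few components and then flattens, which is another way to look for the "knee" discussed in the notebook.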