Skip to content

Instantly share code, notes, and snippets.

@jaganadhg
Created October 9, 2013 14:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaganadhg/6901864 to your computer and use it in GitHub Desktop.
Save jaganadhg/6901864 to your computer and use it in GitHub Desktop.
Geometric illustration of the SVD
{
"metadata": {
"name": "Geometric illustration of the SVD"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"#http://scipy-central.org/item/11/1/geometric-illustration-of-the-svd\n",
"\"\"\"Illustrate the SVD geometrically.\n",
"\n",
":author: Stefan van der Walt\n",
":date: 2006\n",
"\n",
"\"\"\"\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import CirclePolygon\n",
"import copy\n",
"\n",
"class CirclePoint(object):\n",
"    \"\"\"\n",
"    Draggable arrow on a circle.\n",
"\n",
"    Clicks within epsilon pixels of arrow head grab the arrow.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, epsilon=10):\n",
"        \"\"\"Draw a unit circle in the current axes and attach mouse handlers.\n",
"\n",
"        :param epsilon: pick radius, in pixels, around the arrow tip.\n",
"        \"\"\"\n",
"        axes = plt.gca()\n",
"        circ = CirclePolygon((0, 0), 1., resolution=200)\n",
"        circ.set_fill(False)\n",
"        circ.set_edgecolor('b')\n",
"        axes.add_patch(circ)\n",
"\n",
"        canvas = circ.figure.canvas\n",
"        canvas.mpl_connect('button_press_event', self.button_press_callback)\n",
"        canvas.mpl_connect('button_release_event',\n",
"                           self.button_release_callback)\n",
"        canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)\n",
"\n",
"        self.epsilon = epsilon\n",
"        self.circ = circ\n",
"        self.arrow = None  # Created by set_angle\n",
"        self.canvas = canvas\n",
"        self.external_hook = None  # Optional callable run after each redraw\n",
"        self.axes = axes\n",
"        self.nr_pts = 0\n",
"\n",
"        self.arrow_colour = {'default': 'r', 'selected': 'g'}\n",
"        self.arrow_mode = 'default'\n",
"\n",
"        self.set_angle(np.pi / 4.)\n",
"\n",
"    def get_angle(self):\n",
"        \"\"\"Return the angle of the arrow, in radians.\"\"\"\n",
"        return self.__angle\n",
"\n",
"    def set_angle(self, theta):\n",
"        \"\"\"Point the arrow at angle *theta* (radians) and redraw it.\"\"\"\n",
"        self.__angle = theta\n",
"        self.update_arrow()\n",
"\n",
"    angle = property(fget=get_angle, fset=set_angle,\n",
"                     doc=\"Angle of the arrow.\")\n",
"\n",
"    def update_arrow(self):\n",
"        \"\"\"Replace the arrow with one at the current angle and redraw.\n",
"\n",
"        Runs ``external_hook`` (if set) after the canvas is redrawn.\n",
"        \"\"\"\n",
"        # Use Artist.remove(): newer Matplotlib stores arrows as patches and\n",
"        # no longer allows editing the Axes.artists list directly.\n",
"        if self.arrow is not None:\n",
"            self.arrow.remove()\n",
"        ex, ey = self.pos\n",
"        self.arrow = self.axes.arrow(0, 0, ex, ey, width=0.01,\n",
"                                     length_includes_head=True)\n",
"\n",
"        ac = self.arrow_colour[self.arrow_mode]\n",
"        self.arrow.set_edgecolor(ac)\n",
"        self.arrow.set_facecolor(ac)\n",
"        self.canvas.draw()\n",
"\n",
"        if self.external_hook:\n",
"            self.external_hook()\n",
"\n",
"    @property\n",
"    def pos(self):\n",
"        \"\"\"Return position of arrow tip (x, y) on the unit circle.\"\"\"\n",
"        a = self.angle\n",
"        return (np.cos(a), np.sin(a))\n",
"\n",
"    def button_press_callback(self, event):\n",
"        \"\"\"Select the arrow when the left button is pressed near its tip.\"\"\"\n",
"        if event.inaxes is None or event.button != 1:\n",
"            return\n",
"\n",
"        # Compare in pixel coordinates so epsilon is a pixel distance.\n",
"        transf = self.circ.get_transform()\n",
"        x, y = transf.transform(self.pos)\n",
"        if np.hypot(event.x - x, event.y - y) > self.epsilon:\n",
"            return\n",
"\n",
"        # Arrow selected\n",
"        self.arrow_mode = 'selected'\n",
"        self.update_arrow()\n",
"\n",
"    def button_release_callback(self, event):\n",
"        \"\"\"Deselect the arrow when a mouse button is released.\"\"\"\n",
"        # Skip the redraw (and the external hook) if nothing was dragged.\n",
"        if self.arrow_mode != 'selected':\n",
"            return\n",
"        self.arrow_mode = 'default'\n",
"        self.update_arrow()\n",
"\n",
"    def motion_notify_callback(self, event):\n",
"        \"\"\"While the arrow is selected, point it at the mouse cursor.\"\"\"\n",
"        if self.arrow_mode != 'selected':\n",
"            return\n",
"        transf = self.circ.get_transform()\n",
"        xt, yt = transf.inverted().transform((event.x, event.y))\n",
"        # Assigning the property already redraws (set_angle -> update_arrow).\n",
"        self.angle = np.arctan2(yt, xt)\n",
"\n",
"\n",
"class SVD_Geometry:\n",
"    \"\"\"Interactive picture of how a 2x2 matrix M maps the unit circle.\n",
"\n",
"    The columns of U scaled by the singular values are drawn in magenta and\n",
"    the columns of V in cyan.  Each time the red arrow (a unit vector v) is\n",
"    redrawn, a dot is added at M v.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, M):\n",
"        \"\"\"Draw the SVD of *M* on a new figure and show it.\"\"\"\n",
"        plt.figure()\n",
"\n",
"        cp = CirclePoint()\n",
"        cp.external_hook = self.plot_tf\n",
"\n",
"        U, S, Vt = np.linalg.svd(M)\n",
"        V = Vt.transpose()\n",
"\n",
"        # Rows: the columns of U scaled by S, then the columns of V.\n",
"        eig_vecs = np.vstack([(U * S).transpose(), V.transpose()])\n",
"        colours = ['m', 'm', 'c', 'c']\n",
"        labels = ['U', '', 'V', '']\n",
"        for ev, colour, label in zip(eig_vecs, colours, labels):\n",
"            a = plt.arrow(0, 0, ev[0], ev[1], width=0.01,\n",
"                          length_includes_head=True)\n",
"            a.set_edgecolor(colour)\n",
"            a.set_facecolor(colour)\n",
"            plt.text(ev[0] / 1.5, ev[1] / 1.5, label)\n",
"\n",
"        self.M = M\n",
"        self.cp = cp\n",
"        self.add_patch = cp.axes.add_patch\n",
"\n",
"        # Half-width of the view: at least 1.5 (room around the unit circle)\n",
"        # and at least the largest vector component; then pad by 50%.\n",
"        am = max(1.5, np.abs(eig_vecs).max())\n",
"        plt.axis('equal')\n",
"        plt.axis([-1.5 * am, 1.5 * am, -1.5 * am, 1.5 * am])\n",
"        plt.ylabel(\"Geometrical Illustration of the SVD\",\n",
"                   fontsize=14).set_weight('bold')\n",
"        plt.title(\"A unit vector, indicated by the red arrow, is multiplied \"\n",
"                  \"by\\n%s\\n to form a blue dot.\" % self.M, fontsize=10)\n",
"        plt.xlabel(\"Click and drag the tip of \"\n",
"                   \"the red arrow.\").set_style(\"italic\")\n",
"        plt.show()\n",
"\n",
"    def plot_tf(self):\n",
"        \"\"\"Add a dot at M v, where v is the arrow's current unit vector.\"\"\"\n",
"        image = np.dot(self.M, np.array(self.cp.pos))\n",
"        self.add_patch(plt.Circle(image, 0.01))\n",
"\n",
"# Example: a non-symmetric 2x2 matrix; the figure is built on construction.\n",
"demo = SVD_Geometry(M=np.array([[0.7, 1.4],\n",
"                                [1.2, 0.1]]))"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment