Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yrevar/ebe97208c09ca3e9274e8d29ce357b21 to your computer and use it in GitHub Desktop.
Save yrevar/ebe97208c09ca3e9274e8d29ce357b21 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import cm as cm, pyplot as plt, gridspec as gridspec\n",
"from matplotlib.patches import ConnectionPatch\n",
"%matplotlib inline\n",
"\n",
"class MatplotlibGridDisplay:\n",
"    \"\"\"Render a grid of small images, one per cell, with trajectory arrows.\n",
"\n",
"    Cells are addressed by 1-based (x, y) coordinates with x = col + 1 and\n",
"    y = rows - row, so (1, 1) is the bottom-left cell.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, rows, cols):\n",
"        self.rows, self.cols = rows, cols\n",
"        # Kept for backward compatibility; not read anywhere in this class.\n",
"        self.axes_order = {}\n",
"\n",
"    @staticmethod\n",
"    def _prepare_axis(ax):\n",
"        \"\"\"Hide ticks and inner spines so the cells form one framed grid.\n",
"\n",
"        The cell position is read from the axes' SubplotSpec geometry rather\n",
"        than Axes.is_first_row() and friends, which were deprecated in\n",
"        Matplotlib 3.4 and later removed.\n",
"        \"\"\"\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        for spine in ax.spines.values():\n",
"            spine.set_visible(False)\n",
"\n",
"        n_rows, n_cols, start, stop = ax.get_subplotspec().get_geometry()\n",
"        if stop is None:  # a single-cell spec may report no end index\n",
"            stop = start\n",
"        start_row, start_col = divmod(start, n_cols)\n",
"        stop_row, stop_col = divmod(stop, n_cols)\n",
"        # Only spines on the outer edge of the whole grid become visible.\n",
"        if start_row == 0:\n",
"            ax.spines['top'].set_visible(True)\n",
"        if stop_row == n_rows - 1:\n",
"            ax.spines['bottom'].set_visible(True)\n",
"        if start_col == 0:\n",
"            ax.spines['left'].set_visible(True)\n",
"        if stop_col == n_cols - 1:\n",
"            ax.spines['right'].set_visible(True)\n",
"        return ax\n",
"\n",
"    def _xy_to_rowcol(self, x, y):\n",
"        \"\"\"Convert 1-based (x, y) cell coordinates to 0-based (row, col).\"\"\"\n",
"        return self.rows - y, x - 1\n",
"\n",
"    def _rowcol_to_xy(self, row, col):\n",
"        \"\"\"Convert 0-based (row, col) indices to 1-based (x, y) coordinates.\"\"\"\n",
"        return col + 1, self.rows - row\n",
"\n",
"    def connect_axes(self, fig, ax1, ax2, order=\"forward\"):\n",
"        \"\"\"Draw a red arrow from the centre of ax1 to the centre of ax2.\n",
"\n",
"        The arrow is attached to both axes (\"->\" from ax1, \"<-\" from ax2),\n",
"        presumably so it stays visible whichever axes is drawn on top.\n",
"        A red dot is also drawn at the centre of each axes.\n",
"        `fig` and `order` are accepted for compatibility but are not used.\n",
"        \"\"\"\n",
"        # Axes-fraction (0.5, 0.5) is the cell centre for any phi_shape,\n",
"        # whereas data coordinate (0, 0) is the centre only of a 1x1 image.\n",
"        centre = (0.5, 0.5)\n",
"        for ax_from, ax_to, style in ((ax1, ax2, \"->\"), (ax2, ax1, \"<-\")):\n",
"            con = ConnectionPatch(xyA=centre, xyB=centre,\n",
"                                  coordsA=\"axes fraction\",\n",
"                                  coordsB=\"axes fraction\",\n",
"                                  axesA=ax_from, axesB=ax_to, color=\"red\",\n",
"                                  mutation_scale=40, arrowstyle=style,\n",
"                                  shrinkB=5, shrinkA=5)\n",
"            ax_from.add_artist(con)\n",
"        for ax in (ax1, ax2):\n",
"            # The marker is positioned in axes space and must not rescale the view.\n",
"            ax.plot(*centre, 'ro', markersize=10, transform=ax.transAxes,\n",
"                    scalex=False, scaley=False)\n",
"\n",
"    def add_trajectory(self, fig, axes_grid, traj):\n",
"        \"\"\"Draw arrows between consecutive (x, y) points of one trajectory.\n",
"\n",
"        Each step gives its pair of axes decreasing z-orders, with the source\n",
"        cell one level above the destination cell. Empty or single-point\n",
"        trajectories draw nothing. Raises KeyError for off-grid points.\n",
"        \"\"\"\n",
"        points = [tuple(point) for point in traj]\n",
"        for idx, (start, end) in enumerate(zip(points, points[1:])):\n",
"            ax1 = axes_grid[start]\n",
"            ax2 = axes_grid[end]\n",
"            ax1.set_zorder(-2 * idx + 1)\n",
"            ax2.set_zorder(-2 * idx)\n",
"            self.connect_axes(fig, ax1, ax2)\n",
"\n",
"    def add_trajectories(self, fig, axes_grid, traj_lst):\n",
"        \"\"\"Draw every trajectory in traj_lst; does nothing when it is None.\"\"\"\n",
"        if traj_lst is not None:\n",
"            for traj in traj_lst:\n",
"                self.add_trajectory(fig, axes_grid, traj)\n",
"\n",
"    def render(self, data, phi_shape, traj_lst=None,\n",
"               interpolation=\"None\", cmap=cm.viridis, vmin=None, vmax=None):\n",
"        \"\"\"Draw one image per grid cell and overlay optional trajectories.\n",
"\n",
"        Args:\n",
"            data: 3-D array; data[row, col] is reshaped to phi_shape and drawn\n",
"                in that cell. Its first two dimensions must equal (rows, cols).\n",
"            phi_shape: shape each cell's values are reshaped to for imshow.\n",
"            traj_lst: optional iterable of trajectories, each a sequence of\n",
"                1-based (x, y) cell coordinates.\n",
"            interpolation, cmap, vmin, vmax: forwarded to Axes.imshow.\n",
"\n",
"        Returns:\n",
"            (fig, ax), where ax is the axes of the last cell drawn\n",
"            (bottom-right); kept for compatibility with existing callers.\n",
"\n",
"        Raises:\n",
"            ValueError: if data's grid shape differs from (rows, cols), which\n",
"                would otherwise map trajectory coordinates to the wrong cells.\n",
"        \"\"\"\n",
"        H, W, _ = data.shape\n",
"        if (H, W) != (self.rows, self.cols):\n",
"            raise ValueError(f'data grid shape {(H, W)} does not match '\n",
"                             f'display shape {(self.rows, self.cols)}')\n",
"        fig = plt.figure(figsize=(W * 2, H * 2))\n",
"        gs = gridspec.GridSpec(H, W)\n",
"        # Zero spacing and margins so the cells tile the whole figure.\n",
"        gs.update(wspace=0., hspace=0., left=0., right=1., bottom=0., top=1.)\n",
"        axes_grid = {}  # (x, y) -> Axes\n",
"        for row in range(H):\n",
"            for col in range(W):\n",
"                ax = fig.add_subplot(gs[row, col])\n",
"                ax.imshow(data[row, col].reshape(*phi_shape),\n",
"                          interpolation=interpolation, cmap=cmap,\n",
"                          vmin=vmin, vmax=vmax)\n",
"                self._prepare_axis(ax)\n",
"                axes_grid[self._rowcol_to_xy(row, col)] = ax\n",
"\n",
"        self.add_trajectories(fig, axes_grid, traj_lst)\n",
"        return fig, ax"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<Figure size 288x288 with 4 Axes>,\n",
" <matplotlib.axes._subplots.AxesSubplot at 0x113dcc978>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATUAAAE1CAYAAACGH3cEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADJFJREFUeJzt3X+slYV5wPHnICLzKjh+y0aVi9w5WAoqsajZssVGsF0shVrXdcYlZiPRZaMUNocGwWydXa01XXUjWf+wq2sbF0vbKcg2rY1xjEkKXcPsvaDQbnIRyYSK/ObdH+1RQPD8uPee99yHz+cvOHkPecJ9+d7nvec9h0pRFAGQxZCyBwDoT6IGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCpDGzl4WOW8Ynh0DNQsJLe/41AMHTWq7DEYpA7/5H9eL4pibK3jGora8OiID1Sub34qzmpb/3J22SMwiG1ftGRHPce5/ARSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETVoQNfOnRHHj5c9Bu9haNkDwGDy1YdXxQWHDsXf/+ZvxIM3zokYYi9oN74i0IDZK+6JR3/9uviDZ5+LLX92dyx+co3Nrc2IGjTg+NCh8dmbfjum3/8X4tamKkVR1H3wiMqo4gOV6wdwnMFnRvFa3Bmb4tLY9/Zj22NEPBwzY3NlXImTtZ+tX5hd9gj9bsjRo7H0qbXx+997PopKpeHL0tndPbHiidXR1bvr7ce6J4yPFfPnxfquqQM19qC0fdGSjUVRzKp1nKj1wSeLLXFbbImIiMoJj1f/Rh+NafFYZVrL52pXGaNW1Uzc/mjtuli8dl1EnP78eXDuDfGluTcM3NCDTL1Rc/nZpBnFa3FbbIlKnHxCxs9/X4mI22JLzChea/1wtFyjl6Wzu3ti8dp173n+LF67LmZ397Rg+lxErUl3xqa6jrujzuPIod64rXhidV1/3r11Hsc73NLRpEtj37u+w56qEhGTY19MLf6vFSO1vekvbix7hJZ6aeLFsezjC+JDmzbHwn97NhY+8914cub7Y/Gtn4yu3l11nT+/csLP2qiPqLXAoji7/jGfyaRv7C97hPJUIs49dizmbfx+LPvY/LKnSU3UWuDOygfLHqEtbP1c3hcKzuSWF/49/vTJNTHiwNFYfdWVcc/N8+Pg8OFlj5WaqDVpe4yoeQlaRMQrMaJVI9FG3onZgfjOFVfEPTfPj7dOiFn3hPE1L0GLiPjRhPEDPms2otakh2NmfC6+V/O4R2JmC6ahXdSKWdWK+fPiHx9ZVfPPWzl/3kCMmZqoNWlzZVw8WkyreZ+aG3DPDvXGrGp919R4cO4NNe9TcwNu40StDx6rTIsfFmPijtgUk094R8ErMSIe8Y6Cs0KjMTvRl+beEC92To57n1h90qucP5owPlZ6R0HTRK2PNlfGxZ3FB2N1rI5HYmY8VekseyRaoC8xO9H6rqlx411L49sPfCGW3XJz/HDSLw/AtGcXUesHc2J7DIkiFkRPPFVMjqjUugOJwWrBf2yIZd/+5z7HjIEjan00tDgen4iXYmd0RCWKmB07Y31MLHssBsgfPvNcPHf55WLWxkStj+bE9tgRF8ZFcSiejUlxa2yJ9cXFtrWk5vz50rJHoAbv/eyD6pb21fjZJ3H8IMbGOT/f1oByiFofVLe0/66MjoiIIirxDzEtbo0tEQ18pBPQf0StSaduaV
UvxETbGpRI1Jp06pZWVVRsa1AmUWvCmba0KtsalEfUmnCmLa3KtgblEbUG1drSqmxrUA5Ra1CtLa3KtgblELUG1LulVdnWoPVErQH1bmlVtjVoPVGrU6NbWpVtDVpL1OrU6JZWZVuD1hK1OnXGG/FoTG/quS/ExNgTvxAXxaF+ngo4lU/pqNPfVK5s+rlFpRLL47p+nAY4E5sakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqQxt5OC3zjsUb9xyzUDNMqhNeuxbMfqacfGfnbPKHqVtffmmVWWP0Lam3fVq3D/x67HrphFlj9K2rl8Ul9RzXEOb2nkXjGpuGoC+21HPQS4/gVREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNWmDUd9+MIQePl/b8s4moQQtcfs9r8atLept67vBXDseM21+N81492s9T5SRq0ALbPjU6xq19M87Z13iYpi3tjbc6h8WBzmEDMFk+ogYtsGvByDjyi+fE5ctea+h5w185HCO/fzBe+qtxAzRZPqIGLdJz15iGt7XqlrZ31vkDOFkuogYt0ui2ZktrjqhBCzWyrdnSmiNq0EL1bmu2tOaJGrRYPduaLa15ogYtVmtbs6X1jahBCd5rW7Ol9Y2oQQnOtK3Z0vpO1KAkp9vWbGl9J2pQklO3tXN3H7Gl9QNRa3OT9u6KYUcPlz0GA6S6rUVRxC99fZ8trR+IWpua2bstHl99f/zTt/46rurdVvY4DJDqtlY5EnH+9iO2tH4wtOwBONnM3m1x9/rH4337dscPxl4aiz9ye/xk5Niyx2IA9dw1JqYv3RUHx51jS+sHotYmxOzstWvByLjs/tfjpRW+3v1B1Proqp09sWTDN2P4sSNx3/Nfi/ue/1psGzk+Hrj6o7Hx4qk1ny9mZ7eLXtgfXSt3x7A9x2LGHb0R0Rv7pw6L7nvHxhvXdpQ93qAkan1w++Z1sXDz0xERUTnh8Sl7d8Xf/svfxaoZc+LLM2447XPFjEu+uCc6H9oTESefPx09h+OK3/vfeHnR6Njxx6PLGW4QE7UmXbWzJxZufvqkk7Gq+tjCzU/HpnGTT9rYxIyIn21onQ/tec/zp/OhPbF31nAbW4NErUlLNnyzruM+vWF1/O5HlooZJ+laubu+4+7bHRvWilojRK1JU/buOu132RNVIuKyvb3x5OMrYuyBn8aPLxwTX7zyw7Hn/BExfc+OmL5nRytGbRsX/tfBskdoC8fPrURHz+G6zp+ObvcoNkrUWmDcgZ/GsajE+P1vxMLN68oepzRH7nZbZEREcW6tnNEXotYCy6/7RPzJxu/ERYf2xzOT3h+fueZj8eaws+9+pM+uWFX2CG3jtzq7yx4hLd86m7Rt5PgoahxTRMTWkRNizZRZMffjK2Pltb8TV+7aFv/6jeXxmee+EhccfqsVo9KG9k8dVtf5s7/Lf4vXKFFr0gNXf7Su4z5/9by3fy1uVHXfW9+LRN3LvZjUKFFr0saLp8aqGXOiiHjXd9zqY6tmzDntDbjixhvXdsTLi0a/5/nz8qLRbudoQqUoai3B7+gYM6mY9uFPDeA4g89VO3vi0xtWx2V7e99+bOvICfH5q+fV9Y6CiIgbt734zs/c3pf3Z25+pv
ZuF72wP7ru233Sq5z7u4ZF93LvKDjV9Z3dG4uimFXrOFFrI9njJmr0Rb1Rc/nZRk5/Wfqoz1ODBohaGzoxbr/2+o99nho0wH1qbWzNlFmxZkrNbRs4gU0NSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVEQNSEXUgFQqRVHUf3ClsjsidgzcOCR3STh/aN4lRVGMrXVQQ1EDaHcuP4FURA1IRdSAVEQNSEXUgFREDUhF1IBURA1IRdSAVP4f7QCqD6ofNpAAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 288x288 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Demo: a 2x2 grid of scalar cells; the trajectory loops once around it.\n",
"cell_values = np.array([[0., 0.5], [0.3, 0.9]])\n",
"loop_traj = [(1, 1), (1, 2), (2, 2), (2, 1), (1, 1)]\n",
"m = MatplotlibGridDisplay(2, 2)\n",
"m.render(\n",
"    data=cell_values[:, :, np.newaxis],\n",
"    phi_shape=(1, 1),\n",
"    traj_lst=[loop_traj],\n",
"    vmin=0.,\n",
"    vmax=1.,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:irl]",
"language": "python",
"name": "conda-env-irl-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment