Skip to content

Instantly share code, notes, and snippets.

@yrevar
Created August 9, 2019 20:15
Show Gist options
  • Save yrevar/765cf3456af4119c9fcabf8667dadb4f to your computer and use it in GitHub Desktop.
Save yrevar/765cf3456af4119c9fcabf8667dadb4f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import cm as cm, pyplot as plt, gridspec as gridspec\n",
"from matplotlib.patches import ConnectionPatch\n",
"%matplotlib inline\n",
"\n",
"class MatplotlibGridDisplay:\n",
"    \"\"\"Draw a grid of small images as one borderless mosaic with trajectory arrows.\n",
"\n",
"    Cells are addressed either as (row, col), with row 0 at the top, or as\n",
"    1-based (x, y) with y = 1 on the bottom row (see ``_rowcol_to_xy``).\n",
"    Trajectories passed to ``render`` use (x, y) cells.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, rows, cols):\n",
"        self.rows, self.cols = rows, cols\n",
"        # Not read anywhere in this class; kept so existing attribute access keeps working.\n",
"        self.axes_order = {}\n",
"\n",
"    @staticmethod\n",
"    def _prepare_axis(ax):\n",
"        \"\"\"Hide ticks and interior spines so adjacent cells read as one framed grid.\n",
"\n",
"        Only spines on the outer border of the grid stay visible.\n",
"        Returns ``ax`` so the call can be chained.\n",
"        \"\"\"\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        for sp in ax.spines.values():\n",
"            sp.set_visible(False)\n",
"\n",
"        # Axes.is_first_row() and friends were deprecated in Matplotlib 3.4 and\n",
"        # removed in 3.6 in favour of the same methods on the SubplotSpec; fall\n",
"        # back to the Axes methods on versions whose SubplotSpec lacks them.\n",
"        spec = ax.get_subplotspec()\n",
"        pos = spec if hasattr(spec, 'is_first_row') else ax\n",
"        if pos.is_first_row():\n",
"            ax.spines['top'].set_visible(True)\n",
"        if pos.is_last_row():\n",
"            ax.spines['bottom'].set_visible(True)\n",
"        if pos.is_first_col():\n",
"            ax.spines['left'].set_visible(True)\n",
"        if pos.is_last_col():\n",
"            ax.spines['right'].set_visible(True)\n",
"\n",
"        return ax\n",
"\n",
"    def _xy_to_rowcol(self, x, y):\n",
"        \"\"\"Convert a 1-based (x, y) cell, y counted from the bottom, to (row, col).\"\"\"\n",
"        return self.rows - y, x - 1\n",
"\n",
"    def _rowcol_to_xy(self, row, col):\n",
"        \"\"\"Convert (row, col), row 0 at the top, to a 1-based (x, y) cell.\"\"\"\n",
"        return col + 1, self.rows - row\n",
"\n",
"    @staticmethod\n",
"    def _axis_center(ax):\n",
"        \"\"\"Return the data coordinates of the centre of ``ax``'s current view limits.\"\"\"\n",
"        (x0, x1), (y0, y1) = ax.get_xlim(), ax.get_ylim()\n",
"        return (x0 + x1) / 2., (y0 + y1) / 2.\n",
"\n",
"    def connect_axes(self, fig, ax1, ax2, order=\"forward\"):\n",
"        \"\"\"Draw a red arrow from the centre of ``ax1`` to the centre of ``ax2``.\n",
"\n",
"        A red dot is also plotted at both centres. With ``order='forward'`` the\n",
"        ConnectionPatch is attached to ``ax1`` using a '->' style; any other\n",
"        value attaches it to ``ax2`` with a '<-' style, so the head still ends\n",
"        at ``ax2``. ``fig`` is unused and kept only for API compatibility.\n",
"\n",
"        The centre is taken from each axes' view limits instead of being\n",
"        hard-coded as data point (0, 0): with imshow's default origin/extent,\n",
"        (0, 0) is the centre of the top-left pixel, which only coincides with\n",
"        the axes centre for 1x1 images.\n",
"        \"\"\"\n",
"        center1 = self._axis_center(ax1)\n",
"        center2 = self._axis_center(ax2)\n",
"        if order == \"forward\":\n",
"            con = ConnectionPatch(xyA=center1, xyB=center2,\n",
"                                  coordsA=\"data\", coordsB=\"data\",\n",
"                                  axesA=ax1, axesB=ax2, color=\"red\",\n",
"                                  mutation_scale=40, arrowstyle=\"->\",\n",
"                                  shrinkB=5)\n",
"            ax1.add_artist(con)\n",
"        else:\n",
"            con = ConnectionPatch(xyA=center2, xyB=center1,\n",
"                                  coordsA=\"data\", coordsB=\"data\",\n",
"                                  axesA=ax2, axesB=ax1, color=\"red\",\n",
"                                  mutation_scale=40, arrowstyle=\"<-\",\n",
"                                  shrinkB=5)\n",
"            ax2.add_artist(con)\n",
"\n",
"        ax1.plot(*center1, 'ro', markersize=10)\n",
"        ax2.plot(*center2, 'ro', markersize=10)\n",
"\n",
"    def add_trajectory(self, fig, axes_grid, traj):\n",
"        \"\"\"Draw arrows between consecutive cells of one trajectory.\n",
"\n",
"        ``traj`` is a sequence of (x, y) cells used as keys into ``axes_grid``.\n",
"        Each later cell's axes gets a lower z-order than the one before it, so\n",
"        an arrow (attached to its starting axes) is drawn on top of the axes\n",
"        it points into. Empty or single-point trajectories draw nothing.\n",
"        \"\"\"\n",
"        cells = [(x, y) for x, y in traj]  # normalise to hashable (x, y) tuples\n",
"        for idx, (start, end) in enumerate(zip(cells, cells[1:])):\n",
"            ax1 = axes_grid[start]\n",
"            ax2 = axes_grid[end]\n",
"            ax1.set_zorder(-2 * idx + 1)\n",
"            ax2.set_zorder(-2 * idx)\n",
"            self.connect_axes(fig, ax1, ax2)\n",
"\n",
"    def add_trajectories(self, fig, axes_grid, traj_lst):\n",
"        \"\"\"Draw every trajectory in ``traj_lst``; ``None`` draws nothing.\"\"\"\n",
"        if traj_lst is not None:\n",
"            for traj in traj_lst:\n",
"                self.add_trajectory(fig, axes_grid, traj)\n",
"\n",
"    def render(self, data, phi_shape, traj_lst=None,\n",
"               interpolation=\"None\", cmap=cm.viridis, vmin=None, vmax=None):\n",
"        \"\"\"Draw ``data`` as an H x W mosaic of images and overlay trajectories.\n",
"\n",
"        Args:\n",
"            data: array of shape (H, W, D); cell (row, col) displays\n",
"                ``data[row, col].reshape(*phi_shape)``.\n",
"            phi_shape: shape each length-D cell vector is reshaped to.\n",
"            traj_lst: optional iterable of trajectories, each a sequence of\n",
"                1-based (x, y) cells (see ``_rowcol_to_xy``).\n",
"            interpolation, cmap, vmin, vmax: forwarded to ``Axes.imshow``.\n",
"\n",
"        Returns:\n",
"            ``(fig, ax)`` where ``ax`` is the last axes created, i.e. the\n",
"            bottom-right cell.\n",
"        \"\"\"\n",
"        H, W, _ = data.shape\n",
"        # NOTE(review): cells are keyed via _rowcol_to_xy, which uses self.rows,\n",
"        # so H is expected to equal the ``rows`` given to __init__.\n",
"        fig = plt.figure(figsize=(W * 2, H * 2))\n",
"        gs = gridspec.GridSpec(H, W)\n",
"        gs.update(wspace=0., hspace=0., left=0., right=1., bottom=0., top=1.)\n",
"        axes_grid = {}\n",
"\n",
"        for row in range(H):\n",
"            for col in range(W):\n",
"                # Create and register the cell axes on the figure in one step.\n",
"                ax = fig.add_subplot(gs[row, col])\n",
"                ax.imshow(data[row, col].reshape(*phi_shape),\n",
"                          interpolation=interpolation, cmap=cmap,\n",
"                          vmin=vmin, vmax=vmax)\n",
"                self._prepare_axis(ax)\n",
"                axes_grid[self._rowcol_to_xy(row, col)] = ax\n",
"\n",
"        self.add_trajectories(fig, axes_grid, traj_lst)\n",
"        return fig, ax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<Figure size 288x288 with 4 Axes>,\n",
" <matplotlib.axes._subplots.AxesSubplot at 0x113c27a20>)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATUAAAE1CAYAAACGH3cEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAC81JREFUeJzt3X+M1/V9wPHXV2947iZMDgGtlIJybrRyGgmCS2ytirRbFmpxXbJ1zhhnoouelCV0poJNtqaLVrbWTrKRVdPOrbOItjXcubX/LI1xI4GKVg/kl9Xy69DjvEVx8Nkf7uRQ5L7f7/H9wesej7/0m/eXvAIfn/f6fr7fL5aKogiALE5r9AAAJ5OoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpBKSyWHx5XOKFqjrVazkNxg29vRMnFio8fgFHXolV/uL4rinJHOVRS11miLy0tXVz8VY9rWv5rf6BE4he3oWraznHNefgKpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBpUYNLAQERRNHoMTkDUoAJf+5d/ix/evyqu2fy8uDUpUYMK/NnNfxrfXHhN3PVUt7g1qZZGDwCnkuK006JnzsXx9Cc+Htdufj7ueqo77lzfE3+7aGH8+8dnR5RKjR5xzBO1Ueos9sbtsTE+Fgffe2xHjI8H45LYVJrcwMmopZMVt/m9W2Ll2nXRsXvPe4/1Tp0SK69fHM90zKrV+KmVigpW5/GlicXlpatrOM6p5Y+KF+LGeCEiIoZfwkO/ow/H7PheaXbd52pWWx+Y3+gRaqZ05Ehcu/n5uHP903HktFJZcfvz9T2xdH3Pu88f9vjQ9fONRQvjW4sW1m7oU8yOrmUbiqKYO9I599Sq1FnsjRvjhSjFsRdk/P+/lyLixnghOou99R+Ouhva3H5vWVdZ99zm926Jpet7Tnj9LF3fE/N7t9Rh+lxErUq3x8ayzt1W5jlyKDduK9euK+vXW1HmOY5yT61KH4uDH/gJ+36liJgRB2NW8Xo9Rmp6ra/8stEj1NVrZ58dy//whrh868vx5Sd+GMuf/FGs+eSV8ejvLIiO3XvKun4uGnavjfKIWh10xYZGj9AU3v7XFxs9QsMUEfHR/X3xlXVPxmPzRrwtxCiIWh3cXrqm0SM0ha3L8r5R8GEu3b4j7ux+OiYODsaKJdfHY/Pmxjst/rOrJb+7VdoR40d8CVpExPYYX6+RaCJDMbtwz5548Npr4pZ5Nx0Ts96pU0Z8CVpExEtTp9R81my8UVClB+OSss59u8xz5HDp9h3xnYf+Ib75yHeje87FcdXdy+PRK+Z/YDtbef3isn69e8s8x1E2tSptKk2Oh4vZI35OzQdwx4aRNrP3e6ZjVnxj0cIRP6fmA7iVE7VR+F5pdmwuJsVtsTFmDPtGwfYYH9/2jYIxodKYDfetRQvjv2fOiBVr1x3zLudLU6fEvb5RUDXfKKBuMn2j4P0xG+0bAE/e90D85RduiM3Tzj+JU+ZS7jcKbGpQgTm7dsXSp7qr2syoD38aUIEbnvmv6J5zsZg1MX8qUIGv/MHnGz0CI/CRDiAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSET
UgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSaank8P+c8Xa88YUFtZrllPbIjx6Iv16wJF5sn9boUZrWmt9f3egRmtb0f9wf91y5NgYubm30KE3r6q6YXs65ija1M35jYnXTAIzeznIOefkJpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpQB7++7VCU/rdo2PPHElGDOpixqi/Oe7S/que2vH44LluyK1oOHjnJU+UkalAHu24+O6Y/dCBKb1cepo+ueT32Ljor3pl4eg0my0fUoA4GOlvjzd86I877/sGKntfy+uH4yD+/ETtvm1ijyfIRNaiT7Xe0V7ytDW1pb53/azWcLBdRgzqpdFuzpVVH1KCOKtnWbGnVETWoo3K3NVta9UQN6qycbc2WVj1RgzobaVuzpY2OqEEDnGhbs6WNjqhBA3zYtmZLGz1RgwY53rZmSxs9UYMGef+2dvpBW9rJIGpNbsJbb8bpRw43egxqZGhbiyNFTH1iwJZ2EohakzpvoC/u/tn347F1X4/zB/oaPQ41MrStjdt/OCZ3v2lLOwlErckMxezhH6+KvjPPiiWLl8fOCZMbPRY1tP2O9mjdezgOLDjTlnYStDR6AN513kBf3PTcf8Sndj0XP7joiliyeHn0t7Y1eizqYKCzNfZf1RbbuiY1epQURG2ULvvVllj27ONxQf+eeOTHqyIi4uUJU+K+eZ+LDefOGvH5Yja2/ebPBqPj3n3RtuVQtP90MCIiBmeNi94V58QbV7gOqiFqo3Dzpp64dVN3RESUhj1+Qf+e+PunH4rVndfFms6Fx32umDH97/pi5qp375cOv37athyKS//41djW1R4772hvzHCnMFGr0mW/2hK3buo+5mIcMvTYrZu6Y+PkGcdsbGJGxLsb2sxVfSe8fmau6ov+ua02tgp5o6BKy559vKxzX3p2XUQc/w2Ahy79jKCNUR337ivv3FfLO8dRNrUqXdC/57g/ZYcrRcSF/bvjb376TzF399b4yfQ58eVPfjHeHHdmnDt4IM4dPFCPUZvGWc+91egRmkbblkNlXT9tvYfqMU4qolYHV72yOfadOT4u6ns1Og681uhxGuYjO20d1J6o1cGffLYrbvl5T3QceDUe/sSn44lZl8eh08fe55G+vnJ1o0doGlfN7G30CGm5p1allydMiZH+17JFRGydMDV+MWlaLP30zfEXn7opFrz2Uqx9/Gtxw4v/GeMOv1OPUWlCg7PGlXX9DHaMq8c4qYhale6b97myzt0/b/F7/yxuDOldcU555+4p7xxHiVqVNpw7K1Z3XhdFxAd+4g49trrzuuN+AFfceOOKttjW1X7C62dbV7uPc1TBPbVRWNO5MDZOnhFfenZdXNi/+73Ht06YGvfPWzziNwqG4vbb+1+JW37eEzdu/smYvuc21uy8oz3657ZGx1f3HfMu52DHuOi9xzcKqlUqipFe2R/VNmlaMft376rhOGPbUNyyvqHgjQJG4+qZvRuKopg70jkvP5vIh70s9fepQflErQkNj9sle7f7+9SgAu6pNbFfTJoWd1/5xUaPAacUmxqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiq
gBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqZSKoij/cKm0LyJ21m4ckpserh+qN70oinNGOlRR1ACanZefQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAq/wd9PCIrVuNL8gAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 288x288 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 2x2 grid of scalar cells; the trailing axis makes data (H, W, D) = (2, 2, 1).\n",
"cell_values = np.array([[0., 0.5], [0.3, 0.9]])\n",
"# One closed loop over all four cells, in 1-based (x, y) coordinates.\n",
"loop = [(1, 1), (1, 2), (2, 2), (2, 1), (1, 1)]\n",
"\n",
"m = MatplotlibGridDisplay(2, 2)\n",
"m.render(\n",
"    data=cell_values[..., np.newaxis],\n",
"    phi_shape=(1, 1),\n",
"    traj_lst=[loop],\n",
"    vmin=0.,\n",
"    vmax=1.,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:irl]",
"language": "python",
"name": "conda-env-irl-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@yrevar
Copy link
Author

yrevar commented Aug 9, 2019

A workaround for drawing ConnectionPatch arrows between cells of a subplot grid. One limitation is that it requires shrinkA = shrinkB.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment