Skip to content

Instantly share code, notes, and snippets.

@kngwyu
Last active June 25, 2020 04:50
Show Gist options
  • Save kngwyu/76a091d02b1971dbaf7afb172a616f91 to your computer and use it in GitHub Desktop.
Save kngwyu/76a091d02b1971dbaf7afb172a616f91 to your computer and use it in GitHub Desktop.
from typing import Dict, Tuple, Union

import numpy as np
from matplotlib import colors as mc
from matplotlib import pyplot as plt
def _make_colormap(colors: Dict[int, str], name: str) -> mc.LinearSegmentedColormap:
    """Build a piecewise-linear colormap from numeric color stops.

    The keys of ``colors`` are stop positions on an arbitrary numeric scale.
    They are sorted and rescaled linearly onto ``[0, 1]``, so the smallest key
    becomes the low end of the colormap and the largest key the high end.
    Each RGB channel is interpolated linearly between neighbouring stops.

    Args:
        colors: Mapping from stop position to a color spec accepted by
            ``matplotlib.colors.to_rgb`` (e.g. ``"w"``, ``"xkcd:green"``).
        name: Name given to the resulting colormap.

    Returns:
        A ``LinearSegmentedColormap`` passing through every given color.

    Raises:
        ValueError: If ``colors`` is empty.
    """
    if not colors:
        raise ValueError("colors must contain at least one color stop")

    keys = sorted(colors)
    low, high = keys[0], keys[-1]
    if high == low:
        # Only one stop: rescaling would divide by zero and yield NaN, while
        # LinearSegmentedColormap requires stops at exactly x=0 and x=1.
        # Stretch the single color across the whole range instead.
        positions = np.array([0.0, 1.0])
        rgb = [mc.to_rgb(colors[low])] * 2
    else:
        positions = (np.asarray(keys, dtype=float) - low) / (high - low)
        rgb = [mc.to_rgb(colors[k]) for k in keys]

    # segmentdata format: channel -> [(x, value_below_x, value_above_x), ...];
    # equal left/right values give a continuous ramp with no jumps.
    cmap_dict = {
        channel: [(x, c[idx], c[idx]) for x, c in zip(positions, rgb)]
        for idx, channel in enumerate(("red", "green", "blue"))
    }
    return mc.LinearSegmentedColormap(name, cmap_dict)
# Default diverging colormap for value plots: stop 0 (lowest) is scarlet,
# stop 1 (middle) is white, stop 2 (highest) is green.
DEFAULT_VALUE_CM = _make_colormap(
    colors={0: "xkcd:scarlet", 1: "w", 2: "xkcd:green"},
    name="value",
)
class ValueHeatMap:
    """A grid of updatable heat maps (e.g. for visualizing value functions).

    Each subplot displays a 2-D image of shape ``data_shape`` with its own
    colorbar. Use :meth:`update` to replace one subplot's data and
    :meth:`draw` to redraw the figure canvas.
    """

    def __init__(
        self,
        data_shape: Tuple[int, int],
        nrows: int = 1,
        ncols: int = 1,
        name: str = "Value Function",
        cmap: Union[str, mc.Colormap] = DEFAULT_VALUE_CM,
        vmin: float = -1.0,
        vmax: float = 1.0,
    ) -> None:
        """Create the figure with one heat-map subplot per grid cell.

        Args:
            data_shape: Shape that data passed to :meth:`update` is reshaped to.
            nrows: Number of subplot rows.
            ncols: Number of subplot columns.
            name: Figure identifier passed to ``plt.figure``.
            cmap: Colormap name or instance used for every subplot.
            vmin: Lower bound of the color scale passed to ``imshow``.
            vmax: Upper bound of the color scale passed to ``imshow``.
        """
        self.fig = plt.figure(name)
        self.data_shape = data_shape
        cmap = plt.get_cmap(cmap)
        self.imgs = []  # one AxesImage per subplot, in creation order

        n_plots = nrows * ncols
        # Placeholder image until update() supplies real data.
        dummy = np.zeros(data_shape)
        for i in range(n_plots):
            ax = self.fig.add_subplot(nrows, ncols, i + 1)
            img = ax.imshow(
                dummy, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest"
            )
            cbar = self.fig.colorbar(img, ax=ax)
            cbar.ax.set_ylabel("", rotation=-90, va="bottom")
            ax.set_xticks([])
            ax.set_yticks([])
            # A lone plot needs no index label; in a grid, number each panel.
            # (The original tested `nrows == 0 and ncols == 0`, which can never
            # be true inside this loop.)
            ax.set_xlabel("" if n_plots == 1 else f"Value {i}")
            ax.set_aspect("equal")
            self.imgs.append(img)
        self.fig.tight_layout()
        self.fig.canvas.draw()

    def update(self, data, index: int = 0) -> None:
        """Replace the image data of subplot ``index``.

        Only the data is swapped; call :meth:`draw` to refresh the canvas.

        Args:
            data: Array-like whose element count matches ``data_shape``; it is
                reshaped to ``data_shape`` before display.
            index: Position of the subplot in creation order (the
                ``add_subplot`` index minus one).

        Raises:
            IndexError: If ``index`` does not refer to an existing subplot.
            ValueError: If ``data`` cannot be reshaped to ``data_shape``.
        """
        self.imgs[index].set_data(np.asarray(data).reshape(self.data_shape))

    def draw(self) -> None:
        """Redraw the figure canvas so pending :meth:`update` calls show up."""
        self.fig.canvas.draw()
if __name__ == "__main__":
    # Demo: a 2x2 grid of 10x10 heat maps, each filled with uniform noise in
    # [-1, 1), which matches ValueHeatMap's default vmin/vmax.
    shape = (10, 10)
    rows, cols = 2, 2
    demo = ValueHeatMap(shape, rows, cols)
    for panel in range(rows * cols):
        noise = np.random.uniform(-1, 1, shape)
        demo.update(noise, panel)
    demo.draw()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment