Last active: June 25, 2020 04:50
Save kngwyu/76a091d02b1971dbaf7afb172a616f91 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Tuple, Union

import numpy as np
from matplotlib import colors as mc
from matplotlib import pyplot as plt
def _make_colormap(colors: Dict[float, str], name: str) -> mc.LinearSegmentedColormap:
    """Build a piecewise-linear RGB colormap from anchor positions to colors.

    The keys of ``colors`` are anchor positions on any numeric scale. They are
    sorted and rescaled linearly onto ``[0, 1]`` (smallest key -> 0, largest
    key -> 1), and the colormap interpolates linearly between consecutive
    anchors.

    Args:
        colors: Mapping from anchor position to a color spec accepted by
            ``matplotlib.colors.to_rgb`` (e.g. ``"w"``, ``"xkcd:green"``).
        name: Name given to the resulting colormap.

    Returns:
        A ``LinearSegmentedColormap``. A single anchor yields a constant
        colormap of that one color.

    Raises:
        ValueError: If ``colors`` is empty.
    """
    if not colors:
        raise ValueError("colors must contain at least one entry")
    keys = sorted(colors)
    rgb = [mc.to_rgb(colors[k]) for k in keys]
    lo, hi = keys[0], keys[-1]
    if hi == lo:
        # A single anchor would make the rescaling below divide 0 by 0 (NaN
        # positions). Instead, span the whole [0, 1] range with that one color.
        positions = [0.0, 1.0]
        rgb = [rgb[0], rgb[0]]
    else:
        positions = [(k - lo) / (hi - lo) for k in keys]
    # Segment data rows are (x, value_below_x, value_above_x). Using the same
    # value on both sides gives continuous interpolation at every anchor.
    cmap_dict = {
        channel: [(x, c[idx], c[idx]) for x, c in zip(positions, rgb)]
        for idx, channel in enumerate(("red", "green", "blue"))
    }
    return mc.LinearSegmentedColormap(name, cmap_dict)
# Default diverging colormap for value functions. The three anchors are evenly
# spaced, so scarlet marks the low end, white the midpoint and green the high end.
DEFAULT_VALUE_CM = _make_colormap(
    {0: "xkcd:scarlet", 1: "w", 2: "xkcd:green"},
    name="value",
)
class ValueHeatMap:
    """A grid of heat-map panels, each an ``imshow`` image with its own colorbar.

    Use :meth:`update` to replace the data of one panel and :meth:`draw` to
    redraw the figure canvas.
    """

    def __init__(
        self,
        data_shape: Tuple[int, int],
        nrows: int = 1,
        ncols: int = 1,
        name: str = "Value Function",
        cmap: Union[str, mc.Colormap] = DEFAULT_VALUE_CM,
        vmin: float = -1.0,
        vmax: float = 1.0,
    ) -> None:
        """Create the figure and one zero-filled panel per grid cell.

        Args:
            data_shape: Shape each panel's data is reshaped to in :meth:`update`.
            nrows: Number of panel rows.
            ncols: Number of panel columns.
            name: Figure label passed to ``plt.figure``.
            cmap: Colormap instance or registered colormap name.
            vmin: Lower color-scale limit shared by all panels.
            vmax: Upper color-scale limit shared by all panels.
        """
        self.fig = plt.figure(name)
        self.data_shape = data_shape
        colormap = plt.get_cmap(cmap)
        self.imgs = []
        n_panels = nrows * ncols
        # Placeholder image; real data arrives through update().
        dummy = np.zeros(data_shape)
        for i in range(n_panels):
            ax = self.fig.add_subplot(nrows, ncols, i + 1)
            img = ax.imshow(
                dummy, cmap=colormap, vmin=vmin, vmax=vmax, interpolation="nearest"
            )
            cbar = self.fig.colorbar(img, ax=ax)
            cbar.ax.set_ylabel("", rotation=-90, va="bottom")
            ax.set_xticks([])
            ax.set_yticks([])
            # A lone panel needs no label. With several panels, number them
            # (0-based, matching the ``index`` argument of update()).
            ax.set_xlabel("" if n_panels == 1 else f"Value {i}")
            ax.set_aspect("equal")
            self.imgs.append(img)
        self.fig.tight_layout()
        self.fig.canvas.draw()

    def update(self, data, index: int = 0) -> None:
        """Replace the data shown in panel ``index``.

        Args:
            data: Array-like whose total size matches ``data_shape``. It is
                converted with ``np.asarray`` and reshaped to ``data_shape``.
            index: 0-based panel index, in the row-major order the panels
                were created.

        Call :meth:`draw` afterwards to redraw the canvas.
        """
        self.imgs[index].set_data(np.asarray(data).reshape(self.data_shape))

    def draw(self) -> None:
        """Redraw the figure canvas."""
        self.fig.canvas.draw()
if __name__ == "__main__":
    # Demo: fill a 2x2 grid of 10x10 panels with uniform noise in [-1, 1].
    DATA_SHAPE = (10, 10)
    NROWS, NCOLS = 2, 2
    heatmap = ValueHeatMap(DATA_SHAPE, nrows=NROWS, ncols=NCOLS)
    for panel in range(NROWS * NCOLS):
        noise = np.random.uniform(-1, 1, DATA_SHAPE)
        heatmap.update(noise, panel)
    heatmap.draw()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.