Created February 12, 2021 10:48
Save Kensuke-Mitsuzawa/b4132b202f013724ffefb8623ae017cf to your computer and use it in GitHub Desktop.
散布図と重みを同時に表示するためのコード
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas | |
import random | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
import matplotlib | |
import typing | |
def fix_min_max(x_tensor: np.ndarray,
                slicing_x: typing.Tuple[typing.Union[int, slice, None], ...],
                slicing_y: typing.Tuple[typing.Union[int, slice, None], ...],
                mark_value: typing.Optional[typing.Sequence[float]] = None,
                ) -> typing.Tuple[typing.List[float], typing.List[float]]:
    """Return shared ``[min, max]`` axis limits for x and y over the whole series.

    The limits are taken from ``x_tensor[slicing_x]`` and ``x_tensor[slicing_y]``.
    They are then widened, if necessary, so that the point ``mark_value`` lies
    inside them.

    Args:
        x_tensor: Array the limits are computed from.
        slicing_x: Index tuple selecting the values used for the x axis. It is
            passed to numpy unchanged, so use ``slice(None)`` (i.e. ``:``) to take
            a whole axis. A ``None`` entry means ``np.newaxis`` to numpy, not
            "everything".
        slicing_y: Index tuple selecting the values used for the y axis (same
            rules as ``slicing_x``).
        mark_value: Optional ``(x, y)`` point that must lie inside the limits.
            When omitted, the module-level ``mark_value`` is used if one is
            defined (the original behaviour). Otherwise the limits are not
            widened.

    Returns:
        ``(x_lim, y_lim)``, each a two-element list ``[low, high]``.
    """
    if mark_value is None:
        # Backward compatibility: earlier versions always read the global
        # `mark_value`; use it when present instead of failing with NameError.
        mark_value = globals().get('mark_value')

    x_values = x_tensor[slicing_x]
    y_values = x_tensor[slicing_y]
    # np.min / np.max already reduce over every element; no flatten needed.
    history_x_lim = [np.min(x_values), np.max(x_values)]
    history_y_lim = [np.min(y_values), np.max(y_values)]

    if mark_value is not None:
        mark_x, mark_y = mark_value[0], mark_value[1]
        # Widen the x limits with the mark's x coordinate. The previous code
        # compared the upper x bound against the mark's y coordinate.
        history_x_lim = [min(history_x_lim[0], mark_x), max(history_x_lim[1], mark_x)]
        history_y_lim = [min(history_y_lim[0], mark_y), max(history_y_lim[1], mark_y)]
    return history_x_lim, history_y_lim
# Build the data to plot: 30 time steps, each with 100 points of 2 coordinates (x, y).
x_tensor = np.random.uniform(size=(30, 100, 2), low=1.0, high=20)
# Array holding the weights to visualise: 30 time steps, 100 weights each.
weight_tensor = np.random.uniform(size=(30, 100, 1), low=0.0, high=10)
# Reference ("oracle") point drawn in red on every scatter panel.
mark_value = [2.5, 5.5]

# Choose which time steps to draw: at most `max_panels` of them, always including
# the first and the last one.
max_panels = 10
n_steps = len(x_tensor)
if n_steps > max_panels:
    # Sample interior steps 1 .. n_steps - 2 (range's end is exclusive) and keep
    # them in time order so the panels read top-to-bottom chronologically.
    interior = sorted(random.sample(population=range(1, n_steps - 1), k=max_panels - 2))
    seq_plot_range = [0] + interior + [n_steps - 1]
else:
    seq_plot_range = list(range(n_steps))
n_plot_max = len(seq_plot_range)

# One row per selected time step: left column = points, right column = weights.
# The figure grows in height with the number of rows. squeeze=False keeps `axes`
# 2-D even when only a single row is drawn.
fig, axes = plt.subplots(nrows=n_plot_max, ncols=2, figsize=(15, 3 * n_plot_max), squeeze=False)

# Scatter-panel limits shared by every row. Explicit slices are required here:
# numpy treats a `None` index as `np.newaxis`, so `(None, None, 0)` would select
# x_tensor[0] instead of the x column of every time step.
history_x_lim, history_y_lim = fix_min_max(
    x_tensor,
    slicing_x=(slice(None), slice(None), 0),
    slicing_y=(slice(None), slice(None), 1),
)
# Weight-panel limits shared by every row: bars sit at integer positions
# 0 .. n_weights - 1 and grow from 0 to the weight value.
n_weights = weight_tensor.shape[1]
weight_x_lim = [-0.5, n_weights - 0.5]
weight_y_lim = [min(0.0, float(weight_tensor.min())), float(weight_tensor.max())]

for axe_i, n_iter in enumerate(seq_plot_range):
    ax_points, ax_weights = axes[axe_i]

    df_x = pandas.DataFrame(x_tensor[n_iter], columns=['x', 'y'])
    ax_points.set(xlim=history_x_lim, ylim=history_y_lim)
    # Shade each candidate by its index so it can be matched to its weight bar.
    ax_points.scatter(df_x['x'], df_x['y'], c=df_x.index, cmap='Blues')
    ax_points.scatter(mark_value[0], mark_value[1], color='red')
    ax_points.set_title(f'N = {n_iter}')

    # Visualise w.
    ax_weights.set(xlim=weight_x_lim, ylim=weight_y_lim)
    if n_iter == 0:
        # Draw empty data, because there is no weight corresponding to N = 0.
        ax_weights.bar(np.zeros(len(x_tensor[n_iter])), np.zeros(len(x_tensor[n_iter])))
    else:
        # Weights for step n are read from weight_tensor[n - 1], as before.
        # NOTE(review): weight_tensor[-1] is therefore never shown; confirm intent.
        df_weight = pandas.DataFrame(weight_tensor[n_iter - 1], columns=['y'])
        # Same index -> colour mapping as the scatter panel (no +1 shift).
        norm = matplotlib.colors.Normalize(vmin=df_weight.index.min(), vmax=df_weight.index.max())
        cmap = matplotlib.cm.ScalarMappable(norm=norm, cmap='Blues')
        colors = [cmap.to_rgba(i) for i in df_weight.index]
        ax_weights.bar(df_weight.index, df_weight['y'], tick_label=df_weight.index, color=colors)

blue_patch = mpatches.Patch(color='blue', label='candidate')
red_patch = mpatches.Patch(color='red', label='oracle')
fig.legend(handles=[red_patch, blue_patch])
fig.suptitle('Left: x at time_t. Right: w at time_t', fontsize=15)
fig.savefig('plot_practice.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.