Skip to content

Instantly share code, notes, and snippets.

@Kensuke-Mitsuzawa
Created February 12, 2021 10:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Kensuke-Mitsuzawa/b4132b202f013724ffefb8623ae017cf to your computer and use it in GitHub Desktop.
Save Kensuke-Mitsuzawa/b4132b202f013724ffefb8623ae017cf to your computer and use it in GitHub Desktop.
散布図と重みを同時に表示するためのコード
import numpy as np
import pandas
import random
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib
import typing
def fix_min_max(x_tensor: np.ndarray,
                slicing_x: typing.Tuple[typing.Union[int, slice, None], ...],
                slicing_y: typing.Tuple[typing.Union[int, slice, None], ...],
                *,
                mark_value: typing.Optional[typing.Sequence[float]] = None
                ) -> typing.Tuple[typing.List[float], typing.List[float]]:
    """Return shared axis limits covering selected values and a marker point.

    The minimum and maximum over the whole selection (e.g. over every time
    step of a time series) are used, so every panel can share the same limits.

    Args:
        x_tensor: Array whose values are scanned.
        slicing_x: Index tuple selecting the values that define the x-axis
            limits. It is passed to numpy verbatim, i.e. ``x_tensor[slicing_x]``.
            NOTE: inside a numpy index tuple ``None`` means ``np.newaxis``
            (insert a new axis), NOT ``:``. Use ``slice(None)`` to select a
            whole axis, e.g. ``(slice(None), slice(None), 0)`` is
            ``x_tensor[:, :, 0]``.
        slicing_y: Same as ``slicing_x``, for the y-axis limits.
        mark_value: ``(x, y)`` point that must stay visible; each range is
            widened just enough to include it. When omitted, the module-level
            ``mark_value`` global is used if one exists (backward-compatible
            behaviour); otherwise no point is added.

    Returns:
        ``([x_min, x_max], [y_min, y_max])``.

    Raises:
        ValueError: if a selection is empty (numpy's min/max of a zero-size
            array).
    """
    if mark_value is None:
        # Backward compatibility: this function originally always read the
        # module-level ``mark_value`` global instead of taking a parameter.
        mark_value = globals().get('mark_value')

    x_values = x_tensor[slicing_x]
    y_values = x_tensor[slicing_y]
    # np.min/np.max reduce over all axes, so no flatten() copy is needed.
    history_x_lim = [float(np.min(x_values)), float(np.max(x_values))]
    history_y_lim = [float(np.min(y_values)), float(np.max(y_values))]

    if mark_value is not None:
        mark_x, mark_y = mark_value[0], mark_value[1]
        # Both x bounds are compared with the marker's x coordinate and both
        # y bounds with its y coordinate.
        history_x_lim = [min(history_x_lim[0], mark_x), max(history_x_lim[1], mark_x)]
        history_y_lim = [min(history_y_lim[0], mark_y), max(history_y_lim[1], mark_y)]
    return history_x_lim, history_y_lim
# Build the data tensor: 30 time steps x 100 points x (x, y) coordinates.
x_tensor = np.random.uniform(size=(30, 100, 2), low=1.0, high=20)
# Weight tensor for the bar charts: 30 time steps x 100 points x 1 weight.
weight_tensor = np.random.uniform(size=(30, 100, 1), low=0.0, high=10)
# Reference ("oracle") point drawn on every scatter plot, as (x, y).
mark_value = [2.5, 5.5]
n_steps = len(x_tensor)
if n_steps > 10:
    # Always show the first and last step plus 8 distinct interior steps.
    # Interior indices are 1..n_steps-2, i.e. range(1, n_steps - 1); they are
    # sorted so that the rows of the figure appear in chronological order.
    seq_plot_range = [0]
    seq_plot_range += sorted(random.sample(population=range(1, n_steps - 1), k=8))
    seq_plot_range.append(n_steps - 1)
    n_plot_max = 10
else:
    seq_plot_range = list(range(0, n_steps))
    n_plot_max = n_steps
# end if
# Prepare an n_plot_max x 2 grid of panels (left: points, right: weights).
# figsize is (width, height), so the height grows with the number of rows;
# squeeze=False keeps `axes` 2-D even when there is only one row.
fig, axes = plt.subplots(nrows=n_plot_max, ncols=2, figsize=(15, 3 * n_plot_max), squeeze=False)
# Scatter limits shared across ALL time steps. slice(None) is required here:
# a None inside a numpy index tuple means np.newaxis, not ':'.
history_x_lim, history_y_lim = fix_min_max(
    x_tensor,
    slicing_x=(slice(None), slice(None), 0),
    slicing_y=(slice(None), slice(None), 1))
# Bar-chart limits shared across all time steps: bars sit at x = 0..n-1 and
# rise from 0 to the weight, so the x range comes from the bar count and the
# y range from the weights (always including 0).
n_points = weight_tensor.shape[1]
weight_x_lim = [-0.5, n_points - 0.5]
weight_y_lim = [min(0.0, float(np.min(weight_tensor))),
                max(0.0, float(np.max(weight_tensor)))]
# color_palette[0]: candidate points, color_palette[1]: oracle point.
color_palette = ['blue', 'red']
for axe_i, n_iter in enumerate(seq_plot_range):
    ax_points, ax_weights = axes[axe_i, 0], axes[axe_i, 1]
    # Left panel: the points at step n_iter, shaded by point index.
    df_x = pandas.DataFrame(x_tensor[n_iter], columns=['x', 'y'])
    ax_points.set(xlim=history_x_lim, ylim=history_y_lim)
    ax_points.scatter(df_x['x'], df_x['y'], c=df_x.index, cmap='Blues')
    ax_points.scatter(mark_value[0], mark_value[1], color=color_palette[1])
    ax_points.set_title(f'N = {n_iter}')
    # Right panel: visualise the weights w.
    ax_weights.set(xlim=weight_x_lim, ylim=weight_y_lim)
    if n_iter == 0:
        # Draw an empty chart: in this setting there is no weight for N = 0.
        ax_weights.bar(np.zeros(len(x_tensor[n_iter])), np.zeros(len(x_tensor[n_iter])))
    else:
        # The weights shown at step N are weight_tensor[N - 1].
        df_weight = pandas.DataFrame(weight_tensor[n_iter - 1], columns=['y'])
        # One shade of blue per bar, mapped from the bar index onto the
        # same [min index, max index] range the normaliser is built from.
        norm = matplotlib.colors.Normalize(vmin=df_weight.index.min(), vmax=df_weight.index.max())
        cmap = matplotlib.cm.ScalarMappable(norm=norm, cmap='Blues')
        colors = [cmap.to_rgba(i) for i in df_weight.index]
        ax_weights.bar(df_weight.index, df_weight['y'], tick_label=df_weight.index, color=colors)
    # end if
blue_patch = mpatches.Patch(color=color_palette[0], label='candidate')
red_patch = mpatches.Patch(color=color_palette[1], label='oracle')
fig.legend(handles=[red_patch, blue_patch])
fig.suptitle('Left: x at time_t. Right: w at time_t', fontsize=15)
fig.savefig('plot_practice.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment