Created
July 17, 2019 03:53
-
-
Save tocom242242/0333f56160d8e55777c284b28dc1dc7b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.ticker import MaxNLocator | |
def softmax_selection(tau, values):
    """Sample an action index using softmax (Boltzmann) action selection.

    Action ``i`` is drawn with probability proportional to
    ``exp(values[i] / tau)``.

    Args:
        tau: Temperature parameter that divides every value before
            exponentiation.
        values: Sequence of numeric action values, one per action.

    Returns:
        The index of the sampled action, as returned by ``np.random.choice``.

    Raises:
        ValueError: If ``values`` is empty.
        ZeroDivisionError: If ``tau`` is zero and the values are plain
            Python numbers.
    """
    if len(values) == 0:
        raise ValueError("values must not be empty")

    scaled = [v / tau for v in values]

    # Shift by the largest exponent before calling exp(). The shift cancels
    # out in the normalisation, so the distribution is unchanged, but the
    # largest term becomes exp(0) == 1 and exp() can no longer overflow to
    # inf (which previously produced NaN probabilities).
    shifted = np.asarray(scaled, dtype=float) - max(scaled)
    weights = np.exp(shifted)
    p = weights / weights.sum()  # normalised probability distribution

    # Draw one index at random according to the distribution p.
    action = np.random.choice(np.arange(len(values)), p=p)
    return action
# Experiment configuration.
nb_steps = 1000
tau = 0.4  # temperature parameter
values = [0.3, 1.0, 0.5]  # value of each option, e.g. the leftmost option is worth 0.3

# Select an action once per step and record every selected index.
results = [softmax_selection(tau, values) for _ in range(nb_steps)]
# Plot a histogram of how often each action index was selected.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # integer ticks on the x-axis
# The tallest possible bar equals the number of steps, so derive the y-range
# from nb_steps instead of repeating the literal 1000.
ax.set_ylim(0, nb_steps)
# Use one unit-wide bin centred on each action index (edges at -0.5, 0.5, ...)
# so that every bar sits directly above its integer tick; the default 10
# float bins place the bars off-centre for discrete data.
ax.hist(results, bins=np.arange(len(values) + 1) - 0.5)
plt.savefig("result.jpg")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment