Skip to content

Instantly share code, notes, and snippets.

@rizar
Created September 18, 2018 15:03
Show Gist options
  • Save rizar/1aff32d3d914a6ea8308338c28e66280 to your computer and use it in GitHub Desktop.
Save rizar/1aff32d3d914a6ea8308338c28e66280 to your computer and use it in GitHub Desktop.
alphas
def plot_key_alphas(data, softmax=True, fixed_axis=True):
    """Plot row-normalized values of three selected columns in three panels.

    A single-row figure with three side-by-side axes is created through
    ``pyplot.subplots``. Panel ``k`` is drawn from ``data[k]``: columns 4, 5
    and 7 are selected, every row is normalized so that the three values sum
    to one, and each column is drawn as its own line (labelled 'X', 'R' and
    'Y' in the legend, in that order).

    Args:
        data: Indexable by 0, 1 and 2. Each entry must convert to a 2-D
            array with at least 8 columns. Rows are plotted along the
            x-axis (presumably training steps -- confirm against caller).
        softmax: If true, rows are normalized with a softmax; otherwise each
            row is divided by its plain sum.
        fixed_axis: If true, the axes are pinned to x in [0, 5000] and
            y in [0, 1.0].

    Returns:
        None. The figure is only created and drawn on, not returned.
    """
    columns = [4, 5, 7]
    labels = ['X', 'R', 'Y']

    f, axes = pyplot.subplots(1, 3)
    f.set_size_inches((15, 2))
    for k in range(3):
        ax = axes[k]
        arr = numpy.array(data[k])[:, columns]
        if softmax:
            # Subtract the per-row maximum before exponentiating: the result
            # is mathematically unchanged, but large scores no longer
            # overflow to inf and turn the whole row into nan.
            shifted = arr - arr.max(axis=1, keepdims=True)
            weights = numpy.exp(shifted)
            arr = weights / weights.sum(axis=1, keepdims=True)
        else:
            arr = arr / arr.sum(axis=1, keepdims=True)
        for i in range(arr.shape[1]):
            ax.plot(arr[:, i])
        ax.legend(labels)
        if fixed_axis:
            # NOTE(review): the original indentation was lost; both limits
            # are assumed to be controlled by ``fixed_axis`` -- confirm.
            ax.set_xlim(0, 5000)
            ax.set_ylim(0, 1.0)
        ax.set_xlabel('alpha_{}'.format(k))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment