Skip to content

Instantly share code, notes, and snippets.

@diego898
Created October 30, 2020 23:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save diego898/13659ce1aac5c5f5de3c8c4d449a84a6 to your computer and use it in GitHub Desktop.
Save diego898/13659ce1aac5c5f5de3c8c4d449a84a6 to your computer and use it in GitHub Desktop.
Miscellaneous plotting code for our neural-network (NN) discussion
# --- Figure 1: training summary on a 2x4 grid --------------------------------
fig = plt.figure(figsize=(30, 10))

# Left column: loss per epoch on top (grid slot 1), accuracy per epoch below
# it (grid slot 5).
ax11 = plt.subplot(2, 4, 1)
ax11.plot(range(len(epoch_loss)), np.array(epoch_loss))
ax11.set_title('Epoch loss')

ax21 = plt.subplot(2, 4, 5)
ax21.plot(range(len(epoch_acc)), np.array(epoch_acc))
ax21.set_title('Epoch acc')
# --- Hidden-layer and output embeddings (top row, grid slots 2-4) ------------
ax12 = plt.subplot(2, 4, 2, projection='3d')
ax13 = plt.subplot(2, 4, 3, projection='3d')
ax14 = plt.subplot(2, 4, 4)

# Both hidden spaces get the same treatment: blue circles and red stars built
# from the first three columns of each class's array, then axis labels and a
# title. H*_blue / H*_red are indexed [:, k] -- presumably one row per sample;
# confirm against wherever they are produced.
for hidden_ax, blue_pts, red_pts, space_name in (
    (ax12, H1_blue, H1_red, 'H1 space'),
    (ax13, H2_blue, H2_red, 'H2 space'),
):
    hidden_ax.plot(*(np.ravel(blue_pts[:, k]) for k in range(3)),
                   'bo', label='class: blue circle', alpha=0.5)
    hidden_ax.plot(*(np.ravel(red_pts[:, k]) for k in range(3)),
                   'r*', label='class: red star', alpha=0.5)
    hidden_ax.set_xlabel('$h_1$', fontsize=20)
    hidden_ax.set_ylabel('$h_2$', fontsize=20)
    hidden_ax.set_zlabel('$h_3$', fontsize=20)
    # Optional fixed camera: hidden_ax.view_init(elev=30, azim=-200)
    hidden_ax.set_title(space_name)

# Output space: every output value placed on the line y = 0. Red is drawn
# first, blue on top.
for out_vals, shade in ((output_red, 'red'), (output_blue, 'blue')):
    ax14.scatter(out_vals, np.zeros_like(out_vals), color=shade)
ax14.axes.yaxis.set_visible(False)
ax14.set_title('Output space');
# --- Activation traces (bottom row, grid slots 6-8) --------------------------
ax22 = plt.subplot(2, 4, 6)
# None gives each panel its own y-scale; set sharey = ax22 to share one.
sharey = None
# sharey = ax22
ax23 = plt.subplot(2, 4, 7, sharey=sharey)
ax24 = plt.subplot(2, 4, 8, sharey=sharey)

# Hidden layers: act_info[-3] (H1) and act_info[-2] (H2) are transposed, and
# rows 0-2 are drawn as Node 1-3. Only the first three rows are plotted.
for layer_ax, layer_idx, layer_title in ((ax22, -3, 'H1 Activations'),
                                         (ax23, -2, 'H2 Activations')):
    node_rows = np.array(act_info[layer_idx]).T
    for row, node_label in enumerate(('Node 1', 'Node 2', 'Node 3')):
        layer_ax.plot(node_rows[row, :], label=node_label)
    layer_ax.legend()
    layer_ax.set_title(layer_title)

# Output layer: one red trace on a fixed [0, 1] y-range.
ax24.plot(np.array(np.squeeze(act_info[-1])), color='red')
ax24.set_ylim([0, 1])
ax24.set_title('Output Activations');
# --- Figure 2: animated training view (1x4 grid) -----------------------------
# rb_info[0] / rb_info[1] are the red / blue activation snapshots collected
# during training. Producing them requires a modified `.SGD` (not shown here).
act_red, act_blue = rb_info[0], rb_info[1]

fig = plt.figure(figsize=(30, 7))

# Panel 1: loss on the left y-axis (one x tick per snapshot) and accuracy on a
# twin right y-axis.
ax1 = plt.subplot(1, 4, 1)
ax1.set_xlim([0, len(act_red)])
epoch_loss_min = np.asarray(epoch_loss).min()
epoch_loss_max = np.asarray(epoch_loss).max()
# Small multiplicative margin around the loss range (assumes loss > 0).
ax1.set_ylim([epoch_loss_min * 0.999, epoch_loss_max * 1.001])
ax1t = ax1.twinx()
ax1t.set_ylim([0, 1.1])

# Panels 2-3: the two 3-D hidden spaces; panel 4: the 1-D output space.
ax2 = plt.subplot(1, 4, 2, projection='3d')
ax3 = plt.subplot(1, 4, 3, projection='3d')
ax4 = plt.subplot(1, 4, 4)
# --- Initial frame data: the most recent snapshot of each class -------------
# Each snapshot is indexed [-3] (H1), [-2] (H2), [-1] (output). Hidden
# activations are transposed before stacking -- presumably from
# (units, samples) to (samples, units); confirm against the modified `.SGD`.
curr_act_red = act_red[-1]
H1_red = np.array(curr_act_red[-3].T)
H2_red = np.array(curr_act_red[-2].T)
output_red = np.array(curr_act_red[-1])
curr_act_blue = act_blue[-1]
H1_blue = np.array(curr_act_blue[-3].T)
H2_blue = np.array(curr_act_blue[-2].T)
output_blue = np.array(curr_act_blue[-1])

# Stack blue rows first, then red, for plotting as a single collection.
H1 = np.vstack((H1_blue, H1_red))
H2 = np.vstack((H2_blue, H2_red))
output = np.squeeze(np.hstack((output_blue, output_red)))

# Colormap labels in the same order as the stacked rows: 0 = blue, 1 = red.
# These were hard-coded as 500 + 500. Deriving the counts from the data keeps
# the `c=` array the same length as the points for any class sizes.
join_labels = np.repeat([0, 1], [H1_blue.shape[0], H1_red.shape[0]])
# --- Initial artists; update_graph mutates these every frame ----------------
# Panel 1: loss curve (red, left axis) and accuracy curve (green, twin axis).
e_loss = ax1.plot(epoch_loss, color='red')[0]
ax1.set_ylabel('loss', color='red')
ax1.tick_params(axis='y', labelcolor='red')
ax1.set_title('Loss vs Acc')
e_acc = ax1t.plot(epoch_acc, color='green')[0]
ax1t.set_ylabel('acc', color='green')
ax1t.tick_params(axis='y', labelcolor='green')

# Panels 2-3: identical axis dressing for both 3-D hidden spaces.
for space_ax, space_title in ((ax2, 'H1 space'), (ax3, 'H2 space')):
    space_ax.set_xlabel('$h_1$', fontsize=20)
    space_ax.set_ylabel('$h_2$', fontsize=20)
    space_ax.set_zlabel('$h_3$', fontsize=20)
    space_ax.set_title(space_title)

# Points are colored by class label through the diverging 'seismic' colormap.
h1_scat = ax2.scatter(H1[:, 0], H1[:, 1], H1[:, 2],
                      alpha=0.5, cmap='seismic', c=join_labels)
h2_scat = ax3.scatter(H2[:, 0], H2[:, 1], H2[:, 2],
                      alpha=0.5, cmap='seismic', c=join_labels)

# Panel 4: output values laid out on the line y = 0.
ax4.axes.yaxis.set_visible(False)
ax4.set_title('Output space')
output_scat = ax4.scatter(output, np.zeros_like(output),
                          cmap='seismic', c=join_labels)
# --- Animation callback ------------------------------------------------------
# Training epochs between consecutive snapshots in act_red / act_blue. The
# loss/acc histories are subsampled at this stride so that curve point k lines
# up with snapshot k. (This was an inline literal 50; it is assumed to match
# how the modified `.SGD` records snapshots -- confirm.)
SNAPSHOT_STRIDE = 50


def _stack_snapshot(red_snap, blue_snap):
    """Turn one red/blue snapshot pair into plot-ready arrays.

    Each snapshot is indexed as [-3] (H1), [-2] (H2), [-1] (output). Hidden
    activations are transposed before stacking -- presumably from
    (units, samples) to (samples, units); confirm against the recorder.

    Returns:
        (H1, H2, output, labels):
        - H1 and H2 stack blue rows, then red rows.
        - ``output`` is the squeezed concatenation of blue, then red outputs.
        - ``labels`` is 0 for each blue row and 1 for each red row.
    """
    H1_red = np.array(red_snap[-3].T)
    H2_red = np.array(red_snap[-2].T)
    output_red = np.array(red_snap[-1])
    H1_blue = np.array(blue_snap[-3].T)
    H2_blue = np.array(blue_snap[-2].T)
    output_blue = np.array(blue_snap[-1])

    H1 = np.vstack((H1_blue, H1_red))
    H2 = np.vstack((H2_blue, H2_red))
    output = np.squeeze(np.hstack((output_blue, output_red)))
    # Derived from the data instead of the former hard-coded 500 + 500, so the
    # label array always matches the number of stacked rows.
    labels = np.repeat([0, 1], [H1_blue.shape[0], H1_red.shape[0]])
    return H1, H2, output, labels


def _history_upto(history, i, stride=SNAPSHOT_STRIDE):
    """Return (x, y): every ``stride``-th entry of history for snapshots 0..i.

    x is built from the length of the slice actually taken. This keeps x and y
    the same length even if history is shorter than (i + 1) * stride.
    """
    sampled = history[0:(i + 1) * stride:stride]
    return range(len(sampled)), sampled


def update_graph(i):
    """FuncAnimation callback: redraw all four panels for snapshot ``i``.

    Reads the module-level act_red, act_blue, epoch_loss and epoch_acc. Updates
    the artists e_loss, e_acc, h1_scat, h2_scat and output_scat in place. Both
    3-D views are set to azimuth 2*i, so the camera turns as frames advance.
    """
    H1, H2, output_c, labels = _stack_snapshot(act_red[i], act_blue[i])

    # Loss/acc up to *and including* snapshot i, so the curves end at the same
    # point in training as the scatters. (The old range(i) lagged by one point
    # and never drew the final one.)
    e_loss.set_data(*_history_upto(epoch_loss, i))
    e_acc.set_data(*_history_upto(epoch_acc, i))

    # Hidden spaces. NOTE: `_offsets3d` is a private matplotlib attribute of
    # 3-D scatter collections; this relies on it holding the (x, y, z) data.
    h1_scat._offsets3d = (H1[:, 0], H1[:, 1], H1[:, 2])
    h1_scat.set_array(labels)
    ax2.view_init(elev=30, azim=i * 2)

    h2_scat._offsets3d = (H2[:, 0], H2[:, 1], H2[:, 2])
    h2_scat.set_array(labels)
    ax3.view_init(elev=30, azim=i * 2)

    # Output space: every output value placed on the line y = 0.
    output_scat.set_offsets(np.vstack((output_c, np.zeros_like(output_c))).T)
    output_scat.set_array(labels)
# --- Build, display, and optionally save the animation ----------------------
from IPython.display import HTML

# One frame per recorded snapshot, 40 ms apart, full redraw each frame (no
# blitting).
anim = matplotlib.animation.FuncAnimation(
    fig,
    update_graph,
    frames=range(len(act_red)),
    interval=40,
    blit=False,
)

# Render as an HTML5 <video> for inline notebook display. Per the matplotlib
# docs this goes through a movie writer (typically ffmpeg).
HTML(anim.to_html5_video())

# To write the animation to disk instead:
# anim.save('slow_training.gif')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment