Skip to content

Instantly share code, notes, and snippets.

@liyu1981
Last active April 12, 2023 08:54
Show Gist options
  • Save liyu1981/caf28bd5d34343d2759233597962b122 to your computer and use it in GitHub Desktop.
Save liyu1981/caf28bd5d34343d2759233597962b122 to your computer and use it in GitHub Desktop.
A simple function to animate the training process in Google Colab with PyTorch
# Usage:
# just copy this function into your google colab notebook, then use it
# a working example can be seen in this demo notebook:
# https://colab.research.google.com/drive/19Ni0EfOExQmTcrFZh3Z6DxsOM25KULqN?usp=sharing
import matplotlib.pyplot as plt
from matplotlib import animation, rc
def make_animate_train(train_step_fn, animate_setup_fn):
    """Build a ``train`` function that visualizes training as an animation.

    The returned function has the signature::

        def train(net, x, y, optimizer, loss_func, iterations=100)

    and returns a ``FuncAnimation`` object that can be shown in a Colab
    notebook. Every animation frame runs exactly one call of
    ``train_step_fn``.

    :param train_step_fn: a function with the signature
        ``train_step(i, plot_state, net, x, y, optimizer, loss_func)``.

        Returns:
            tuple ``(loss_value, plot_state)``.
        Params:
            i - int, current iteration (frame) number
            plot_state - tuple of (ax, plot_elem1, plot_elem2, ...),
                e.g. (ax, line); on the first frame this is the value
                returned by ``animate_setup_fn``, afterwards it is the
                state returned by the previous ``train_step_fn`` call
            net - the model being trained
            x - torch tensor, data
            y - torch tensor, result
            optimizer - torch optimizer
            loss_func - torch loss func
    :param animate_setup_fn: a function with the signature
        ``animate_setup(ax)``.

        Returns:
            tuple (ax, plot_elem1, plot_elem2, ...), e.g. (ax, line)
        Params:
            ax - the Axes instance returned by ``plt.subplots()``
    :return: the ``train`` function described above.
    """

    def train(net, x, y, optimizer, loss_func, iterations=100):
        fig, ax = plt.subplots()
        # Close the pyplot figure right away; presumably this keeps the
        # notebook from also rendering a static copy of the figure.
        plt.close()

        plot_state = animate_setup_fn(ax)

        def anim_init():
            # Element 0 is the Axes; the rest are the artists to redraw.
            return plot_state[1:]

        def anim_frame(i):
            # Without ``nonlocal`` the assignment below would only create a
            # local variable, and every frame would be handed the initial
            # plot state instead of the one returned by the previous step.
            nonlocal plot_state
            # Train the model while generating each frame.
            loss, plot_state = train_step_fn(
                i, plot_state, net, x, y, optimizer, loss_func
            )
            print("\riteration %d done with loss=%f." % (i, loss), end='')
            return plot_state[1:]

        anim = animation.FuncAnimation(
            fig,
            anim_frame,
            init_func=anim_init,
            frames=iterations,
            interval=100,
            blit=True,
        )
        # Below is the part which makes it work on Colab: animations are
        # rendered as interactive JavaScript/HTML by default.
        rc('animation', html='jshtml')
        return anim

    return train
@liyu1981
Copy link
Author

liyu1981 commented Apr 12, 2023

Example usage:

def train_step(i, plot_state, net, x, y, optimizer, loss_func):
    """Run a single optimisation step, then refresh the plot artists.

    Params:
        i - iteration number (not used by this step)
        plot_state - tuple (ax, line) holding the Axes and the fitted line
        net - model that is called on ``x``
        x, y - torch tensors with the inputs and the targets
        optimizer - torch optimizer that updates ``net``
        loss_func - callable computing the loss of (prediction, y)
    Return:
        tuple (loss, (ax, line))
    """
    # Forward pass and loss for the current parameters.
    outputs = net(x)
    step_loss = loss_func(outputs, y)

    # Clear old gradients, backpropagate, and apply the update.
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()

    # Redraw the fitted curve and update the most recently added text label.
    axes, curve = plot_state
    curve.set_data(x.data.numpy(), outputs.data.numpy())
    loss_label = axes.texts[-1]
    loss_label.set_text('Loss=%.4f' % step_loss.data.numpy())

    return (step_loss, (axes, curve))
# Mean-squared-error criterion used to score the predictions.
loss_func = torch.nn.MSELoss()

# `Net` is not defined in this snippet; it has to come from the notebook.
net = Net(n_feature=1, n_hidden=10, n_output=1)

# Plain SGD over all of the network's parameters.
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)

def animate_setup(ax):
    """Prepare the static parts of the plot and create the animated line.

    Reads the module-level tensors ``x`` and ``y`` for the scatter plot.

    Params:
        ax - Axes to draw on
    Return:
        tuple (ax, line), where ``line`` is the initially empty red line
    """
    # Fixed viewport for the animation.
    ax.set_xlim((-1.5, 1.5))
    ax.set_ylim((-0.2, 1.4))

    # Scatter plot of the data points.
    data_x = x.data.numpy()
    data_y = y.data.numpy()
    ax.scatter(data_x, data_y)

    # Empty line that later receives data through set_data().
    (line,) = ax.plot([], [], lw=5, color='red')

    # Loss label, added last so that it is reachable as ax.texts[-1].
    label_style = {'size': 20, 'color': 'red'}
    ax.text(0.5, 0, 'Loss=%.4f' % 0, fontdict=label_style)

    return (ax, line)

# Combine the per-step trainer and the plot setup into one train function.
anim_train = make_animate_train(
    train_step_fn=train_step,
    animate_setup_fn=animate_setup,
)

# Build the animation; one training step is run per frame.
anim = anim_train(
    net,
    x,
    y,
    optimizer,
    loss_func,
    iterations=200,
)

# Leave the animation as the final expression so the notebook can show it.
anim

@liyu1981
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment