Skip to content

Instantly share code, notes, and snippets.

@vikhyat
Last active April 21, 2024 05:30
Show Gist options
  • Save vikhyat/61e12b126ce6b098ae7700262204e479 to your computer and use it in GitHub Desktop.
Save vikhyat/61e12b126ce6b098ae7700262204e479 to your computer and use it in GitHub Desktop.
Activation Visualization
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# Activation under test. The GIF name is derived from it so that switching
# ``act`` (e.g. to nn.ReLU) also switches the output file instead of silently
# overwriting 'gelu_training.gif' with a different activation's animation.
act = nn.GELU
filename = f'{act.__name__.lower()}_training.gif'
# Step 1: Generate Data
# 100 evenly spaced samples of sin on [-2*pi, 2*pi], reshaped to (100, 1) so
# each row is one scalar input for an nn.Linear(1, ...) layer.
x = torch.linspace(-2*np.pi, 2*np.pi, 100).view(-1, 1)
y = torch.sin(x)
# Step 2: Define the MLP Model
class MLP(nn.Module):
    """One-hidden-layer perceptron mapping a scalar input to a scalar output.

    Architecture: ``Linear(1, hidden) -> activation -> Linear(hidden, 1)``.

    Args:
        hidden: width of the hidden layer (default 16, the original width).
        activation: an ``nn.Module`` class (not an instance); it is
            instantiated with no arguments for the hidden nonlinearity.
            ``None`` falls back to the module-level ``act`` chosen at the
            top of the script.
    """

    def __init__(self, hidden=16, activation=None):
        super().__init__()
        if activation is None:
            # Resolved at construction time so the script-level choice of
            # ``act`` still drives the default model.
            activation = act
        self.fc1 = nn.Linear(1, hidden)
        self.act = activation()
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Map inputs of shape ``(..., 1)`` to predictions of shape ``(..., 1)``."""
        return self.fc2(self.act(self.fc1(x)))
# Regression setup: mean-squared error between the network's predictions and
# the targets, minimized with Adam (lr=0.01) over every model parameter.
criterion = nn.MSELoss()
model = MLP()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
# Step 3: Train the Model and Yield Results for Animation
def train(num_epochs=5000, interval=100, *, net=None, opt=None, loss_fn=None,
          inputs=None, targets=None):
    """Run full-batch training and periodically yield snapshots for animation.

    Each epoch performs one forward pass over the whole dataset, one backward
    pass and one optimizer step.

    Args:
        num_epochs: total number of optimization steps.
        interval: yield a snapshot every ``interval`` epochs (epoch 0 is
            always included). Must be positive.
        net, opt, loss_fn, inputs, targets: keyword-only overrides. Each one
            left as ``None`` falls back to the module-level ``model``,
            ``optimizer``, ``criterion``, ``x`` and ``y`` respectively.

    Yields:
        ``(output, epoch)``: the detached prediction computed at ``epoch``
        (before that epoch's optimizer step) and the epoch index.

    Raises:
        ValueError: if ``interval`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()``, before
            any optimizer step is taken.
    """
    if interval <= 0:
        # Previously ``epoch % 0`` raised ZeroDivisionError only after the
        # model had already been stepped once.
        raise ValueError(f"interval must be positive, got {interval}")
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    inputs = x if inputs is None else inputs
    targets = y if targets is None else targets

    for epoch in range(num_epochs):
        opt.zero_grad()
        output = net(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        opt.step()
        if epoch % interval == 0:
            # Detach so the snapshot holds no autograd graph.
            yield output.detach(), epoch
# Step 4: Create Animation
fig, ax = plt.subplots()
ax.set_xlim(-2 * np.pi, 2 * np.pi)
ax.set_ylim(-2, 2)
# Convert the tensors to NumPy explicitly rather than relying on Matplotlib
# to coerce torch tensors.
x_np, y_np = x.numpy(), y.numpy()
line, = ax.plot(x_np, y_np, 'r', label='True Function')
# Seeded with the true curve; its y-data is replaced on every animation frame.
line2, = ax.plot(x_np, y_np, 'b', label='MLP Approximation')
# Step counter placed in axes coordinates (near the top-left corner).
text = ax.text(0.05, 0.95, '', transform=ax.transAxes)
ax.legend()
ax.set_title(f"{model.act} MLP")
def update(frame):
    """Redraw the approximation curve and the step label for one frame.

    ``frame`` is a single ``(output, epoch)`` pair from the frame source
    handed to ``FuncAnimation``. The returned artists are the ones the
    animation re-blits.
    """
    prediction, step = frame
    line2.set_ydata(prediction.numpy())
    text.set_text(f'Step: {step}')
    changed_artists = (line2, text)
    return changed_artists
# Run training up front and keep the snapshots in a list. A bare generator
# has no length, so FuncAnimation must fall back to ``save_count`` (100 by
# default on older Matplotlib releases), which would cut the 120 snapshots
# short in the GIF. A generator also cannot be re-iterated. A list fixes both.
frames = list(train(1200, 10))
ani = FuncAnimation(fig, update, frames=frames, blit=True)
# Save as GIF
ani.save(filename, writer='imagemagick', fps=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment