Skip to content

Instantly share code, notes, and snippets.

@joonas-yoon
Created May 31, 2022 10:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joonas-yoon/ef70b190b6e72a0f920f8af0ff9d6939 to your computer and use it in GitHub Desktop.
MNIST AutoEncoder
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    With a 1x28x28 input the spatial sizes are:

        encoder: 28x28 (32 ch) -> 14x14 (64 ch) -> 7x7 (64 ch) -> 7x7 (2 ch)
        decoder: 7x7 (64 ch) -> 14x14 (64 ch) -> 28x28 (32 ch) -> 28x28 (1 ch)

    Every convolution is followed by ReLU except the last decoder layer, so
    the reconstruction is not squashed into any fixed range.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as a hand-written Sequential
        # would create them: this keeps the state_dict keys (encoder.0,
        # encoder.2, ...) and the random weight-initialisation order stable.

        # (in_channels, out_channels, stride); all kernels 3x3, padding 1.
        encoder_spec = (
            (1, 32, 1),
            (32, 64, 2),
            (64, 64, 2),
            (64, 2, 1),   # 2-channel latent map
        )
        encoder_layers = []
        for c_in, c_out, step in encoder_spec:
            encoder_layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=step))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # (in_channels, out_channels, stride, output_padding); output_padding
        # restores the exact even size lost by the stride-2 encoder convs.
        decoder_spec = (
            (2, 64, 1, 0),
            (64, 64, 2, 1),
            (64, 32, 2, 1),
            (32, 1, 1, 0),
        )
        last = len(decoder_spec) - 1
        decoder_layers = []
        for pos, (c_in, c_out, step, extra) in enumerate(decoder_spec):
            decoder_layers.append(
                nn.ConvTranspose2d(
                    c_in, c_out, 3, padding=1, output_padding=extra, stride=step
                )
            )
            if pos != last:  # no activation after the output layer
                decoder_layers.append(nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent map, then decode it back."""
        return self.decoder(self.encoder(x))
import io
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms.transforms import Lambda
from time import sleep
# Train the autoencoder for one pass over `train_loader`. Every
# SNAPSHOT_EVERY batches, plot (input, reconstruction) pairs for one sample
# per label and keep the rendered figure. At the end, write every kept
# frame to ./mnist.gif.
#
# NOTE(review): relies on notebook globals defined outside this snippet:
# `device`, `train_loader`, `tqdm_nb`, `imgrid`, `plt`, `clear_output`.

SNAPSHOT_EVERY = 10    # batches between recorded frames
NUM_LABELS = 10        # distinct labels shown per frame (one sample each)
IMAGE_SHAPE = (28, 28)


def _first_index_per_label(labels, limit=NUM_LABELS):
    """Map each label to the index of its first occurrence in ``labels``.

    Scanning stops as soon as ``limit`` distinct labels have been seen.
    """
    first_seen = {}
    for index, label in enumerate(labels):
        first_seen.setdefault(label, index)
        if len(first_seen) >= limit:
            break
    return first_seen


def _capture_current_figure():
    """Render the current matplotlib figure to PNG and return it as an array.

    Decodes the PNG that savefig actually produced. This avoids reading the
    private, backend-specific ``canvas.renderer._renderer`` buffer.
    """
    with io.BytesIO() as img_buf:
        plt.savefig(img_buf, format='png')
        img_buf.seek(0)
        with Image.open(img_buf) as frame:
            return np.array(frame)  # forces a full decode before buf closes


model = AutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
keyframes = []  # rendered frames, written to the GIF after training

model.train()
total = len(train_loader)
for idx, (inputs, outputs) in tqdm_nb(enumerate(train_loader), total=total):
    # `outputs` are the dataset labels. The reconstruction target is the
    # input itself, so labels are only used to pick samples to display.
    inputs = inputs.to(device)
    y_hat = model(inputs)  # model and inputs are both on `device`
    loss = criterion(y_hat, inputs)

    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if idx % SNAPSHOT_EVERY != 0:
        continue

    # Interleave (input, reconstruction) for one sample per label, in
    # ascending label order. Detach so plotting doesn't keep the graph alive.
    picked = _first_index_per_label(l.item() for l in outputs)
    pairs = []
    for _, i in sorted(picked.items()):
        pairs.append(inputs[i].view(-1, *IMAGE_SHAPE))
        pairs.append(y_hat[i].detach().view(-1, *IMAGE_SHAPE))
    images = torch.cat(pairs).cpu().numpy()

    # NOTE(review): `imgrid(images, rows, cols)` is defined elsewhere; with
    # interleaved pairs this presumably yields 2 x NUM_LABELS tiles — confirm.
    imgrid(images, 2, NUM_LABELS)
    keyframes.append(_capture_current_figure())
    plt.show()
    clear_output(wait=True)

imageio.mimsave('./mnist.gif', keyframes)
@joonas-yoon
Copy link
Author

mnist

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment