Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active February 5, 2021 13:52
Show Gist options
  • Save crowsonkb/eb24067fb22a1604c40498011ea17fad to your computer and use it in GitHub Desktop.
Save crowsonkb/eb24067fb22a1604c40498011ea17fad to your computer and use it in GitHub Desktop.
Trains IMLE on the MNIST dataset.
"""Trains IMLE on the MNIST dataset."""
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from tqdm import tqdm
# Generated samples drawn per optimization step.
BATCH_SIZE = 25
# Real images fetched per DataLoader batch; each big batch is reused for
# BIG_BATCH_SIZE // BATCH_SIZE optimization steps.
BIG_BATCH_SIZE = 100
# Total passes over the training set.
EPOCHS = 100
# Dimensionality of the latent noise vectors fed to the generator.
LATENT_SIZE = 32
class ConvBlock(nn.Sequential):
    """A 3x3 same-padding convolution followed by an in-place ReLU."""

    def __init__(self, c_in, c_out):
        layers = [
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*layers)
def main():
    """Train an IMLE generator on MNIST, saving sample grids and the model.

    Side effects: downloads MNIST into data/mnist, writes demo.png after each
    epoch, and writes mnist_imle.pth on exit (including after Ctrl-C).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    torch.manual_seed(0)

    to_tensor = transforms.ToTensor()
    train_set = datasets.MNIST('data/mnist', download=True, transform=to_tensor)
    train_dl = data.DataLoader(train_set, BIG_BATCH_SIZE, shuffle=True,
                               num_workers=1, pin_memory=True)

    # Generator: latent vector -> 7x7 feature map, upsampled twice to 28x28,
    # with a sigmoid so outputs land in [0, 1] like the ToTensor targets.
    model = nn.Sequential(
        nn.Linear(LATENT_SIZE, 16 * 7 * 7),
        nn.Unflatten(-1, (16, 7, 7)),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2),
        ConvBlock(16, 16),
        ConvBlock(16, 8),
        nn.Upsample(scale_factor=2),
        ConvBlock(8, 8),
        nn.Conv2d(8, 1, 3, padding=1),
        nn.Sigmoid(),
    ).to(device)
    print('Parameters:', sum(p.numel() for p in model.parameters()))

    def crit(x, z):
        # IMLE objective: for each real image, the squared error to its
        # nearest generated sample, averaged over the real batch.
        diffs = model(z).unsqueeze(0) - x.unsqueeze(1)
        dists = diffs.pow(2).mean([2, 3, 4])
        return dists.min(1).values.mean()

    opt = optim.Adam(model.parameters(), lr=1e-3)

    def train():
        # One epoch: each big batch of real images is reused for several
        # optimization steps, each against a fresh batch of latents.
        model.train()
        losses = []
        i = 0
        with tqdm(total=len(train_set), unit='samples', dynamic_ncols=True) as pbar:
            for x, _ in train_dl:
                x = x.to(device, non_blocking=True)
                for _ in range(BIG_BATCH_SIZE // BATCH_SIZE):
                    i += 1
                    z = torch.randn([BATCH_SIZE, LATENT_SIZE], device=device)
                    opt.zero_grad()
                    loss = crit(x, z)
                    losses.append(loss.item())
                    loss.backward()
                    opt.step()
                    pbar.update(len(z))
                    if i % 50 == 0:
                        tqdm.write(f'{i * BATCH_SIZE} {sum(losses[-50:]) / 50:g}')

    @torch.no_grad()
    @torch.random.fork_rng()  # restore RNG state so demos don't perturb training
    def demo():
        # Render a 10x10 grid of generated samples to demo.png.
        model.eval()
        z = torch.randn([100, LATENT_SIZE], device=device)
        grid = utils.make_grid(model(z), 10).cpu()
        TF.to_pil_image(grid).save('demo.png')
        print('Wrote examples to demo.png.')

    try:
        for epoch in range(1, EPOCHS + 1):
            print('Epoch', epoch)
            train()
            demo()
    except KeyboardInterrupt:
        pass  # allow early stop; still save the partially trained model

    torch.save(model.state_dict(), 'mnist_imle.pth')
    print('Wrote trained model to mnist_imle.pth.')
# Script entry point.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment