@pbloem
Created June 17, 2019 19:24
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.distributions import Beta
from torch.optim import Adam
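# Hyperparameters. EPSILON keeps the pixel values strictly inside (0, 1),
# the open support of the Beta distribution.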
EPSILON = 0.01
EPOCHS = 5
BATCH = 256
C, H, W = 3, 32, 32
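# CIFAR-10 training images, loaded as tensors with values in [0, 1]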
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH,
                                          shuffle=True, num_workers=2)
# instead of a model, we just create a parameter tensor holding two Beta
# parameters per colour channel per pixel (ideally the resulting per-pixel
# means converge to the mean image of the data)
model = nn.Parameter(torch.zeros(C * 2, H, W))
optimizer = Adam([model], lr=0.001)
for e in range(EPOCHS):
    for i, (data, _) in enumerate(trainloader):
        b, c, h, w = data.size()
        # expand the "model" to fit the data
        result = model[None, :, :, :].expand(b, 2*C, H, W)
        # activation: softplus keeps the Beta parameters strictly positive
        result = F.softplus(result)
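        # the first three channels are the alpha parameters, the last three the
        # betas: one Beta distribution per colour channel per pixel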
        m = Beta(result[:, :3, :, :], result[:, 3:, :, :])
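        # squeeze the targets from [0, 1] into [EPSILON, 1 - EPSILON], so the
        # log-probability stays finite at the boundaries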
        data = data * (1 - 2*EPSILON) + EPSILON
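        # negative log-likelihood, averaged over batch, channels and pixels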
        loss = - m.log_prob(data).mean()
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
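
# after training, inspect the top-left 5x5 corner of the learned (raw) parameters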
print(model.data[0, :5, :5])
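
# Optional sanity check, a minimal sketch: recover the mean of each fitted
# per-pixel Beta, alpha / (alpha + beta), and compare it to the empirical mean
# image of the (squeezed) training data, which the comment above suggests it
# should roughly track. The names `alpha`, `beta`, `beta_mean` and `mean_img`
# are illustrative.
with torch.no_grad():
    params = F.softplus(model)            # same activation as during training
    alpha, beta = params[:3], params[3:]  # per-pixel Beta parameters
    beta_mean = alpha / (alpha + beta)    # mean of a Beta(alpha, beta)

    # empirical mean image, squeezed into [EPSILON, 1 - EPSILON] like the targets
    mean_img = torch.zeros(C, H, W)
    for data, _ in trainloader:
        mean_img += (data * (1 - 2*EPSILON) + EPSILON).sum(dim=0)
    mean_img /= len(trainset)

    print('max abs difference to mean image:',
          (beta_mean - mean_img).abs().max().item())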