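# WGAN-GP (Wasserstein GAN with gradient penalty, Gulrajani et al. 2017) on
# MNIST: a weight-normalized MLP generator/critic pair, trained with a small
# replay buffer of past fakes, logging losses and sample images to TensorBoard.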
import torch
import torch.utils.data
from torch import nn, autograd, optim
from torch.utils.tensorboard import SummaryWriter
import tqdm  # in a notebook, use `import tqdm.notebook as tqdm` instead
import torchvision.datasets
import torchvision.transforms

class WGanGpLoss(nn.Module):
    def __init__(self, critic, gp_lambda=10, align=False):
        super().__init__()
        self.critic_net = critic
        self._gp_lambda = gp_lambda  # penalty weight; 10 is the paper's default
        self._align = align  # if True, real and fake batch sizes may differ
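
    # The critic objective is the WGAN-GP loss:
    #   E[critic(fake)] - E[critic(real)] + gp_lambda * E[(||grad||_2 - 1)^2]
    # Note that here the "fake" term is evaluated on the interpolated samples,
    # so one critic forward pass feeds both the Wasserstein term and the penalty.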
    def critic_loss(self, fake, real):
        assert real.size()[1:] == fake.size()[1:]
        if self._align:
            real, fake = self.align(real, fake)
        assert len(real) == len(fake)
        fake = self.interpolate(real, fake)
        c_real = self.critic(real)
        c_fake = self.critic(fake)
        wasserstein_loss = c_fake.mean() - c_real.mean()
        gradient_penalty = self.gradient_penalty(c_fake, fake)
        loss = wasserstein_loss + self._gp_lambda * gradient_penalty
        return loss, (wasserstein_loss, gradient_penalty)
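
    # align() equalizes batch sizes by cycling indices modulo each batch's
    # length, so replayed fakes from the deque can be paired with a smaller
    # real batch.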
    def align(self, real, fake):
        n = max(len(real), len(fake))
        indices = torch.arange(n, device=fake.device)
        real = real[indices % len(real)]
        fake = fake[indices % len(fake)]
        return real, fake
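
    # interpolate() draws a per-sample epsilon in [0, 1) and mixes each real
    # image with a fake one; the gradient penalty is measured at these points.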
    def interpolate(self, real, fake):
        assert real.size() == fake.size()
        dims = [len(real)] + (real.ndim - 1) * [1]
        epsilon = torch.rand(*dims, device=fake.device, dtype=fake.dtype)
        fake = epsilon * real + (1 - epsilon) * fake
        return fake

    def critic(self, x):
        # Drop a trailing singleton dimension so scores are always 1-D.
        c = self.critic_net(x)
        return c[:, 0] if c.size() == (len(c), 1) else c

    def gradient_penalty(self, c_fake, fake):
        # create_graph=True keeps this gradient differentiable, so the penalty
        # itself can be backpropagated into the critic's weights (retain_graph
        # is implied).
        [grad] = autograd.grad(c_fake.unbind(0), [fake], create_graph=True)
        gradient_norm = grad.norm(2, dim=[1, 2, 3])
        gradient_penalty = (gradient_norm - 1)**2
        assert gradient_penalty.size() == (len(c_fake),)
        gradient_penalty = gradient_penalty.mean()
        return gradient_penalty

    def generator_loss(self, fake):
        # The generator maximizes the critic's score on its samples.
        c_fake = self.critic(fake)
        loss = -c_fake.mean()
        return loss
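
# TensorDeque is a fixed-capacity FIFO of past generator outputs: each call
# appends the new batch and returns the newest `capacity` samples, so the
# critic also trains against slightly stale fakes (a replay-buffer trick).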
class TensorDeque(nn.Module):
    def __init__(self, capacity, *size, dtype=torch.float):
        super().__init__()
        self.capacity = capacity
        # Non-trainable storage; a registered buffer moves with .cuda()/.to().
        self.register_buffer("buffer", torch.empty(0, *size, dtype=dtype))

    def forward(self, batch):
        assert len(batch) <= self.capacity
        # Append the new batch, keep only the newest `capacity` entries.
        t = torch.cat([self.buffer, batch])[-self.capacity:]
        # Store a detached copy; the returned tensor keeps the current batch's
        # autograd graph.
        self.buffer = t.detach()
        return t
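
# The critic scores flattened 28x28 images with a weight-normalized MLP and
# averages its 100 output units into one scalar per image.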
class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(28*28, 1024)),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Linear(1024, 100)),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Linear(100, 100)),
        )

    def forward(self, x):
        N = len(x)
        x = x.view(N, 28*28)  # flatten 1x28x28 images
        x = self.net(x).mean(dim=1)  # average the 100 heads into one score
        return x.view(N)
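
# The generator maps a 128-d standard-normal latent vector to a 1x28x28 image
# in [0, 1] (Sigmoid output), matching MNIST's ToTensor() range.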
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(128, 1024)),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Linear(1024, 128)),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Linear(128, 28*28, bias=False)),
            nn.Sigmoid(),
        )

    def forward(self, z):
        N = len(z)
        assert z.size() == (N, 128)
        fake = self.net(z)
        fake = fake.view(N, 1, 28, 28)
        return fake
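
# MNIST training split; ToTensor() yields 1x28x28 float tensors in [0, 1].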
mnist = torchvision.datasets.MNIST(
    './mnist',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

critic = Critic().cuda()
generator = Generator().cuda()
critic_opt = optim.Adam(critic.parameters(), lr=0.01)
generator_opt = optim.Adam(generator.parameters(), lr=0.01)
loss = WGanGpLoss(critic, align=True).cuda()
fake_dq = TensorDeque(128, 1, 28, 28).cuda()
loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
)
it = 0
summary_writer = SummaryWriter(comment="mnist")
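
# Update schedule: five critic steps for every generator step (the generator
# trains only when it % 6 == 5), as in the usual WGAN recipe.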
for epoch in tqdm.trange(100):
    for real, _ in tqdm.tqdm(loader):
        it += 1
        real = real.cuda()
        [N, C, H, W] = real.size()
        assert [C, H, W] == [1, 28, 28]
        latent = torch.randn(N, 128, device=real.device)
        if it % 6 == 5:
            # Generator step.
            generator_opt.zero_grad()
            fake = generator(latent)
            lgen = loss.generator_loss(fake)
            lgen.backward()
            generator_opt.step()
            summary_writer.add_scalar("loss.generator", lgen, global_step=it)
        else:
            # Critic step: score fresh fakes mixed with replayed ones.
            critic_opt.zero_grad()
            fake = generator(latent)
            lcri, (w_loss, gp_loss) = loss.critic_loss(fake_dq(fake), real)
            lcri.backward()
            critic_opt.step()
            summary_writer.add_scalar("loss.wasserstein", w_loss, global_step=it)
            summary_writer.add_scalar("loss.gradient_penalty", gp_loss, global_step=it)
            summary_writer.add_scalar("loss.critic", lcri, global_step=it)
        # Log a horizontal strip of 10 generated digits.
        latent = torch.randn(10, 128, device=real.device)
        fake = generator(latent).detach().cpu()
        summary_writer.add_image("image.fake", torch.cat(fake.unbind(0), dim=2), global_step=it)