Skip to content

Instantly share code, notes, and snippets.

@apaszke
Last active January 26, 2017 23:37
Show Gist options
  • Save apaszke/38a258ba9b585d9b411ebe98f5d4f997 to your computer and use it in GitHub Desktop.
Save apaszke/38a258ba9b585d9b411ebe98f5d4f997 to your computer and use it in GitHub Desktop.
#
# Copyright (c) Alex J. Champandard, 2017.
#
import os

import PIL.Image
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
# Pick the tensor type that every tensor and module in this script is cast to:
# CUDA floats when a GPU is available, CPU floats otherwise.
if torch.cuda.is_available():
    # Allow cuDNN to benchmark its convolution algorithms and keep the fastest.
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
class FeatureExtractor(nn.Module):
    """Bank of parallel convolutions with odd kernel sizes from 3 to 19.

    Every convolution maps a 3-channel input to ``num_features`` channels and
    is padded by ``kernel_size // 2``, so all outputs keep the input's spatial
    size and can be concatenated along the channel dimension.

    Args:
        num_features: number of output channels of each convolution.
    """

    def __init__(self, num_features):
        super(FeatureExtractor, self).__init__()
        for i in self.kernel_range:
            # Registered as 'conv<kernel size>' so __getitem__ can find it.
            self.add_module(
                'conv' + str(i),
                nn.Conv2d(3, num_features, kernel_size=i, padding=i // 2))

    def __getitem__(self, i):
        """Return the convolution whose kernel size is ``i``."""
        return getattr(self, 'conv' + str(i))

    @property
    def kernel_range(self):
        """Kernel sizes used by the bank: 3, 5, ..., 19."""
        return range(3, 21, 2)

    def forward(self, input):
        """Run every convolution on ``input`` and rectify the result.

        Args:
            input: batch of 3-channel images.

        Returns:
            ReLU of the per-kernel outputs concatenated along dimension 1,
            i.e. ``num_features * len(kernel_range)`` channels.
        """
        responses = [self[i](input) for i in self.kernel_range]
        return F.relu(torch.cat(responses, 1))
class MultiScaleFeatureExtractor(nn.Module):
    """Collection of ``FeatureExtractor`` modules, one per entry of ``units``.

    Args:
        units: iterable of feature counts; each entry ``u`` creates a
            ``FeatureExtractor(u)`` registered as ``'extractor<u>'``.
    """

    def __init__(self, units):
        super(MultiScaleFeatureExtractor, self).__init__()
        for u in units:
            self.add_module('extractor' + str(u), FeatureExtractor(u))

    def forward(self, input):
        """Apply every child module to ``input``.

        Returns:
            A list with one output per child, in registration order.
        """
        return [child(input) for child in self.children()]
class StyleReproductionLoss(nn.Module):
    """Mean-squared error between Gram matrices of feature maps.

    Args:
        target: iterable of feature maps. Their Gram matrices are computed
            once and detached, so no gradient flows into the target.
    """

    def __init__(self, target):
        super(StyleReproductionLoss, self).__init__()
        self.target = [self.gram(t).detach() for t in target]
        self.criterion = nn.MSELoss()

    def gram(self, f):
        """Return the ``C x C`` Gram matrix of a feature map.

        Args:
            f: 4-D feature map ``(N, C, H, W)``. The reshape to
                ``(C, H * W)`` only succeeds when ``N`` is 1.
        """
        C, H, W = f.size(1), f.size(2), f.size(3)
        flattened = f.view(C, H * W)
        return torch.mm(flattened, flattened.t())

    def backward_from(self, input):
        """Back-propagate the loss of each feature map and sum the values.

        ``backward()`` is called once per feature map, so gradients accumulate
        on whatever leaves produced ``input``.

        Args:
            input: iterable of feature maps, paired in order with the targets
                given at construction.

        Returns:
            ``0.0`` when there is nothing to compare; otherwise a tensor
            holding the sum of the individual loss values.
        """
        total = 0.0
        for i, t in zip(input, self.target):
            loss = self.criterion(self.gram(i), t)
            loss.backward()
            # Accumulate the raw value only, outside the autograd graph.
            total += loss.data
        return total
# Feature extractor with frozen parameters: only the image is optimised.
msfe = MultiScaleFeatureExtractor([64])
msfe.type(dtype)
for p in msfe.parameters():
    p.requires_grad = False

# Load the texture to reproduce. Forcing RGB keeps the channel count at the 3
# channels the extractor's convolutions expect, even for greyscale/RGBA files.
target_image = PIL.Image.open('images/Sand.128.png').convert('RGB')
# PIL reports (width, height); the tensors below are laid out (height, width).
target_size = target_image.size[::-1]
# ToTensor yields values in [0, 1]; shift them to [-0.5, 0.5].
original = ToTensor()(target_image).add_(-0.5).type(dtype)
target_variable = Variable(original).view(1, -1, *target_size)
# Call the module rather than .forward() directly, as evaluate() does.
target_features = msfe(target_variable)

# The image being optimised starts as uniform noise in the same value range.
buffer = torch.FloatTensor(1, 3, *target_size).uniform_().add_(-0.5).type(dtype)
source_image = Variable(buffer, requires_grad=True)
optimizer = optim.Adam([source_image], lr=1e-1)
style_reproduce = StyleReproductionLoss(target_features)


def evaluate():
    """Back-propagate the style loss of the current image and return its value."""
    features = msfe(source_image)
    return style_reproduce.backward_from(features)


# Intermediate frames are written here; create it so the first save cannot fail.
if not os.path.isdir('frames'):
    os.makedirs('frames')

for i in range(500):
    optimizer.zero_grad()
    loss = evaluate()
    optimizer.step()
    # The optimiser updates `source_image` while the script clamps and saves
    # `buffer`, so this relies on the two sharing the same storage.
    buffer.clamp_(min=-0.5, max=+0.5)
    # NOTE(review): the progress print is assumed to belong to the
    # every-10th-iteration branch together with the frame dump - confirm.
    if i % 10 == 0:
        img = ToPILImage()(buffer.add(0.5).view(3, *target_size).cpu())
        img.save('frames/%04d.png' % i)
        # float() handles a Python float as well as a one-element or 0-dim
        # tensor; indexing with [0] fails on 0-dim tensors.
        print('%i %r' % (i, float(loss)))

img = ToPILImage()(buffer.add_(0.5).view(3, *target_size).cpu())
img.save('output.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment