Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Created March 16, 2018 04:59
Show Gist options
  • Save ProGamerGov/f94d1ba5defad7725781537de6812c99 to your computer and use it in GitHub Desktop.
Save ProGamerGov/f94d1ba5defad7725781537de6812c99 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
#import torch.legacy.nn as lnn
import torchvision
from torch.autograd import Variable
class ContentLoss(nn.Module):
    """Content-loss probe that is inserted into a feature network.

    ``forward`` is transparent (it returns its input unchanged). What else it
    does depends on ``self.mode``, which the caller sets:

    * ``'capture'`` -- copy the current activations into ``self.target``
      (stored on the CPU, detached from autograd).
    * ``'loss'`` -- store ``strength * MSE(input, target)`` in ``self.loss``.
    * anything else -- no side effects.

    Args:
        strength: Weight multiplied into the MSE loss (and into the
            normalized gradient when normalization is enabled).
        normalize: Enable L1 gradient normalization in ``backward``. Accepts
            a bool or the legacy string flags ``'True'`` / ``'true'``.
    """

    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        # Plain tensor filled in 'capture' mode; it never joins the autograd
        # graph, so the criterion treats it as a constant target.
        self.target = torch.Tensor()
        self.strength = strength
        self.criterion = nn.MSELoss()
        self.mode = None
        # The argument used to be overwritten with 'false' and thus ignored.
        self.normalize = normalize

    def _normalize_enabled(self):
        """Return True when gradient normalization was requested."""
        # str() lets both real bools and the legacy 'True' string work.
        return str(self.normalize).lower() == 'true'

    def forward(self, input):
        """Capture or score ``input`` according to ``self.mode``; return it."""
        if self.mode == 'loss':
            # Bring the CPU-stored target to the activations' device rather
            # than forcing everything onto CUDA.
            target = self.target.to(input.device)
            self.loss = self.criterion(input, target) * self.strength
        elif self.mode == 'capture':
            captured = input.detach().cpu()
            self.target.resize_as_(captured).copy_(captured)
        self.output = input
        return self.output

    def backward(self, input, gradOutput, retain_graph=True):
        """Back-propagate this module's loss through autograd.

        In 'loss' mode, and only when ``input`` has as many elements as the
        target, gradients of ``self.loss`` are accumulated into the graph.
        With normalization enabled, the gradient w.r.t. ``input`` is divided
        by its L1 norm (plus 1e-8) and rescaled by ``strength`` before being
        propagated further. That path requires ``input`` to be the tensor
        passed to the most recent ``forward`` call.

        Args:
            input: The tensor given to the last ``forward`` call.
            gradOutput: Upstream gradient; returned as-is outside 'loss' mode.
            retain_graph: Forwarded to the autograd backward call.

        Returns:
            ``self.loss`` in 'loss' mode, otherwise ``gradOutput``.
        """
        if self.mode != 'loss':
            # Pure pass-through. The old code copied gradOutput into
            # self.target here, destroying the captured content target.
            return gradOutput
        if input.nelement() == self.target.nelement():
            if self._normalize_enabled():
                grad, = torch.autograd.grad(self.loss, input, retain_graph=True)
                grad = grad / (torch.norm(grad, 1) + 1e-8) * self.strength
                input.backward(grad, retain_graph=retain_graph)
            else:
                self.loss.backward(retain_graph=retain_graph)
        return self.loss
class GramMatrix(nn.Module):
    """Compute the element-count-normalized Gram matrix of a 4-D tensor.

    The input is unpacked as ``(a, b, c, d)``, per the original author's
    notes (batch, feature maps, and the two spatial dimensions). The
    ``a * b`` maps are flattened into rows of length ``c * d``. The result
    is ``rows @ rows.T`` divided by ``a * b * c * d``.
    """

    def forward(self, input):
        batch, maps, height, width = input.size()
        rows = input.view(batch * maps, height * width)
        gram = rows.mm(rows.t())
        element_count = batch * maps * height * width
        return gram / element_count
class StyleLoss(nn.Module):
    """Style-loss probe comparing Gram matrices of activations.

    ``forward`` returns a clone of its input and always refreshes ``self.G``
    with ``self.gram(input)``. Further behavior depends on ``self.mode``:

    * ``'capture'`` -- store the Gram matrix (on the CPU) as ``self.target``.
      If ``self.blend_weight`` is set, successive captures are accumulated as
      ``sum(blend_weight_i * G_i)``, which blends several style images.
    * ``'loss'`` -- store ``strength * MSE(G, target)`` in ``self.loss``.

    Args:
        strength: Weight multiplied into the loss (and into the normalized
            gradient when normalization is enabled).
        normalize: Enable L1 gradient normalization in ``backward``. Accepts
            a bool or the legacy string flags ``'True'`` / ``'true'``.
    """

    def __init__(self, strength, normalize):
        super(StyleLoss, self).__init__()
        # Plain tensor filled in 'capture' mode; constant w.r.t. autograd.
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()
        self.mode = None
        self.blend_weight = None
        self.G = None
        # The argument used to be overwritten with 'false' and thus ignored.
        self.normalize = normalize

    def _normalize_enabled(self):
        """Return True when gradient normalization was requested."""
        # str() lets both real bools and the legacy 'True' string work.
        return str(self.normalize).lower() == 'true'

    def forward(self, input):
        """Capture or score the Gram matrix of ``input``; return a clone."""
        self.output = input.clone()
        # GramMatrix already divides by the element count. The former
        # `self.G.div(input.nelement())` discarded its result; dropping it
        # keeps the numbers unchanged (making it in-place would normalize twice).
        self.G = self.gram(input)
        if self.mode == 'capture':
            gram = self.G.detach().cpu()
            if self.blend_weight is None:
                self.target.resize_as_(gram).copy_(gram)
            elif self.target.nelement() == 0:
                self.target.resize_as_(gram).copy_(gram).mul_(self.blend_weight)
            else:
                # Must be in-place: the old out-of-place `add` threw away the
                # result, so only the first blended style ever counted.
                self.target.add_(gram * self.blend_weight)
        elif self.mode == 'loss':
            # Bring the CPU-stored target to G's device instead of forcing CUDA.
            target = self.target.to(self.G.device)
            self.loss = self.strength * self.criterion(self.G, target)
        return self.output

    def backward(self, input, gradOutput, retain_graph=True):
        """Back-propagate this module's loss through autograd.

        In 'loss' mode, gradients of ``self.loss`` are accumulated into the
        graph. With normalization enabled, the gradient w.r.t. ``input`` is
        divided by its L1 norm (plus 1e-8) and rescaled by ``strength``
        before being propagated further. That path requires ``input`` to be
        the tensor passed to the most recent ``forward`` call.

        Args:
            input: The tensor given to the last ``forward`` call.
            gradOutput: Upstream gradient; returned as-is outside 'loss' mode.
            retain_graph: Forwarded to the autograd backward call.

        Returns:
            ``self.loss`` in 'loss' mode, otherwise ``gradOutput``.
        """
        if self.mode != 'loss':
            # Pass-through. The old code also overwrote self.loss with the
            # gradient, corrupting the stored loss value.
            return gradOutput
        if self._normalize_enabled():
            grad, = torch.autograd.grad(self.loss, input, retain_graph=True)
            grad = grad / (torch.norm(grad, 1) + 1e-8) * self.strength
            input.backward(grad, retain_graph=retain_graph)
        else:
            self.loss.backward(retain_graph=retain_graph)
        return self.loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment