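This Gist contains two files. The first, CaffeLoader.py (the module name is confirmed by the import in the main script further down), defines the VGG and NIN architectures plus a loadcaffe-style model loader: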
import torch
import torch.nn as nn


class VGG(nn.Module):

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
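# Note: only self.features is used for style transfer; loadCaffemodel below
# loads the full state dict (the classifier weights are needed to match the
# checkpoint's keys) and then keeps cnn.features alone.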
class NIN(nn.Module):

    def __init__(self, pooling):
        super(NIN, self).__init__()
        if pooling == 'max':
            pool2d = nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)
        elif pooling == 'avg':
            pool2d = nn.AvgPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)

        self.features = nn.Sequential(
            nn.Conv2d(3, 96, (11, 11), (4, 4)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(96, 256, (5, 5), (1, 1), (2, 2)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(256, 384, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Dropout(0.5),
            nn.Conv2d(384, 1024, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1000, (1, 1)),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((6, 6), (1, 1), (0, 0), ceil_mode=True),
            nn.Softmax(dim=1),  # explicit dim avoids the implicit-dim deprecation warning
        )
def buildSequential(channel_list, pooling):
    layers = []
    in_channels = 3
    if pooling == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pooling == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        raise ValueError("Unrecognized pooling parameter: " + str(pooling))
    for c in channel_list:
        if c == 'P':
            layers += [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, c, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = c
    return nn.Sequential(*layers)
channel_list = {
    'VGG-16': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 'P', 512, 512, 512, 'P', 512, 512, 512, 'P'],
    'VGG-19': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 256, 'P', 512, 512, 512, 512, 'P', 512, 512, 512, 512, 'P'],
}
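# Example: buildSequential(channel_list['VGG-16'], 'max') yields the 13-conv,
# 5-pool VGG-16 feature extractor; 'P' entries become pooling layers.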
nin_dict = {
    'C': ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6', 'conv4-1024', 'cccp7-1024', 'cccp8-1024'],
    'R': ['relu0', 'relu1', 'relu2', 'relu3', 'relu5', 'relu6', 'relu7', 'relu8', 'relu9', 'relu10', 'relu11', 'relu12'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4'],
    'D': ['drop'],
}
vgg16_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu4_1', 'relu4_2', 'relu4_3', 'relu5_1', 'relu5_2', 'relu5_3'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}
vgg19_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}
def modelSelector(model_file, pooling):
    if "vgg19" in str(model_file):
        print("VGG-19 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict
    elif "vgg16" in str(model_file):
        print("VGG-16 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
    elif "nin" in str(model_file):
        print("NIN Architecture Detected")
        cnn, layerList = NIN(pooling), nin_dict
    else:
        raise ValueError('Model architecture not recognized. Please ensure that the model '
                         'file name contains "vgg16", "vgg19", or "nin".')
    return cnn, layerList
# Print like Lua/loadcaffe
def print_loadcaffe(cnn, layerList):
    c = 0
    for l in list(cnn):
        if "Conv2d" in str(l):
            in_c, out_c, ks = str(l.in_channels), str(l.out_channels), str(l.kernel_size)
            print(layerList['C'][c] + ": " + (out_c + " " + in_c + " " + ks).replace(")", '').replace("(", '').replace(",", ''))
            c += 1
        if c == len(layerList['C']):
            break
# Get the model class, and configure the pooling layer type
def loadCaffemodel(model_file, pooling, use_gpu):
    cnn, layerList = modelSelector(model_file, pooling)
    cnn.load_state_dict(torch.load(model_file))
    print("Successfully loaded " + str(model_file))

    # Maybe convert the model to cuda now, to avoid later issues
    if use_gpu > -1:
        cnn = cnn.cuda()
    cnn = cnn.features

    print_loadcaffe(cnn, layerList)
    return cnn, layerList
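A minimal usage sketch for the loader. The weights path matches the -model_file default in the main script below; the architecture is picked purely from the file name:

# Load VGG-19 weights with max pooling on GPU 0 (path from the -model_file
# default below; otherwise an assumption).
cnn, layerList = loadCaffemodel('models/vgg19-d01eb7cb.pth', 'max', 0)

The second file is the main style-transfer script, which imports the loader above: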
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1", type=int, default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=0)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl'], default='cudnn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
def main():
    dtype = setup_gpu()

    # Build the model definition and set up pooling layers:
    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_list = params.style_image.split(',')
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)
    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)
    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = style_blend_sum + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = style_blend_weights[i] / style_blend_sum
    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)
                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1
            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)
                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1
                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1
            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)
    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'
    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i + 1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Freeze the network in order to prevent unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False
    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img.type(dtype))
    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration: " + str(t) + " / " + str(params.num_iterations))
            totalLoss = 0
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
                totalLoss += loss_module.loss.item()
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
                totalLoss += loss_module.loss.item()
            print("  C/S Total loss: " + str(totalLoss))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            deprocess(img.clone(), filename)
    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]

    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0
        for mod in content_losses:
            loss += mod.loss
        for mod in style_losses:
            loss += mod.loss
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss
        loss.backward()
        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)
        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)
# Configure the optimizer for the input image
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction > 0:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
def setup_gpu():
    if params.gpu > -1:
        if params.backend == 'cudnn':
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False
        torch.cuda.set_device(params.gpu)
        dtype = torch.cuda.FloatTensor
    elif params.gpu == -1:
        if params.backend == 'mkl':
            torch.backends.mkl.enabled = True
        dtype = torch.FloatTensor
    return dtype
# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor
# Undo the above preprocessing and save the tensor as an image:
def deprocess(output_tensor, output_name):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    image.save(str(output_name))
# Print like Lua/Torch7
def print_torch(net):
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    for i, l in enumerate(net):
        is_2d = "2d" in str(l)
        is_conv = "Conv2d" in str(l)
        if is_2d:
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if is_conv:
                in_c, out_c = str(l.in_channels), str(l.out_channels)
                print("  (" + str(i + 1) + "): " + "nn.Conv2d(" + in_c + " -> " + out_c + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            else:
                print("  (" + str(i + 1) + "): " + "nn." + str(l).split("(", 1)[0] + "(" + ((ks).replace(",", 'x' + ks, 1) + st.replace("  ", ' ') + st.replace(", ", ')')).replace(", ", ','))
        else:
            print("  (" + str(i + 1) + "): " + "nn." + str(l).split("(", 1)[0])
    print(")")
# Define an nn Module to compute content loss in-place
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input
class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
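# (The Gram matrix F·Fᵀ of the C x (H*W) feature map F measures channel-wise
# feature correlations; StyleLoss below matches it against the Gram matrix
# captured from the style image. The view(C, H * W) assumes batch size B == 1.)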
# Define an nn Module to compute style loss in-place
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input
class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength
        self.x_diff = torch.Tensor()
        self.y_diff = torch.Tensor()

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
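# (TVLoss is anisotropic total variation regularization: the L1 norm of
# horizontal and vertical neighbor differences, which smooths the output image.)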
if __name__ == "__main__":
    main()
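Assuming the script is saved as neural_style.py (the file name is not shown in this Gist), a typical run with the bundled example images would be:

python neural_style.py -content_image examples/inputs/tubingen.jpg -style_image examples/inputs/seated-nude.jpg -output_image out.png

All other parameters fall back to the argparse defaults defined above.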
The code from this Gist was continued from here: https://gist.github.com/ProGamerGov/89973941721107f0bf713edfcfb467cf

A more up to date version of this code can be found here: https://gist.github.com/ProGamerGov/089a082c2a000d1e1cc034fc75ff5931
