-
-
Save ProGamerGov/4fbb4a8340ae654a3ae460ccddb7757c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
class VGG(nn.Module):
    """Parameter container for a VGG network.

    The class only holds the ``features`` stack passed in and a three-layer
    fully connected ``classifier``; it defines no ``forward`` of its own.

    Args:
        features: Module to store as ``self.features``.
        num_classes: Output width of the final linear layer.
    """

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        hidden = 4096
        # Keep this order: the position of each layer inside the Sequential
        # determines its state-dict key (classifier.0, classifier.3, ...).
        head = [
            nn.Linear(512 * 7 * 7, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
class NIN(nn.Module):
    """Network-in-Network model definition exposed as a single ``features`` stack.

    Args:
        pooling: ``'max'`` or ``'avg'``; selects the 3x3 / stride-2 pooling
            module that is reused at the three down-sampling points.

    Raises:
        ValueError: If ``pooling`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, pooling):
        super(NIN, self).__init__()
        if pooling == 'max':
            pool2d = nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)
        elif pooling == 'avg':
            pool2d = nn.AvgPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)
        else:
            # Without this branch ``pool2d`` would be unbound below and the
            # caller would only see an UnboundLocalError.
            raise ValueError("Unrecognized pooling parameter: " + repr(pooling))

        # Keep this order: each layer's position inside the Sequential is its
        # state-dict key, so reordering would break checkpoint loading.
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, (11, 11), (4, 4)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(96, 256, (5, 5), (1, 1), (2, 2)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(256, 384, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Dropout(0.5),
            nn.Conv2d(384, 1024, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1000, (1, 1)),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((6, 6), (1, 1), (0, 0), ceil_mode=True),
            # Explicit channel axis. dim=-3 is the axis the deprecated
            # implicit-dim Softmax picked for both (C, H, W) and (N, C, H, W)
            # inputs, so results are unchanged and the warning goes away.
            nn.Softmax(dim=-3),
        )
def buildSequential(channel_list, pooling):
    """Build a VGG-style convolutional stack from a channel specification.

    Args:
        channel_list: Sequence whose items are either an output channel count
            (adds a 3x3, padding-1 convolution followed by an in-place ReLU)
            or the string ``'P'`` (adds the pooling layer).
        pooling: ``'max'`` or ``'avg'``; selects the 2x2 / stride-2 pooling
            module. One instance is shared by every ``'P'`` position.

    Returns:
        An ``nn.Sequential`` containing the layers; the first convolution
        takes 3 input channels.

    Raises:
        ValueError: If ``pooling`` is neither ``'max'`` nor ``'avg'``.
    """
    if pooling == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pooling == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        # Raise instead of print + quit(): quit() is a site builtin meant for
        # interactive use and would end the caller's process.
        raise ValueError("Unrecognized pooling parameter: " + repr(pooling))

    layers = []
    in_channels = 3
    for c in channel_list:
        if c == 'P':
            layers.append(pool2d)
        else:
            layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = c
    return nn.Sequential(*layers)
# Channel layout per architecture: integers are convolution output widths,
# 'P' marks a pooling layer.
channel_list = {
    'VGG-16': [
        64, 64, 'P',
        128, 128, 'P',
        256, 256, 256, 'P',
        512, 512, 512, 'P',
        512, 512, 512, 'P',
    ],
    'VGG-19': [
        64, 64, 'P',
        128, 128, 'P',
        256, 256, 256, 256, 'P',
        512, 512, 512, 512, 'P',
        512, 512, 512, 512, 'P',
    ],
}

# Layer-name tables keyed by layer kind: 'C' convolutions, 'R' ReLUs,
# 'P' pooling layers, 'D' dropout. Names are listed in network order.
nin_dict = {
    'C': [
        'conv1', 'cccp1', 'cccp2',
        'conv2', 'cccp3', 'cccp4',
        'conv3', 'cccp5', 'cccp6',
        'conv4-1024', 'cccp7-1024', 'cccp8-1024',
    ],
    'R': [
        'relu0', 'relu1', 'relu2',
        'relu3', 'relu5', 'relu6',
        'relu7', 'relu8', 'relu9',
        'relu10', 'relu11', 'relu12',
    ],
    'P': ['pool1', 'pool2', 'pool3', 'pool4'],
    'D': ['drop'],
}


def _vgg_layer_names(convs_per_stage):
    """Return the 'C'/'R'/'P' name table for a VGG net given its convs per stage."""
    conv_names, relu_names, pool_names = [], [], []
    for stage, count in enumerate(convs_per_stage, 1):
        for k in range(1, count + 1):
            conv_names.append('conv%d_%d' % (stage, k))
            relu_names.append('relu%d_%d' % (stage, k))
        pool_names.append('pool%d' % stage)
    return {'C': conv_names, 'R': relu_names, 'P': pool_names}


vgg16_dict = _vgg_layer_names((2, 2, 3, 3, 3))
vgg19_dict = _vgg_layer_names((2, 2, 4, 4, 4))
def modelSelector(model_file, pooling):
    """Pick the network definition and layer-name table from the model file name.

    The architecture is recognised by a substring of ``str(model_file)``,
    compared case-insensitively: ``"vgg19"``, then ``"vgg16"``, then ``"nin"``.

    Args:
        model_file: Path (or path-like) of the weights file; only its text is
            inspected here, the file is not opened.
        pooling: Pooling type forwarded to the model constructor.

    Returns:
        A ``(cnn, layerList)`` tuple: the freshly built, uninitialised model
        and the matching layer-name dictionary.

    Raises:
        ValueError: If none of the known architecture tags occurs in the name.
    """
    # Lower-case once so names such as "VGG19.pth" are recognised as well.
    name = str(model_file).lower()
    if "vgg19" in name:
        print("VGG-19 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict
    elif "vgg16" in name:
        print("VGG-16 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
    elif "nin" in name:
        print("NIN Architecture Detected")
        cnn, layerList = NIN(pooling), nin_dict
    else:
        # Single-line message: the old triple-quoted literal put a line break
        # and its continuation indentation inside the exception text, and the
        # same sentence was printed right before being raised.
        raise ValueError(
            'Model Architecture Not Recognized. Please ensure that the model '
            'name contains either "vgg16", "vgg19", or "nin", in the file name.'
        )
    return cnn, layerList
# Print like Lua/loadcaffe
def print_loadcaffe(cnn, layerList):
    """Print one ``name: out_channels in_channels kh kw`` line per convolution.

    Args:
        cnn: Iterable of layers (e.g. an ``nn.Sequential``).
        layerList: Dictionary whose ``'C'`` entry lists convolution names in
            network order. Printing stops once every name has been used.
    """
    conv_names = layerList['C']
    used = 0
    for layer in cnn:
        # Check before indexing so an empty or short name list cannot raise
        # IndexError on the next convolution.
        if used == len(conv_names):
            break
        # isinstance instead of a substring test on str(layer): the repr of a
        # container module also contains "Conv2d".
        if isinstance(layer, nn.Conv2d):
            kernel = " ".join(str(k) for k in layer.kernel_size)
            print(conv_names[used] + ": " + str(layer.out_channels) + " "
                  + str(layer.in_channels) + " " + kernel)
            used += 1
# Get the model class, and configure pooling layer type
def loadCaffemodel(model_file, pooling, use_gpu):
    """Build the model for ``model_file``, load its weights and return the feature stack.

    Args:
        model_file: Path of the saved state dict; also used to select the
            architecture by name.
        pooling: Pooling type forwarded to the model constructor.
        use_gpu: GPU index; any value greater than -1 moves the model to CUDA.

    Returns:
        A ``(cnn, layerList)`` tuple where ``cnn`` is the model's ``features``
        module and ``layerList`` is the matching layer-name dictionary.
    """
    cnn, layerList = modelSelector(model_file, pooling)

    # Load onto the CPU first so a checkpoint saved from a GPU can also be
    # read on a machine without CUDA; the model is moved to the GPU below.
    # NOTE(security): torch.load unpickles the file, so only load model files
    # from sources you trust.
    state_dict = torch.load(model_file, map_location='cpu')
    cnn.load_state_dict(state_dict)
    print("Successfully loaded " + str(model_file))

    # Maybe convert the model to cuda now, to avoid later issues
    if use_gpu > -1:
        cnn = cnn.cuda()
    cnn = cnn.features

    print_loadcaffe(cnn, layerList)
    return cnn, layerList
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import copy | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from CaffeLoader import loadCaffemodel | |
import argparse | |
# Command-line interface. Options use a single leading dash and are parsed
# once, at module level, into ``params``.
parser = argparse.ArgumentParser()

# Basic options
parser.add_argument(
    "-style_image",
    help="Style target image",
    default="examples/inputs/seated-nude.jpg",
)
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument(
    "-content_image",
    help="Content target image",
    default="examples/inputs/tubingen.jpg",
)
parser.add_argument(
    "-image_size",
    help="Maximum height / width of generated image",
    type=int,
    default=512,
)
parser.add_argument(
    "-gpu",
    help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1",
    type=int,
    default=0,
)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=["random", "image"], default="random")
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=["lbfgs", "adam"], default="lbfgs")
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=0)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default="out.png")

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-pooling", choices=["avg", "max"], default="max")
parser.add_argument("-model_file", type=str, default="models/vgg19-d01eb7cb.pth")
parser.add_argument("-backend", choices=["nn", "cudnn", "mkl"], default="cudnn")
parser.add_argument("-cudnn_autotune", action="store_true")
parser.add_argument("-seed", type=int, default=-1)

parser.add_argument(
    "-content_layers",
    help="layers for content",
    default="relu4_2",
)
parser.add_argument(
    "-style_layers",
    help="layers for style",
    default="relu1_1,relu2_1,relu3_1,relu4_1,relu5_1",
)

params = parser.parse_args()

# Raise PIL's pixel-count limit to support gigapixel images.
Image.MAX_IMAGE_PIXELS = 1000000000
def main():
    """Run style transfer end to end using the module-level ``params``.

    Steps: load the network, preprocess the content/style (and optional
    init) images, insert content/style/TV loss modules into a copy of the
    network, capture the targets, then optimise the image and save the
    results through ``maybe_save``.

    Raises:
        AssertionError: If the number of ``-style_blend_weights`` entries
            differs from the number of style images.
        ValueError: If the style blend weights sum to zero.
    """
    dtype = setup_gpu()

    # Build the model definition and setup pooling layers:
    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_list = params.style_image.split(',')
    style_size = int(params.image_size * params.style_scale)
    style_images_caffe = [preprocess(image, style_size).type(dtype)
                          for image in style_image_list]

    init_image = None
    if params.init_image is not None:
        # The init image is resized to exactly the content image's H x W.
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        style_blend_weights = [1.0] * len(style_image_list)
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_image must have the same number of elements!"
        style_blend_weights = [float(w) for w in style_blend_weights]

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = sum(style_blend_weights)
    if style_blend_sum == 0:
        # Normalisation is undefined; fail with a clear message rather than a
        # bare ZeroDivisionError.
        raise ValueError("-style_blend_weights must not sum to zero")
    style_blend_weights = [w / style_blend_sum for w in style_blend_weights]

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(cnn, 1):
        # Stop copying layers once every requested loss layer has been placed.
        # NOTE(review): only the ReLU branch advances these counters, so a
        # loss attached to a convolution never triggers this early stop.
        if next_content_idx > len(content_layers) and next_style_idx > len(style_layers):
            break

        if isinstance(layer, nn.Conv2d):
            net.add_module(str(len(net)), layer)

            if layerList['C'][c] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                loss_module = ContentLoss(params.content_weight)
                net.add_module(str(len(net)), loss_module)
                content_losses.append(loss_module)

            if layerList['C'][c] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                loss_module = StyleLoss(params.style_weight)
                net.add_module(str(len(net)), loss_module)
                style_losses.append(loss_module)
            c += 1

        if isinstance(layer, nn.ReLU):
            net.add_module(str(len(net)), layer)

            if layerList['R'][r] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                loss_module = ContentLoss(params.content_weight)
                net.add_module(str(len(net)), loss_module)
                content_losses.append(loss_module)
                next_content_idx += 1

            if layerList['R'][r] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                loss_module = StyleLoss(params.style_weight)
                net.add_module(str(len(net)), loss_module)
                style_losses.append(loss_module)
                next_style_idx += 1
            r += 1

        if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
            net.add_module(str(len(net)), layer)

    # Capture content targets
    for mod in content_losses:
        mod.mode = 'capture'
    print("Capturing content targets")
    print_torch(net)
    net(content_image)

    # Capture style targets
    for mod in content_losses:
        mod.mode = 'None'

    for i, style_image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i+1))
        for mod in style_losses:
            mod.mode = 'capture'
            mod.blend_weight = style_blend_weights[i]
        net(style_image)

    # Set all loss modules to loss mode
    for mod in content_losses:
        mod.mode = 'loss'
    for mod in style_losses:
        mod.mode = 'loss'

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        _, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img.type(dtype))

    def maybe_print(t, loss):
        """Print per-module and total losses every ``-print_iter`` evaluations."""
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration: " + str(t) + " / " + str(params.num_iterations))
            totalLoss = 0
            for i, loss_module in enumerate(content_losses):
                print(" Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
                totalLoss += loss_module.loss.item()
            for i, loss_module in enumerate(style_losses):
                print(" Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
                totalLoss += loss_module.loss.item()
            print(" C/S Total loss: " + str(totalLoss))
            print(" Total loss: " + str(loss.item()))

    def maybe_save(t):
        """Save the image every ``-save_iter`` evaluations and on the final one."""
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                # Intermediate results get the evaluation count as a suffix.
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            deprocess(img.clone(), filename)

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]

    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss
        for mod in style_losses:
            loss += mod.loss
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)
# Configure optimizer and input image
def setup_optimizer(img):
    """Create the optimizer selected by ``params.optimizer`` for ``img``.

    Args:
        img: The tensor being optimised; it is the optimizer's only parameter.

    Returns:
        ``(optimizer, loopVal)`` where ``loopVal`` is the evaluation-count
        bound the caller compares against: 1 for L-BFGS and
        ``params.num_iterations - 1`` for Adam.
    """
    choice = params.optimizer
    if choice == 'lbfgs':
        print("Running optimization with L-BFGS")
        # Negative tolerances are passed so that only max_iter bounds the run.
        lbfgs_options = dict(
            max_iter=params.num_iterations,
            tolerance_change=-1,
            tolerance_grad=-1,
        )
        if params.lbfgs_num_correction > 0:
            lbfgs_options['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **lbfgs_options)
        loopVal = 1
    elif choice == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
def setup_gpu():
    """Configure the compute backend from ``params`` and pick the tensor type.

    With ``params.gpu`` greater than -1 the cuDNN flags are set according to
    ``params.backend`` / ``params.cudnn_autotune`` and that CUDA device is
    selected. Any negative ``params.gpu`` means CPU mode.

    Returns:
        ``torch.cuda.FloatTensor`` in GPU mode, ``torch.FloatTensor`` otherwise.
    """
    if params.gpu > -1:
        if params.backend == 'cudnn':
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False
        torch.cuda.set_device(params.gpu)
        dtype = torch.cuda.FloatTensor
    else:
        # Treat every negative id as CPU. The previous ``elif gpu == -1`` left
        # ``dtype`` unbound for values such as -2, while the model loader
        # already treats any value below 0 as CPU.
        if params.backend == 'mkl':
            torch.backends.mkl.enabled = True
        dtype = torch.FloatTensor
    return dtype
# Preprocess an image before passing it to a model.
# We rescale from [0, 1] by a factor of 256, convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    """Load an image file and convert it to a 1 x 3 x H x W network input.

    Args:
        image_name: Path of the image to open.
        image_size: Either a ``(height, width)`` tuple used as-is, or a single
            number giving the target length of the longer side (the aspect
            ratio is kept).

    Returns:
        A float tensor in BGR channel order, scaled by 256 with the per-channel
        mean subtracted, with a leading batch dimension of 1.
    """
    # Close the file once the RGB copy has been made instead of leaving the
    # handle open until garbage collection.
    with Image.open(image_name) as source:
        image = source.convert('RGB')

    # isinstance also accepts tuple subclasses (e.g. torch.Size) as an
    # explicit size.
    if not isinstance(image_size, tuple):
        scale = float(image_size) / max(image.size)
        image_size = tuple([int(scale * x) for x in (image.height, image.width)])

    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    # The factor 256 matches the division performed when an output tensor is
    # converted back to an image; keep the two in sync.
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor
# Undo the above preprocessing and save the tensor as an image:
def deprocess(output_tensor, output_name):
    """Convert a network-space tensor back to an RGB image and write it to disk.

    Args:
        output_tensor: Tensor with a leading batch dimension of 1, in the BGR,
            mean-subtracted, 256-scaled representation.
        output_name: Destination file name (converted with ``str``).
    """
    # Normalising with the negated mean adds the mean pixel back.
    restore_mean = transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])
    bgr = restore_mean(output_tensor.squeeze(0).cpu())

    # Swap channels 0 and 2 (BGR -> RGB) and scale back towards [0, 1].
    channel_order = torch.LongTensor([2, 1, 0])
    rgb = bgr[channel_order] / 256
    rgb.clamp_(0, 1)

    to_pil = transforms.ToPILImage()
    picture = to_pil(rgb.cpu())
    picture.save(str(output_name))
# Print like Lua/Torch7
def print_torch(net):
    """Print ``net`` layer by layer in a Lua/Torch7-like layout.

    Args:
        net: Iterable of modules (e.g. an ``nn.Sequential``). Modules whose
            text form contains ``"2d"`` must expose ``kernel_size``,
            ``stride`` and ``padding``.
    """
    chain = "".join("(" + str(pos) + ") -> " for pos, _ in enumerate(net, 1))
    print("nn.Sequential ( \n [input -> " + chain + "output]")

    def strip(x):
        # "(3, 3)" -> "3,3, " and "2" -> "2, "
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    for pos, module in enumerate(net, 1):
        text = str(module)
        prefix = " (" + str(pos) + "): " + "nn."

        if "2d" not in text:
            # Everything else is printed by class name only.
            print(prefix + text.split("(", 1)[0])
            continue

        ks, st, pd = strip(module.kernel_size), strip(module.stride), strip(module.padding)
        if "Conv2d" in text:
            channels = str(module.in_channels) + " -> " + str(module.out_channels) + ", "
            print(prefix + "Conv2d(" + channels + ks.replace(",", 'x', 1) + st + pd.replace(", ", ')'))
        else:
            # Pooling layers: the size is repeated around an 'x', followed by
            # the stride written twice.
            size_part = ks.replace(",", 'x' + ks, 1)
            stride_part = st + st.replace(", ", ')')
            print(prefix + text.split("(", 1)[0] + "(" + (size_part + stride_part).replace(", ", ','))
    print(")")
# Define an nn Module to compute content loss in-place
class ContentLoss(nn.Module):
    """Pass-through layer that records or scores activations against a target.

    ``mode`` selects the behaviour of ``forward``:
    ``'capture'`` stores the input as the target, ``'loss'`` sets
    ``self.loss`` to the weighted MSE between input and target, and any other
    value (initially ``'None'``) does nothing. The input is always returned
    unchanged.

    Args:
        strength: Multiplier applied to the MSE.
    """

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            # detach() alone shares storage with the activation, so a later
            # in-place op on it (e.g. ReLU(inplace=True)) would silently
            # rewrite the target. Keep an independent copy.
            self.target = input.detach().clone()
        return input
class GramMatrix(nn.Module):
    """Compute the C x C Gram matrix of a 4-D feature tensor.

    The input is viewed as ``(C, H * W)``, so the view only succeeds when the
    leading (batch) dimension is 1.
    """

    def forward(self, input):
        _, channels, height, width = input.size()
        features = input.view(channels, height * width)
        return features.mm(features.t())
# Define an nn Module to compute style loss in-place
class StyleLoss(nn.Module):
    """Pass-through layer that records or scores Gram matrices against a target.

    ``forward`` always stores the input's Gram matrix divided by the input's
    element count in ``self.G`` and returns the input unchanged. ``mode``
    selects what else happens:

    * ``'capture'``: update ``self.target`` from ``self.G``. With
      ``blend_weight`` unset the target is replaced; otherwise the weighted
      Gram matrix is accumulated across successive capture passes.
    * ``'loss'``: set ``self.loss`` to ``strength * MSE(G, target)``.

    Args:
        strength: Multiplier applied to the MSE.
    """

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        # An empty tensor marks "nothing captured yet".
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                # Plain arithmetic instead of the deprecated positional-alpha
                # overload Tensor.add(Number, Tensor).
                self.target = self.target + self.G.detach() * self.blend_weight
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input
class TVLoss(nn.Module):
    """Pass-through layer that computes a total-variation penalty on its input.

    ``forward`` stores the neighbour differences along dim 2 in ``x_diff``
    and along dim 3 in ``y_diff``, sets ``self.loss`` to ``strength`` times
    the sum of their absolute values, and returns the input unchanged.

    Args:
        strength: Multiplier applied to the summed absolute differences.
    """

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength
        self.x_diff = torch.Tensor()
        self.y_diff = torch.Tensor()

    def forward(self, input):
        # Each element minus its predecessor along the given dimension.
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        variation = self.x_diff.abs().sum() + self.y_diff.abs().sum()
        self.loss = self.strength * variation
        return input
# Run the style transfer only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code from this Gist was continued from here: https://gist.github.com/ProGamerGov/89973941721107f0bf713edfcfb467cf
A more up to date version of this code can be found here: https://gist.github.com/ProGamerGov/089a082c2a000d1e1cc034fc75ff5931