@ProGamerGov
Last active July 25, 2018 19:26
Release Candidate For Neural-Style-PT
*.swp
out*.png
*.png
*.jpg
*.pyc
*.pth
models/
!models/download_models.py

neural-style-pt

This is a PyTorch implementation of the paper A Neural Algorithm of Artistic Style by Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge. The code is based on Justin Johnson's Neural-Style.

The paper presents an algorithm for combining the content of one image with the style of another image using convolutional neural networks. Here's an example that maps the artistic style of The Starry Night onto a night-time photograph of the Stanford campus:

Applying the style of different images to the same content image gives interesting results. Here we reproduce Figure 2 from the paper, which renders a photograph of Tübingen, Germany in a variety of styles:

Here are the results of applying the style of various pieces of artwork to this photograph of the Golden Gate Bridge:

Content / Style Tradeoff

The algorithm allows the user to trade off the relative weights of the style and content reconstruction terms, as shown in this example where we port the style of Picasso's 1907 self-portrait onto Brad Pitt:
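
The weights are exposed through the -content_weight and -style_weight flags described under Options below. As a rough, illustrative sketch (these values are not the exact settings used for the figure):

python neural_style.py -content_image examples/inputs/brad_pitt.jpg -style_image examples/inputs/picasso_selfport1907.jpg -content_weight 50 -style_weight 100

python neural_style.py -content_image examples/inputs/brad_pitt.jpg -style_image examples/inputs/picasso_selfport1907.jpg -content_weight 5 -style_weight 1000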

Style Scale

By resizing the style image before extracting style features, we can control the types of artistic features that are transferred from the style image; you can control this behavior with the -style_scale flag. Below we see three examples of rendering the Golden Gate Bridge in the style of The Starry Night. From left to right, -style_scale is 2.0, 1.0, and 0.5.
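
For example, a smaller-scale rendering of the style could be requested like this (a sketch; the image file paths are assumed and not taken from this document):

python neural_style.py -content_image examples/inputs/golden_gate.jpg -style_image examples/inputs/starry_night.jpg -style_scale 0.5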

Multiple Style Images

You can use more than one style image to blend multiple artistic styles.

Clockwise from upper left: "The Starry Night" + "The Scream", "The Scream" + "Composition VII", "Seated Nude" + "Composition VII", and "Seated Nude" + "The Starry Night"

Style Interpolation

When using multiple style images, you can control the degree to which they are blended:
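
The blend is set with the -style_blend_weights flag described under Options below; for instance, a 70/30 mix of two styles might look like this (a sketch with illustrative file names):

python neural_style.py -content_image <image.jpg> -style_image starry_night.jpg,the_scream.jpg -style_blend_weights 7,3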

Transfer style but not color

If you add the flag -original_colors 1 then the output image will retain the colors of the original image.
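
For example, using the script's default content and style images (a minimal sketch):

python neural_style.py -content_image examples/inputs/tubingen.jpg -style_image examples/inputs/seated-nude.jpg -original_colors 1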

Setup:

Dependencies:

  • PyTorch

Optional dependencies:

  • For CUDA backend:
    • CUDA 7.5 or above
  • For cuDNN backend:
    • cuDNN v6 or above

After installing the dependencies, you'll need to run the following script to download the VGG model:

python models/download_models.py

This will download the original VGG-19 and VGG-16 models; the VGG-19 model is used by default.

If your GPU has less memory, using the NIN ImageNet model is a better choice; it gives slightly worse yet comparable results. You can find details on the model in the BVLC Caffe Model Zoo. The NIN model is downloaded when you run the download_models.py script.

You can find detailed installation instructions for Ubuntu in the installation guide.

Usage

Basic usage:

python neural_style.py -style_image <image.jpg> -content_image <image.jpg>

cuDNN usage with NIN Model:

python neural_style.py -style_image examples/inputs/picasso_selfport1907.jpg -content_image examples/inputs/brad_pitt.jpg -output_image profile.png -model_file models/nin_imagenet.pth -gpu 0 -backend cudnn -num_iterations 1000 -seed 123 -content_layers relu0,relu3,relu7,relu12 -style_layers relu0,relu3,relu7,relu12 -content_weight 10 -style_weight 500 -image_size 512 -optimizer adam

cuDNN NIN Model Picasso Brad Pitt

To use multiple style images, pass a comma-separated list like this:

-style_image starry_night.jpg,the_scream.jpg

Note that paths to images should not contain the ~ character to represent your home directory; you should instead use a relative path or a full absolute path.
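
For example, from inside the repository directory you could use relative paths instead of ~-prefixed ones (the file names here are illustrative):

python neural_style.py -content_image examples/inputs/tubingen.jpg -style_image examples/inputs/starry_night.jpg,examples/inputs/the_scream.jpg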

Options:

  • -image_size: Maximum side length (in pixels) of the generated image. Default is 512.
  • -style_blend_weights: The weight for blending the style of multiple style images, as a comma-separated list, such as -style_blend_weights 3,7. By default all style images are equally weighted.
  • -gpu: Zero-indexed ID of the GPU to use; for CPU mode set -gpu to -1.

Optimization options:

  • -content_weight: How much to weight the content reconstruction term. Default is 5e0.
  • -style_weight: How much to weight the style reconstruction term. Default is 1e2.
  • -tv_weight: Weight of total-variation (TV) regularization; this helps to smooth the image. Default is 1e-3. Set to 0 to disable TV regularization.
  • -num_iterations: Default is 1000.
  • -init: Method for initializing the generated image; one of random or image. Default is random, which uses a noise initialization as in the paper; image initializes with the content image.
  • -optimizer: The optimization algorithm to use; either lbfgs or adam; default is lbfgs. L-BFGS tends to give better results, but uses more memory. Switching to ADAM will reduce memory usage; when using ADAM you will probably need to play with other parameters to get good results, especially the style weight, content weight, and learning rate (see the example after this list).
  • -learning_rate: Learning rate to use with the ADAM optimizer. Default is 1e1.
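
A hedged starting point for the ADAM settings mentioned above (the exact values are illustrative and will likely need tuning for your images):

python neural_style.py -content_image <image.jpg> -style_image <image.jpg> -optimizer adam -learning_rate 10 -content_weight 10 -style_weight 500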

Output options:

  • -output_image: Name of the output image. Default is out.png.
  • -print_iter: Print progress every print_iter iterations. Set to 0 to disable printing.
  • -save_iter: Save the image every save_iter iterations. Set to 0 to disable saving intermediate results.

Layer options:

  • -content_layers: Comma-separated list of layer names to use for content reconstruction. Default is relu4_2.
  • -style_layers: Comma-separated list of layer names to use for style reconstruction. Default is relu1_1,relu2_1,relu3_1,relu4_1,relu5_1.

Other options:

  • -style_scale: Scale at which to extract features from the style image. Default is 1.0.
  • -original_colors: If you set this to 1, then the output image will keep the colors of the content image.
  • -model_file: Path to the .pth file for the VGG Caffe model which was converted to PyTorch. Default is the original VGG-19 model; you can also try the original VGG-16 model.
  • -pooling: The type of pooling layers to use; one of max or avg. Default is max. The VGG-19 model uses max pooling layers, but the paper mentions that replacing these layers with average pooling layers can improve the results. I haven't been able to get good results using average pooling, but the option is here (see the example after this list).
  • -backend: nn, cudnn, or mkl. Default is nn. mkl requires Intel's MKL backend.
  • -cudnn_autotune: When using the cuDNN backend, pass this flag to use the built-in cuDNN autotuner to select the best convolution algorithms for your architecture. This will make the first iteration a bit slower and can take a bit more memory, but may significantly speed up the cuDNN backend.
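
A hedged example of the -pooling flag mentioned above, combined with the cuDNN autotuner (whether average pooling improves your results may vary, as noted):

python neural_style.py -content_image <image.jpg> -style_image <image.jpg> -pooling avg -backend cudnn -cudnn_autotune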

Frequently Asked Questions

Problem: The program runs out of memory and dies

Solution: Try reducing the image size: -image_size 256 (or lower). Note that different image sizes will likely require non-default values for -style_weight and -content_weight for optimal results. If you are running on a GPU, you can also try running with -backend cudnn to reduce memory usage.

Problem: -backend cudnn is slower than the default NN backend

Solution: Add the flag -cudnn_autotune; this will use the built-in cuDNN autotuner to select the best convolution algorithms.

Memory Usage

By default, neural-style-pt uses the nn backend for convolutions and L-BFGS for optimization. These give good results, but can both use a lot of memory. You can reduce memory usage with the following (a combined example follows the list):

  • Use cuDNN: Add the flag -backend cudnn to use the cuDNN backend. This will only work in GPU mode.
  • Use ADAM: Add the flag -optimizer adam to use ADAM instead of L-BFGS. This should significantly reduce memory usage, but may require tuning of other parameters for good results; in particular you should play with the learning rate, content weight, and style weight. This should work in both CPU and GPU modes.
  • Reduce image size: If the above tricks are not enough, you can reduce the size of the generated image; pass the flag -image_size 256 to generate an image at half the default size.
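
Putting the three suggestions together, a minimal low-memory invocation might look like this (as noted above, the content and style weights will likely need re-tuning at a different image size):

python neural_style.py -content_image <image.jpg> -style_image <image.jpg> -backend cudnn -optimizer adam -image_size 256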

With the default settings, neural-style-pt uses about 3.7 GB of GPU memory on my system; switching to ADAM and cuDNN reduces the GPU memory footprint to about 1GB.

Speed

Speed can vary a lot depending on the backend and the optimizer. Here are some times for running 500 iterations with -image_size=512 on a Tesla K80 with different settings:

  • -backend nn -optimizer lbfgs: 117 seconds
  • -backend nn -optimizer adam: 100 seconds
  • -backend cudnn -optimizer lbfgs: 124 seconds
  • -backend cudnn -optimizer adam: 107 seconds
  • -backend cudnn -cudnn_autotune -optimizer lbfgs: 109 seconds
  • -backend cudnn -cudnn_autotune -optimizer adam: 91 seconds

Here are the same benchmarks on a GTX 1080:

  • -backend nn -optimizer lbfgs: 56 seconds
  • -backend nn -optimizer adam: 38 seconds
  • -backend cudnn -optimizer lbfgs: 40 seconds
  • -backend cudnn -optimizer adam: 40 seconds
  • -backend cudnn -cudnn_autotune -optimizer lbfgs: 23 seconds
  • -backend cudnn -cudnn_autotune -optimizer adam: 24 seconds

Implementation details

Images are initialized with white noise and optimized using L-BFGS.

We perform style reconstructions using the conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1 layers and content reconstructions using the conv4_2 layer. As in the paper, the five style reconstruction losses have equal weights.
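
These layers correspond to the -content_layers and -style_layers flag defaults listed under Layer options (which use the ReLU layer names), so the setup described above is roughly equivalent to passing them explicitly:

python neural_style.py -content_image <image.jpg> -style_image <image.jpg> -init random -content_layers relu4_2 -style_layers relu1_1,relu2_1,relu3_1,relu4_1,relu5_1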

Citation

If you find this code useful for your research, please cite:

@misc{ProGamerGov2018,
  author = {ProGamerGov},
  title = {neural-style-pt},
  year = {2018},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ProGamerGov/neural-style-pt}},
}
import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )


class NIN(nn.Module):
    def __init__(self, pooling):
        super(NIN, self).__init__()
        if pooling == 'max':
            pool2d = nn.MaxPool2d((3, 3),(2, 2),(0, 0),ceil_mode=True)
        elif pooling == 'avg':
            pool2d = nn.AvgPool2d((3, 3),(2, 2),(0, 0),ceil_mode=True)

        self.features = nn.Sequential(
            nn.Conv2d(3,96,(11, 11),(4, 4)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(96,256,(5, 5),(1, 1),(2, 2)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(256,384,(3, 3),(1, 1),(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Dropout(0.5),
            nn.Conv2d(384,1024,(3, 3),(1, 1),(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1024,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1000,(1, 1)),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((6, 6),(1, 1),(0, 0),ceil_mode=True),
            nn.Softmax(),
        )


def buildSequential(channel_list, pooling):
    layers = []
    in_channels = 3
    if pooling == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pooling == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        raise ValueError("Unrecognized pooling parameter")
    for c in channel_list:
        if c == 'P':
            layers += [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, c, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = c
    return nn.Sequential(*layers)


channel_list = {
    'VGG-16': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 'P', 512, 512, 512, 'P', 512, 512, 512, 'P'],
    'VGG-19': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 256, 'P', 512, 512, 512, 512, 'P', 512, 512, 512, 512, 'P'],
}

nin_dict = {
    'C': ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6', 'conv4-1024', 'cccp7-1024', 'cccp8-1024'],
    'R': ['relu0', 'relu1', 'relu2', 'relu3', 'relu5', 'relu6', 'relu7', 'relu8', 'relu9', 'relu10', 'relu11', 'relu12'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4'],
    'D': ['drop'],
}

vgg16_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu4_1', 'relu4_2', 'relu4_3', 'relu5_1', 'relu5_2', 'relu5_3'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}

vgg19_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}


def modelSelector(model_file, pooling):
    if "vgg" in model_file:
        if "19" in model_file:
            print("VGG-19 Architecture Detected")
            cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict
        elif "16" in model_file:
            print("VGG-16 Architecture Detected")
            cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
        else:
            raise ValueError("VGG architecture not recognized.")
    elif "nin" in model_file:
        print("NIN Architecture Detected")
        cnn, layerList = NIN(pooling), nin_dict
    else:
        raise ValueError("Model architecture not recognized.")
    return cnn, layerList


# Print like Torch7/loadcaffe
def print_loadcaffe(cnn, layerList):
    c = 0
    for l in list(cnn):
        if "Conv2d" in str(l):
            in_c, out_c, ks = str(l.in_channels), str(l.out_channels), str(l.kernel_size)
            print(layerList['C'][c] +": " + (out_c + " " + in_c + " " + ks).replace(")",'').replace("(",'').replace(",",'') )
            c+=1
        if c == len(layerList['C']):
            break


# Load the model, and configure pooling layer type
def loadCaffemodel(model_file, pooling, use_gpu):
    cnn, layerList = modelSelector(str(model_file).lower(), pooling)
    cnn.load_state_dict(torch.load(model_file))
    print("Successfully loaded " + str(model_file))
    # Maybe convert the model to cuda now, to avoid later issues
    if use_gpu > -1:
        cnn = cnn.cuda()
    cnn = cnn.features
    print_loadcaffe(cnn, layerList)
    return cnn, layerList
import torch
from sys import version_info
from collections import OrderedDict
from torch.utils.model_zoo import load_url
# Download the VGG-19 model and fix the layer names
sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")
map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(map[k] if k in map else k,v) for k,v in sd.items()])
torch.save(sd, "models/vgg19-d01eb7cb.pth")
# Download the VGG-16 model and fix the layer names
sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth")
map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(map[k] if k in map else k,v) for k,v in sd.items()])
torch.save(sd, "models/vgg16-00b39a1b.pth")
# Download the NIN model
if version_info[0] < 3:
    import urllib
    urllib.URLopener().retrieve("https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth", "models/nin_imagenet.pth")
else:
    import urllib.request
    urllib.request.urlretrieve("https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth", "models/nin_imagenet.pth")

neural-style-pt Installation

This guide will walk you through the setup for neural-style-pt on Ubuntu. If you wish to install PyTorch on a different operating system like Windows or macOS, installation guides can be found here.

Note that, in order to reduce their size, the pre-packaged binary releases (pip, Conda, etc.) have removed support for some older GPUs; to use those GPUs you will have to install PyTorch from source.

With A Package Manager:

The pip and Conda packages ship with CUDA and cuDNN already built in, so after you have installed PyTorch with pip or Conda, you can skip to installing neural-style-pt.

pip:

Following the pip installation instructions here, you can install PyTorch with the following commands:

# in a terminal, run the commands
cd ~/
pip install torch 
pip install torchvision 

Or:

cd ~/
pip3 install torch 
pip3 install torchvision 

Conda:

Following the Conda installation instructions here, you can install PyTorch with the following command:

conda install pytorch torchvision -c pytorch

From Source:

(Optional) Step 1: Install CUDA

If you have a CUDA-capable GPU from NVIDIA then you can speed up neural-style-pt with CUDA.

First download and unpack the local CUDA installer from NVIDIA; note that there are different installers for each recent version of Ubuntu:

# For Ubuntu 17.04
sudo dpkg -i cuda-repo-ubuntu1704-9-1-local_9.1.85-1_amd64.deb
sudo apt-key add /var/cuda-repo-<version>/7fa2af80.pub
# For Ubuntu 16.04
sudo dpkg -i cuda-repo-ubuntu1604-9-1-local_9.1.85-1_amd64.deb
sudo apt-key add /var/cuda-repo-<version>/7fa2af80.pub

Instructions for downloading and installing the latest CUDA version on all supported operating systems can be found here.

Now update the repository cache and install CUDA. Note that this will also install a graphics driver from NVIDIA.

sudo apt-get update
sudo apt-get install cuda

At this point you may need to reboot your machine to load the new graphics driver. After rebooting, you should be able to see the status of your graphics card(s) by running the command nvidia-smi; it should give output that looks something like this:

Wed Apr 11 21:54:49 2018
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.90                 Driver Version: 384.90                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   62C    P0    68W / 149W |      0MiB / 11439MiB |     94%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

(Optional) Step 2: Install cuDNN

cuDNN is a library from NVIDIA that efficiently implements many of the operations (like convolutions and pooling) that are commonly used in deep learning.

After registering as a developer with NVIDIA, you can download cuDNN here. Make sure that you use the appropriate version of cuDNN for your version of CUDA.

After downloading, you can unpack and install cuDNN like this:

tar -xvzf cudnn-9.1-linux-x64-v7.1.tgz
sudo cp cuda/lib64/libcudnn* /usr/local/cuda-9.1/lib64/
sudo cp cuda/include/cudnn.h /usr/local/cuda-9.1/include/

Note that the cuDNN backend can only be used for GPU mode.

Step 3: Install PyTorch

To install PyTorch from source on Ubuntu (instructions may differ if you are using a different OS):

cd ~/
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
python setup.py install

cd ~/
git clone --recursive https://github.com/pytorch/vision
cd vision
python setup.py install

To check that your PyTorch installation is working, run the command python or python3 to enter the Python interpreter. Then type import torch and hit enter.

You can then type print(torch.version.cuda) and print(torch.backends.cudnn.version()) to confirm that you are using the desired versions of CUDA and cuDNN.

To quit just type exit() or use Ctrl-D.
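
The same check can also be run non-interactively; a minimal sketch:

python -c "import torch; print(torch.version.cuda); print(torch.backends.cudnn.version())"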

Install neural-style-pt

First we clone neural-style-pt from GitHub:

cd ~/
git clone https://github.com/ProGamerGov/neural-style-pt.git
cd neural-style-pt

Next we need to download the pretrained neural network models:

python models/download_models.py

You should now be able to run neural-style-pt in CPU mode like this:

python neural_style.py -gpu -1 -print_iter 1

If you installed PyTorch with support for CUDA, then you should now be able to run neural-style-pt in GPU mode like this:

python neural_style.py -gpu 0 -print_iter 1

If you installed PyTorch with support for cuDNN, then you should now be able to run neural-style-pt with the cudnn backend like this:

python neural_style.py -gpu 0 -backend cudnn -print_iter 1

If everything is working properly you should see output like this:

Iteration 1 / 1000
  Content 1 loss: 1616196.125
  Style 1 loss: 29890.9980469
  Style 2 loss: 658038.625
  Style 3 loss: 145283.671875
  Style 4 loss: 11347409.0
  Style 5 loss: 563.368896484
  Total loss: 13797382.0
Iteration 2 / 1000
  Content 1 loss: 1616195.625
  Style 1 loss: 29890.9980469
  Style 2 loss: 658038.625
  Style 3 loss: 145283.671875
  Style 4 loss: 11347409.0
  Style 5 loss: 563.368896484
  Total loss: 13797382.0
Iteration 3 / 1000
  Content 1 loss: 1579918.25
  Style 1 loss: 29881.3164062
  Style 2 loss: 654351.75
  Style 3 loss: 144214.640625
  Style 4 loss: 11301945.0
  Style 5 loss: 562.733032227
  Total loss: 13711628.0
Iteration 4 / 1000
  Content 1 loss: 1460443.0
  Style 1 loss: 29849.7226562
  Style 2 loss: 643799.1875
  Style 3 loss: 140405.015625
  Style 4 loss: 10940431.0
  Style 5 loss: 553.507446289
  Total loss: 13217080.0
Iteration 5 / 1000
  Content 1 loss: 1298983.625
  Style 1 loss: 29734.8964844
  Style 2 loss: 604133.8125
  Style 3 loss: 125455.945312
  Style 4 loss: 8850759.0
  Style 5 loss: 526.118591309
  Total loss: 10912633.0
The MIT License (MIT)
Copyright (c) 2018 ProGamerGov
Copyright (c) 2015 Justin Johnson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1", type=int, default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=0)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images


def main():
    dtype = setup_gpu()
    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_list = params.style_image.split(',')
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)
    if params.init_image != None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights == None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)
                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1
            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)
                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1
                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1
            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'
    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i+1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image != None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img.type(dtype))

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img.clone())
            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)
            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0
        for mod in content_losses:
            loss += mod.loss
        for mod in style_losses:
            loss += mod.loss
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss
        loss.backward()
        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)
        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction > 0:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr = params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal


def setup_gpu():
    if params.gpu > -1:
        if params.backend == 'cudnn':
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False
        torch.cuda.set_device(params.gpu)
        dtype = torch.cuda.FloatTensor
    elif params.gpu == -1:
        if params.backend == 'mkl':
            torch.backends.mkl.enabled = True
        dtype = torch.FloatTensor
    return dtype


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net):
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ",' ') + st.replace(", ",')')
                print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
        else:
            print(n())
    print(")")


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())


# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight == None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:,:,1:,:] - input[:,:,:-1,:]
        self.y_diff = input[:,:,:,1:] - input[:,:,:,:-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input


if __name__ == "__main__":
    main()
# To run this script you'll need to download the ultra-high res
# scan of Starry Night from the Google Art Project, using this command:
# wget -c https://upload.wikimedia.org/wikipedia/commons/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry_night_gigapixel.jpg
# Or you can manually download the image from here: https://commons.wikimedia.org/wiki/File:Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg
STYLE_IMAGE=starry_night_gigapixel.jpg
CONTENT_IMAGE=examples/inputs/hoovertowernight.jpg
STYLE_WEIGHT=5e2
STYLE_SCALE=1.0
PYTHON=python
GPU=0
$PYTHON neural_style.py \
-content_image $CONTENT_IMAGE \
-style_image $STYLE_IMAGE \
-style_scale $STYLE_SCALE \
-print_iter 1 \
-style_weight $STYLE_WEIGHT \
-image_size 256 \
-output_image out1.png \
-tv_weight 0 \
-gpu $GPU \
-backend cudnn -cudnn_autotune
$PYTHON neural_style.py \
-content_image $CONTENT_IMAGE \
-style_image $STYLE_IMAGE \
-init image -init_image out1.png \
-style_scale $STYLE_SCALE \
-print_iter 1 \
-style_weight $STYLE_WEIGHT \
-image_size 512 \
-num_iterations 500 \
-output_image out2.png \
-tv_weight 0 \
-gpu $GPU \
-backend cudnn -cudnn_autotune
$PYTHON neural_style.py \
-content_image $CONTENT_IMAGE \
-style_image $STYLE_IMAGE \
-init image -init_image out2.png \
-style_scale $STYLE_SCALE \
-print_iter 1 \
-style_weight $STYLE_WEIGHT \
-image_size 1024 \
-num_iterations 200 \
-output_image out3.png \
-tv_weight 0 \
-gpu $GPU \
-backend cudnn -cudnn_autotune
STYLE_WEIGHT=2500
$PYTHON neural_style.py \
-content_image $CONTENT_IMAGE \
-style_image $STYLE_IMAGE \
-init image -init_image out3.png \
-style_scale $STYLE_SCALE \
-print_iter 1 \
-style_weight $STYLE_WEIGHT \
-image_size 2048 \
-num_iterations 200 \
-output_image out4.png \
-tv_weight 0 \
-gpu $GPU \
-backend cudnn
$PYTHON neural_style.py \
-content_image $CONTENT_IMAGE \
-style_image $STYLE_IMAGE \
-init image -init_image out4.png \
-style_scale $STYLE_SCALE \
-print_iter 1 \
-style_weight $STYLE_WEIGHT \
-image_size 2350 \
-num_iterations 200 \
-output_image out5.png \
-tv_weight 0 \
-gpu $GPU \
-backend cudnn -optimizer adam
@ProGamerGov

Neural-Style Control:

| Command | Initial Memory Spike | Memory Stabilizes At |
| --- | --- | --- |
| LBFGS nn | No Spike | 3753 MiB |
| LBFGS cudnn | No Spike | 1755 MiB |
| LBFGS cudnn autotune | 10257 MiB | 1774 MiB |
| Adam nn | No Spike | 3753 MiB |
| Adam cudnn | No Spike | 1755 MiB |
| Adam cudnn autotune | 2880 MiB | 1774 MiB |


ProGamerGov commented May 9, 2018

Neural-Style-PT:

| Command | Initial Memory Spike | Memory Stabilizes At |
| --- | --- | --- |
| LBFGS nn | No Spike | 3930 MiB |
| LBFGS cudnn | No Spike | 1483 MiB |
| LBFGS cudnn autotune | 9682 MiB | 1541 MiB |
| Adam nn | No Spike | 3930 MiB |
| Adam cudnn | No Spike | 1194 MiB |
| Adam cudnn autotune | 4291 MiB | 965 MiB |

Neural-Style Control:

| Command | Initial Memory Spike | Memory Stabilizes At |
| --- | --- | --- |
| LBFGS nn | No Spike | 3753 MiB |
| LBFGS cudnn | No Spike | 1755 MiB |
| LBFGS cudnn autotune | 10257 MiB | 1774 MiB |
| Adam nn | No Spike | 3753 MiB |
| Adam cudnn | No Spike | 1755 MiB |
| Adam cudnn autotune | 2880 MiB | 1774 MiB |

Speed:

| Command | Neural-Style | Neural-Style-PT |
| --- | --- | --- |
| LBFGS nn | 131.874060512 seconds | 117.297629714 seconds |
| LBFGS cudnn | 135.883114755 seconds | 125.011222303 seconds |
| LBFGS cudnn autotune | 118.96576196 seconds | 110.40396905 seconds |
| Adam nn | 119.635071182 seconds | 100.37450546 seconds |
| Adam cudnn | 123.493975282 seconds | 107.182621241 seconds |
| Adam cudnn autotune | 106.208788693 seconds | 91.7877862453 seconds |


ProGamerGov commented May 11, 2018

Old Test:

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3930 MiB | 117 seconds |
| LBFGS cudnn | No Spike | 1483 MiB | 125 seconds |
| LBFGS cudnn autotune | 9682 MiB | 1541 MiB | 110 seconds |
| Adam nn | No Spike | 3930 MiB | 100 seconds |
| Adam cudnn | No Spike | 1194 MiB | 107 seconds |
| Adam cudnn autotune | 4291 MiB | 965 MiB | 92 seconds |

New Optimization Test Version:

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3743 MiB | 120 seconds |
| LBFGS cudnn | No Spike | 1260 MiB | 127 seconds |
| LBFGS cudnn autotune | 4351 MiB | 1206 MiB | 111 seconds |
| Adam nn | No Spike | 3744 MiB | 103 seconds |
| Adam cudnn | No Spike | 1008 MiB | 110 seconds |
| Adam cudnn autotune | 9266 MiB | 791 MiB | 93 seconds |

Neural-Style Control:

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3753 MiB | 132 seconds |
| LBFGS cudnn | No Spike | 1755 MiB | 136 seconds |
| LBFGS cudnn autotune | 10257 MiB | 1774 MiB | 119 seconds |
| Adam nn | No Spike | 3753 MiB | 120 seconds |
| Adam cudnn | No Spike | 1755 MiB | 123 seconds |
| Adam cudnn autotune | 2880 MiB | 1774 MiB | 106 seconds |


ProGamerGov commented May 12, 2018

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3743 MiB | 116.584144433 seconds |
| LBFGS cudnn | No Spike | 1259-1260 MiB | 123.626243353 seconds |
| LBFGS cudnn autotune | 4351 MiB | 1206 MiB | 109.293155988 seconds |
| Adam nn | No Spike | 3744 MiB | 99.8951559067 seconds |
| Adam cudnn | No Spike | 1008 MiB | 106.525854747 seconds |
| Adam cudnn autotune | 9266 MiB | 791 MiB | 91.1590566635 seconds |

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3744 MiB | 116.580465356 seconds |
| LBFGS cudnn | No Spike | 1259-1260 MiB | 123.617907127 seconds |
| LBFGS cudnn autotune | 2520-9473 MiB | 1207 MiB | 109.26023519 seconds |
| Adam nn | No Spike | 3744 MiB | 99.9323001305 seconds |
| Adam cudnn | No Spike | 1007-1008 MiB | 106.575520198 seconds |
| Adam cudnn autotune | 2520-9266 MiB | 791 MiB | 91.1483420531 seconds |

| Command | Initial Memory Spike | Memory Stabilizes At | Time |
| --- | --- | --- | --- |
| LBFGS nn | No Spike | 3744 MiB | 117 seconds |
| LBFGS cudnn | No Spike | 1259-1260 MiB | 124 seconds |
| LBFGS cudnn autotune | 2520-9473 MiB | 1207 MiB | 109 seconds |
| Adam nn | No Spike | 3744 MiB | 100 seconds |
| Adam cudnn | No Spike | 1007-1008 MiB | 107 seconds |
| Adam cudnn autotune | 2520-9266 MiB | 791 MiB | 91 seconds |

@ProGamerGov

With the default settings, neural-style uses about 3.5GB of GPU memory on my system; switching to ADAM and cuDNN reduces the GPU memory footprint to about 1GB.

With the default settings, neural-style-pt uses about 3.7 GB of GPU memory on my system; switching to ADAM and cuDNN reduces the GPU memory footprint to about 1GB.

@ProGamerGov

Speed

Speed can vary a lot depending on the backend and the optimizer.
Here are some times for running 500 iterations with -image_size=512 on a Tesla K80 with different settings:

  • -backend nn -optimizer lbfgs: 117 seconds
  • -backend nn -optimizer adam: 100 seconds
  • -backend cudnn -optimizer lbfgs: 124 seconds
  • -backend cudnn -optimizer adam: 107 seconds
  • -backend cudnn -cudnn_autotune -optimizer lbfgs: 109 seconds
  • -backend cudnn -cudnn_autotune -optimizer adam: 91 seconds


ProGamerGov commented May 14, 2018

3 of: " (" + str(i+1) + "): " +

2 of: + "nn." +

2 of .split("(", 1)[0] ---> str(l).split("(", 1)[0]

2 of: .replace(", ",')'


.lower()

'' and ""


def n(i):
 return "  (" + str(i+1) + "): " 


    for i, l in enumerate(net): 
         is_2d = True if "2d" in str(l) else False
         is_conv = True if "Conv2d" in str(l) else False
         if is_2d:
             ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding).replace(", ",')')
             if is_conv:
                 in_c, out_c = str(l.in_channels), str(l.out_channels)
                 print( n(i) + "nn.Conv2d(" + in_c + " -> " + out_c + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')')) 
             else: 
                 print(n(i) + "nn." + str(l).split("(", 1)[0] + "(" + ((ks).replace(",",'x' + ks, 1) + st.replace("  ",' ') + st.replace(", ",')')).replace(", ",',') )
         else:
             print(n(i) + "nn." + str(l).split("(", 1)[0]) 

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "



May also want to add an "is_pool2d" check to:

         is_2d = True if "2d" in str(l) else False
         is_conv = True if "Conv2d" in str(l) else False




May or may not need .lower() in this function:

def modelSelector(model_file, pooling):
    if "vgg19" in str(model_file).lower():
        print("VGG-19 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict
    elif "vgg16" in str(model_file).lower():
        print("VGG-16 Architecture Detected")
        cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
    elif "nin" in str(model_file).lower():
        print("NIN Architecture Detected")
        cnn, layerList = NIN(pooling), nin_dict
    else:
        print("Model Architecture Not Recognized")
        raise ValueError("""Model Architecture Not Recognized. Please ensure that the model
        name contains either "vgg16", "vgg19", or "nin", in the file name.""")       
    return cnn, layerList

@ProGamerGov

The code in this Gist was continued from here: https://gist.github.com/ProGamerGov/4fbb4a8340ae654a3ae460ccddb7757c

@ProGamerGov

The latest version of this Gist can be found here: https://github.com/ProGamerGov/neural-style-pt
