neural-dream

This is a PyTorch implementation of DeepDream. The code is based on neural-style-pt.

Here we DeepDream a photograph of the Golden Gate Bridge with a variety of settings:

Specific Channel Selection

You can select individual channels or specific combinations of channels.

Clockwise from upper left: 119, 1, 29, and all channels of the inception_4d_3x3_reduce layer

Clockwise from upper left: 25, 108, 25 & 108, and 25 & 119 from the inception_4d_3x3_reduce layer
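
For example, a command along these lines (the input image path is a placeholder; the layer and channel numbers are just the ones shown above) selects channels 25 and 108 of the inception_4d_3x3_reduce layer:

python neural_dream.py -content_image <image.jpg> -model_file models/bvlc_googlenet.pth -dream_layers inception_4d_3x3_reduce -channels 25,108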

Channel Selection Based On Activation Strength

You can select channels automatically based on their activation strength.

Clockwise from upper left: The top 10 weakest channels, the 10 most average channels, the top 10 strongest channels, and all channels of the inception_4e_3x3_reduce layer
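
A command like the following (placeholder image path) would automatically select the 10 strongest channels of the inception_4e_3x3_reduce layer:

python neural_dream.py -content_image <image.jpg> -model_file models/bvlc_googlenet.pth -dream_layers inception_4e_3x3_reduce -channels 10 -channel_mode strong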

Setup:

Dependencies:

  • PyTorch

Optional dependencies:

  • For CUDA backend:
    • CUDA 7.5 or above
  • For cuDNN backend:
    • cuDNN v6 or above
  • For ROCm backend:
    • ROCm 2.1 or above
  • For MKL backend:
    • MKL 2019 or above
  • For OpenMP backend:
    • OpenMP 5.0 or above

After installing the dependencies, you'll need to run the following script to download the BVLC GoogleNet model:

python models/download_models.py

This will download the original BVLC GoogleNet model.

If you have a GPU with less memory, the NIN Imagenet model can be an alternative to the BVLC GoogleNet model, though its DeepDream quality is nowhere near that of the other models. You can find details about the model at the BVLC Caffe ModelZoo. The NIN model is downloaded when you run the download_models.py script with default parameters.

To download most of the compatible models, run the download_models.py script with the following parameters:

python models/download_models.py -models all

You can find detailed installation instructions for Ubuntu and Windows in the installation guide.

Usage

Basic usage:

python neural_dream.py -content_image <image.jpg>

cuDNN usage with NIN Model:

python neural_dream.py -content_image examples/inputs/brad_pitt.jpg -output_image pitt_nin_cudnn.png -model_file models/nin_imagenet.pth -gpu 0 -backend cudnn -num_iterations 10 -seed 876 -dream_layers relu0,relu3,relu7,relu12 -dream_weight 10 -image_size 512 -optimizer adam -learning_rate 0.1

Example output: Brad Pitt processed with the NIN model and the cuDNN backend

Note that paths to images should not contain the ~ character to represent your home directory; you should instead use a relative path or a full absolute path.

Options:

  • -image_size: Maximum side length (in pixels) of the generated image. Default is 512.
  • -gpu: Zero-indexed ID of the GPU to use; for CPU mode set -gpu to c.

Optimization options:

  • -dream_weight: How much to weight DeepDream. Default is 1e3.
  • -tv_weight: Weight of total-variation (TV) regularization; this helps to smooth the image. Default is set to 0 to disable total-variation (TV) regularization.
  • -l2_weight: Weight of latent state regularization. Default is set to 0 to disable latent state regularization.
  • -num_iterations: Default is 10.
  • -init: Method for initializing the generated image; one of random or image. Default is image, which initializes with the content image; random uses random noise to initialize the input image.
  • -jitter: The amount of jitter to apply to the image. Default is 32. Set to 0 to disable jitter.
  • -layer_sigma: Strength of the Gaussian blur applied to the image. Default is 0, which disables the Gaussian blur layer.
  • -optimizer: The optimization algorithm to use; either lbfgs or adam; default is adam. Adam tends to perform best for DeepDream. L-BFGS tends to give worse results and uses more memory; when using L-BFGS you will probably need to play with other parameters, especially the learning rate, to get good results.
  • -learning_rate: Learning rate to use with the ADAM and L-BFGS optimizers. Default is 1.5.
  • -normalize_weights: If this flag is present, dream weights will be divided by the number of channels for each layer. Idea from PytorchNeuralStyleTransfer.
  • -loss_mode: The DeepDream loss mode; bce, mse, mean, norm, or l2; default is l2.
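
As a rough sketch, the optimization options above can be combined like this (placeholder image path, illustrative values):

python neural_dream.py -content_image <image.jpg> -optimizer adam -learning_rate 1.5 -num_iterations 10 -dream_weight 1e3 -jitter 32 -init image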

Output options:

  • -output_image: Name of the output image. Default is out.png.
  • -print_iter: Print progress every print_iter iterations. Set to 0 to disable printing.
  • -print_octave_iter: Print octave progress every print_octave_iter iterations. Default is set to 0 to disable printing.
  • -save_iter: Save the image every save_iter iterations. Set to 0 to disable saving intermediate results.
  • -save_octave_iter: Save the image every save_octave_iter iterations. Default is set to 0 to disable saving intermediate results.

Layer options:

  • -dream_layers: Comma-separated list of layer names to use for DeepDream reconstruction.

Channel options:

  • -channels: Comma-separated list of channels to use for DeepDream. If -channel_mode is set to a value other than all, only the first value in the list will be used.
  • -channel_mode: The DeepDream channel selection mode; all, strong, avg, weak, or ignore; default is all. The strong option will select the strongest channels, while weak will do the same with the weakest channels. The avg option will select the most average channels instead of the strongest or weakest. The number of channels selected by strong, avg, or weak is based on the first value for the -channels parameter. The ignore option will omit any specified channels.
  • -channel_capture: How often to select channels based on activation strength; either once or iter; default is once. The once option will select channels once at the start, while the iter will select potentially new channels every iteration. This parameter only comes into play if -channel_mode is not set to all.
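
For instance, a command like the following (placeholder image path) selects the 5 weakest channels of a layer and re-selects them every iteration:

python neural_dream.py -content_image <image.jpg> -model_file models/bvlc_googlenet.pth -dream_layers inception_4d_3x3_reduce -channels 5 -channel_mode weak -channel_capture iter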

Octave Options:

  • -num_octaves: Number of octaves per iteration. Default is 4.
  • -octave_scale: Scale factor by which the image is resized for each octave. Default is 0.6.
  • -octave_iter: Number of iterations per octave. Default is 50. On other DeepDream projects this parameter is commonly called 'steps'.
  • -octave_mode: The octave size calculation mode; one of normal or advanced. Default is normal.
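
For example (placeholder image path), four octaves of 50 iterations each at the default 0.6 scale:

python neural_dream.py -content_image <image.jpg> -num_octaves 4 -octave_scale 0.6 -octave_iter 50 -octave_mode normal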

Laplacian Pyramid Options:

  • -lap_scale: The number of layers in a layer's Laplacian pyramid. Default is 0, which disables Laplacian pyramids.
  • -sigma: The strength of the Gaussian blur used in Laplacian pyramids. Default is 1. Unless a second sigma value is provided (separated from the first by a comma), the high Gaussian layers will use sigma * lap_scale.
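
For example (placeholder image path, illustrative sigma values), a four-layer Laplacian pyramid with separate low and high sigma values:

python neural_dream.py -content_image <image.jpg> -lap_scale 4 -sigma 1,12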

Zoom Options:

  • -zoom: The amount to zoom in on the image.
  • -zoom_mode: Whether to read the zoom value as a percentage or pixel value; one of percent or pixel. Default is percent.

FFT Options:

  • -use_fft: Whether to enable Fast Fourier transform (FFT) decorrelation.
  • -fft_block: The size of your FFT frequency filtering block. Default is 25.
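
For example (placeholder image path), enabling FFT decorrelation with the default block size:

python neural_dream.py -content_image <image.jpg> -use_fft -fft_block 25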

Help Options:

  • -print_layers: Pass this flag to print the names of all usable layers for the selected model.
  • -print_channels: Pass this flag to print all the selected channels.

Other options:

  • -original_colors: If you set this to 1, then the output image will keep the colors of the content image.
  • -model_file: Path to the .pth file for the VGG Caffe model. Default is the original VGG-19 model; you can also try the original VGG-16 model.
  • -model_type: Whether the model was trained using Caffe or PyTorch preprocessing; caffe, pytorch, or auto; default is auto.
  • -model_mean: A comma separated list of 3 numbers for the model's mean; default is auto.
  • -pooling: The type of pooling layers to use for VGG and NIN models; one of max or avg. Default is max. VGG models seem to create better results with average pooling.
  • -seed: An integer value that you can specify for repeatable results. By default this value is random for each run.
  • -multidevice_strategy: A comma-separated list of layer indices at which to split the network when using multiple devices. See Multi-GPU scaling for more details. Currently this feature only works for VGG and NIN models.
  • -backend: nn, cudnn, openmp, or mkl. Default is nn. mkl requires Intel's MKL backend.
  • -cudnn_autotune: When using the cuDNN backend, pass this flag to use the built-in cuDNN autotuner to select the best convolution algorithms for your architecture. This will make the first iteration a bit slower and can take a bit more memory, but may significantly speed up the cuDNN backend.
  • -clamp: If this flag is enabled, every iteration will clamp the output image so that it stays within the model's input range.
  • -adjust_contrast: A value between 0 and 100.0 for altering the image's contrast (ex: 99.98). Default is set to 0 to disable contrast adjustments.
  • -label_file: Path to the .txt category list file for classification and channel selection.
  • -random_transforms: Whether to use random transforms on the image; either none, rotate, flip, or all; default is none.
  • -classify: Display what the model thinks the image contains; an integer specifying how many of the top predictions to show, ranked by likelihood. Default is 0, which disables classification.
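
For example (placeholder paths), printing the model's top 5 predictions for the image using a category label file:

python neural_dream.py -content_image <image.jpg> -label_file <categories.txt> -classify 5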

Frequently Asked Questions

Problem: The program runs out of memory and dies

Solution: Try reducing the image size: -image_size 512 (or lower). Note that different image sizes will likely require non-default values for -octave_scale and -num_octaves for optimal results. If you are running on a GPU, you can also try running with -backend cudnn to reduce memory usage.

Problem: -backend cudnn is slower than default NN backend

Solution: Add the flag -cudnn_autotune; this will use the built-in cuDNN autotuner to select the best convolution algorithms.

Problem: Get the following error message:

Missing key(s) in state_dict: "classifier.0.bias", "classifier.0.weight", "classifier.3.bias", "classifier.3.weight". Unexpected key(s) in state_dict: "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias".

Solution: Due to a mix-up with layer locations, older models require a fix to be compatible with newer versions of PyTorch. The included download_models.py script will automatically perform these fixes after downloading the models.

Problem: Get the following error message:

Given input size: (...). Calculated output size: (...). Output size is too small

Solution: Use a larger -image_size value and/or adjust the octave parameters so that the smallest octave size is larger.

Memory Usage

By default, neural-dream uses the nn backend for convolutions and Adam for optimization. These give good results, but can both use a lot of memory. You can reduce memory usage with the following:

  • Use cuDNN: Add the flag -backend cudnn to use the cuDNN backend. This will only work in GPU mode.
  • Reduce image size: You can reduce the size of the generated image to lower memory usage; pass the flag -image_size 256 to generate an image at half the default size.

With the default settings, neural-dream uses about 1.3 GB of GPU memory on my system; switching to cuDNN reduces the GPU memory footprint to about 1 GB.
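
For example, both memory-saving options can be combined (placeholder image path):

python neural_dream.py -content_image <image.jpg> -backend cudnn -image_size 256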

Multi-GPU scaling

You can use multiple CPU and GPU devices to process images at higher resolutions; different layers of the network will be computed on different devices. You can control which GPU and CPU devices are used with the -gpu flag, and you can control how to split layers across devices using the -multidevice_strategy flag.

For example in a server with four GPUs, you can give the flag -gpu 0,1,2,3 to process on GPUs 0, 1, 2, and 3 in that order; by also giving the flag -multidevice_strategy 3,6,12 you indicate that the first two layers should be computed on GPU 0, layers 3 to 5 should be computed on GPU 1, layers 6 to 11 should be computed on GPU 2, and the remaining layers should be computed on GPU 3. You will need to tune the -multidevice_strategy for your setup in order to achieve maximal resolution.
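
Written out as a single command (placeholder image path and an illustrative -image_size), the example above looks like this:

python neural_dream.py -content_image <image.jpg> -gpu 0,1,2,3 -multidevice_strategy 3,6,12 -image_size 2048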

We can achieve very high quality results at high resolution by combining multi-GPU processing with multiscale generation as described in the paper Controlling Perceptual Factors in Neural Style Transfer by Leon A. Gatys, Alexander S. Ecker, Matthias Bethge, Aaron Hertzmann and Eli Shechtman.
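
A minimal two-pass sketch of such a multiscale workflow (all paths and sizes are illustrative): DeepDream the image at a lower resolution first, then feed the result back in as the content image at a higher resolution across multiple GPUs.

python neural_dream.py -content_image <image.jpg> -image_size 1024 -output_image pass1.png
python neural_dream.py -content_image pass1.png -image_size 2048 -gpu 0,1,2,3 -multidevice_strategy 3,6,12 -output_image pass2.png

# ---------------------------------------------------------------------------
# neural_dream/CaffeLoader.py (model loading utilities; filename inferred from
# the `from neural_dream.CaffeLoader import ...` line in neural_dream.py below)
# ---------------------------------------------------------------------------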

import torch
import torch.nn as nn
import torchvision
from torchvision import models
from neural_dream.googlenet_models import GoogLeNetPlaces205, GoogLeNetPlaces365, BVLC_GOOGLENET, GoogleNet_SOS, GOOGLENET_CARS, googlenet_layer_names
from neural_dream.inception_models import Inception5h, inception_layer_names
from neural_dream.resnet_models import ResNet_50_1by2_nsfw, resnet_layer_names
class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

class VGG_SOD(nn.Module):
    def __init__(self, features, num_classes=100):
        super(VGG_SOD, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 100),
        )

class VGG_FCN32S(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG_FCN32S, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 4096, (7, 7)),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Conv2d(4096, 4096, (1, 1)),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

class VGG_PRUNED(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG_PRUNED, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )
class NIN(nn.Module):
    def __init__(self, pooling):
        super(NIN, self).__init__()
        if pooling == 'max':
            pool2d = nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)
        elif pooling == 'avg':
            pool2d = nn.AvgPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True)
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, (11, 11), (4, 4)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(96, 256, (5, 5), (1, 1), (2, 2)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(256, 384, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, (1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Dropout(0.5),
            nn.Conv2d(384, 1024, (3, 3), (1, 1), (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, (1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1000, (1, 1)),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((6, 6), (1, 1), (0, 0), ceil_mode=True),
            nn.Softmax(),
        )
class ModelParallel(nn.Module):
    def __init__(self, net, device_ids, device_splits):
        super(ModelParallel, self).__init__()
        self.device_list = self.name_devices(device_ids.split(','))
        self.chunks = self.chunks_to_devices(self.split_net(net, device_splits.split(',')))

    def name_devices(self, input_list):
        device_list = []
        for i, device in enumerate(input_list):
            if str(device).lower() != 'c':
                device_list.append("cuda:" + str(device))
            else:
                device_list.append("cpu")
        return device_list

    def split_net(self, net, device_splits):
        chunks, cur_chunk = [], nn.Sequential()
        for i, l in enumerate(net):
            cur_chunk.add_module(str(i), net[i])
            if str(i) in device_splits and device_splits != '':
                del device_splits[0]
                chunks.append(cur_chunk)
                cur_chunk = nn.Sequential()
        chunks.append(cur_chunk)
        return chunks

    def chunks_to_devices(self, chunks):
        for i, chunk in enumerate(chunks):
            chunk.to(self.device_list[i])
        return chunks

    def c(self, input, i):
        if input.type() == 'torch.FloatTensor' and 'cuda' in self.device_list[i]:
            input = input.type('torch.cuda.FloatTensor')
        elif input.type() == 'torch.cuda.FloatTensor' and 'cpu' in self.device_list[i]:
            input = input.type('torch.FloatTensor')
        return input

    def forward(self, input):
        for i, chunk in enumerate(self.chunks):
            if i < len(self.chunks) - 1:
                input = self.c(chunk(self.c(input, i).to(self.device_list[i])), i+1).to(self.device_list[i+1])
            else:
                input = chunk(input)
        return input
def buildSequential(channel_list, pooling):
    layers = []
    in_channels = 3
    if pooling == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pooling == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        raise ValueError("Unrecognized pooling parameter")
    for c in channel_list:
        if c == 'P':
            layers += [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, c, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = c
    return nn.Sequential(*layers)
channel_list = {
'VGG-11': [64, 'P', 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512, 'P'],
'VGG-13': [64, 64, 'P', 128, 128, 'P', 256, 256, 'P', 512, 512, 'P', 512, 512, 'P'],
'VGG-16p': [24, 22, 'P', 41, 51, 'P', 108, 89, 111, 'P', 184, 276, 228, 'P', 512, 512, 512, 'P'],
'VGG-16': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 'P', 512, 512, 512, 'P', 512, 512, 512, 'P'],
'VGG-19': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 256, 'P', 512, 512, 512, 512, 'P', 512, 512, 512, 512, 'P'],
}
nin_dict = {
'C': ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6', 'conv4-1024', 'cccp7-1024', 'cccp8-1024'],
'R': ['relu0', 'relu1', 'relu2', 'relu3', 'relu5', 'relu6', 'relu7', 'relu8', 'relu9', 'relu10', 'relu11', 'relu12'],
'P': ['pool1', 'pool2', 'pool3', 'pool4'],
'D': ['drop'],
}
vgg11_dict = {
'C': ['conv1_1', 'conv2_1', 'conv3_1', 'conv3_2', 'conv4_1', 'conv4_2', 'conv5_1', 'conv5_2'],
'R': ['relu1_1', 'relu2_1', 'relu3_1', 'relu3_2', 'relu4_1', 'relu4_2', 'relu5_1', 'relu5_2', 'relu6', 'relu7'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
'L': ['fc6', 'fc7', 'fc8'],
'D': ['drop6', 'drop7'],
}
vgg13_dict = {
'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv4_1', 'conv4_2', 'conv5_1', 'conv5_2'],
'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu4_1', 'relu4_2', 'relu5_1', 'relu5_2', 'relu6', 'relu7'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
'L': ['fc6', 'fc7', 'fc8'],
'D': ['drop6', 'drop7'],
}
vgg16_dict = {
'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'],
'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu4_1', 'relu4_2', 'relu4_3', 'relu5_1', 'relu5_2', 'relu5_3', 'relu6', 'relu7'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
'L': ['fc6', 'fc7', 'fc8'],
'D': ['drop6', 'drop7'],
}
vgg19_dict = {
'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4', 'relu6', 'relu7'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
'L': ['fc6', 'fc7', 'fc8'],
'D': ['drop6', 'drop7'],
}
name_dict = {
'vgg': ['vgg'],
'vgg11': ['vgg-11', 'vgg11', 'vgg_11'],
'vgg13': ['vgg-13', 'vgg13', 'vgg_13'],
'vgg16': ['vgg-16', 'vgg16', 'vgg_16', 'fcn32s', 'pruning', 'sod'],
'vgg19': ['vgg-19', 'vgg19', 'vgg_19',],
}
ic_dict = {
'inception': ['inception'],
'googlenet': ['googlenet'],
'inceptionv3': ['inception_v3', 'inceptionv3'],
'resnet': ['resnet'],
}
def build_googlenet_list(cnn):
main_layers = ['conv1', 'maxpool1', 'conv2', 'conv3', 'maxpool2', 'inception3a', 'inception3b', 'maxpool3', \
'inception4a', 'inception4b', 'inception4c', 'inception4d', 'inception4e', 'maxpool4', 'inception5a', \
'inception5b', 'aux1', 'aux2', 'avgpool', 'dropout', 'fc']
branch_list = ['branch1', 'branch2', 'branch3', 'branch4']
ax = ['conv', 'fc1', 'fc2']
conv_block =['conv', 'bn']
layer_name_list = []
for i, layer in enumerate(list(cnn.children())):
if 'BasicConv2d' in str(type(layer)):
for bl, block in enumerate(list(layer.children())):
name = main_layers[i] + '/' + conv_block[bl]
layer_name_list.append(name)
elif 'Inception' in str(type(layer)) and 'Aux' not in str(type(layer)):
for br, branch in enumerate(list(layer.children())):
for bl, block in enumerate(list(branch.children())):
name = main_layers[i] + '/' + branch_list[br] + '/' + conv_block[bl]
layer_name_list.append(name)
elif 'Inception' in str(type(layer)) and 'Aux' in str(type(layer)):
for bl, block in enumerate(list(layer.children())):
name = main_layers[i] + '/' + ax[bl]
layer_name_list.append(name)
elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AdaptiveAvgPool2d) \
or isinstance(layer, nn.Dropout) or isinstance(layer, nn.Linear):
layer_name_list.append(main_layers[i])
return layer_name_list
def build_inceptionv3_list(cnn):
main_layers = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'Mixed_5b', \
'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'AuxLogits', 'Mixed_7a', \
'Mixed_7b', 'Mixed_7c', 'fc']
ba = ['branch1x1', 'branch5x5_1', 'branch5x5_2', 'branch3x3dbl_1', 'branch3x3dbl_2', 'branch3x3dbl_3', 'branch_pool']
bb = ['branch3x3', 'branch3x3dbl_1', 'branch3x3dbl_2', 'branch3x3dbl_3']
bc = ['branch1x1', 'branch7x7_1', 'branch7x7_2', 'branch7x7_3', 'branch7x7dbl_1', 'branch7x7dbl_2', 'branch7x7dbl_3', \
'branch7x7dbl_4', 'branch7x7dbl_5', 'branch_pool']
bd = ['branch3x3_1', 'branch3x3_2', 'branch7x7x3_1', 'branch7x7x3_2', 'branch7x7x3_3', 'branch7x7x3_4']
be = ['branch1x1', 'branch3x3_1', 'branch3x3_2a', 'branch3x3_2b', 'branch3x3dbl_1', 'branch3x3dbl_2', \
'branch3x3dbl_3a', 'branch3x3dbl_3b', 'branch_pool']
ax = ['conv0', 'conv1', 'fc']
conv_block =['conv', 'bn']
layer_name_list = []
for i, layer in enumerate(list(cnn.children())):
if 'BasicConv2d' in str(type(layer)):
for bl, block in enumerate(list(layer.children())):
name = main_layers[i] + '/' + conv_block[bl]
layer_name_list.append(name)
elif 'Inception' in str(type(layer)) and 'Aux' not in str(type(layer)):
if 'InceptionA' in str(type(layer)):
branch_list = ba
elif 'InceptionB' in str(type(layer)):
branch_list = bb
elif 'InceptionC' in str(type(layer)):
branch_list = bc
elif 'InceptionD' in str(type(layer)):
branch_list = bd
elif 'InceptionE' in str(type(layer)):
branch_list = be
for br, branch in enumerate(list(layer.children())):
for bl, block in enumerate(list(branch.children())):
name = main_layers[i] + '/' + branch_list[br] + '/' + conv_block[bl]
layer_name_list.append(name)
elif 'Inception' in str(type(layer)) and 'Aux' in str(type(layer)):
for bl, block in enumerate(list(layer.children())):
name = main_layers[i] + '/' + ax[bl]
layer_name_list.append(name)
elif isinstance(layer, nn.Linear):
layer_name_list.append(main_layers[i])
return layer_name_list
def modelSelector(model_file, pooling):
if any(name in model_file for name in name_dict):
if any(name in model_file for name in name_dict['vgg16']):
print("VGG-16 Architecture Detected")
if "pruning" in model_file:
print("Using The Channel Pruning Model")
cnn, layerList = VGG_PRUNED(buildSequential(channel_list['VGG-16p'], pooling)), vgg16_dict
elif "fcn32s" in model_file:
print("Using the fcn32s-heavy-pascal Model")
cnn, layerList = VGG_FCN32S(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
layerList['C'] = layerList['C'] + layerList['L']
elif "sod" in model_file:
print("Using The SOD Fintune Model")
cnn, layerList = VGG_SOD(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
elif "16" in model_file:
cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict
elif any(name in model_file for name in name_dict['vgg19']):
print("VGG-19 Architecture Detected")
if "19" in model_file:
cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict
elif any(name in model_file for name in name_dict['vgg13']):
print("VGG-13 Architecture Detected")
cnn, layerList = VGG(buildSequential(channel_list['VGG-13'], pooling)), vgg13_dict
elif any(name in model_file for name in name_dict['vgg11']):
print("VGG-11 Architecture Detected")
cnn, layerList = VGG(buildSequential(channel_list['VGG-11'], pooling)), vgg11_dict
else:
raise ValueError("VGG architecture not recognized.")
elif "googlenet" in model_file:
print("GoogLeNet Architecture Detected")
if '205' in model_file:
cnn, layerList = GoogLeNetPlaces205(), googlenet_layer_names('places')
elif '365' in model_file:
cnn, layerList = GoogLeNetPlaces365(), googlenet_layer_names('places')
elif 'bvlc' in model_file:
cnn, layerList = BVLC_GOOGLENET(), googlenet_layer_names('bvlc')
elif 'cars' in model_file:
cnn, layerList = GOOGLENET_CARS(), googlenet_layer_names('cars')
elif 'sos' in model_file:
cnn, layerList = GoogleNet_SOS(), googlenet_layer_names('sos')
else:
cnn, layerList = models.googlenet(pretrained=False, transform_input=False), ''
elif "inception" in model_file:
print("Inception Architecture Detected")
if 'inception5h' in model_file:
cnn, layerList = Inception5h(), inception_layer_names('5h')
else:
cnn, layerList = models.inception_v3(pretrained=False, transform_input=False), ''
elif "resnet" in model_file:
print("ResNet Architecture Detected")
if 'resnet_50_1by2_nsfw' in model_file:
cnn, layerList = ResNet_50_1by2_nsfw(), resnet_layer_names
else:
raise ValueError("ResNet architecture not recognized.")
elif "nin" in model_file:
print("NIN Architecture Detected")
cnn, layerList = NIN(pooling), nin_dict
else:
raise ValueError("Model architecture not recognized.")
return cnn, layerList
# Print like Torch7/loadcaffe
def print_loadcaffe(cnn, layerList):
    c = 0
    for l in list(cnn):
        if "Conv2d" in str(l) and "Basic" not in str(l):
            in_c, out_c, ks = str(l.in_channels), str(l.out_channels), str(l.kernel_size)
            print(layerList['C'][c] + ": " + (out_c + " " + in_c + " " + ks).replace(")", '').replace("(", '').replace(",", ''))
            c += 1
        if c == len(layerList['C']):
            break

class Flatten(nn.Module):
    def forward(self, input):
        return torch.flatten(input, 1)
def add_classifier_layers(cnn, pooling='avg'):
    new_cnn, cnn_classifier = cnn.features, cnn.classifier
    if 'avg' in pooling:
        adaptive_pool2d = nn.AdaptiveAvgPool2d((7, 7))
    elif 'max' in pooling:
        adaptive_pool2d = nn.AdaptiveMaxPool2d((7, 7))
    new_cnn.add_module(str(len(new_cnn)), adaptive_pool2d)
    if not isinstance(cnn, VGG_FCN32S):
        flatten_layer = Flatten()
        new_cnn.add_module(str(len(new_cnn)), flatten_layer)
    for layer in cnn_classifier:
        new_cnn.add_module(str(len(new_cnn)), layer)
    return new_cnn
# Load the model, and configure pooling layer type
def loadCaffemodel(model_file, pooling, use_gpu, disable_check, add_classifier=False):
cnn, layerList = modelSelector(str(model_file).lower(), pooling)
cnn.load_state_dict(torch.load(model_file), strict=(not disable_check))
print("Successfully loaded " + str(model_file))
# Maybe convert the model to cuda now, to avoid later issues
if "c" not in str(use_gpu).lower() or "c" not in str(use_gpu[0]).lower():
cnn = cnn.cuda()
if not isinstance(cnn, NIN) and not any(name in model_file.lower() for name in ic_dict) and add_classifier:
cnn, has_inception = add_classifier_layers(cnn, pooling), False
elif any(name in model_file.lower() for name in ic_dict['googlenet']):
if '205' in model_file or '365' in model_file:
has_inception = True
elif 'cars' in model_file.lower() or 'sos' in model_file.lower() or 'bvlc' in model_file.lower():
has_inception = True
else:
layerList, has_inception = build_googlenet_list(cnn), True
elif any(name in model_file.lower() for name in ic_dict['inceptionv3']):
layerList, has_inception = build_inceptionv3_list(cnn), True
elif 'inception5h' in model_file.lower() or 'resnet' in model_file.lower():
has_inception = True
else:
cnn, has_inception = cnn.features, False
if has_inception:
cnn.eval()
cnn.has_inception = True
else:
cnn.has_inception = False
if not any(name in model_file.lower() for name in ic_dict):
print_loadcaffe(cnn, layerList)
if 'resnet' in model_file.lower():
cnn.add_layers()
return cnn, layerList
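
# ---------------------------------------------------------------------------
# models/download_models.py (the model download script referenced in the README)
# ---------------------------------------------------------------------------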
import torch
import argparse
from os import path
from sys import version_info
from collections import OrderedDict
from torch.utils.model_zoo import load_url
if version_info[0] < 3:
    import urllib
else:
    import urllib.request
options_list = ['all', 'caffe-vgg16', 'caffe-vgg19', 'caffe-nin', 'caffe-googlenet-places205', 'caffe-googlenet-places365', 'caffe-googlenet-bvlc', 'caffe-googlenet-cars', 'caffe-googlenet-sos', \
'caffe-resnet-opennsfw', 'pytorch-vgg16', 'pytorch-vgg19', 'pytorch-googlenet', 'pytorch-inceptionv3', 'tensorflow-inception5h', 'all-caffe', 'all-caffe-googlenet']
def main():
params = params_list()
if params.models == 'all':
params.models = options_list[1:15]
elif 'all-caffe' in params.models and 'all-caffe-googlenet' not in params.models:
params.models = options_list[1:10] + params.models.split(',')
elif 'all-caffe-googlenet' in params.models:
params.models = options_list[4:9] + params.models.split(',')
else:
params.models = params.models.split(',')
if 'caffe-vgg19' in params.models:
# Download the VGG-19 ILSVRC model and fix the layer names
print("Downloading the VGG-19 ILSVRC model")
sd = load_url("https://web.eecs.umich.edu/~justincj/models/vgg19-d01eb7cb.pth")
map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(map[k] if k in map else k,v) for k,v in sd.items()])
torch.save(sd, path.join(params.download_path, "vgg19-d01eb7cb.pth"))
if 'caffe-vgg16' in params.models:
# Download the VGG-16 ILSVRC model and fix the layer names
print("Downloading the VGG-16 ILSVRC model")
sd = load_url("https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth")
map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(map[k] if k in map else k,v) for k,v in sd.items()])
torch.save(sd, path.join(params.download_path, "vgg16-00b39a1b.pth"))
if 'caffe-nin' in params.models:
# Download the NIN model
print("Downloading the NIN model")
fileurl = "https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth"
name = "nin_imagenet.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-googlenet-places205' in params.models:
# Download the Caffe GoogeLeNet Places205 model
print("Downloading the Places205 GoogeLeNet model")
fileurl = "https://github.com/ProGamerGov/pytorch-places/raw/master/googlenet_places205.pth"
name = "googlenet_places205.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-googlenet-places365' in params.models:
# Download the Caffe GoogeLeNet Places365 model
print("Downloading the Places365 GoogeLeNet model")
fileurl = "https://github.com/ProGamerGov/pytorch-places/raw/master/googlenet_places365.pth"
name = "googlenet_places365.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-googlenet-bvlc' in params.models:
# Download the Caffe BVLC GoogeLeNet model
print("Downloading the BVLC GoogeLeNet model")
fileurl = "https://github.com/ProGamerGov/pytorch-old-caffemodels/raw/master/bvlc_googlenet.pth"
name = "bvlc_googlenet.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-googlenet-cars' in params.models:
# Download the Caffe GoogeLeNet Cars model
print("Downloading the Cars GoogeLeNet model")
fileurl = "https://github.com/ProGamerGov/pytorch-old-caffemodels/raw/master/googlenet_finetune_web_cars.pth"
name = "googlenet_finetune_web_cars.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-googlenet-sos' in params.models:
# Download the Caffe GoogeLeNet SOS model
print("Downloading the SOD GoogeLeNet model")
fileurl = "https://github.com/ProGamerGov/pytorch-old-caffemodels/raw/master/GoogleNet_SOS.pth"
name = "GoogleNet_SOS.pth"
download_file(fileurl, name, params.download_path)
if 'pytorch-vgg19' in params.models:
# Download the PyTorch VGG19 model
print("Downloading the PyTorch VGG 19 model")
fileurl = "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth"
name = "vgg19-dcbb9e9d.pth"
download_file(fileurl, name, params.download_path)
if 'pytorch-vgg16' in params.models:
# Download the PyTorch VGG16 model
print("Downloading the PyTorch VGG 16 model")
fileurl = "https://download.pytorch.org/models/vgg16-397923af.pth"
name = "vgg16-397923af.pth"
download_file(fileurl, name, params.download_path)
if 'pytorch-googlenet' in params.models:
# Download the PyTorch GoogLeNet model
print("Downloading the PyTorch GoogLeNet model")
fileurl = "https://download.pytorch.org/models/googlenet-1378be20.pth"
name = "googlenet-1378be20.pth"
download_file(fileurl, name, params.download_path)
if 'pytorch-inception' in params.models:
# Download the PyTorch Inception V3 model
print("Downloading the PyTorch Inception V3 model")
fileurl = "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth"
name = "inception_v3_google-1a9a5a14.pth"
download_file(fileurl, name, params.download_path)
if 'tensorflow-inception5h' in params.models:
# Download the Inception5h model
print("Downloading the Tensorflow Inception5h model")
fileurl = "https://github.com/ProGamerGov/pytorch-old-tensorflow-models/raw/master/inception5h.pth"
name = "inception5h.pth"
download_file(fileurl, name, params.download_path)
if 'caffe-resnet-opennsfw' in params.models:
# Download the ResNet Yahoo Open NSFW model
print("Downloading the ResNet Yahoo Open NSFW model")
fileurl = "https://github.com/ProGamerGov/pytorch-old-caffemodels/raw/master/ResNet_50_1by2_nsfw.pth"
name = "ResNet_50_1by2_nsfw.pth"
download_file(fileurl, name, params.download_path)
print("All selected models have been successfully downloaded")
def params_list():
    parser = argparse.ArgumentParser()
    parser.add_argument("-models", help="Models to download", default='caffe-googlenet-bvlc,caffe-nin', action=MultipleChoice)
    parser.add_argument("-download_path", help="Download location for models", default='models')
    params = parser.parse_args()
    return params

def download_file(fileurl, name, download_path):
    if version_info[0] < 3:
        urllib.URLopener().retrieve(fileurl, path.join(download_path, name))
    else:
        urllib.request.urlretrieve(fileurl, path.join(download_path, name))

class MultipleChoice(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        self.options = options_list
        e = [o.lower() for o in values.split(',') if o.lower() not in self.options]
        if len(e) > 0:
            raise argparse.ArgumentError(self, 'invalid choices: ' + ','.join([str(v) for v in e]) +
                ' (choose from ' + ','.join(["'"+str(v)+"'" for v in self.options]) + ')')
        setattr(namespace, self.dest, values)

if __name__ == "__main__":
    main()
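
# ---------------------------------------------------------------------------
# neural_dream.py (the main DeepDream script referenced in the README)
# ---------------------------------------------------------------------------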
import os
import copy
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from neural_dream.CaffeLoader import loadCaffemodel, ModelParallel, Flatten
import neural_dream.dream_utils as dream_utils
import neural_dream.dream_model as dream_model
from neural_dream.dream_auto import auto_model_mode, auto_mean
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)
# Optimization options
parser.add_argument("-dream_weight", type=float, default=1000)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=0)
parser.add_argument("-l2_weight", type=float, default=0)
parser.add_argument("-num_iterations", type=int, default=5)
parser.add_argument("-jitter", type=int, default=32)
parser.add_argument("-init", choices=['random', 'image'], default='image')
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='adam')
parser.add_argument("-learning_rate", type=float, default=0.001)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
parser.add_argument("-loss_mode", choices=['bce', 'mse', 'mean', 'norm', 'l2'], default='mean')
# Output options
parser.add_argument("-print_iter", type=int, default=1)
parser.add_argument("-print_octave_iter", type=int, default=0)
parser.add_argument("-save_iter", type=int, default=1)
parser.add_argument("-save_octave_iter", type=int, default=0)
parser.add_argument("-output_image", default='out.png')
# Octave options
parser.add_argument("-num_octaves", type=int, default=4)
parser.add_argument("-octave_scale", type=float, default=0.25)
parser.add_argument("-octave_iter", type=int, default=100)
parser.add_argument("-octave_mode", choices=['advanced', 'normal'], default='normal')
# Channel options
parser.add_argument("-channels", type=str, help="channels for DeepDream", default='-1')
parser.add_argument("-channel_mode", choices=['all', 'strong', 'avg', 'weak', 'ignore'], default='all')
parser.add_argument("-channel_capture", choices=['once', 'iter'], default='once')
# Gaussian Blur Options
parser.add_argument("-layer_sigma", type=float, default=0)
# Laplacian pyramid options
parser.add_argument("-lap_scale", type=int, default=0)
parser.add_argument("-sigma", default='1')
# FFT options
parser.add_argument("-use_fft", action='store_true')
parser.add_argument("-fft_block", type=int, default=25)
# Zoom options
parser.add_argument("-zoom", type=int, default=0)
parser.add_argument("-zoom_mode", choices=['percent', 'pixel'], default='percent')
# Other options
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-dcbb9e9d.pth')
parser.add_argument("-model_type", choices=['caffe', 'pytorch', 'auto'], default='auto')
parser.add_argument("-model_mean", default='auto')
parser.add_argument("-label_file", type=str, default='')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-clamp", action='store_true')
parser.add_argument("-classify", type=int, default=0)
parser.add_argument("-dream_layers", help="layers for DeepDream", default='relu4_2')
parser.add_argument("-multidevice_strategy", default='4,7,29')
# Help Options
parser.add_argument("-print_layers", action='store_true')
# Experimental Params
parser.add_argument("-norm_percent", type=float, default=0)
parser.add_argument("-abs_percent", type=float, default=0)
parser.add_argument("-mean_percent", type=float, default=0)
parser.add_argument("-percent_mode", choices=['slow', 'fast'], default='fast')
parser.add_argument("-random_transforms", choices=['none', 'all', 'flip', 'rotate'], default='none')
parser.add_argument("-adjust_contrast", type=float, help="try 99.98", default=-1)
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
def main():
dtype, multidevice, backward_device = setup_gpu()
cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check, True)
has_inception = cnn.has_inception
if params.print_layers:
print_layers(layerList, params.model_file, has_inception)
params.model_type = auto_model_mode(params.model_file) if params.model_type == 'auto' else params.model_type
input_mean = auto_mean(params.model_file, params.model_type) if params.model_mean == 'auto' else params.model_mean
if params.model_mean != 'auto':
input_mean = [float(m) for m in input_mean.split(',')]
content_image = preprocess(params.content_image, params.image_size, params.model_type, input_mean).type(dtype)
clamp_val = 256 if params.model_type == 'caffe' else 1
if params.label_file != '':
labels = load_label_file(params.label_file)
params.channels = channel_ids(labels, params.channels)
if params.classify > 0:
if not has_inception:
params.dream_layers += ',classifier'
if params.label_file == '':
labels = list(range(0, 1000))
dream_layers = params.dream_layers.split(',')
start_params = (dtype, params.random_transforms, params.jitter, params.tv_weight, params.l2_weight, params.layer_sigma)
primary_params = (params.loss_mode, params.dream_weight, params.channels, params.channel_mode)
secondary_params = {'channel_capture': params.channel_capture, 'scale': params.lap_scale, 'sigma': params.sigma, \
'use_fft': (params.use_fft, params.fft_block), 'r': clamp_val, 'p_mode': params.percent_mode, 'norm_p': params.norm_percent, \
'abs_p': params.abs_percent, 'mean_p': params.mean_percent}
# Set up the network, inserting dream loss modules
net_base, dream_losses, tv_losses, l2_losses, lm_layer_names, loss_module_list = dream_model.build_net(cnn, dream_layers, \
has_inception, layerList, params.classify, start_params, primary_params, secondary_params)
if params.classify > 0:
classify_img = dream_utils.Classify(labels, params.classify)
if multidevice and not has_inception:
net_base = setup_multi_device(net_base)
if not has_inception:
print_torch(net_base, multidevice)
# Initialize the image
if params.seed >= 0:
torch.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
torch.backends.cudnn.deterministic=True
random.seed(params.seed)
if params.init == 'random':
base_img = torch.randn_like(content_image).mul(0.001)
elif params.init == 'image':
base_img = content_image.clone()
if params.optimizer == 'lbfgs':
print("Running optimization with L-BFGS")
else:
print("Running optimization with ADAM")
for param in net_base.parameters():
param.requires_grad = False
for i in dream_losses:
i.mode = 'capture'
net_base(base_img.clone())
if params.channels != '-1' or params.channel_mode != 'all' and params.channels != '-1':
print_channels(dream_losses, dream_layers)
if params.classify > 0:
feat = net_base(base_img.clone())
classify_img(feat)
for i in dream_losses:
i.mode = 'None'
current_img = new_img(base_img, -1)
h, w = current_img.size(2), current_img.size(3)
octave_list = ocatve_calc((h,w), params.octave_scale, params.num_octaves, params.octave_mode)
print_octave_sizes(octave_list)
total_dream_losses, total_loss = [], [0]
for iter in range(1, params.num_iterations+1):
for octave, octave_sizes in enumerate(octave_list, 1):
net = copy.deepcopy(net_base) if not has_inception else net_base
for param in net.parameters():
param.requires_grad = False
dream_losses, tv_losses, l2_losses = [], [], []
if not has_inception:
for i, layer in enumerate(net):
if isinstance(layer, dream_utils.TVLoss):
tv_losses.append(layer)
if isinstance(layer, dream_utils.L2Regularizer):
l2_losses.append(layer)
if 'DreamLoss' in str(type(layer)):
dream_losses.append(layer)
elif has_inception:
net, dream_losses, tv_losses, l2_losses = dream_model.renew_net(start_params, net, loss_module_list, lm_layer_names)
img = new_img(current_img.clone(), octave_sizes)
net(img)
for i in dream_losses:
i.mode = 'loss'
for i in dream_losses:
i.mode = 'loss'
# Maybe normalize dream weight
if params.normalize_weights:
normalize_weights(dream_losses)
# Freeze the net_basework in order to prevent
# unnecessary gradient calculations
for param in net.parameters():
param.requires_grad = False
# Function to evaluate loss and gradient. We run the net_base forward and
# backward to get the gradient, and sum up losses from the loss modules.
# optim.lbfgs internally handles iteration and calls this function many
# times, so we manually count the number of iterations to handle printing
# and saving intermediate results.
num_calls = [0]
def feval():
num_calls[0] += 1
optimizer.zero_grad()
net(img)
loss = 0
for mod in dream_losses:
loss += -mod.loss.to(backward_device)
if params.tv_weight > 0:
for mod in tv_losses:
loss += mod.loss.to(backward_device)
if params.l2_weight > 0:
for mod in l2_losses:
loss += mod.loss.to(backward_device)
if params.clamp:
img.clamp(0, clamp_val)
if params.adjust_contrast > -1:
img.data = adjust_contrast(img, r=clamp_val, p=params.adjust_contrast)
total_loss[0] += loss.item()
loss.backward()
maybe_print_octave_iter(num_calls[0], octave, params.octave_iter, dream_losses)
maybe_save_octave(iter, num_calls[0], octave, img, content_image, input_mean)
return loss
optimizer, loopVal = setup_optimizer(img)
while num_calls[0] <= params.octave_iter:
optimizer.step(feval)
if octave == 1:
for mod in dream_losses:
total_dream_losses.append(mod.loss.item())
else:
for d_loss, mod in enumerate(dream_losses):
total_dream_losses[d_loss] += mod.loss.item()
if img.size(2) != h or img.size(3) != w:
current_img = dream_utils.resize_tensor(img.clone(), (h,w))
else:
current_img = img.clone()
maybe_print(iter, total_loss[0], total_dream_losses)
maybe_save(iter, current_img, content_image, input_mean)
total_dream_losses, total_loss = [], [0]
if params.zoom > 0:
current_img = dream_utils.zoom(current_img, params.zoom, params.zoom_mode)
if params.classify > 0:
feat = net_base(base_img.clone())
classify_img(feat)
def save_output(t, save_img, content_image, iter_name, model_mean):
output_filename, file_extension = os.path.splitext(params.output_image)
if t == params.num_iterations:
filename = output_filename + str(file_extension)
else:
filename = str(output_filename) + iter_name + str(file_extension)
disp = deprocess(save_img.clone(), params.model_type, model_mean)
# Maybe perform postprocessing for color-independent style transfer
if params.original_colors == 1:
disp = original_colors(deprocess(content_image.clone(), params.model_type, model_mean), disp)
disp.save(str(filename))
def maybe_save(t, save_img, content_image, input_mean):
should_save = params.save_iter > 0 and t % params.save_iter == 0
should_save = should_save or t == params.num_iterations
if should_save:
save_output(t, save_img, content_image, "_" + str(t), input_mean)
def maybe_save_octave(t, n, o, save_img, content_image, input_mean):
should_save = params.save_octave_iter > 0 and n % params.save_octave_iter == 0
should_save = should_save or params.save_octave_iter > 0 and n == params.octave_iter
if o == params.num_octaves:
should_save = False if params.save_iter > 0 and t % params.save_iter == 0 or t == params.num_iterations else should_save
if should_save:
save_output(t, save_img, content_image, "_" + str(t) + "_" + str(o) + "_" + str(n), input_mean)
def maybe_print(t, loss, dream_losses):
if params.print_iter > 0 and t % params.print_iter == 0:
print("Iteration " + str(t) + " / "+ str(params.num_iterations))
for i, loss_module in enumerate(dream_losses):
print(" DeepDream " + str(i+1) + " loss: " + str(loss_module))
print(" Total loss: " + str(abs(loss)))
def maybe_print_octave_iter(t, n, total, dream_losses):
if params.print_octave_iter > 0 and t % params.print_octave_iter == 0:
print("Octave iter "+str(n) +" iteration " + str(t) + " / "+ str(total))
for i, loss_module in enumerate(dream_losses):
print(" DeepDream " + str(i+1) + " loss: " + str(loss_module.loss.item()))
def print_channels(dream_losses, layers):
print('\nSelected layer channels:')
for i, l in enumerate(dream_losses):
if len(l.dream.channels) > 25:
ch = l.dream.channels[0:25] + ['and ' + str(len(l.dream.channels[25:])) + ' more...']
else:
ch = l.dream.channels
print(' ' + layers[i] + ': ', ch)
# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
            'lr': params.learning_rate
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal

def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True
        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor
    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device

def setup_multi_device(net_base):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices minus 1, must be equal to the number of -gpu devices."
    new_net_base = ModelParallel(net_base, params.gpu, params.multidevice_strategy)
    return new_net_base
# Preprocess an image before passing it to a model.
# Maybe rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size, mode='caffe', input_mean=[103.939, 116.779, 123.68]):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    if mode == 'caffe':
        rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
        Normalize = transforms.Compose([transforms.Normalize(mean=input_mean, std=[1, 1, 1])])
        tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    elif mode == 'pytorch':
        Normalize = transforms.Compose([transforms.Normalize(mean=input_mean, std=[1, 1, 1])])
        tensor = Normalize(Loader(image)).unsqueeze(0)
    return tensor

# Undo the above preprocessing.
def deprocess(output_tensor, mode='caffe', input_mean=[-103.939, -116.779, -123.68]):
    input_mean = [n * -1 for n in input_mean]
    if mode == 'caffe':
        Normalize = transforms.Compose([transforms.Normalize(mean=input_mean, std=[1, 1, 1])])
        bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
        output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    elif mode == 'pytorch':
        Normalize = transforms.Compose([transforms.Normalize(mean=input_mean, std=[1, 1, 1])])
        output_tensor = Normalize(output_tensor.squeeze(0).cpu())
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image

# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')
# Print like Lua/Torch7
def print_torch(net_base, multidevice):
if multidevice:
return
simplelist = ""
for i, layer in enumerate(net_base, 1):
simplelist = simplelist + "(" + str(i) + ") -> "
print("nn.Sequential ( \n [input -> " + simplelist + "output]")
def strip(x):
return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
def n():
return " (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]
for i, l in enumerate(net_base, 1):
if "2d" in str(l):
if "AdaptiveAvgPool2d" not in str(l) and "AdaptiveMaxPool2d" not in str(l) and "BasicConv2d" not in str(l):
ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
if "BasicConv2d" in str(l):
print(n())
elif "Conv2d" in str(l):
ch = str(l.in_channels) + " -> " + str(l.out_channels)
print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
elif "AdaptiveAvgPool2d" in str(l) or "AdaptiveMaxPool2d" in str(l):
print(n())
elif "Pool2d" in str(l):
st = st.replace(" ",' ') + st.replace(", ",')')
print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
else:
print(n())
print(")")
# Print planned octave image sizes
def print_octave_sizes(octave_list):
    print('\nPerforming ' + str(len(octave_list)) + ' octaves with the following image sizes:')
    for o, octave in enumerate(octave_list):
        print('  Octave ' + str(o+1) + ' image size: ' + str(octave[0]) + 'x' + str(octave[1]))
    print()

def ocatve_calc(image_size, octave_scale, num_octaves, mode='advanced'):
    print('start', image_size)
    octave_list = []
    h_size, w_size = image_size[0], image_size[1]
    if mode == 'normal':
        for o in range(1, num_octaves+1):
            h_size *= octave_scale
            w_size *= octave_scale
            if o < num_octaves:
                octave_list.append((int(h_size), int(w_size)))
        octave_list.reverse()
        octave_list.append((image_size[0], image_size[1]))
    elif mode == 'advanced':
        for o in range(1, num_octaves+1):
            h_size = image_size[0] * (o * octave_scale)
            w_size = image_size[1] * (o * octave_scale)
            octave_list.append((int(h_size), int(w_size)))
    return octave_list
# Divide weights by channel size
def normalize_weights(dream_losses):
    for n, i in enumerate(dream_losses):
        i.strength = i.strength / max(i.target_size)
# Print all available/usable layer names
def print_layers(layerList, model_name, has_inception):
print()
print("\nUsable Layers For '" + model_name + "':")
if not has_inception:
for l_names in layerList:
if l_names == 'P':
n = ' Pooling Layers:'
if l_names == 'C':
n = ' Conv Layers:'
if l_names == 'R':
n = ' ReLU Layers:'
elif l_names == 'BC':
n = ' BasicConv2d Layers:'
elif l_names == 'L':
n = ' Linear/FC layers:'
if l_names == 'D':
n = ' Dropout Layers:'
elif l_names == 'IC':
n = ' Inception Layers:'
print(n, ', '.join(layerList[l_names]))
elif has_inception:
print(layerList)
quit()
# Load a label file
def load_label_file(filename):
    with open(filename, 'r') as f:
        x = [l.rstrip('\n') for l in f.readlines()]
    return x

# Convert names to channel values
def channel_ids(l, channels):
    channels = channels.split(',')
    c_vals = ''
    for c in channels:
        if c.isdigit():
            c_vals += ',' + str(c)
        elif c.isalpha():
            v = ','.join(str(ch) for ch, n in enumerate(l) if c in n)
            v = ',' + v + ',' if len(v.split(',')) == 1 else v
            c_vals += v
    c_vals = '-1' if c_vals == '' else c_vals
    c_vals = c_vals.replace(',', '', 1) if c_vals[0] == ',' else c_vals
    return c_vals

# Prepare input image
def new_img(input_image, scale_factor, mode='bilinear'):
    img = input_image.clone()
    if scale_factor != -1:
        img = dream_utils.resize_tensor(img, scale_factor, mode)
    return nn.Parameter(img)

# Adjust tensor contrast
def adjust_contrast(t, r, p=99.98):
    return t * (r / dream_utils.tensor_percentile(t))
if __name__ == "__main__":
    main()
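
# ---------------------------------------------------------------------------
# neural_dream/dream_auto.py (automatic model type / mean selection; filename
# inferred from the `from neural_dream.dream_auto import ...` line above)
# ---------------------------------------------------------------------------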
pytorch_names = ['vgg11-bbd30ac9.pth', 'vgg13-c768596a.pth', 'vgg16-397923af.pth' , \
'vgg19-dcbb9e9d.pth', 'googlenet-1378be20.pth', 'inception_v3_google-1a9a5a14.pth']
caffe_names = ['vgg16-00b39a1b.pth', 'vgg19-d01eb7cb.pth', 'nin_imagenet.pth', \
'VGG16_SOD_finetune.pth', 'VGG16-Stylized-ImageNet.pth', 'vgg16_places365.pth', \
'vgg16_hybrid1365.pth', 'fcn32s-heavy-pascal.pth', 'nyud-fcn32s-color-heavy.pth', \
'pascalcontext-fcn32s-heavy.pth', 'siftflow-fcn32s-heavy.pth', 'channel_pruning.pth', \
'googlenet_places205.pth', 'googlenet_places365.pth']
# Automatically determine model type
def auto_model_mode(model_name):
if any(name.lower() in model_name.lower() for name in pytorch_names):
input_mode = 'pytorch'
elif any(name.lower() in model_name.lower() for name in caffe_names):
input_mode = 'caffe'
else:
raise ValueError("Model not recognized, please manually specify the model type.")
return input_mode
# Automatically determine preprocessing to use for model
def auto_mean(model_name, model_type):
if any(name.lower() in model_name.lower() for name in pytorch_names) or model_type == 'pytorch':
input_mean = 'pytorch', [0.485, 0.456, 0.406] # PyTorch Imagenet
elif any(name.lower() in model_name.lower() for name in caffe_names) or model_type == 'caffe':
input_mean = [103.939, 116.779, 123.68] # Caffe Imagenet
if 'googlenet_places205.pth' in model_name.lower():
input_mean = [105.417, 113.753, 116.047] # Caffe Places205
elif 'googlenet_places365.pth' in model_name.lower():
input_mean = [104.051, 112.514, 116.676] # Caffe Places365
else:
raise ValueError("Model not recognized, please manually specify the model type or model mean.")
return input_mean
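# Example (illustrative): auto_mean('models/googlenet_places365.pth', 'caffe') returns the
# Places365 BGR mean [104.051, 112.514, 116.676], while a torchvision model such as
# 'vgg19-dcbb9e9d.pth' falls through to ('pytorch', [0.485, 0.456, 0.406]).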
import torch
import torch.nn as nn
# Define a function to partially zero inputs based on channel strength
def percentile_zero(input, p=99.98, mode='norm'):
if 'norm' in mode:
px = input.norm(1)
elif 'sum' in mode:
px = input.sum(1)
elif 'mean' in mode:
px = input.mean(1)
if 'abs' in mode:
px, tp = dream_utils.tensor_percentile(abs(px), p), abs(px)
else:
tp = dream_utils.tensor_percentile(px, p)
th = (0.01*tp)
if 'abs' in mode:
input[abs(input) < abs(th)] = 0
else:
input[input < th] = 0
return input
# Define an nn Module to mask inputs based on channel strength
class ChannelMask(torch.nn.Module):
def __init__(self, mode, channels=-1, rank_mode='norm', channel_percent=-1):
super(ChannelMask, self).__init__()
self.mode = mode
self.channels = channels
self.rank_mode = rank_mode
self.channel_percent = channel_percent
def list_channels(self, input):
if input.is_cuda:
channel_list = torch.zeros(input.size(1), device=input.get_device())
else:
channel_list = torch.zeros(input.size(1))
for i in range(input.size(1)):
y = input.clone().narrow(1,i,1)
if self.rank_mode == 'norm':
y = torch.norm(y)
elif self.rank_mode == 'sum':
y = torch.sum(y)
elif self.rank_mode == 'mean':
y = torch.mean(y)
elif self.rank_mode == 'norm-abs':
y = torch.norm(torch.abs(y))
elif self.rank_mode == 'sum-abs':
y = torch.sum(torch.abs(y))
elif self.rank_mode == 'mean-abs':
y = torch.mean(torch.abs(y))
channel_list[i] = y
return channel_list
def channel_strengths(self, input, num_channels):
channel_list = self.list_channels(input)
channels, idx = torch.sort(channel_list, 0, True)
selected_channels = []
for i in range(num_channels):
if i < input.size(1):
selected_channels.append(idx[i])
return selected_channels
def mask_threshold(self, input, channels, ft=1, fm=0.2, fw=5):
t = torch.ones_like(input.squeeze(0)) * ft
m = torch.ones_like(input.squeeze(0)) * fm
for c in channels:
m[c] = fw
return (t * m).unsqueeze(0)
def average_mask(self, input, channels):
mask = torch.ones_like(input.squeeze(0))
avg = torch.sum(channels)/input.size(1)
for i in range(channels.size(0)):
w = avg/channels[i]
mask[i] = w
return mask.unsqueeze(0)
def average_tensor(self, input):
channel_list = self.list_channels(input)
mask = self.average_mask(input, channel_list)
self.mask = mask
def weak_channels(self, input):
channel_list = self.channel_strengths(input, self.channels)
mask = self.mask_threshold(input, channel_list, ft=1, fm=2, fw=0.2)
self.mask = mask
def strong_channels(self, input):
channel_list = self.channel_strengths(input, self.channels)
mask = self.mask_threshold(input, channel_list, ft=1, fm=0.2, fw=5)
self.mask = mask
def zero_weak(self, input):
channel_list = self.channel_strengths(input, self.channels)
mask = self.mask_threshold(input, channel_list, ft=1, fm=0, fw=1)
self.mask = mask
def mask_input(self, input):
if self.channel_percent > 0:
channels = int((float(self.channel_percent)/100) * float(input.size(1)))
if channels < input.size(1) and channels > 0:
self.channels = channels
else:
self.channels = input.size(1)
if self.mode == 'weak':
self.weak_channels(input)
elif self.mode == 'strong':
self.strong_channels(input)
elif self.mode == 'average':
self.average_tensor(input)
elif self.mode == 'zero_weak':
self.zero_weak(input)
def capture(self, input):
self.mask_input(input)
def forward(self, input):
return self.mask * input
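# Minimal usage sketch (illustrative; tensor shapes assumed to be 1xCxHxW): emphasize the
# ten strongest channels of an activation before the DeepDream gradient is computed.
#   masker = ChannelMask('strong', channels=10, rank_mode='norm')
#   masker.capture(activations)   # builds self.mask from the captured activations
#   boosted = masker(activations) # strong channels weighted by fw=5, the rest by fm=0.2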
# Define a function to partially zero inputs based on channel strength
class ChannelMod(torch.nn.Module):
def __init__(self, p_mode='fast', channels=0, norm_p=0, abs_p=0, mean_p=0):
super(ChannelMod, self).__init__()
self.p_mode = p_mode
self.channels = channels
self.norm_p = norm_p
self.abs_p = abs_p
self.mean_p = mean_p
self.enabled = False
if self.norm_p > 0 and self.p_mode == 'slow':
self.zero_weak_norm = ChannelMask('zero_weak', self.channels, 'norm', channel_percent=self.norm_p)
if self.abs_p > 0 and self.p_mode == 'slow':
self.zero_weak_abs = ChannelMask('zero_weak', self.channels, 'sum', channel_percent=self.abs_p)
if self.mean_p > 0 and self.p_mode == 'slow':
self.zero_weak_mean = ChannelMask('zero_weak', self.channels, 'mean', channel_percent=self.mean_p)
if self.norm_p > 0 or self.abs_p > 0 or self.mean_p > 0:
self.enabled = True
def forward(self, input):
if self.norm_p > 0 and self.p_mode == 'fast':
input = percentile_zero(input, p=self.norm_p, mode='abs-norm')
if self.abs_p > 0 and self.p_mode == 'fast':
input = percentile_zero(input, p=self.abs_p, mode='abs-sum')
if self.mean_p > 0 and self.p_mode == 'fast':
input = percentile_zero(input, p=self.mean_p, mode='abs-mean')
if self.norm_p > 0 and self.p_mode == 'slow':
self.zero_weak_norm.capture(input.clone())
input = self.zero_weak_norm(input)
if self.abs_p > 0 and self.p_mode == 'slow':
self.zero_weak_abs.capture(input.clone())
input = self.zero_weak_abs(input)
if self.mean_p > 0 and self.p_mode == 'slow':
self.zero_weak_mean.capture(input.clone())
input = self.zero_weak_mean(input)
return input
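# Minimal usage sketch (illustrative): apply the fast percentile-zeroing path to an
# activation tensor; the parameter values here are assumptions, not project defaults.
#   mod = ChannelMod(p_mode='fast', norm_p=99.98)
#   if mod.enabled:
#       activations = mod(activations)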
import copy
import torch
import torch.nn as nn
import neural_dream.dream_utils as dream_utils
import dream as dream_main
from neural_dream.CaffeLoader import Flatten
def add_to_incept(net, n, sn, loss_module, capture='after'):
if len(n) == 1:
if capture == 'after':
getattr(net, n[0]).register_forward_hook(loss_module)
elif capture == 'before':
getattr(net, n[0]).register_forward_pre_hook(loss_module)
elif len(n) == 2:
if isinstance(getattr(getattr(net, n[0]), n[1]), nn.Sequential):
if capture == 'after':
getattr(getattr(getattr(net, n[0]), n[1]), str(sn)).register_forward_hook(loss_module)
elif capture == 'before':
getattr(getattr(getattr(net, n[0]), n[1]), str(sn)).register_forward_pre_hook(loss_module)
sn = sn+1 if sn < 1 else 0
else:
if capture == 'after':
getattr(getattr(net, n[0]), n[1]).register_forward_hook(loss_module)
elif capture == 'before':
getattr(getattr(net, n[0]), n[1]).register_forward_pre_hook(loss_module)
elif len(n) == 3:
if isinstance(getattr(getattr(net, n[0]), n[1]), nn.Sequential):
if capture == 'after':
getattr(getattr(getattr(getattr(net, n[0]), n[1]), str(sn)), n[2]).register_forward_hook(loss_module)
elif capture == 'before':
getattr(getattr(getattr(getattr(net, n[0]), n[1]), str(sn)), n[2]).register_forward_pre_hook(loss_module)
sn = sn+1 if sn < 1 else 0
else:
if capture == 'after':
getattr(getattr(getattr(net, n[0]), n[1]), n[2]).register_forward_hook(loss_module)
elif capture == 'before':
getattr(getattr(getattr(net, n[0]), n[1]), n[2]).register_forward_pre_hook(loss_module)
return loss_module, sn
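# Example (illustrative; layer names are hypothetical): a dream layer given as
# 'mixed_5b/branch1' is split on '/' into ['mixed_5b', 'branch1'] by the caller, and
# add_to_incept(net, ['mixed_5b', 'branch1'], 0, hook) then registers the hook on
# net.mixed_5b.branch1 (or on its sn-th child if that attribute is an nn.Sequential).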
def build_net(cnn, dream_layers, has_inception, layerList, use_classify, start_params, primary_params, secondary_params):
cnn = copy.deepcopy(cnn)
dream_losses, tv_losses, l2_losses = [], [], []
lm_layer_names, loss_module_list = [], []
dtype = start_params[0]
if not has_inception:
next_dream_idx = 1
net_base = nn.Sequential()
c, r, p, l, d = 0, 0, 0, 0, 0
net_base, tv_losses, l2_losses = start_network(*start_params)
for i, layer in enumerate(list(cnn), 1):
if next_dream_idx <= len(dream_layers):
if isinstance(layer, nn.Conv2d):
net_base.add_module(str(len(net_base)), layer)
if layerList['C'][c] in dream_layers:
print("Setting up dream layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = dream_main.DreamLoss(*primary_params, **secondary_params)
net_base.add_module(str(len(net_base)), loss_module)
dream_losses.append(loss_module)
c+=1
if isinstance(layer, nn.ReLU):
net_base.add_module(str(len(net_base)), layer)
if layerList['R'][r] in dream_layers:
print("Setting up dream layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = dream_main.DreamLoss(*primary_params, **secondary_params)
net_base.add_module(str(len(net_base)), loss_module)
dream_losses.append(loss_module)
next_dream_idx += 1
r+=1
if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
net_base.add_module(str(len(net_base)), layer)
if layerList['P'][p] in dream_layers:
print("Setting up dream layer " + str(i) + ": " + str(layerList['P'][p]))
loss_module = dream_main.DreamLoss(*primary_params, **secondary_params)
net_base.add_module(str(len(net_base)), loss_module)
dream_losses.append(loss_module)
next_dream_idx += 1
p+=1
if isinstance(layer, nn.AdaptiveAvgPool2d) or isinstance(layer, nn.AdaptiveMaxPool2d):
net_base.add_module(str(len(net_base)), layer)
if isinstance(layer, Flatten):
flatten_mod = Flatten().type(dtype)
net_base.add_module(str(len(net_base)), flatten_mod)
if isinstance(layer, nn.Linear):
net_base.add_module(str(len(net_base)), layer)
if layerList['L'][l] in dream_layers:
print("Setting up dream layer " + str(i) + ": " + str(layerList['L'][l]))
loss_module = dream_main.DreamLoss(*primary_params, **secondary_params)
net_base.add_module(str(len(net_base)), loss_module)
dream_losses.append(loss_module)
next_dream_idx += 1
l+=1
if isinstance(layer, nn.Dropout):
net_base.add_module(str(len(net_base)), layer)
if layerList['D'][d] in dream_layers:
print("Setting up dream layer " + str(i) + ": " + str(layerList['D'][d]))
loss_module = dream_main.DreamLoss(*primary_params, **secondary_params)
net_base.add_module(str(len(net_base)), loss_module)
dream_losses.append(loss_module)
next_dream_idx += 1
d+=1
if use_classify > 0 and l == len(layerList['L']):
next_dream_idx += 1
elif has_inception:
start_net, tv_losses, l2_losses = start_network(*start_params)
lm_layer_names, loss_module_list = [], []
net_base = copy.deepcopy(cnn)
sn=0
for i, n in enumerate(dream_layers):
print("Setting up dream layer " + str(i+1) + ": " + n)
if 'before_' in n:
n = n.split('before_')[1]
loss_module = dream_main.DreamLossPreHook(*primary_params, **secondary_params)
module_loc = 'before'
else:
loss_module = dream_main.DreamLossHook(*primary_params, **secondary_params)
module_loc = 'after'
n = n.split('/')
lm_layer_names.append(n)
loss_module, sn = add_to_incept(net_base, n, sn, loss_module, module_loc)
loss_module_list.append(loss_module)
dream_losses.append(loss_module)
if len(start_net) > 0:
net_base = dream_utils.ModelPlus(start_net, net_base)
return net_base, dream_losses, tv_losses, l2_losses, lm_layer_names, loss_module_list
def start_network(dtype, random_transforms='none', jitter_val=32, tv_weight=0, l2_weight=0, layer_sigma=0):
tv_losses, l2_losses = [], []
start_net = nn.Sequential()
if random_transforms != 'none':
rt_mod = dream_utils.RandomTransform(random_transforms).type(dtype)
start_net.add_module(str(len(start_net)), rt_mod)
if jitter_val > 0:
jitter_mod = dream_utils.Jitter(jitter_val).type(dtype)
start_net.add_module(str(len(start_net)), jitter_mod)
if tv_weight > 0:
tv_mod = dream_utils.TVLoss(tv_weight).type(dtype)
start_net.add_module(str(len(start_net)), tv_mod)
tv_losses.append(tv_mod)
if l2_weight > 0:
l2_mod = dream_utils.L2Regularizer(l2_weight).type(dtype)
start_net.add_module(str(len(start_net)), l2_mod)
l2_losses.append(l2_mod)
if layer_sigma > 0:
gauss_mod = dream_utils.GaussianBlurLayer(5, layer_sigma).type(dtype)
start_net.add_module(str(len(start_net)), gauss_mod)
return start_net, tv_losses, l2_losses
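# Minimal usage sketch (illustrative values): build the preprocessing chain that runs
# ahead of the model -- random transforms, jitter, TV/L2 regularization and optional blur.
#   start_params = (torch.cuda.FloatTensor, 'none', 32, 0.000085, 0, 0)
#   start_net, tv_losses, l2_losses = start_network(*start_params)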
def renew_net(start_params, net, loss_module_list, dream_layers):
start_net, tv_losses, l2_losses = start_network(*start_params)
if isinstance(net, dream_utils.ModelPlus):
net = net.net
new_dream_losses = []
sn=0
for i, layer in enumerate(dream_layers):
n = layer
loss_module = loss_module_list[i]
if str(loss_module).split('(')[0] == 'DreamLossPreHook':
module_loc = 'before'
else:
module_loc = 'after'
loss_module, sn = add_to_incept(net, n, sn, loss_module, module_loc)
new_dream_losses.append(loss_module_list[i])
if len(start_net) > 0:
net = dream_utils.ModelPlus(start_net, net)
return net, new_dream_losses, tv_losses, l2_losses
import math
import random
import torch
import torch.nn as nn
'''
rescale_tensor(tensor, sf, mode='bilinear')
resize_tensor(tensor, size, mode='bilinear')
zoom(input, crop_val, mode='percent')
tensor_percentile(t, p=99.98)
MatchHistogram
GaussianBlur
GaussianBlurLP
GaussianBlurLayer
LaplacianPyramid
RankChannels
FFTTensor
Jitter
RandomTransform
L2Regularizer
Classify
import utils
'''
def start_net(dtype, random_transforms='none', jitter_val=32, tv_weight=0, l2_weight=0, layer_sigma=0):
tv_losses, l2_losses = [], []
s_net = nn.Sequential()
if random_transforms != 'none':
rt_mod = RandomTransform(random_transforms).type(dtype)
s_net.add_module(str(len(s_net)), rt_mod)
if jitter_val > 0:
jitter_mod = Jitter(jitter_val).type(dtype)
s_net.add_module(str(len(s_net)), jitter_mod)
if tv_weight > 0:
tv_mod = TVLoss(tv_weight).type(dtype)
s_net.add_module(str(len(s_net)), tv_mod)
tv_losses.append(tv_mod)
if l2_weight > 0:
l2_mod = L2Regularizer(l2_weight).type(dtype)
s_net.add_module(str(len(s_net)), l2_mod)
l2_losses.append(l2_mod)
if layer_sigma > 0:
gauss_mod = GaussianBlurLayer(5, layer_sigma).type(dtype)
s_net.add_module(str(len(s_net)), gauss_mod)
return s_net, tv_losses, l2_losses
# Rescale tensor
def rescale_tensor(tensor, sf, mode='bilinear'):
if not isinstance(sf, (tuple, list)):
sf = (sf, sf)
return torch.nn.functional.interpolate(tensor.clone(), scale_factor=sf, mode=mode, align_corners=True)
# Resize tensor
def resize_tensor(tensor, size, mode='bilinear'):
return torch.nn.functional.interpolate(tensor.clone(), size=size, mode=mode, align_corners=True)
# Center crop a tensor
def center_crop(input, crop_val, mode='percent'):
h, w = input.size(2), input.size(3)
if mode == 'percent':
h_crop = int((crop_val / 100) * input.size(2))
w_crop = int((crop_val / 100) * input.size(3))
elif mode == 'pixel':
h_crop = input.size(2) - crop_val
w_crop = input.size(3) - crop_val
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
return input[:, :, sh:sh + h_crop, sw:sw + w_crop]
# Center crop and resize a tensor
def zoom(input, crop_val, mode='percent'):
h, w = input.size(2), input.size(3)
input = center_crop(input.clone(), crop_val, mode=mode)
input = resize_tensor(input, (h,w))
return input
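# Example (illustrative): zoom(t, 95) center-crops a 1x3x512x512 tensor to 486x486
# (95% of each side) and resizes it back to 512x512, giving a slight zoom-in.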
# Get tensor percentile
def tensor_percentile(t, p=99.98):
return t.view(-1).kthvalue(1 + round(0.01 * float(p) * (t.numel() - 1))).values.item()
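# Example (illustrative): tensor_percentile(torch.arange(1., 101.), 99.98) picks the
# k-th smallest value with k = 1 + round(0.01 * 99.98 * 99) = 100, i.e. it returns 100.0.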
# Define a module to match histograms
class MatchHistogram(nn.Module):
def __init__(self, eps=1e-5, mode='pca'):
super(MatchHistogram, self).__init__()
self.eps = eps or 1e-5
self.mode = mode or 'pca'
self.dim_val = 3
def extract_values(self, tensor):
m = tensor.mean(0).mean(0)
h = (tensor - m).permute(2,0,1).reshape(tensor.size(2),-1)
if h.is_cuda:
ch = torch.mm(h, h.T) / h.size(1) + self.eps * torch.eye(h.size(0), device=h.get_device())
else:
ch = torch.mm(h, h.T) / h.size(1) + self.eps * torch.eye(h.size(0))
return m, h, ch
def permute_tensor(self, tensor):
if tensor.dim() == 4:
tensor = tensor.squeeze(0).permute(2, 1, 0)
self.dim_val = 4
elif tensor.dim() == 3 and self.dim_val != 4:
tensor = tensor.permute(2, 1, 0)
elif tensor.dim() == 3 and self.dim_val == 4:
tensor = tensor.permute(2, 1, 0).unsqueeze(0)
return tensor
def nan2zero(self, tensor):
tensor[tensor != tensor] = 0
return tensor
def chol(self, t, c, s):
return torch.mm(torch.mm(torch.cholesky(s), torch.inverse(torch.cholesky(c))), t)
def sym(self, t, c, s):
p = self.pca(t, c)
psp = torch.mm(torch.mm(p, s), p)
eval_psp, evec_psp = torch.symeig(psp, eigenvectors=True, upper=True)
evec_mm = torch.mm(torch.mm(evec_psp, self.nan2zero(torch.sqrt(torch.diagflat(eval_psp)))), evec_psp.T)
return torch.mm(torch.mm(torch.mm(torch.inverse(p), evec_mm), torch.inverse(p)), t)
def pca(self, t, c):
eval_t, evec_t = torch.symeig(c, eigenvectors=True, upper=True)
return torch.mm(torch.mm(evec_t, self.nan2zero(torch.sqrt(torch.diagflat(eval_t)))), evec_t.T)
def match(self, target_tensor, source_tensor):
source_tensor = self.permute_tensor(source_tensor)
target_tensor = self.permute_tensor(target_tensor)
_, t, ct = self.extract_values(target_tensor)
ms, s, cs = self.extract_values(source_tensor)
if self.mode == 'pca':
mt = torch.mm(torch.mm(self.pca(s, cs), torch.inverse(self.pca(t, ct))), t)
elif self.mode == 'sym':
mt = self.sym(t, ct, cs)
elif self.mode == 'chol':
mt = self.chol(t, ct, cs)
matched_tensor = mt.reshape(*target_tensor.permute(2,0,1).size()).permute(1,2,0) + ms
return self.permute_tensor(matched_tensor)
def forward(self, input, source_tensor):
return self.match(input, source_tensor)
# Define an nn Module to perform Gaussian blurring
class GaussianBlur(nn.Module):
def __init__(self, k_size, sigma):
super(GaussianBlur, self).__init__()
self.k_size = k_size
self.sigma = sigma
def capture(self, input):
if input.dim() == 4:
d_val = 2
self.groups = input.size(1)
elif input.dim() == 2:
d_val = 1
self.groups = input.size(0)
self.k_size, self.sigma = [self.k_size] * d_val, [self.sigma] * d_val
kernel = 1
meshgrid_tensor = torch.meshgrid([torch.arange(size, dtype=torch.float32, \
device=input.device) for size in self.k_size])
for size, std, mgrid in zip(self.k_size, self.sigma, meshgrid_tensor):
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - ((size - 1) / 2)) / std) ** 2 / 2)
kernel = (kernel / torch.sum(kernel)).view(1, 1, * kernel.size())
kernel = kernel.repeat(self.groups, * [1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
if d_val == 2:
self.conv = torch.nn.functional.conv2d
elif d_val == 1:
self.conv = torch.nn.functional.conv1d
def forward(self, input, pad_mode='reflect'):
d_val = input.dim()
if input.dim() > 2:
input = torch.nn.functional.pad(input, (2, 2, 2, 2), mode=pad_mode)
else:
input = input.view(1, 1, input.size(1))
input = self.conv(input, weight=self.weight, groups=self.groups)
if d_val == 2:
p1d = nn.ConstantPad1d(2, 0)
input = p1d(input)
input = input.view(1, input.size(2))
return input
# Define a Module to create Gaussian blur for Laplacian pyramids
class GaussianBlurLP(GaussianBlur):
def __init__(self, input, k_size=5, sigma=0):
super(GaussianBlur, self).__init__()
self.guass_blur = GaussianBlur(k_size, sigma)
self.guass_blur.capture(input)
def forward(self, input):
return self.guass_blur(input)
# Define an nn Module to apply Gaussian blur as a layer
class GaussianBlurLayer(nn.Module):
def __init__(self, k_size=5, sigma=0):
super(GaussianBlurLayer, self).__init__()
self.blur = GaussianBlur(k_size, sigma)
self.mode = 'None'
def forward(self, input):
if self.mode == 'loss':
input = self.blur(input)
if self.mode == 'None':
self.mode = 'capture'
if self.mode == 'capture':
self.blur.capture(input.clone())
self.mode = 'loss'
return input
# Define an nn Module to create a Laplacian pyramid
class LaplacianPyramid(nn.Module):
def __init__(self, input, scale=4, sigma=1):
super(LaplacianPyramid, self).__init__()
if not isinstance(sigma, (list, tuple)):
sigma = [sigma]
if len(sigma) == 1:
sigma = (float(sigma[0]), float(sigma[0]) * scale)
else:
sigma = [float(s) for s in sigma]
self.gauss_blur = GaussianBlurLP(input, 5, sigma[0])
self.gauss_blur_hi = GaussianBlurLP(input, 5, sigma[1])
self.scale = scale
def split_lap(self, input):
g = self.gauss_blur(input)
gt = self.gauss_blur_hi(input)
return g, input - gt
def pyramid_list(self, input):
pyramid_levels = []
for i in range(self.scale):
input, hi = self.split_lap(input)
pyramid_levels.append(hi)
pyramid_levels.append(input)
return pyramid_levels[::-1]
def lap_merge(self, pyramid_levels):
b = torch.zeros_like(pyramid_levels[0])
for p in pyramid_levels:
b = b + p
return b
def forward(self, input):
return self.lap_merge(self.pyramid_list(input))
# Define an nn Module to rank channels based on activation strength
class RankChannels(torch.nn.Module):
def __init__(self, channels=1, channel_mode='strong'):
super(RankChannels, self).__init__()
self.channels = channels
self.channel_mode = channel_mode
def sort_channels(self, input):
channel_list = []
for i in range(input.size(1)):
channel_list.append(torch.mean(input.clone().squeeze(0).narrow(0,i,1)).item())
return sorted((c,v) for v,c in enumerate(channel_list))
def get_middle(self, sequence):
num = self.channels[0]
m = (len(sequence) - 1)//2 - num//2
return sequence[m:m+num]
def remove_channels(self, cl):
return [c for c in cl if c[1] not in self.channels]
def rank_channel_list(self, input):
top_channels = self.channels[0]
channel_list = self.sort_channels(input)
if 'strong' in self.channel_mode:
channel_list.reverse()
elif 'avg' in self.channel_mode:
channel_list = self.get_middle(channel_list)
elif 'ignore' in self.channel_mode:
channel_list = self.remove_channels(channel_list)
top_channels = len(channel_list)
channels = []
for i in range(top_channels):
channels.append(channel_list[i][1])
return channels
def forward(self, input):
return self.rank_channel_list(input)
class FFTTensor(nn.Module):
def __init__(self, r=1, bl=25):
super(FFTTensor, self).__init__()
self.r = r
self.bl = bl
def block_low(self, input):
if input.dim() == 5:
hh, hw = int(input.size(2)/2), int(input.size(3)/2)
input[:, :, hh-self.bl:hh+self.bl+1, hw-self.bl:hw+self.bl+1, :] = self.r
elif input.dim() == 3:
m = (input.size(1) - 1)//2 - self.bl//2
input[:, m:m+self.bl, :] = self.r
return input
def fft_image(self, input, s=0):
s_dim = 3 if input.dim() == 4 else 1
s_dim = s_dim if s == 0 else s
input = torch.rfft(input, signal_ndim=s_dim, onesided=False)
real, imaginary = torch.unbind(input, -1)
for r_dim in range(1, len(real.size())):
n_shift = real.size(r_dim)//2
if real.size(r_dim) % 2 != 0:
n_shift += 1
real = torch.roll(real, n_shift, dims=r_dim)
imaginary = torch.roll(imaginary, n_shift, dims=r_dim)
return torch.stack((real, imaginary), -1)
def ifft_image(self, input, s=0):
s_dim = 3 if input.dim() == 5 else 1
s_dim = s_dim if s == 0 else s
real, imaginary = torch.unbind(input, -1)
for r_dim in range(len(real.size()) - 1, 0, -1):
real = torch.roll(real, real.size(r_dim)//2, dims=r_dim)
imaginary = torch.roll(imaginary, imaginary.size(r_dim)//2, dims=r_dim)
return torch.irfft(torch.stack((real, imaginary), -1), signal_ndim=s_dim, onesided=False)
def forward(self, input):
input = self.block_low(self.fft_image(input))
return torch.abs(self.ifft_image(input))
# Define an nn Module to apply jitter
class Jitter(torch.nn.Module):
def __init__(self, jitter_val):
super(Jitter, self).__init__()
self.jitter_val = jitter_val
def roll_tensor(self, input):
h_shift = random.randint(-self.jitter_val, self.jitter_val)
w_shift = random.randint(-self.jitter_val, self.jitter_val)
return torch.roll(torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3)
def forward(self, input):
return self.roll_tensor(input)
# Define an nn Module to apply random transforms
class RandomTransform(torch.nn.Module):
def __init__(self, t_val):
super(RandomTransform, self).__init__()
self.rotate, self.flip = False, False
if t_val == 'all' or t_val == 'rotate':
self.rotate = True
if t_val == 'all' or t_val == 'flip':
self.flip = True
def rotate_tensor(self, input):
if self.rotate:
k_val = random.randint(0,3)
input = torch.rot90(input, k_val, [2,3])
return input
def flip_tensor(self, input):
if self.flip:
flip_tensor = bool(random.randint(0,1))
if flip_tensor:
input = input.flip([2,3])
return input
def forward(self, input):
return self.flip_tensor(self.rotate_tensor(input))
# Define an nn Module to compute l2 loss
class L2Regularizer(nn.Module):
def __init__(self, strength):
super(L2Regularizer, self).__init__()
self.strength = strength
def forward(self, input):
self.loss = self.strength * (input.clone().norm(3)/2)
return input
# Define an nn Module to compute tv loss
class TVLoss(nn.Module):
def __init__(self, strength):
super(TVLoss, self).__init__()
self.strength = strength
def forward(self, input):
self.x_diff = input[:,:,1:,:] - input[:,:,:-1,:]
self.y_diff = input[:,:,:,1:] - input[:,:,:,:-1]
self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
return input
# Define an nn Module to label predicted channels
class Classify(nn.Module):
def __init__(self, labels, k=1):
super(Classify, self).__init__()
self.labels = [str(n) for n in labels]
self.k = k
def forward(self, input):
channel_ids = torch.topk(input, self.k).indices
channel_ids = [n.item() for n in channel_ids[0]]
label_names = ''
for i in channel_ids:
if label_names != '':
label_names += ', ' + self.labels[i]
else:
label_names += self.labels[i]
print(' Predicted labels: ' + label_names)
# Run inception modules with preprocessing layers
class ModelPlus(nn.Module):
def __init__(self, input_net, net):
super(ModelPlus, self).__init__()
self.input_net = input_net
self.net = net
def forward(self, input):
return self.net(self.input_net(input))
class AdditionLayer(nn.Module):
def forward(self, input, input2):
return input + input2
import torch
import torch.nn as nn
import torch.nn.functional as F
class GoogLeNetPlaces205(nn.Module):
def __init__(self):
super(GoogLeNetPlaces205, self).__init__()
self.conv1_7x7_s2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2_3x3_reduce = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2_3x3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3a_1x1 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5_reduce = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3_reduce = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_pool_proj = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3_reduce = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_1x1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5_reduce = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_pool_proj = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4a_1x1 = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3_reduce = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5_reduce = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_pool_proj = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3 = nn.Conv2d(in_channels=96, out_channels=208, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5 = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_1x1 = nn.Conv2d(in_channels=512, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss1_conv = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3 = nn.Conv2d(in_channels=112, out_channels=224, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_4c_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_1x1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_classifier_1 = nn.Linear(in_features = 1024, out_features = 205, bias = True)
self.inception_4d_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_1x1 = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3 = nn.Conv2d(in_channels=144, out_channels=288, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_1x1 = nn.Conv2d(in_channels=528, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5_reduce = nn.Conv2d(in_channels=528, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3_reduce = nn.Conv2d(in_channels=528, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_conv = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_pool_proj = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss2_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_5a_1x1 = nn.Conv2d(in_channels=832, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_classifier_1 = nn.Linear(in_features = 1024, out_features = 205, bias = True)
self.inception_5a_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_1x1 = nn.Conv2d(in_channels=832, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
def forward(self, x):
self.training = False
conv1_7x7_s2_pad = F.pad(x, (3, 3, 3, 3))
conv1_7x7_s2 = self.conv1_7x7_s2(conv1_7x7_s2_pad)
conv1_relu_7x7 = F.relu(conv1_7x7_s2)
pool1_3x3_s2_pad = F.pad(conv1_relu_7x7, (0, 1, 0, 1), value=float('-inf'))
pool1_3x3_s2 = F.max_pool2d(pool1_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
pool1_norm1 = F.local_response_norm(pool1_3x3_s2, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
conv2_3x3_reduce = self.conv2_3x3_reduce(pool1_norm1)
conv2_relu_3x3_reduce = F.relu(conv2_3x3_reduce)
conv2_3x3_pad = F.pad(conv2_relu_3x3_reduce, (1, 1, 1, 1))
conv2_3x3 = self.conv2_3x3(conv2_3x3_pad)
conv2_relu_3x3 = F.relu(conv2_3x3)
conv2_norm2 = F.local_response_norm(conv2_relu_3x3, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
pool2_3x3_s2_pad = F.pad(conv2_norm2, (0, 1, 0, 1), value=float('-inf'))
pool2_3x3_s2 = F.max_pool2d(pool2_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_3a_pool_pad = F.pad(pool2_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_3a_pool = F.max_pool2d(inception_3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3a_1x1 = self.inception_3a_1x1(pool2_3x3_s2)
inception_3a_5x5_reduce = self.inception_3a_5x5_reduce(pool2_3x3_s2)
inception_3a_3x3_reduce = self.inception_3a_3x3_reduce(pool2_3x3_s2)
inception_3a_pool_proj = self.inception_3a_pool_proj(inception_3a_pool)
inception_3a_relu_1x1 = F.relu(inception_3a_1x1)
inception_3a_relu_5x5_reduce = F.relu(inception_3a_5x5_reduce)
inception_3a_relu_3x3_reduce = F.relu(inception_3a_3x3_reduce)
inception_3a_relu_pool_proj = F.relu(inception_3a_pool_proj)
inception_3a_5x5_pad = F.pad(inception_3a_relu_5x5_reduce, (2, 2, 2, 2))
inception_3a_5x5 = self.inception_3a_5x5(inception_3a_5x5_pad)
inception_3a_3x3_pad = F.pad(inception_3a_relu_3x3_reduce, (1, 1, 1, 1))
inception_3a_3x3 = self.inception_3a_3x3(inception_3a_3x3_pad)
inception_3a_relu_5x5 = F.relu(inception_3a_5x5)
inception_3a_relu_3x3 = F.relu(inception_3a_3x3)
inception_3a_output = torch.cat((inception_3a_relu_1x1, inception_3a_relu_3x3, inception_3a_relu_5x5, inception_3a_relu_pool_proj), 1)
inception_3b_3x3_reduce = self.inception_3b_3x3_reduce(inception_3a_output)
inception_3b_pool_pad = F.pad(inception_3a_output, (1, 1, 1, 1), value=float('-inf'))
inception_3b_pool = F.max_pool2d(inception_3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3b_1x1 = self.inception_3b_1x1(inception_3a_output)
inception_3b_5x5_reduce = self.inception_3b_5x5_reduce(inception_3a_output)
inception_3b_relu_3x3_reduce = F.relu(inception_3b_3x3_reduce)
inception_3b_pool_proj = self.inception_3b_pool_proj(inception_3b_pool)
inception_3b_relu_1x1 = F.relu(inception_3b_1x1)
inception_3b_relu_5x5_reduce = F.relu(inception_3b_5x5_reduce)
inception_3b_3x3_pad = F.pad(inception_3b_relu_3x3_reduce, (1, 1, 1, 1))
inception_3b_3x3 = self.inception_3b_3x3(inception_3b_3x3_pad)
inception_3b_relu_pool_proj = F.relu(inception_3b_pool_proj)
inception_3b_5x5_pad = F.pad(inception_3b_relu_5x5_reduce, (2, 2, 2, 2))
inception_3b_5x5 = self.inception_3b_5x5(inception_3b_5x5_pad)
inception_3b_relu_3x3 = F.relu(inception_3b_3x3)
inception_3b_relu_5x5 = F.relu(inception_3b_5x5)
inception_3b_output = torch.cat((inception_3b_relu_1x1, inception_3b_relu_3x3, inception_3b_relu_5x5, inception_3b_relu_pool_proj), 1)
pool3_3x3_s2_pad = F.pad(inception_3b_output, (0, 1, 0, 1), value=float('-inf'))
pool3_3x3_s2 = F.max_pool2d(pool3_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_4a_1x1 = self.inception_4a_1x1(pool3_3x3_s2)
inception_4a_3x3_reduce = self.inception_4a_3x3_reduce(pool3_3x3_s2)
inception_4a_5x5_reduce = self.inception_4a_5x5_reduce(pool3_3x3_s2)
inception_4a_pool_pad = F.pad(pool3_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_4a_pool = F.max_pool2d(inception_4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4a_relu_1x1 = F.relu(inception_4a_1x1)
inception_4a_relu_3x3_reduce = F.relu(inception_4a_3x3_reduce)
inception_4a_relu_5x5_reduce = F.relu(inception_4a_5x5_reduce)
inception_4a_pool_proj = self.inception_4a_pool_proj(inception_4a_pool)
inception_4a_3x3_pad = F.pad(inception_4a_relu_3x3_reduce, (1, 1, 1, 1))
inception_4a_3x3 = self.inception_4a_3x3(inception_4a_3x3_pad)
inception_4a_5x5_pad = F.pad(inception_4a_relu_5x5_reduce, (2, 2, 2, 2))
inception_4a_5x5 = self.inception_4a_5x5(inception_4a_5x5_pad)
inception_4a_relu_pool_proj = F.relu(inception_4a_pool_proj)
inception_4a_relu_3x3 = F.relu(inception_4a_3x3)
inception_4a_relu_5x5 = F.relu(inception_4a_5x5)
inception_4a_output = torch.cat((inception_4a_relu_1x1, inception_4a_relu_3x3, inception_4a_relu_5x5, inception_4a_relu_pool_proj), 1)
inception_4b_pool_pad = F.pad(inception_4a_output, (1, 1, 1, 1), value=float('-inf'))
inception_4b_pool = F.max_pool2d(inception_4b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
#loss1_ave_pool = F.avg_pool2d(inception_4a_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4b_5x5_reduce = self.inception_4b_5x5_reduce(inception_4a_output)
inception_4b_1x1 = self.inception_4b_1x1(inception_4a_output)
inception_4b_3x3_reduce = self.inception_4b_3x3_reduce(inception_4a_output)
inception_4b_pool_proj = self.inception_4b_pool_proj(inception_4b_pool)
#loss1_conv = self.loss1_conv(loss1_ave_pool)
inception_4b_relu_5x5_reduce = F.relu(inception_4b_5x5_reduce)
inception_4b_relu_1x1 = F.relu(inception_4b_1x1)
inception_4b_relu_3x3_reduce = F.relu(inception_4b_3x3_reduce)
inception_4b_relu_pool_proj = F.relu(inception_4b_pool_proj)
#loss1_relu_conv = F.relu(loss1_conv)
inception_4b_5x5_pad = F.pad(inception_4b_relu_5x5_reduce, (2, 2, 2, 2))
inception_4b_5x5 = self.inception_4b_5x5(inception_4b_5x5_pad)
inception_4b_3x3_pad = F.pad(inception_4b_relu_3x3_reduce, (1, 1, 1, 1))
inception_4b_3x3 = self.inception_4b_3x3(inception_4b_3x3_pad)
#loss1_fc_0 = loss1_relu_conv.view(loss1_relu_conv.size(0), -1)
inception_4b_relu_5x5 = F.relu(inception_4b_5x5)
inception_4b_relu_3x3 = F.relu(inception_4b_3x3)
#loss1_fc_1 = self.loss1_fc_1(loss1_fc_0)
inception_4b_output = torch.cat((inception_4b_relu_1x1, inception_4b_relu_3x3, inception_4b_relu_5x5, inception_4b_relu_pool_proj), 1)
#loss1_relu_fc = F.relu(loss1_fc_1)
inception_4c_5x5_reduce = self.inception_4c_5x5_reduce(inception_4b_output)
inception_4c_pool_pad = F.pad(inception_4b_output, (1, 1, 1, 1), value=float('-inf'))
inception_4c_pool = F.max_pool2d(inception_4c_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4c_1x1 = self.inception_4c_1x1(inception_4b_output)
inception_4c_3x3_reduce = self.inception_4c_3x3_reduce(inception_4b_output)
#loss1_drop_fc = F.dropout(input = loss1_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_4c_relu_5x5_reduce = F.relu(inception_4c_5x5_reduce)
inception_4c_pool_proj = self.inception_4c_pool_proj(inception_4c_pool)
inception_4c_relu_1x1 = F.relu(inception_4c_1x1)
inception_4c_relu_3x3_reduce = F.relu(inception_4c_3x3_reduce)
#loss1_classifier_0 = loss1_drop_fc.view(loss1_drop_fc.size(0), -1)
inception_4c_5x5_pad = F.pad(inception_4c_relu_5x5_reduce, (2, 2, 2, 2))
inception_4c_5x5 = self.inception_4c_5x5(inception_4c_5x5_pad)
inception_4c_relu_pool_proj = F.relu(inception_4c_pool_proj)
inception_4c_3x3_pad = F.pad(inception_4c_relu_3x3_reduce, (1, 1, 1, 1))
inception_4c_3x3 = self.inception_4c_3x3(inception_4c_3x3_pad)
#loss1_classifier_1 = self.loss1_classifier_1(loss1_classifier_0)
inception_4c_relu_5x5 = F.relu(inception_4c_5x5)
inception_4c_relu_3x3 = F.relu(inception_4c_3x3)
inception_4c_output = torch.cat((inception_4c_relu_1x1, inception_4c_relu_3x3, inception_4c_relu_5x5, inception_4c_relu_pool_proj), 1)
inception_4d_pool_pad = F.pad(inception_4c_output, (1, 1, 1, 1), value=float('-inf'))
inception_4d_pool = F.max_pool2d(inception_4d_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4d_3x3_reduce = self.inception_4d_3x3_reduce(inception_4c_output)
inception_4d_1x1 = self.inception_4d_1x1(inception_4c_output)
inception_4d_5x5_reduce = self.inception_4d_5x5_reduce(inception_4c_output)
inception_4d_pool_proj = self.inception_4d_pool_proj(inception_4d_pool)
inception_4d_relu_3x3_reduce = F.relu(inception_4d_3x3_reduce)
inception_4d_relu_1x1 = F.relu(inception_4d_1x1)
inception_4d_relu_5x5_reduce = F.relu(inception_4d_5x5_reduce)
inception_4d_relu_pool_proj = F.relu(inception_4d_pool_proj)
inception_4d_3x3_pad = F.pad(inception_4d_relu_3x3_reduce, (1, 1, 1, 1))
inception_4d_3x3 = self.inception_4d_3x3(inception_4d_3x3_pad)
inception_4d_5x5_pad = F.pad(inception_4d_relu_5x5_reduce, (2, 2, 2, 2))
inception_4d_5x5 = self.inception_4d_5x5(inception_4d_5x5_pad)
inception_4d_relu_3x3 = F.relu(inception_4d_3x3)
inception_4d_relu_5x5 = F.relu(inception_4d_5x5)
inception_4d_output = torch.cat((inception_4d_relu_1x1, inception_4d_relu_3x3, inception_4d_relu_5x5, inception_4d_relu_pool_proj), 1)
inception_4e_1x1 = self.inception_4e_1x1(inception_4d_output)
inception_4e_5x5_reduce = self.inception_4e_5x5_reduce(inception_4d_output)
#loss2_ave_pool = F.avg_pool2d(inception_4d_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4e_3x3_reduce = self.inception_4e_3x3_reduce(inception_4d_output)
inception_4e_pool_pad = F.pad(inception_4d_output, (1, 1, 1, 1), value=float('-inf'))
inception_4e_pool = F.max_pool2d(inception_4e_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4e_relu_1x1 = F.relu(inception_4e_1x1)
inception_4e_relu_5x5_reduce = F.relu(inception_4e_5x5_reduce)
#loss2_conv = self.loss2_conv(loss2_ave_pool)
inception_4e_relu_3x3_reduce = F.relu(inception_4e_3x3_reduce)
inception_4e_pool_proj = self.inception_4e_pool_proj(inception_4e_pool)
inception_4e_5x5_pad = F.pad(inception_4e_relu_5x5_reduce, (2, 2, 2, 2))
inception_4e_5x5 = self.inception_4e_5x5(inception_4e_5x5_pad)
#loss2_relu_conv = F.relu(loss2_conv)
inception_4e_3x3_pad = F.pad(inception_4e_relu_3x3_reduce, (1, 1, 1, 1))
inception_4e_3x3 = self.inception_4e_3x3(inception_4e_3x3_pad)
inception_4e_relu_pool_proj = F.relu(inception_4e_pool_proj)
inception_4e_relu_5x5 = F.relu(inception_4e_5x5)
#loss2_fc_0 = loss2_relu_conv.view(loss2_relu_conv.size(0), -1)
inception_4e_relu_3x3 = F.relu(inception_4e_3x3)
#loss2_fc_1 = self.loss2_fc_1(loss2_fc_0)
inception_4e_output = torch.cat((inception_4e_relu_1x1, inception_4e_relu_3x3, inception_4e_relu_5x5, inception_4e_relu_pool_proj), 1)
#loss2_relu_fc = F.relu(loss2_fc_1)
pool4_3x3_s2_pad = F.pad(inception_4e_output, (0, 1, 0, 1), value=float('-inf'))
pool4_3x3_s2 = F.max_pool2d(pool4_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
#loss2_drop_fc = F.dropout(input = loss2_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_5a_1x1 = self.inception_5a_1x1(pool4_3x3_s2)
inception_5a_5x5_reduce = self.inception_5a_5x5_reduce(pool4_3x3_s2)
inception_5a_pool_pad = F.pad(pool4_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_5a_pool = F.max_pool2d(inception_5a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5a_3x3_reduce = self.inception_5a_3x3_reduce(pool4_3x3_s2)
#loss2_classifier_0 = loss2_drop_fc.view(loss2_drop_fc.size(0), -1)
inception_5a_relu_1x1 = F.relu(inception_5a_1x1)
inception_5a_relu_5x5_reduce = F.relu(inception_5a_5x5_reduce)
inception_5a_pool_proj = self.inception_5a_pool_proj(inception_5a_pool)
inception_5a_relu_3x3_reduce = F.relu(inception_5a_3x3_reduce)
#loss2_classifier_1 = self.loss2_classifier_1(loss2_classifier_0)
inception_5a_5x5_pad = F.pad(inception_5a_relu_5x5_reduce, (2, 2, 2, 2))
inception_5a_5x5 = self.inception_5a_5x5(inception_5a_5x5_pad)
inception_5a_relu_pool_proj = F.relu(inception_5a_pool_proj)
inception_5a_3x3_pad = F.pad(inception_5a_relu_3x3_reduce, (1, 1, 1, 1))
inception_5a_3x3 = self.inception_5a_3x3(inception_5a_3x3_pad)
inception_5a_relu_5x5 = F.relu(inception_5a_5x5)
inception_5a_relu_3x3 = F.relu(inception_5a_3x3)
inception_5a_output = torch.cat((inception_5a_relu_1x1, inception_5a_relu_3x3, inception_5a_relu_5x5, inception_5a_relu_pool_proj), 1)
inception_5b_3x3_reduce = self.inception_5b_3x3_reduce(inception_5a_output)
inception_5b_pool_pad = F.pad(inception_5a_output, (1, 1, 1, 1), value=float('-inf'))
inception_5b_pool = F.max_pool2d(inception_5b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5b_5x5_reduce = self.inception_5b_5x5_reduce(inception_5a_output)
inception_5b_1x1 = self.inception_5b_1x1(inception_5a_output)
inception_5b_relu_3x3_reduce = F.relu(inception_5b_3x3_reduce)
inception_5b_pool_proj = self.inception_5b_pool_proj(inception_5b_pool)
inception_5b_relu_5x5_reduce = F.relu(inception_5b_5x5_reduce)
inception_5b_relu_1x1 = F.relu(inception_5b_1x1)
inception_5b_3x3_pad = F.pad(inception_5b_relu_3x3_reduce, (1, 1, 1, 1))
inception_5b_3x3 = self.inception_5b_3x3(inception_5b_3x3_pad)
inception_5b_relu_pool_proj = F.relu(inception_5b_pool_proj)
inception_5b_5x5_pad = F.pad(inception_5b_relu_5x5_reduce, (2, 2, 2, 2))
inception_5b_5x5 = self.inception_5b_5x5(inception_5b_5x5_pad)
inception_5b_relu_3x3 = F.relu(inception_5b_3x3)
inception_5b_relu_5x5 = F.relu(inception_5b_5x5)
inception_5b_output = torch.cat((inception_5b_relu_1x1, inception_5b_relu_3x3, inception_5b_relu_5x5, inception_5b_relu_pool_proj), 1)
pool5_7x7_s1 = F.avg_pool2d(inception_5b_output, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
pool5_drop_7x7_s1 = F.dropout(input = pool5_7x7_s1, p = 0.4000000059604645, training = self.training, inplace = True)
return pool5_drop_7x7_s1
#test_avgpool = nn.AdaptiveAvgPool2d((7, 7))
#x = test_avgpool(pool5_drop_7x7_s1)
#x = torch.flatten(x, 1)
#test_fc = nn.Linear(1024 * 7 * 7, 204)
#x = test_fc(x)
#return x
#return pool5_drop_7x7_s1#, loss2_classifier_1, loss1_classifier_1
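# Loading sketch (illustrative; assumes the converted checkpoint's state_dict keys match
# the attribute names above):
#   model = GoogLeNetPlaces205()
#   model.load_state_dict(torch.load('models/googlenet_places205.pth'))
#   model.eval()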
class GoogLeNetPlaces365(nn.Module):
def __init__(self):
super(GoogLeNetPlaces365, self).__init__()
self.conv1_7x7_s2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2_3x3_reduce = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2_3x3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3a_1x1 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5_reduce = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3_reduce = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_pool_proj = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3_reduce = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_1x1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5_reduce = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_pool_proj = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4a_1x1 = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3_reduce = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5_reduce = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_pool_proj = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3 = nn.Conv2d(in_channels=96, out_channels=208, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5 = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_1x1 = nn.Conv2d(in_channels=512, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss1_conv = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3 = nn.Conv2d(in_channels=112, out_channels=224, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_4c_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_1x1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_classifier_1 = nn.Linear(in_features = 1024, out_features = 365, bias = True)
self.inception_4d_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_1x1 = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3 = nn.Conv2d(in_channels=144, out_channels=288, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_1x1 = nn.Conv2d(in_channels=528, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5_reduce = nn.Conv2d(in_channels=528, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3_reduce = nn.Conv2d(in_channels=528, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_conv = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_pool_proj = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss2_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_5a_1x1 = nn.Conv2d(in_channels=832, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_classifier_1 = nn.Linear(in_features = 1024, out_features = 365, bias = True)
self.inception_5a_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_1x1 = nn.Conv2d(in_channels=832, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
def forward(self, x):
conv1_7x7_s2_pad = F.pad(x, (3, 3, 3, 3))
conv1_7x7_s2 = self.conv1_7x7_s2(conv1_7x7_s2_pad)
conv1_relu_7x7 = F.relu(conv1_7x7_s2)
pool1_3x3_s2_pad = F.pad(conv1_relu_7x7, (0, 1, 0, 1), value=float('-inf'))
pool1_3x3_s2 = F.max_pool2d(pool1_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
pool1_norm1 = F.local_response_norm(pool1_3x3_s2, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
conv2_3x3_reduce = self.conv2_3x3_reduce(pool1_norm1)
conv2_relu_3x3_reduce = F.relu(conv2_3x3_reduce)
conv2_3x3_pad = F.pad(conv2_relu_3x3_reduce, (1, 1, 1, 1))
conv2_3x3 = self.conv2_3x3(conv2_3x3_pad)
conv2_relu_3x3 = F.relu(conv2_3x3)
conv2_norm2 = F.local_response_norm(conv2_relu_3x3, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
pool2_3x3_s2_pad = F.pad(conv2_norm2, (0, 1, 0, 1), value=float('-inf'))
pool2_3x3_s2 = F.max_pool2d(pool2_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_3a_pool_pad = F.pad(pool2_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_3a_pool = F.max_pool2d(inception_3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3a_1x1 = self.inception_3a_1x1(pool2_3x3_s2)
inception_3a_5x5_reduce = self.inception_3a_5x5_reduce(pool2_3x3_s2)
inception_3a_3x3_reduce = self.inception_3a_3x3_reduce(pool2_3x3_s2)
inception_3a_pool_proj = self.inception_3a_pool_proj(inception_3a_pool)
inception_3a_relu_1x1 = F.relu(inception_3a_1x1)
inception_3a_relu_5x5_reduce = F.relu(inception_3a_5x5_reduce)
inception_3a_relu_3x3_reduce = F.relu(inception_3a_3x3_reduce)
inception_3a_relu_pool_proj = F.relu(inception_3a_pool_proj)
inception_3a_5x5_pad = F.pad(inception_3a_relu_5x5_reduce, (2, 2, 2, 2))
inception_3a_5x5 = self.inception_3a_5x5(inception_3a_5x5_pad)
inception_3a_3x3_pad = F.pad(inception_3a_relu_3x3_reduce, (1, 1, 1, 1))
inception_3a_3x3 = self.inception_3a_3x3(inception_3a_3x3_pad)
inception_3a_relu_5x5 = F.relu(inception_3a_5x5)
inception_3a_relu_3x3 = F.relu(inception_3a_3x3)
inception_3a_output = torch.cat((inception_3a_relu_1x1, inception_3a_relu_3x3, inception_3a_relu_5x5, inception_3a_relu_pool_proj), 1)
inception_3b_3x3_reduce = self.inception_3b_3x3_reduce(inception_3a_output)
inception_3b_pool_pad = F.pad(inception_3a_output, (1, 1, 1, 1), value=float('-inf'))
inception_3b_pool = F.max_pool2d(inception_3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3b_1x1 = self.inception_3b_1x1(inception_3a_output)
inception_3b_5x5_reduce = self.inception_3b_5x5_reduce(inception_3a_output)
inception_3b_relu_3x3_reduce = F.relu(inception_3b_3x3_reduce)
inception_3b_pool_proj = self.inception_3b_pool_proj(inception_3b_pool)
inception_3b_relu_1x1 = F.relu(inception_3b_1x1)
inception_3b_relu_5x5_reduce = F.relu(inception_3b_5x5_reduce)
inception_3b_3x3_pad = F.pad(inception_3b_relu_3x3_reduce, (1, 1, 1, 1))
inception_3b_3x3 = self.inception_3b_3x3(inception_3b_3x3_pad)
inception_3b_relu_pool_proj = F.relu(inception_3b_pool_proj)
inception_3b_5x5_pad = F.pad(inception_3b_relu_5x5_reduce, (2, 2, 2, 2))
inception_3b_5x5 = self.inception_3b_5x5(inception_3b_5x5_pad)
inception_3b_relu_3x3 = F.relu(inception_3b_3x3)
inception_3b_relu_5x5 = F.relu(inception_3b_5x5)
inception_3b_output = torch.cat((inception_3b_relu_1x1, inception_3b_relu_3x3, inception_3b_relu_5x5, inception_3b_relu_pool_proj), 1)
pool3_3x3_s2_pad = F.pad(inception_3b_output, (0, 1, 0, 1), value=float('-inf'))
pool3_3x3_s2 = F.max_pool2d(pool3_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_4a_1x1 = self.inception_4a_1x1(pool3_3x3_s2)
inception_4a_3x3_reduce = self.inception_4a_3x3_reduce(pool3_3x3_s2)
inception_4a_5x5_reduce = self.inception_4a_5x5_reduce(pool3_3x3_s2)
inception_4a_pool_pad = F.pad(pool3_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_4a_pool = F.max_pool2d(inception_4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4a_relu_1x1 = F.relu(inception_4a_1x1)
inception_4a_relu_3x3_reduce = F.relu(inception_4a_3x3_reduce)
inception_4a_relu_5x5_reduce = F.relu(inception_4a_5x5_reduce)
inception_4a_pool_proj = self.inception_4a_pool_proj(inception_4a_pool)
inception_4a_3x3_pad = F.pad(inception_4a_relu_3x3_reduce, (1, 1, 1, 1))
inception_4a_3x3 = self.inception_4a_3x3(inception_4a_3x3_pad)
inception_4a_5x5_pad = F.pad(inception_4a_relu_5x5_reduce, (2, 2, 2, 2))
inception_4a_5x5 = self.inception_4a_5x5(inception_4a_5x5_pad)
inception_4a_relu_pool_proj = F.relu(inception_4a_pool_proj)
inception_4a_relu_3x3 = F.relu(inception_4a_3x3)
inception_4a_relu_5x5 = F.relu(inception_4a_5x5)
inception_4a_output = torch.cat((inception_4a_relu_1x1, inception_4a_relu_3x3, inception_4a_relu_5x5, inception_4a_relu_pool_proj), 1)
inception_4b_pool_pad = F.pad(inception_4a_output, (1, 1, 1, 1), value=float('-inf'))
inception_4b_pool = F.max_pool2d(inception_4b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
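# The commented-out loss1_* lines below (and the loss2_* lines further down) are the
# auxiliary classifier branches of the original GoogLeNet. They are kept for reference
# only; DeepDream uses activations from the main trunk, so they are never evaluated.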
#loss1_ave_pool = F.avg_pool2d(inception_4a_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4b_5x5_reduce = self.inception_4b_5x5_reduce(inception_4a_output)
inception_4b_1x1 = self.inception_4b_1x1(inception_4a_output)
inception_4b_3x3_reduce = self.inception_4b_3x3_reduce(inception_4a_output)
inception_4b_pool_proj = self.inception_4b_pool_proj(inception_4b_pool)
#loss1_conv = self.loss1_conv(loss1_ave_pool)
inception_4b_relu_5x5_reduce = F.relu(inception_4b_5x5_reduce)
inception_4b_relu_1x1 = F.relu(inception_4b_1x1)
inception_4b_relu_3x3_reduce = F.relu(inception_4b_3x3_reduce)
inception_4b_relu_pool_proj = F.relu(inception_4b_pool_proj)
#loss1_relu_conv = F.relu(loss1_conv)
inception_4b_5x5_pad = F.pad(inception_4b_relu_5x5_reduce, (2, 2, 2, 2))
inception_4b_5x5 = self.inception_4b_5x5(inception_4b_5x5_pad)
inception_4b_3x3_pad = F.pad(inception_4b_relu_3x3_reduce, (1, 1, 1, 1))
inception_4b_3x3 = self.inception_4b_3x3(inception_4b_3x3_pad)
#loss1_fc_0 = loss1_relu_conv.view(loss1_relu_conv.size(0), -1)
inception_4b_relu_5x5 = F.relu(inception_4b_5x5)
inception_4b_relu_3x3 = F.relu(inception_4b_3x3)
#loss1_fc_1 = self.loss1_fc_1(loss1_fc_0)
inception_4b_output = torch.cat((inception_4b_relu_1x1, inception_4b_relu_3x3, inception_4b_relu_5x5, inception_4b_relu_pool_proj), 1)
#loss1_relu_fc = F.relu(loss1_fc_1)
inception_4c_5x5_reduce = self.inception_4c_5x5_reduce(inception_4b_output)
inception_4c_pool_pad = F.pad(inception_4b_output, (1, 1, 1, 1), value=float('-inf'))
inception_4c_pool = F.max_pool2d(inception_4c_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4c_1x1 = self.inception_4c_1x1(inception_4b_output)
inception_4c_3x3_reduce = self.inception_4c_3x3_reduce(inception_4b_output)
#loss1_drop_fc = F.dropout(input = loss1_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_4c_relu_5x5_reduce = F.relu(inception_4c_5x5_reduce)
inception_4c_pool_proj = self.inception_4c_pool_proj(inception_4c_pool)
inception_4c_relu_1x1 = F.relu(inception_4c_1x1)
inception_4c_relu_3x3_reduce = F.relu(inception_4c_3x3_reduce)
#loss1_classifier_0 = loss1_drop_fc.view(loss1_drop_fc.size(0), -1)
inception_4c_5x5_pad = F.pad(inception_4c_relu_5x5_reduce, (2, 2, 2, 2))
inception_4c_5x5 = self.inception_4c_5x5(inception_4c_5x5_pad)
inception_4c_relu_pool_proj = F.relu(inception_4c_pool_proj)
inception_4c_3x3_pad = F.pad(inception_4c_relu_3x3_reduce, (1, 1, 1, 1))
inception_4c_3x3 = self.inception_4c_3x3(inception_4c_3x3_pad)
#loss1_classifier_1 = self.loss1_classifier_1(loss1_classifier_0)
inception_4c_relu_5x5 = F.relu(inception_4c_5x5)
inception_4c_relu_3x3 = F.relu(inception_4c_3x3)
inception_4c_output = torch.cat((inception_4c_relu_1x1, inception_4c_relu_3x3, inception_4c_relu_5x5, inception_4c_relu_pool_proj), 1)
inception_4d_pool_pad = F.pad(inception_4c_output, (1, 1, 1, 1), value=float('-inf'))
inception_4d_pool = F.max_pool2d(inception_4d_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4d_3x3_reduce = self.inception_4d_3x3_reduce(inception_4c_output)
inception_4d_1x1 = self.inception_4d_1x1(inception_4c_output)
inception_4d_5x5_reduce = self.inception_4d_5x5_reduce(inception_4c_output)
inception_4d_pool_proj = self.inception_4d_pool_proj(inception_4d_pool)
inception_4d_relu_3x3_reduce = F.relu(inception_4d_3x3_reduce)
inception_4d_relu_1x1 = F.relu(inception_4d_1x1)
inception_4d_relu_5x5_reduce = F.relu(inception_4d_5x5_reduce)
inception_4d_relu_pool_proj = F.relu(inception_4d_pool_proj)
inception_4d_3x3_pad = F.pad(inception_4d_relu_3x3_reduce, (1, 1, 1, 1))
inception_4d_3x3 = self.inception_4d_3x3(inception_4d_3x3_pad)
inception_4d_5x5_pad = F.pad(inception_4d_relu_5x5_reduce, (2, 2, 2, 2))
inception_4d_5x5 = self.inception_4d_5x5(inception_4d_5x5_pad)
inception_4d_relu_3x3 = F.relu(inception_4d_3x3)
inception_4d_relu_5x5 = F.relu(inception_4d_5x5)
inception_4d_output = torch.cat((inception_4d_relu_1x1, inception_4d_relu_3x3, inception_4d_relu_5x5, inception_4d_relu_pool_proj), 1)
inception_4e_1x1 = self.inception_4e_1x1(inception_4d_output)
inception_4e_5x5_reduce = self.inception_4e_5x5_reduce(inception_4d_output)
#loss2_ave_pool = F.avg_pool2d(inception_4d_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4e_3x3_reduce = self.inception_4e_3x3_reduce(inception_4d_output)
inception_4e_pool_pad = F.pad(inception_4d_output, (1, 1, 1, 1), value=float('-inf'))
inception_4e_pool = F.max_pool2d(inception_4e_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4e_relu_1x1 = F.relu(inception_4e_1x1)
inception_4e_relu_5x5_reduce = F.relu(inception_4e_5x5_reduce)
#loss2_conv = self.loss2_conv(loss2_ave_pool)
inception_4e_relu_3x3_reduce = F.relu(inception_4e_3x3_reduce)
inception_4e_pool_proj = self.inception_4e_pool_proj(inception_4e_pool)
inception_4e_5x5_pad = F.pad(inception_4e_relu_5x5_reduce, (2, 2, 2, 2))
inception_4e_5x5 = self.inception_4e_5x5(inception_4e_5x5_pad)
#loss2_relu_conv = F.relu(loss2_conv)
inception_4e_3x3_pad = F.pad(inception_4e_relu_3x3_reduce, (1, 1, 1, 1))
inception_4e_3x3 = self.inception_4e_3x3(inception_4e_3x3_pad)
inception_4e_relu_pool_proj = F.relu(inception_4e_pool_proj)
inception_4e_relu_5x5 = F.relu(inception_4e_5x5)
#loss2_fc_0 = loss2_relu_conv.view(loss2_relu_conv.size(0), -1)
inception_4e_relu_3x3 = F.relu(inception_4e_3x3)
#loss2_fc_1 = self.loss2_fc_1(loss2_fc_0)
inception_4e_output = torch.cat((inception_4e_relu_1x1, inception_4e_relu_3x3, inception_4e_relu_5x5, inception_4e_relu_pool_proj), 1)
#loss2_relu_fc = F.relu(loss2_fc_1)
pool4_3x3_s2_pad = F.pad(inception_4e_output, (0, 1, 0, 1), value=float('-inf'))
pool4_3x3_s2 = F.max_pool2d(pool4_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
#loss2_drop_fc = F.dropout(input = loss2_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_5a_1x1 = self.inception_5a_1x1(pool4_3x3_s2)
inception_5a_5x5_reduce = self.inception_5a_5x5_reduce(pool4_3x3_s2)
inception_5a_pool_pad = F.pad(pool4_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_5a_pool = F.max_pool2d(inception_5a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5a_3x3_reduce = self.inception_5a_3x3_reduce(pool4_3x3_s2)
#loss2_classifier_0 = loss2_drop_fc.view(loss2_drop_fc.size(0), -1)
inception_5a_relu_1x1 = F.relu(inception_5a_1x1)
inception_5a_relu_5x5_reduce = F.relu(inception_5a_5x5_reduce)
inception_5a_pool_proj = self.inception_5a_pool_proj(inception_5a_pool)
inception_5a_relu_3x3_reduce = F.relu(inception_5a_3x3_reduce)
#loss2_classifier_1 = self.loss2_classifier_1(loss2_classifier_0)
inception_5a_5x5_pad = F.pad(inception_5a_relu_5x5_reduce, (2, 2, 2, 2))
inception_5a_5x5 = self.inception_5a_5x5(inception_5a_5x5_pad)
inception_5a_relu_pool_proj = F.relu(inception_5a_pool_proj)
inception_5a_3x3_pad = F.pad(inception_5a_relu_3x3_reduce, (1, 1, 1, 1))
inception_5a_3x3 = self.inception_5a_3x3(inception_5a_3x3_pad)
inception_5a_relu_5x5 = F.relu(inception_5a_5x5)
inception_5a_relu_3x3 = F.relu(inception_5a_3x3)
inception_5a_output = torch.cat((inception_5a_relu_1x1, inception_5a_relu_3x3, inception_5a_relu_5x5, inception_5a_relu_pool_proj), 1)
inception_5b_3x3_reduce = self.inception_5b_3x3_reduce(inception_5a_output)
inception_5b_pool_pad = F.pad(inception_5a_output, (1, 1, 1, 1), value=float('-inf'))
inception_5b_pool = F.max_pool2d(inception_5b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5b_5x5_reduce = self.inception_5b_5x5_reduce(inception_5a_output)
inception_5b_1x1 = self.inception_5b_1x1(inception_5a_output)
inception_5b_relu_3x3_reduce = F.relu(inception_5b_3x3_reduce)
inception_5b_pool_proj = self.inception_5b_pool_proj(inception_5b_pool)
inception_5b_relu_5x5_reduce = F.relu(inception_5b_5x5_reduce)
inception_5b_relu_1x1 = F.relu(inception_5b_1x1)
inception_5b_3x3_pad = F.pad(inception_5b_relu_3x3_reduce, (1, 1, 1, 1))
inception_5b_3x3 = self.inception_5b_3x3(inception_5b_3x3_pad)
inception_5b_relu_pool_proj = F.relu(inception_5b_pool_proj)
inception_5b_5x5_pad = F.pad(inception_5b_relu_5x5_reduce, (2, 2, 2, 2))
inception_5b_5x5 = self.inception_5b_5x5(inception_5b_5x5_pad)
inception_5b_relu_3x3 = F.relu(inception_5b_3x3)
inception_5b_relu_5x5 = F.relu(inception_5b_5x5)
inception_5b_output = torch.cat((inception_5b_relu_1x1, inception_5b_relu_3x3, inception_5b_relu_5x5, inception_5b_relu_pool_proj), 1)
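# Classifier head: 7x7 global average pooling followed by dropout (p is the float32
# rendering of 0.4); the final fully connected classification layer is not applied in
# this forward pass.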
pool5_7x7_s1 = F.avg_pool2d(inception_5b_output, kernel_size=(7, 7), stride=(1, 1), padding=0, ceil_mode=False, count_include_pad=False)
pool5_drop_7x7_s1 = F.dropout(input=pool5_7x7_s1, p=0.4000000059604645, training=self.training, inplace=True)
return pool5_drop_7x7_s1  # auxiliary outputs loss2_classifier_1 and loss1_classifier_1 are disabled
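# BVLC GoogLeNet: the original 1000-class ImageNet model that models/download_models.py
# fetches by default. Layer names follow the Caffe prototxt, and the loss1/loss2
# auxiliary branches are again carried only as commented-out code.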
class BVLC_GOOGLENET(nn.Module):
def __init__(self):
super(BVLC_GOOGLENET, self).__init__()
self.conv1_7x7_s2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2_3x3_reduce = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2_3x3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3a_1x1 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5_reduce = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3_reduce = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_pool_proj = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3_reduce = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_1x1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5_reduce = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_pool_proj = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4a_1x1 = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3_reduce = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5_reduce = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_pool_proj = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3 = nn.Conv2d(in_channels=96, out_channels=208, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5 = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_1x1 = nn.Conv2d(in_channels=512, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss1_conv = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3 = nn.Conv2d(in_channels=112, out_channels=224, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_fc_1 = nn.Linear(in_features=2048, out_features=1024, bias=True)
self.inception_4c_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_1x1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_classifier_1 = nn.Linear(in_features=1024, out_features=1000, bias=True)
self.inception_4d_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_1x1 = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3 = nn.Conv2d(in_channels=144, out_channels=288, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_1x1 = nn.Conv2d(in_channels=528, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5_reduce = nn.Conv2d(in_channels=528, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3_reduce = nn.Conv2d(in_channels=528, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_conv = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_pool_proj = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss2_fc_1 = nn.Linear(in_features=2048, out_features=1024, bias=True)
self.inception_5a_1x1 = nn.Conv2d(in_channels=832, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_classifier_1 = nn.Linear(in_features=1024, out_features=1000, bias=True)
self.inception_5a_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_1x1 = nn.Conv2d(in_channels=832, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
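# Note on style: these definitions appear to be machine-converted from Caffe. Pooling is
# written as an explicit F.pad followed by F.max_pool2d; padding with -inf keeps the
# padded cells from ever winning the max, and the asymmetric (0, 1, 0, 1) pads reproduce
# Caffe's ceil_mode output sizes. Constants such as alpha=9.999999747378752e-05 are just
# the float32 rendering of the original values (here 1e-4).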
def forward(self, x):
conv1_7x7_s2_pad = F.pad(x, (3, 3, 3, 3))
conv1_7x7_s2 = self.conv1_7x7_s2(conv1_7x7_s2_pad)
conv1_relu_7x7 = F.relu(conv1_7x7_s2)
pool1_3x3_s2_pad = F.pad(conv1_relu_7x7, (0, 1, 0, 1), value=float('-inf'))
pool1_3x3_s2 = F.max_pool2d(pool1_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
pool1_norm1 = F.local_response_norm(pool1_3x3_s2, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
conv2_3x3_reduce = self.conv2_3x3_reduce(pool1_norm1)
conv2_relu_3x3_reduce = F.relu(conv2_3x3_reduce)
conv2_3x3_pad = F.pad(conv2_relu_3x3_reduce, (1, 1, 1, 1))
conv2_3x3 = self.conv2_3x3(conv2_3x3_pad)
conv2_relu_3x3 = F.relu(conv2_3x3)
conv2_norm2 = F.local_response_norm(conv2_relu_3x3, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
pool2_3x3_s2_pad = F.pad(conv2_norm2, (0, 1, 0, 1), value=float('-inf'))
pool2_3x3_s2 = F.max_pool2d(pool2_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_3a_pool_pad = F.pad(pool2_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_3a_pool = F.max_pool2d(inception_3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3a_1x1 = self.inception_3a_1x1(pool2_3x3_s2)
inception_3a_5x5_reduce = self.inception_3a_5x5_reduce(pool2_3x3_s2)
inception_3a_3x3_reduce = self.inception_3a_3x3_reduce(pool2_3x3_s2)
inception_3a_pool_proj = self.inception_3a_pool_proj(inception_3a_pool)
inception_3a_relu_1x1 = F.relu(inception_3a_1x1)
inception_3a_relu_5x5_reduce = F.relu(inception_3a_5x5_reduce)
inception_3a_relu_3x3_reduce = F.relu(inception_3a_3x3_reduce)
inception_3a_relu_pool_proj = F.relu(inception_3a_pool_proj)
inception_3a_5x5_pad = F.pad(inception_3a_relu_5x5_reduce, (2, 2, 2, 2))
inception_3a_5x5 = self.inception_3a_5x5(inception_3a_5x5_pad)
inception_3a_3x3_pad = F.pad(inception_3a_relu_3x3_reduce, (1, 1, 1, 1))
inception_3a_3x3 = self.inception_3a_3x3(inception_3a_3x3_pad)
inception_3a_relu_5x5 = F.relu(inception_3a_5x5)
inception_3a_relu_3x3 = F.relu(inception_3a_3x3)
inception_3a_output = torch.cat((inception_3a_relu_1x1, inception_3a_relu_3x3, inception_3a_relu_5x5, inception_3a_relu_pool_proj), 1)
inception_3b_3x3_reduce = self.inception_3b_3x3_reduce(inception_3a_output)
inception_3b_pool_pad = F.pad(inception_3a_output, (1, 1, 1, 1), value=float('-inf'))
inception_3b_pool = F.max_pool2d(inception_3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3b_1x1 = self.inception_3b_1x1(inception_3a_output)
inception_3b_5x5_reduce = self.inception_3b_5x5_reduce(inception_3a_output)
inception_3b_relu_3x3_reduce = F.relu(inception_3b_3x3_reduce)
inception_3b_pool_proj = self.inception_3b_pool_proj(inception_3b_pool)
inception_3b_relu_1x1 = F.relu(inception_3b_1x1)
inception_3b_relu_5x5_reduce = F.relu(inception_3b_5x5_reduce)
inception_3b_3x3_pad = F.pad(inception_3b_relu_3x3_reduce, (1, 1, 1, 1))
inception_3b_3x3 = self.inception_3b_3x3(inception_3b_3x3_pad)
inception_3b_relu_pool_proj = F.relu(inception_3b_pool_proj)
inception_3b_5x5_pad = F.pad(inception_3b_relu_5x5_reduce, (2, 2, 2, 2))
inception_3b_5x5 = self.inception_3b_5x5(inception_3b_5x5_pad)
inception_3b_relu_3x3 = F.relu(inception_3b_3x3)
inception_3b_relu_5x5 = F.relu(inception_3b_5x5)
inception_3b_output = torch.cat((inception_3b_relu_1x1, inception_3b_relu_3x3, inception_3b_relu_5x5, inception_3b_relu_pool_proj), 1)
pool3_3x3_s2_pad = F.pad(inception_3b_output, (0, 1, 0, 1), value=float('-inf'))
pool3_3x3_s2 = F.max_pool2d(pool3_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_4a_1x1 = self.inception_4a_1x1(pool3_3x3_s2)
inception_4a_3x3_reduce = self.inception_4a_3x3_reduce(pool3_3x3_s2)
inception_4a_5x5_reduce = self.inception_4a_5x5_reduce(pool3_3x3_s2)
inception_4a_pool_pad = F.pad(pool3_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_4a_pool = F.max_pool2d(inception_4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4a_relu_1x1 = F.relu(inception_4a_1x1)
inception_4a_relu_3x3_reduce = F.relu(inception_4a_3x3_reduce)
inception_4a_relu_5x5_reduce = F.relu(inception_4a_5x5_reduce)
inception_4a_pool_proj = self.inception_4a_pool_proj(inception_4a_pool)
inception_4a_3x3_pad = F.pad(inception_4a_relu_3x3_reduce, (1, 1, 1, 1))
inception_4a_3x3 = self.inception_4a_3x3(inception_4a_3x3_pad)
inception_4a_5x5_pad = F.pad(inception_4a_relu_5x5_reduce, (2, 2, 2, 2))
inception_4a_5x5 = self.inception_4a_5x5(inception_4a_5x5_pad)
inception_4a_relu_pool_proj = F.relu(inception_4a_pool_proj)
inception_4a_relu_3x3 = F.relu(inception_4a_3x3)
inception_4a_relu_5x5 = F.relu(inception_4a_5x5)
inception_4a_output = torch.cat((inception_4a_relu_1x1, inception_4a_relu_3x3, inception_4a_relu_5x5, inception_4a_relu_pool_proj), 1)
inception_4b_pool_pad = F.pad(inception_4a_output, (1, 1, 1, 1), value=float('-inf'))
inception_4b_pool = F.max_pool2d(inception_4b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
#loss1_ave_pool = F.avg_pool2d(inception_4a_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4b_5x5_reduce = self.inception_4b_5x5_reduce(inception_4a_output)
inception_4b_1x1 = self.inception_4b_1x1(inception_4a_output)
inception_4b_3x3_reduce = self.inception_4b_3x3_reduce(inception_4a_output)
inception_4b_pool_proj = self.inception_4b_pool_proj(inception_4b_pool)
#loss1_conv = self.loss1_conv(loss1_ave_pool)
inception_4b_relu_5x5_reduce = F.relu(inception_4b_5x5_reduce)
inception_4b_relu_1x1 = F.relu(inception_4b_1x1)
inception_4b_relu_3x3_reduce = F.relu(inception_4b_3x3_reduce)
inception_4b_relu_pool_proj = F.relu(inception_4b_pool_proj)
#loss1_relu_conv = F.relu(loss1_conv)
inception_4b_5x5_pad = F.pad(inception_4b_relu_5x5_reduce, (2, 2, 2, 2))
inception_4b_5x5 = self.inception_4b_5x5(inception_4b_5x5_pad)
inception_4b_3x3_pad = F.pad(inception_4b_relu_3x3_reduce, (1, 1, 1, 1))
inception_4b_3x3 = self.inception_4b_3x3(inception_4b_3x3_pad)
#loss1_fc_0 = loss1_relu_conv.view(loss1_relu_conv.size(0), -1)
inception_4b_relu_5x5 = F.relu(inception_4b_5x5)
inception_4b_relu_3x3 = F.relu(inception_4b_3x3)
#loss1_fc_1 = self.loss1_fc_1(loss1_fc_0)
inception_4b_output = torch.cat((inception_4b_relu_1x1, inception_4b_relu_3x3, inception_4b_relu_5x5, inception_4b_relu_pool_proj), 1)
#loss1_relu_fc = F.relu(loss1_fc_1)
inception_4c_5x5_reduce = self.inception_4c_5x5_reduce(inception_4b_output)
inception_4c_pool_pad = F.pad(inception_4b_output, (1, 1, 1, 1), value=float('-inf'))
inception_4c_pool = F.max_pool2d(inception_4c_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4c_1x1 = self.inception_4c_1x1(inception_4b_output)
inception_4c_3x3_reduce = self.inception_4c_3x3_reduce(inception_4b_output)
#loss1_drop_fc = F.dropout(input = loss1_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_4c_relu_5x5_reduce = F.relu(inception_4c_5x5_reduce)
inception_4c_pool_proj = self.inception_4c_pool_proj(inception_4c_pool)
inception_4c_relu_1x1 = F.relu(inception_4c_1x1)
inception_4c_relu_3x3_reduce = F.relu(inception_4c_3x3_reduce)
#loss1_classifier_0 = loss1_drop_fc.view(loss1_drop_fc.size(0), -1)
inception_4c_5x5_pad = F.pad(inception_4c_relu_5x5_reduce, (2, 2, 2, 2))
inception_4c_5x5 = self.inception_4c_5x5(inception_4c_5x5_pad)
inception_4c_relu_pool_proj = F.relu(inception_4c_pool_proj)
inception_4c_3x3_pad = F.pad(inception_4c_relu_3x3_reduce, (1, 1, 1, 1))
inception_4c_3x3 = self.inception_4c_3x3(inception_4c_3x3_pad)
#loss1_classifier_1 = self.loss1_classifier_1(loss1_classifier_0)
inception_4c_relu_5x5 = F.relu(inception_4c_5x5)
inception_4c_relu_3x3 = F.relu(inception_4c_3x3)
inception_4c_output = torch.cat((inception_4c_relu_1x1, inception_4c_relu_3x3, inception_4c_relu_5x5, inception_4c_relu_pool_proj), 1)
inception_4d_pool_pad = F.pad(inception_4c_output, (1, 1, 1, 1), value=float('-inf'))
inception_4d_pool = F.max_pool2d(inception_4d_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4d_3x3_reduce = self.inception_4d_3x3_reduce(inception_4c_output)
inception_4d_1x1 = self.inception_4d_1x1(inception_4c_output)
inception_4d_5x5_reduce = self.inception_4d_5x5_reduce(inception_4c_output)
inception_4d_pool_proj = self.inception_4d_pool_proj(inception_4d_pool)
inception_4d_relu_3x3_reduce = F.relu(inception_4d_3x3_reduce)
inception_4d_relu_1x1 = F.relu(inception_4d_1x1)
inception_4d_relu_5x5_reduce = F.relu(inception_4d_5x5_reduce)
inception_4d_relu_pool_proj = F.relu(inception_4d_pool_proj)
inception_4d_3x3_pad = F.pad(inception_4d_relu_3x3_reduce, (1, 1, 1, 1))
inception_4d_3x3 = self.inception_4d_3x3(inception_4d_3x3_pad)
inception_4d_5x5_pad = F.pad(inception_4d_relu_5x5_reduce, (2, 2, 2, 2))
inception_4d_5x5 = self.inception_4d_5x5(inception_4d_5x5_pad)
inception_4d_relu_3x3 = F.relu(inception_4d_3x3)
inception_4d_relu_5x5 = F.relu(inception_4d_5x5)
inception_4d_output = torch.cat((inception_4d_relu_1x1, inception_4d_relu_3x3, inception_4d_relu_5x5, inception_4d_relu_pool_proj), 1)
inception_4e_1x1 = self.inception_4e_1x1(inception_4d_output)
inception_4e_5x5_reduce = self.inception_4e_5x5_reduce(inception_4d_output)
#loss2_ave_pool = F.avg_pool2d(inception_4d_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4e_3x3_reduce = self.inception_4e_3x3_reduce(inception_4d_output)
inception_4e_pool_pad = F.pad(inception_4d_output, (1, 1, 1, 1), value=float('-inf'))
inception_4e_pool = F.max_pool2d(inception_4e_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4e_relu_1x1 = F.relu(inception_4e_1x1)
inception_4e_relu_5x5_reduce = F.relu(inception_4e_5x5_reduce)
#loss2_conv = self.loss2_conv(loss2_ave_pool)
inception_4e_relu_3x3_reduce = F.relu(inception_4e_3x3_reduce)
inception_4e_pool_proj = self.inception_4e_pool_proj(inception_4e_pool)
inception_4e_5x5_pad = F.pad(inception_4e_relu_5x5_reduce, (2, 2, 2, 2))
inception_4e_5x5 = self.inception_4e_5x5(inception_4e_5x5_pad)
#loss2_relu_conv = F.relu(loss2_conv)
inception_4e_3x3_pad = F.pad(inception_4e_relu_3x3_reduce, (1, 1, 1, 1))
inception_4e_3x3 = self.inception_4e_3x3(inception_4e_3x3_pad)
inception_4e_relu_pool_proj = F.relu(inception_4e_pool_proj)
inception_4e_relu_5x5 = F.relu(inception_4e_5x5)
#loss2_fc_0 = loss2_relu_conv.view(loss2_relu_conv.size(0), -1)
inception_4e_relu_3x3 = F.relu(inception_4e_3x3)
#loss2_fc_1 = self.loss2_fc_1(loss2_fc_0)
inception_4e_output = torch.cat((inception_4e_relu_1x1, inception_4e_relu_3x3, inception_4e_relu_5x5, inception_4e_relu_pool_proj), 1)
#loss2_relu_fc = F.relu(loss2_fc_1)
pool4_3x3_s2_pad = F.pad(inception_4e_output, (0, 1, 0, 1), value=float('-inf'))
pool4_3x3_s2 = F.max_pool2d(pool4_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
#loss2_drop_fc = F.dropout(input = loss2_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_5a_1x1 = self.inception_5a_1x1(pool4_3x3_s2)
inception_5a_5x5_reduce = self.inception_5a_5x5_reduce(pool4_3x3_s2)
inception_5a_pool_pad = F.pad(pool4_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_5a_pool = F.max_pool2d(inception_5a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5a_3x3_reduce = self.inception_5a_3x3_reduce(pool4_3x3_s2)
#loss2_classifier_0 = loss2_drop_fc.view(loss2_drop_fc.size(0), -1)
inception_5a_relu_1x1 = F.relu(inception_5a_1x1)
inception_5a_relu_5x5_reduce = F.relu(inception_5a_5x5_reduce)
inception_5a_pool_proj = self.inception_5a_pool_proj(inception_5a_pool)
inception_5a_relu_3x3_reduce = F.relu(inception_5a_3x3_reduce)
#loss2_classifier_1 = self.loss2_classifier_1(loss2_classifier_0)
inception_5a_5x5_pad = F.pad(inception_5a_relu_5x5_reduce, (2, 2, 2, 2))
inception_5a_5x5 = self.inception_5a_5x5(inception_5a_5x5_pad)
inception_5a_relu_pool_proj = F.relu(inception_5a_pool_proj)
inception_5a_3x3_pad = F.pad(inception_5a_relu_3x3_reduce, (1, 1, 1, 1))
inception_5a_3x3 = self.inception_5a_3x3(inception_5a_3x3_pad)
inception_5a_relu_5x5 = F.relu(inception_5a_5x5)
inception_5a_relu_3x3 = F.relu(inception_5a_3x3)
inception_5a_output = torch.cat((inception_5a_relu_1x1, inception_5a_relu_3x3, inception_5a_relu_5x5, inception_5a_relu_pool_proj), 1)
inception_5b_3x3_reduce = self.inception_5b_3x3_reduce(inception_5a_output)
inception_5b_pool_pad = F.pad(inception_5a_output, (1, 1, 1, 1), value=float('-inf'))
inception_5b_pool = F.max_pool2d(inception_5b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5b_5x5_reduce = self.inception_5b_5x5_reduce(inception_5a_output)
inception_5b_1x1 = self.inception_5b_1x1(inception_5a_output)
inception_5b_relu_3x3_reduce = F.relu(inception_5b_3x3_reduce)
inception_5b_pool_proj = self.inception_5b_pool_proj(inception_5b_pool)
inception_5b_relu_5x5_reduce = F.relu(inception_5b_5x5_reduce)
inception_5b_relu_1x1 = F.relu(inception_5b_1x1)
inception_5b_3x3_pad = F.pad(inception_5b_relu_3x3_reduce, (1, 1, 1, 1))
inception_5b_3x3 = self.inception_5b_3x3(inception_5b_3x3_pad)
inception_5b_relu_pool_proj = F.relu(inception_5b_pool_proj)
inception_5b_5x5_pad = F.pad(inception_5b_relu_5x5_reduce, (2, 2, 2, 2))
inception_5b_5x5 = self.inception_5b_5x5(inception_5b_5x5_pad)
inception_5b_relu_3x3 = F.relu(inception_5b_3x3)
inception_5b_relu_5x5 = F.relu(inception_5b_5x5)
inception_5b_output = torch.cat((inception_5b_relu_1x1, inception_5b_relu_3x3, inception_5b_relu_5x5, inception_5b_relu_pool_proj), 1)
pool5_7x7_s1 = F.avg_pool2d(inception_5b_output, kernel_size=(7, 7), stride=(1, 1), padding=0, ceil_mode=False, count_include_pad=False)
pool5_drop_7x7_s1 = F.dropout(input=pool5_7x7_s1, p=0.4000000059604645, training=self.training, inplace=True)
return pool5_drop_7x7_s1  # auxiliary outputs loss2_classifier_1 and loss1_classifier_1 are disabled
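# Illustrative usage (a sketch, not part of the original file; the checkpoint path is an
# assumption): each class here is a plain nn.Module, so converted weights load as an
# ordinary state_dict, e.g.
#
#   model = BVLC_GOOGLENET()
#   model.load_state_dict(torch.load('models/bvlc_googlenet.pth'))  # hypothetical path
#   model.eval()
#
# GoogleNet_SOS below shares the same GoogLeNet topology but drops the auxiliary
# classifier layers entirely instead of carrying them as commented-out code.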
class GoogleNet_SOS(nn.Module):
def __init__(self):
super(GoogleNet_SOS, self).__init__()
self.conv1_7x7_s2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2_3x3_reduce = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2_3x3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3a_1x1 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5_reduce = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3_reduce = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_pool_proj = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3_reduce = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_1x1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5_reduce = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_pool_proj = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4a_1x1 = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3_reduce = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5_reduce = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_pool_proj = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3 = nn.Conv2d(in_channels=96, out_channels=208, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5 = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_1x1 = nn.Conv2d(in_channels=512, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3 = nn.Conv2d(in_channels=112, out_channels=224, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_1x1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_1x1 = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3 = nn.Conv2d(in_channels=144, out_channels=288, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5_reduce = nn.Conv2d(in_channels=528, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_1x1 = nn.Conv2d(in_channels=528, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3_reduce = nn.Conv2d(in_channels=528, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_pool_proj = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5a_1x1 = nn.Conv2d(in_channels=832, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_1x1 = nn.Conv2d(in_channels=832, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
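# forward() below is the same GoogLeNet trunk as in BVLC_GOOGLENET, without the
# commented-out auxiliary-branch lines.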
def forward(self, x):
conv1_7x7_s2_pad = F.pad(x, (3, 3, 3, 3))
conv1_7x7_s2 = self.conv1_7x7_s2(conv1_7x7_s2_pad)
conv1_relu_7x7 = F.relu(conv1_7x7_s2)
pool1_3x3_s2_pad = F.pad(conv1_relu_7x7, (0, 1, 0, 1), value=float('-inf'))
pool1_3x3_s2 = F.max_pool2d(pool1_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
pool1_norm1 = F.local_response_norm(pool1_3x3_s2, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
conv2_3x3_reduce = self.conv2_3x3_reduce(pool1_norm1)
conv2_relu_3x3_reduce = F.relu(conv2_3x3_reduce)
conv2_3x3_pad = F.pad(conv2_relu_3x3_reduce, (1, 1, 1, 1))
conv2_3x3 = self.conv2_3x3(conv2_3x3_pad)
conv2_relu_3x3 = F.relu(conv2_3x3)
conv2_norm2 = F.local_response_norm(conv2_relu_3x3, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
pool2_3x3_s2_pad = F.pad(conv2_norm2, (0, 1, 0, 1), value=float('-inf'))
pool2_3x3_s2 = F.max_pool2d(pool2_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_3a_pool_pad = F.pad(pool2_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_3a_pool = F.max_pool2d(inception_3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3a_1x1 = self.inception_3a_1x1(pool2_3x3_s2)
inception_3a_5x5_reduce = self.inception_3a_5x5_reduce(pool2_3x3_s2)
inception_3a_3x3_reduce = self.inception_3a_3x3_reduce(pool2_3x3_s2)
inception_3a_pool_proj = self.inception_3a_pool_proj(inception_3a_pool)
inception_3a_relu_1x1 = F.relu(inception_3a_1x1)
inception_3a_relu_5x5_reduce = F.relu(inception_3a_5x5_reduce)
inception_3a_relu_3x3_reduce = F.relu(inception_3a_3x3_reduce)
inception_3a_relu_pool_proj = F.relu(inception_3a_pool_proj)
inception_3a_5x5_pad = F.pad(inception_3a_relu_5x5_reduce, (2, 2, 2, 2))
inception_3a_5x5 = self.inception_3a_5x5(inception_3a_5x5_pad)
inception_3a_3x3_pad = F.pad(inception_3a_relu_3x3_reduce, (1, 1, 1, 1))
inception_3a_3x3 = self.inception_3a_3x3(inception_3a_3x3_pad)
inception_3a_relu_5x5 = F.relu(inception_3a_5x5)
inception_3a_relu_3x3 = F.relu(inception_3a_3x3)
inception_3a_output = torch.cat((inception_3a_relu_1x1, inception_3a_relu_3x3, inception_3a_relu_5x5, inception_3a_relu_pool_proj), 1)
inception_3b_3x3_reduce = self.inception_3b_3x3_reduce(inception_3a_output)
inception_3b_pool_pad = F.pad(inception_3a_output, (1, 1, 1, 1), value=float('-inf'))
inception_3b_pool = F.max_pool2d(inception_3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_3b_1x1 = self.inception_3b_1x1(inception_3a_output)
inception_3b_5x5_reduce = self.inception_3b_5x5_reduce(inception_3a_output)
inception_3b_relu_3x3_reduce = F.relu(inception_3b_3x3_reduce)
inception_3b_pool_proj = self.inception_3b_pool_proj(inception_3b_pool)
inception_3b_relu_1x1 = F.relu(inception_3b_1x1)
inception_3b_relu_5x5_reduce = F.relu(inception_3b_5x5_reduce)
inception_3b_3x3_pad = F.pad(inception_3b_relu_3x3_reduce, (1, 1, 1, 1))
inception_3b_3x3 = self.inception_3b_3x3(inception_3b_3x3_pad)
inception_3b_relu_pool_proj = F.relu(inception_3b_pool_proj)
inception_3b_5x5_pad = F.pad(inception_3b_relu_5x5_reduce, (2, 2, 2, 2))
inception_3b_5x5 = self.inception_3b_5x5(inception_3b_5x5_pad)
inception_3b_relu_3x3 = F.relu(inception_3b_3x3)
inception_3b_relu_5x5 = F.relu(inception_3b_5x5)
inception_3b_output = torch.cat((inception_3b_relu_1x1, inception_3b_relu_3x3, inception_3b_relu_5x5, inception_3b_relu_pool_proj), 1)
pool3_3x3_s2_pad = F.pad(inception_3b_output, (0, 1, 0, 1), value=float('-inf'))
pool3_3x3_s2 = F.max_pool2d(pool3_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_4a_1x1 = self.inception_4a_1x1(pool3_3x3_s2)
inception_4a_3x3_reduce = self.inception_4a_3x3_reduce(pool3_3x3_s2)
inception_4a_5x5_reduce = self.inception_4a_5x5_reduce(pool3_3x3_s2)
inception_4a_pool_pad = F.pad(pool3_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_4a_pool = F.max_pool2d(inception_4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4a_relu_1x1 = F.relu(inception_4a_1x1)
inception_4a_relu_3x3_reduce = F.relu(inception_4a_3x3_reduce)
inception_4a_relu_5x5_reduce = F.relu(inception_4a_5x5_reduce)
inception_4a_pool_proj = self.inception_4a_pool_proj(inception_4a_pool)
inception_4a_3x3_pad = F.pad(inception_4a_relu_3x3_reduce, (1, 1, 1, 1))
inception_4a_3x3 = self.inception_4a_3x3(inception_4a_3x3_pad)
inception_4a_5x5_pad = F.pad(inception_4a_relu_5x5_reduce, (2, 2, 2, 2))
inception_4a_5x5 = self.inception_4a_5x5(inception_4a_5x5_pad)
inception_4a_relu_pool_proj = F.relu(inception_4a_pool_proj)
inception_4a_relu_3x3 = F.relu(inception_4a_3x3)
inception_4a_relu_5x5 = F.relu(inception_4a_5x5)
inception_4a_output = torch.cat((inception_4a_relu_1x1, inception_4a_relu_3x3, inception_4a_relu_5x5, inception_4a_relu_pool_proj), 1)
inception_4b_pool_pad = F.pad(inception_4a_output, (1, 1, 1, 1), value=float('-inf'))
inception_4b_pool = F.max_pool2d(inception_4b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4b_5x5_reduce = self.inception_4b_5x5_reduce(inception_4a_output)
inception_4b_1x1 = self.inception_4b_1x1(inception_4a_output)
inception_4b_3x3_reduce = self.inception_4b_3x3_reduce(inception_4a_output)
inception_4b_pool_proj = self.inception_4b_pool_proj(inception_4b_pool)
inception_4b_relu_5x5_reduce = F.relu(inception_4b_5x5_reduce)
inception_4b_relu_1x1 = F.relu(inception_4b_1x1)
inception_4b_relu_3x3_reduce = F.relu(inception_4b_3x3_reduce)
inception_4b_relu_pool_proj = F.relu(inception_4b_pool_proj)
inception_4b_5x5_pad = F.pad(inception_4b_relu_5x5_reduce, (2, 2, 2, 2))
inception_4b_5x5 = self.inception_4b_5x5(inception_4b_5x5_pad)
inception_4b_3x3_pad = F.pad(inception_4b_relu_3x3_reduce, (1, 1, 1, 1))
inception_4b_3x3 = self.inception_4b_3x3(inception_4b_3x3_pad)
inception_4b_relu_5x5 = F.relu(inception_4b_5x5)
inception_4b_relu_3x3 = F.relu(inception_4b_3x3)
inception_4b_output = torch.cat((inception_4b_relu_1x1, inception_4b_relu_3x3, inception_4b_relu_5x5, inception_4b_relu_pool_proj), 1)
inception_4c_5x5_reduce = self.inception_4c_5x5_reduce(inception_4b_output)
inception_4c_pool_pad = F.pad(inception_4b_output, (1, 1, 1, 1), value=float('-inf'))
inception_4c_pool = F.max_pool2d(inception_4c_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4c_1x1 = self.inception_4c_1x1(inception_4b_output)
inception_4c_3x3_reduce = self.inception_4c_3x3_reduce(inception_4b_output)
inception_4c_relu_5x5_reduce = F.relu(inception_4c_5x5_reduce)
inception_4c_pool_proj = self.inception_4c_pool_proj(inception_4c_pool)
inception_4c_relu_1x1 = F.relu(inception_4c_1x1)
inception_4c_relu_3x3_reduce = F.relu(inception_4c_3x3_reduce)
inception_4c_5x5_pad = F.pad(inception_4c_relu_5x5_reduce, (2, 2, 2, 2))
inception_4c_5x5 = self.inception_4c_5x5(inception_4c_5x5_pad)
inception_4c_relu_pool_proj = F.relu(inception_4c_pool_proj)
inception_4c_3x3_pad = F.pad(inception_4c_relu_3x3_reduce, (1, 1, 1, 1))
inception_4c_3x3 = self.inception_4c_3x3(inception_4c_3x3_pad)
inception_4c_relu_5x5 = F.relu(inception_4c_5x5)
inception_4c_relu_3x3 = F.relu(inception_4c_3x3)
inception_4c_output = torch.cat((inception_4c_relu_1x1, inception_4c_relu_3x3, inception_4c_relu_5x5, inception_4c_relu_pool_proj), 1)
inception_4d_pool_pad = F.pad(inception_4c_output, (1, 1, 1, 1), value=float('-inf'))
inception_4d_pool = F.max_pool2d(inception_4d_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4d_3x3_reduce = self.inception_4d_3x3_reduce(inception_4c_output)
inception_4d_1x1 = self.inception_4d_1x1(inception_4c_output)
inception_4d_5x5_reduce = self.inception_4d_5x5_reduce(inception_4c_output)
inception_4d_pool_proj = self.inception_4d_pool_proj(inception_4d_pool)
inception_4d_relu_3x3_reduce = F.relu(inception_4d_3x3_reduce)
inception_4d_relu_1x1 = F.relu(inception_4d_1x1)
inception_4d_relu_5x5_reduce = F.relu(inception_4d_5x5_reduce)
inception_4d_relu_pool_proj = F.relu(inception_4d_pool_proj)
inception_4d_3x3_pad = F.pad(inception_4d_relu_3x3_reduce, (1, 1, 1, 1))
inception_4d_3x3 = self.inception_4d_3x3(inception_4d_3x3_pad)
inception_4d_5x5_pad = F.pad(inception_4d_relu_5x5_reduce, (2, 2, 2, 2))
inception_4d_5x5 = self.inception_4d_5x5(inception_4d_5x5_pad)
inception_4d_relu_3x3 = F.relu(inception_4d_3x3)
inception_4d_relu_5x5 = F.relu(inception_4d_5x5)
inception_4d_output = torch.cat((inception_4d_relu_1x1, inception_4d_relu_3x3, inception_4d_relu_5x5, inception_4d_relu_pool_proj), 1)
inception_4e_5x5_reduce = self.inception_4e_5x5_reduce(inception_4d_output)
inception_4e_1x1 = self.inception_4e_1x1(inception_4d_output)
inception_4e_3x3_reduce = self.inception_4e_3x3_reduce(inception_4d_output)
inception_4e_pool_pad = F.pad(inception_4d_output, (1, 1, 1, 1), value=float('-inf'))
inception_4e_pool = F.max_pool2d(inception_4e_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4e_relu_5x5_reduce = F.relu(inception_4e_5x5_reduce)
inception_4e_relu_1x1 = F.relu(inception_4e_1x1)
inception_4e_relu_3x3_reduce = F.relu(inception_4e_3x3_reduce)
inception_4e_pool_proj = self.inception_4e_pool_proj(inception_4e_pool)
inception_4e_5x5_pad = F.pad(inception_4e_relu_5x5_reduce, (2, 2, 2, 2))
inception_4e_5x5 = self.inception_4e_5x5(inception_4e_5x5_pad)
inception_4e_3x3_pad = F.pad(inception_4e_relu_3x3_reduce, (1, 1, 1, 1))
inception_4e_3x3 = self.inception_4e_3x3(inception_4e_3x3_pad)
inception_4e_relu_pool_proj = F.relu(inception_4e_pool_proj)
inception_4e_relu_5x5 = F.relu(inception_4e_5x5)
inception_4e_relu_3x3 = F.relu(inception_4e_3x3)
inception_4e_output = torch.cat((inception_4e_relu_1x1, inception_4e_relu_3x3, inception_4e_relu_5x5, inception_4e_relu_pool_proj), 1)
pool4_3x3_s2_pad = F.pad(inception_4e_output, (0, 1, 0, 1), value=float('-inf'))
pool4_3x3_s2 = F.max_pool2d(pool4_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_5a_1x1 = self.inception_5a_1x1(pool4_3x3_s2)
inception_5a_5x5_reduce = self.inception_5a_5x5_reduce(pool4_3x3_s2)
inception_5a_pool_pad = F.pad(pool4_3x3_s2, (1, 1, 1, 1), value=float('-inf'))
inception_5a_pool = F.max_pool2d(inception_5a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5a_3x3_reduce = self.inception_5a_3x3_reduce(pool4_3x3_s2)
inception_5a_relu_1x1 = F.relu(inception_5a_1x1)
inception_5a_relu_5x5_reduce = F.relu(inception_5a_5x5_reduce)
inception_5a_pool_proj = self.inception_5a_pool_proj(inception_5a_pool)
inception_5a_relu_3x3_reduce = F.relu(inception_5a_3x3_reduce)
inception_5a_5x5_pad = F.pad(inception_5a_relu_5x5_reduce, (2, 2, 2, 2))
inception_5a_5x5 = self.inception_5a_5x5(inception_5a_5x5_pad)
inception_5a_relu_pool_proj = F.relu(inception_5a_pool_proj)
inception_5a_3x3_pad = F.pad(inception_5a_relu_3x3_reduce, (1, 1, 1, 1))
inception_5a_3x3 = self.inception_5a_3x3(inception_5a_3x3_pad)
inception_5a_relu_5x5 = F.relu(inception_5a_5x5)
inception_5a_relu_3x3 = F.relu(inception_5a_3x3)
inception_5a_output = torch.cat((inception_5a_relu_1x1, inception_5a_relu_3x3, inception_5a_relu_5x5, inception_5a_relu_pool_proj), 1)
inception_5b_pool_pad = F.pad(inception_5a_output, (1, 1, 1, 1), value=float('-inf'))
inception_5b_pool = F.max_pool2d(inception_5b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5b_3x3_reduce = self.inception_5b_3x3_reduce(inception_5a_output)
inception_5b_5x5_reduce = self.inception_5b_5x5_reduce(inception_5a_output)
inception_5b_1x1 = self.inception_5b_1x1(inception_5a_output)
inception_5b_pool_proj = self.inception_5b_pool_proj(inception_5b_pool)
inception_5b_relu_3x3_reduce = F.relu(inception_5b_3x3_reduce)
inception_5b_relu_5x5_reduce = F.relu(inception_5b_5x5_reduce)
inception_5b_relu_1x1 = F.relu(inception_5b_1x1)
inception_5b_relu_pool_proj = F.relu(inception_5b_pool_proj)
inception_5b_3x3_pad = F.pad(inception_5b_relu_3x3_reduce, (1, 1, 1, 1))
inception_5b_3x3 = self.inception_5b_3x3(inception_5b_3x3_pad)
inception_5b_5x5_pad = F.pad(inception_5b_relu_5x5_reduce, (2, 2, 2, 2))
inception_5b_5x5 = self.inception_5b_5x5(inception_5b_5x5_pad)
inception_5b_relu_3x3 = F.relu(inception_5b_3x3)
inception_5b_relu_5x5 = F.relu(inception_5b_5x5)
inception_5b_output = torch.cat((inception_5b_relu_1x1, inception_5b_relu_3x3, inception_5b_relu_5x5, inception_5b_relu_pool_proj), 1)
pool5_7x7_s1 = F.avg_pool2d(inception_5b_output, kernel_size=(7, 7), stride=(1, 1), padding=0, ceil_mode=False, count_include_pad=False)
pool5_drop_7x7_s1 = F.dropout(input=pool5_7x7_s1, p=0.4000000059604645, training=self.training, inplace=True)
return pool5_drop_7x7_s1
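# GOOGLENET_CARS: a GoogLeNet variant whose auxiliary heads (loss1_classifier_model_1 and
# loss2_classifier_model_1) are 431-way Linear layers, consistent with the fine-tuned
# car-model classifier from the Caffe model zoo (an inference from the layer shapes, not
# stated here). Note the slightly different stem layer names (conv1, conv2_1x1).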
class GOOGLENET_CARS(nn.Module):
def __init__(self):
super(GOOGLENET_CARS, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2_1x1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2_3x3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5_reduce = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3_reduce = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_1x1 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_pool_proj = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3a_5x5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_3a_3x3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_1x1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3_reduce = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5_reduce = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_pool_proj = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_3b_3x3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_3b_5x5 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3_reduce = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5_reduce = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_1x1 = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_pool_proj = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4a_3x3 = nn.Conv2d(in_channels=96, out_channels=208, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4a_5x5 = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_1x1 = nn.Conv2d(in_channels=512, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss1_conv = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4b_3x3 = nn.Conv2d(in_channels=112, out_channels=224, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4b_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.loss1_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_4c_1x1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4c_5x5 = nn.Conv2d(in_channels=24, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4c_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.loss1_classifier_model_1 = nn.Linear(in_features = 1024, out_features = 431, bias = True)
self.inception_4d_1x1 = nn.Conv2d(in_channels=512, out_channels=112, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3_reduce = nn.Conv2d(in_channels=512, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5_reduce = nn.Conv2d(in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_pool_proj = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4d_3x3 = nn.Conv2d(in_channels=144, out_channels=288, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4d_5x5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_4e_1x1 = nn.Conv2d(in_channels=528, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3_reduce = nn.Conv2d(in_channels=528, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5_reduce = nn.Conv2d(in_channels=528, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_conv = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_pool_proj = nn.Conv2d(in_channels=528, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_4e_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_4e_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.loss2_fc_1 = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.inception_5a_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_1x1 = nn.Conv2d(in_channels=832, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5a_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.loss2_classifier_model_1 = nn.Linear(in_features = 1024, out_features = 431, bias = True)
self.inception_5a_5x5 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.inception_5a_3x3 = nn.Conv2d(in_channels=160, out_channels=320, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3_reduce = nn.Conv2d(in_channels=832, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_1x1 = nn.Conv2d(in_channels=832, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5_reduce = nn.Conv2d(in_channels=832, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_pool_proj = nn.Conv2d(in_channels=832, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.inception_5b_3x3 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.inception_5b_5x5 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
def forward(self, x):
conv1_pad = F.pad(x, (3, 3, 3, 3))
conv1 = self.conv1(conv1_pad)
relu1 = F.relu(conv1)
pool1_pad = F.pad(relu1, (0, 1, 0, 1), value=float('-inf'))
pool1 = F.max_pool2d(pool1_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
norm1 = F.local_response_norm(pool1, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
conv2_1x1 = self.conv2_1x1(norm1)
relu_conv2_1x1 = F.relu(conv2_1x1)
conv2_3x3_pad = F.pad(relu_conv2_1x1, (1, 1, 1, 1))
conv2_3x3 = self.conv2_3x3(conv2_3x3_pad)
relu2_3x3 = F.relu(conv2_3x3)
norm2 = F.local_response_norm(relu2_3x3, size=5, alpha=9.999999747378752e-05, beta=0.75, k=1.0)
pool2_pad = F.pad(norm2, (0, 1, 0, 1), value=float('-inf'))
pool2 = F.max_pool2d(pool2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_3a_5x5_reduce = self.inception_3a_5x5_reduce(pool2)
inception_3a_3x3_reduce = self.inception_3a_3x3_reduce(pool2)
inception_3a_1x1 = self.inception_3a_1x1(pool2)
inception_3a_pool_pad = F.pad(pool2, (1, 1, 1, 1), value=float('-inf'))
inception_3a_pool = F.max_pool2d(inception_3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
relu_inception_3a_5x5_reduce = F.relu(inception_3a_5x5_reduce)
relu_inception_3a_3x3_reduce = F.relu(inception_3a_3x3_reduce)
relu_inception_3a_1x1 = F.relu(inception_3a_1x1)
inception_3a_pool_proj = self.inception_3a_pool_proj(inception_3a_pool)
inception_3a_5x5_pad = F.pad(relu_inception_3a_5x5_reduce, (2, 2, 2, 2))
inception_3a_5x5 = self.inception_3a_5x5(inception_3a_5x5_pad)
inception_3a_3x3_pad = F.pad(relu_inception_3a_3x3_reduce, (1, 1, 1, 1))
inception_3a_3x3 = self.inception_3a_3x3(inception_3a_3x3_pad)
relu_inception_3a_pool_proj = F.relu(inception_3a_pool_proj)
relu_inception_3a_5x5 = F.relu(inception_3a_5x5)
relu_inception_3a_3x3 = F.relu(inception_3a_3x3)
inception_3a_output = torch.cat((relu_inception_3a_1x1, relu_inception_3a_3x3, relu_inception_3a_5x5, relu_inception_3a_pool_proj), 1)
inception_3b_1x1 = self.inception_3b_1x1(inception_3a_output)
inception_3b_3x3_reduce = self.inception_3b_3x3_reduce(inception_3a_output)
inception_3b_5x5_reduce = self.inception_3b_5x5_reduce(inception_3a_output)
inception_3b_pool_pad = F.pad(inception_3a_output, (1, 1, 1, 1), value=float('-inf'))
inception_3b_pool = F.max_pool2d(inception_3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
relu_inception_3b_1x1 = F.relu(inception_3b_1x1)
relu_inception_3b_3x3_reduce = F.relu(inception_3b_3x3_reduce)
relu_inception_3b_5x5_reduce = F.relu(inception_3b_5x5_reduce)
inception_3b_pool_proj = self.inception_3b_pool_proj(inception_3b_pool)
inception_3b_3x3_pad = F.pad(relu_inception_3b_3x3_reduce, (1, 1, 1, 1))
inception_3b_3x3 = self.inception_3b_3x3(inception_3b_3x3_pad)
inception_3b_5x5_pad = F.pad(relu_inception_3b_5x5_reduce, (2, 2, 2, 2))
inception_3b_5x5 = self.inception_3b_5x5(inception_3b_5x5_pad)
relu_inception_3b_pool_proj = F.relu(inception_3b_pool_proj)
relu_inception_3b_3x3 = F.relu(inception_3b_3x3)
relu_inception_3b_5x5 = F.relu(inception_3b_5x5)
inception_3b_output = torch.cat((relu_inception_3b_1x1, relu_inception_3b_3x3, relu_inception_3b_5x5, relu_inception_3b_pool_proj), 1)
pool3_pad = F.pad(inception_3b_output, (0, 1, 0, 1), value=float('-inf'))
pool3 = F.max_pool2d(pool3_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
inception_4a_pool_pad = F.pad(pool3, (1, 1, 1, 1), value=float('-inf'))
inception_4a_pool = F.max_pool2d(inception_4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4a_3x3_reduce = self.inception_4a_3x3_reduce(pool3)
inception_4a_5x5_reduce = self.inception_4a_5x5_reduce(pool3)
inception_4a_1x1 = self.inception_4a_1x1(pool3)
inception_4a_pool_proj = self.inception_4a_pool_proj(inception_4a_pool)
relu_inception_4a_3x3_reduce = F.relu(inception_4a_3x3_reduce)
relu_inception_4a_5x5_reduce = F.relu(inception_4a_5x5_reduce)
relu_inception_4a_1x1 = F.relu(inception_4a_1x1)
relu_inception_4a_pool_proj = F.relu(inception_4a_pool_proj)
inception_4a_3x3_pad = F.pad(relu_inception_4a_3x3_reduce, (1, 1, 1, 1))
inception_4a_3x3 = self.inception_4a_3x3(inception_4a_3x3_pad)
inception_4a_5x5_pad = F.pad(relu_inception_4a_5x5_reduce, (2, 2, 2, 2))
inception_4a_5x5 = self.inception_4a_5x5(inception_4a_5x5_pad)
relu_inception_4a_3x3 = F.relu(inception_4a_3x3)
relu_inception_4a_5x5 = F.relu(inception_4a_5x5)
inception_4a_output = torch.cat((relu_inception_4a_1x1, relu_inception_4a_3x3, relu_inception_4a_5x5, relu_inception_4a_pool_proj), 1)
#loss1_ave_pool = F.avg_pool2d(inception_4a_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4b_3x3_reduce = self.inception_4b_3x3_reduce(inception_4a_output)
inception_4b_1x1 = self.inception_4b_1x1(inception_4a_output)
inception_4b_pool_pad = F.pad(inception_4a_output, (1, 1, 1, 1), value=float('-inf'))
inception_4b_pool = F.max_pool2d(inception_4b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4b_5x5_reduce = self.inception_4b_5x5_reduce(inception_4a_output)
#loss1_conv = self.loss1_conv(loss1_ave_pool)
inception_4b_relu_3x3_reduce = F.relu(inception_4b_3x3_reduce)
inception_4b_relu_1x1 = F.relu(inception_4b_1x1)
inception_4b_pool_proj = self.inception_4b_pool_proj(inception_4b_pool)
inception_4b_relu_5x5_reduce = F.relu(inception_4b_5x5_reduce)
#loss1_relu_conv = F.relu(loss1_conv)
inception_4b_3x3_pad = F.pad(inception_4b_relu_3x3_reduce, (1, 1, 1, 1))
inception_4b_3x3 = self.inception_4b_3x3(inception_4b_3x3_pad)
inception_4b_relu_pool_proj = F.relu(inception_4b_pool_proj)
inception_4b_5x5_pad = F.pad(inception_4b_relu_5x5_reduce, (2, 2, 2, 2))
inception_4b_5x5 = self.inception_4b_5x5(inception_4b_5x5_pad)
#loss1_fc_0 = loss1_relu_conv.view(loss1_relu_conv.size(0), -1)
inception_4b_relu_3x3 = F.relu(inception_4b_3x3)
inception_4b_relu_5x5 = F.relu(inception_4b_5x5)
#loss1_fc_1 = self.loss1_fc_1(loss1_fc_0)
inception_4b_output = torch.cat((inception_4b_relu_1x1, inception_4b_relu_3x3, inception_4b_relu_5x5, inception_4b_relu_pool_proj), 1)
#loss1_relu_fc = F.relu(loss1_fc_1)
inception_4c_pool_pad = F.pad(inception_4b_output, (1, 1, 1, 1), value=float('-inf'))
inception_4c_pool = F.max_pool2d(inception_4c_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4c_1x1 = self.inception_4c_1x1(inception_4b_output)
inception_4c_5x5_reduce = self.inception_4c_5x5_reduce(inception_4b_output)
inception_4c_3x3_reduce = self.inception_4c_3x3_reduce(inception_4b_output)
#loss1_drop_fc = F.dropout(input = loss1_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_4c_pool_proj = self.inception_4c_pool_proj(inception_4c_pool)
inception_4c_relu_1x1 = F.relu(inception_4c_1x1)
inception_4c_relu_5x5_reduce = F.relu(inception_4c_5x5_reduce)
inception_4c_relu_3x3_reduce = F.relu(inception_4c_3x3_reduce)
#loss1_classifier_model_0 = loss1_drop_fc.view(loss1_drop_fc.size(0), -1)
inception_4c_relu_pool_proj = F.relu(inception_4c_pool_proj)
inception_4c_5x5_pad = F.pad(inception_4c_relu_5x5_reduce, (2, 2, 2, 2))
inception_4c_5x5 = self.inception_4c_5x5(inception_4c_5x5_pad)
inception_4c_3x3_pad = F.pad(inception_4c_relu_3x3_reduce, (1, 1, 1, 1))
inception_4c_3x3 = self.inception_4c_3x3(inception_4c_3x3_pad)
#loss1_classifier_model_1 = self.loss1_classifier_model_1(loss1_classifier_model_0)
inception_4c_relu_5x5 = F.relu(inception_4c_5x5)
inception_4c_relu_3x3 = F.relu(inception_4c_3x3)
inception_4c_output = torch.cat((inception_4c_relu_1x1, inception_4c_relu_3x3, inception_4c_relu_5x5, inception_4c_relu_pool_proj), 1)
inception_4d_1x1 = self.inception_4d_1x1(inception_4c_output)
inception_4d_3x3_reduce = self.inception_4d_3x3_reduce(inception_4c_output)
inception_4d_5x5_reduce = self.inception_4d_5x5_reduce(inception_4c_output)
inception_4d_pool_pad = F.pad(inception_4c_output, (1, 1, 1, 1), value=float('-inf'))
inception_4d_pool = F.max_pool2d(inception_4d_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4d_relu_1x1 = F.relu(inception_4d_1x1)
inception_4d_relu_3x3_reduce = F.relu(inception_4d_3x3_reduce)
inception_4d_relu_5x5_reduce = F.relu(inception_4d_5x5_reduce)
inception_4d_pool_proj = self.inception_4d_pool_proj(inception_4d_pool)
inception_4d_3x3_pad = F.pad(inception_4d_relu_3x3_reduce, (1, 1, 1, 1))
inception_4d_3x3 = self.inception_4d_3x3(inception_4d_3x3_pad)
inception_4d_5x5_pad = F.pad(inception_4d_relu_5x5_reduce, (2, 2, 2, 2))
inception_4d_5x5 = self.inception_4d_5x5(inception_4d_5x5_pad)
inception_4d_relu_pool_proj = F.relu(inception_4d_pool_proj)
inception_4d_relu_3x3 = F.relu(inception_4d_3x3)
inception_4d_relu_5x5 = F.relu(inception_4d_5x5)
inception_4d_output = torch.cat((inception_4d_relu_1x1, inception_4d_relu_3x3, inception_4d_relu_5x5, inception_4d_relu_pool_proj), 1)
#loss2_ave_pool = F.avg_pool2d(inception_4d_output, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=True, count_include_pad=False)
inception_4e_pool_pad = F.pad(inception_4d_output, (1, 1, 1, 1), value=float('-inf'))
inception_4e_pool = F.max_pool2d(inception_4e_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_4e_1x1 = self.inception_4e_1x1(inception_4d_output)
inception_4e_3x3_reduce = self.inception_4e_3x3_reduce(inception_4d_output)
inception_4e_5x5_reduce = self.inception_4e_5x5_reduce(inception_4d_output)
#loss2_conv = self.loss2_conv(loss2_ave_pool)
inception_4e_pool_proj = self.inception_4e_pool_proj(inception_4e_pool)
inception_4e_relu_1x1 = F.relu(inception_4e_1x1)
inception_4e_relu_3x3_reduce = F.relu(inception_4e_3x3_reduce)
inception_4e_relu_5x5_reduce = F.relu(inception_4e_5x5_reduce)
#loss2_relu_conv = F.relu(loss2_conv)
inception_4e_relu_pool_proj = F.relu(inception_4e_pool_proj)
inception_4e_3x3_pad = F.pad(inception_4e_relu_3x3_reduce, (1, 1, 1, 1))
inception_4e_3x3 = self.inception_4e_3x3(inception_4e_3x3_pad)
inception_4e_5x5_pad = F.pad(inception_4e_relu_5x5_reduce, (2, 2, 2, 2))
inception_4e_5x5 = self.inception_4e_5x5(inception_4e_5x5_pad)
#loss2_fc_0 = loss2_relu_conv.view(loss2_relu_conv.size(0), -1)
inception_4e_relu_3x3 = F.relu(inception_4e_3x3)
inception_4e_relu_5x5 = F.relu(inception_4e_5x5)
#loss2_fc_1 = self.loss2_fc_1(loss2_fc_0)
inception_4e_output = torch.cat((inception_4e_relu_1x1, inception_4e_relu_3x3, inception_4e_relu_5x5, inception_4e_relu_pool_proj), 1)
#loss2_relu_fc = F.relu(loss2_fc_1)
pool4_pad = F.pad(inception_4e_output, (0, 1, 0, 1), value=float('-inf'))
pool4 = F.max_pool2d(pool4_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
#loss2_drop_fc = F.dropout(input = loss2_relu_fc, p = 0.699999988079071, training = self.training, inplace = True)
inception_5a_pool_pad = F.pad(pool4, (1, 1, 1, 1), value=float('-inf'))
inception_5a_pool = F.max_pool2d(inception_5a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5a_5x5_reduce = self.inception_5a_5x5_reduce(pool4)
inception_5a_3x3_reduce = self.inception_5a_3x3_reduce(pool4)
inception_5a_1x1 = self.inception_5a_1x1(pool4)
#loss2_classifier_model_0 = loss2_drop_fc.view(loss2_drop_fc.size(0), -1)
inception_5a_pool_proj = self.inception_5a_pool_proj(inception_5a_pool)
inception_5a_relu_5x5_reduce = F.relu(inception_5a_5x5_reduce)
inception_5a_relu_3x3_reduce = F.relu(inception_5a_3x3_reduce)
inception_5a_relu_1x1 = F.relu(inception_5a_1x1)
#loss2_classifier_model_1 = self.loss2_classifier_model_1(loss2_classifier_model_0)
inception_5a_relu_pool_proj = F.relu(inception_5a_pool_proj)
inception_5a_5x5_pad = F.pad(inception_5a_relu_5x5_reduce, (2, 2, 2, 2))
inception_5a_5x5 = self.inception_5a_5x5(inception_5a_5x5_pad)
inception_5a_3x3_pad = F.pad(inception_5a_relu_3x3_reduce, (1, 1, 1, 1))
inception_5a_3x3 = self.inception_5a_3x3(inception_5a_3x3_pad)
inception_5a_relu_5x5 = F.relu(inception_5a_5x5)
inception_5a_relu_3x3 = F.relu(inception_5a_3x3)
inception_5a_output = torch.cat((inception_5a_relu_1x1, inception_5a_relu_3x3, inception_5a_relu_5x5, inception_5a_relu_pool_proj), 1)
inception_5b_3x3_reduce = self.inception_5b_3x3_reduce(inception_5a_output)
inception_5b_1x1 = self.inception_5b_1x1(inception_5a_output)
inception_5b_5x5_reduce = self.inception_5b_5x5_reduce(inception_5a_output)
inception_5b_pool_pad = F.pad(inception_5a_output, (1, 1, 1, 1), value=float('-inf'))
inception_5b_pool = F.max_pool2d(inception_5b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
inception_5b_relu_3x3_reduce = F.relu(inception_5b_3x3_reduce)
inception_5b_relu_1x1 = F.relu(inception_5b_1x1)
inception_5b_relu_5x5_reduce = F.relu(inception_5b_5x5_reduce)
inception_5b_pool_proj = self.inception_5b_pool_proj(inception_5b_pool)
inception_5b_3x3_pad = F.pad(inception_5b_relu_3x3_reduce, (1, 1, 1, 1))
inception_5b_3x3 = self.inception_5b_3x3(inception_5b_3x3_pad)
inception_5b_5x5_pad = F.pad(inception_5b_relu_5x5_reduce, (2, 2, 2, 2))
inception_5b_5x5 = self.inception_5b_5x5(inception_5b_5x5_pad)
inception_5b_relu_pool_proj = F.relu(inception_5b_pool_proj)
inception_5b_relu_3x3 = F.relu(inception_5b_3x3)
inception_5b_relu_5x5 = F.relu(inception_5b_5x5)
inception_5b_output = torch.cat((inception_5b_relu_1x1, inception_5b_relu_3x3, inception_5b_relu_5x5, inception_5b_relu_pool_proj), 1)
pool5 = F.avg_pool2d(inception_5b_output, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
pool5_drop = F.dropout(input = pool5, p = 0.4000000059604645, training = self.training, inplace = True)
return pool5_drop # loss1_classifier_model_1, loss2_classifier_model_1,
def googlenet_layer_names(model_name='places'):
if 'places' in model_name:
gnet_layers = ['conv1_7x7_s2', 'conv2_3x3_reduce', 'conv2_3x3', 'inception_3a_1x1', 'inception_3a_5x5_reduce', 'inception_3a_3x3_reduce', 'inception_3a_pool_proj', \
'inception_3a_5x5', 'inception_3a_3x3', 'inception_3b_3x3_reduce', 'inception_3b_1x1', 'inception_3b_5x5_reduce', 'inception_3b_pool_proj', 'inception_3b_3x3', \
'inception_3b_5x5', 'inception_4a_1x1', 'inception_4a_3x3_reduce', 'inception_4a_5x5_reduce', 'inception_4a_pool_proj', 'inception_4a_3x3', 'inception_4a_5x5', \
'inception_4b_5x5_reduce', 'inception_4b_1x1', 'inception_4b_3x3_reduce', 'inception_4b_pool_proj', 'loss1_conv', 'inception_4b_5x5', 'inception_4b_3x3', \
'loss1_fc_1', 'inception_4c_5x5_reduce', 'inception_4c_1x1', 'inception_4c_3x3_reduce', 'inception_4c_pool_proj', 'inception_4c_5x5', 'inception_4c_3x3', \
'loss1_classifier_1', 'inception_4d_3x3_reduce', 'inception_4d_1x1', 'inception_4d_5x5_reduce', 'inception_4d_pool_proj', 'inception_4d_3x3', 'inception_4d_5x5', \
'inception_4e_1x1', 'inception_4e_5x5_reduce', 'inception_4e_3x3_reduce', 'loss2_conv', 'inception_4e_pool_proj', 'inception_4e_5x5', 'inception_4e_3x3', \
'loss2_fc_1', 'inception_5a_1x1', 'inception_5a_5x5_reduce', 'inception_5a_3x3_reduce', 'inception_5a_pool_proj', 'loss2_classifier_1', 'inception_5a_5x5', \
'inception_5a_3x3', 'inception_5b_3x3_reduce', 'inception_5b_5x5_reduce', 'inception_5b_1x1', 'inception_5b_pool_proj', 'inception_5b_3x3', 'inception_5b_5x5']
elif 'bvlc' in model_name:
gnet_layers = ['conv1_7x7_s2', 'conv2_3x3_reduce', 'conv2_3x3', 'inception_3a_1x1', 'inception_3a_5x5_reduce', 'inception_3a_3x3_reduce', 'inception_3a_pool_proj', \
'inception_3a_5x5', 'inception_3a_3x3', 'inception_3b_3x3_reduce', 'inception_3b_1x1', 'inception_3b_5x5_reduce', 'inception_3b_pool_proj', 'inception_3b_3x3', \
'inception_3b_5x5', 'inception_4a_1x1', 'inception_4a_3x3_reduce', 'inception_4a_5x5_reduce', 'inception_4a_pool_proj', 'inception_4a_3x3', 'inception_4a_5x5', \
'inception_4b_5x5_reduce', 'inception_4b_1x1', 'inception_4b_3x3_reduce', 'inception_4b_pool_proj', 'loss1_conv', 'inception_4b_5x5', 'inception_4b_3x3', 'loss1_fc_1', \
'inception_4c_5x5_reduce', 'inception_4c_1x1', 'inception_4c_3x3_reduce', 'inception_4c_pool_proj', 'inception_4c_5x5', 'inception_4c_3x3', 'loss1_classifier_1', \
'inception_4d_3x3_reduce', 'inception_4d_1x1', 'inception_4d_5x5_reduce', 'inception_4d_pool_proj', 'inception_4d_3x3', 'inception_4d_5x5', 'inception_4e_1x1', \
'inception_4e_5x5_reduce', 'inception_4e_3x3_reduce', 'loss2_conv', 'inception_4e_pool_proj', 'inception_4e_5x5', 'inception_4e_3x3', 'loss2_fc_1', 'inception_5a_1x1', \
'inception_5a_5x5_reduce', 'inception_5a_3x3_reduce', 'inception_5a_pool_proj', 'loss2_classifier_1', 'inception_5a_5x5', 'inception_5a_3x3', 'inception_5b_3x3_reduce', \
'inception_5b_5x5_reduce', 'inception_5b_1x1', 'inception_5b_pool_proj', 'inception_5b_3x3', 'inception_5b_5x5']
elif 'sos' in model_name:
gnet_layers = ['conv1_7x7_s2', 'conv2_3x3_reduce', 'conv2_3x3', 'inception_3a_1x1', 'inception_3a_5x5_reduce', 'inception_3a_3x3_reduce', 'inception_3a_pool_proj', \
'inception_3a_5x5', 'inception_3a_3x3', 'inception_3b_3x3_reduce', 'inception_3b_1x1', 'inception_3b_5x5_reduce', 'inception_3b_pool_proj', 'inception_3b_3x3', \
'inception_3b_5x5', 'inception_4a_1x1', 'inception_4a_3x3_reduce', 'inception_4a_5x5_reduce', 'inception_4a_pool_proj', 'inception_4a_3x3', 'inception_4a_5x5', \
'inception_4b_5x5_reduce', 'inception_4b_1x1', 'inception_4b_3x3_reduce', 'inception_4b_pool_proj', 'inception_4b_5x5', 'inception_4b_3x3', 'inception_4c_5x5_reduce', \
'inception_4c_1x1', 'inception_4c_3x3_reduce', 'inception_4c_pool_proj', 'inception_4c_5x5', 'inception_4c_3x3', 'inception_4d_3x3_reduce', 'inception_4d_1x1', \
'inception_4d_5x5_reduce', 'inception_4d_pool_proj', 'inception_4d_3x3', 'inception_4d_5x5', 'inception_4e_5x5_reduce', 'inception_4e_1x1', 'inception_4e_3x3_reduce', \
'inception_4e_pool_proj', 'inception_4e_5x5', 'inception_4e_3x3', 'inception_5a_1x1', 'inception_5a_5x5_reduce', 'inception_5a_3x3_reduce', 'inception_5a_pool_proj', \
'inception_5a_5x5', 'inception_5a_3x3', 'inception_5b_3x3_reduce', 'inception_5b_5x5_reduce', 'inception_5b_1x1', 'inception_5b_pool_proj', 'inception_5b_3x3', 'inception_5b_5x5']
elif 'cars' in model_name:
gnet_layers = ['conv1', 'conv2_1x1', 'conv2_3x3', 'inception_3a_5x5_reduce', 'inception_3a_3x3_reduce', 'inception_3a_1x1', 'inception_3a_pool_proj', 'inception_3a_5x5', \
'inception_3a_3x3', 'inception_3b_1x1', 'inception_3b_3x3_reduce', 'inception_3b_5x5_reduce', 'inception_3b_pool_proj', 'inception_3b_3x3', 'inception_3b_5x5', \
'inception_4a_3x3_reduce', 'inception_4a_5x5_reduce', 'inception_4a_1x1', 'inception_4a_pool_proj', 'inception_4a_3x3', 'inception_4a_5x5', 'inception_4b_3x3_reduce', \
'inception_4b_1x1', 'inception_4b_5x5_reduce', 'loss1_conv', 'inception_4b_pool_proj', 'inception_4b_3x3', 'inception_4b_5x5', 'loss1_fc_1', 'inception_4c_1x1', \
'inception_4c_5x5_reduce', 'inception_4c_3x3_reduce', 'inception_4c_pool_proj', 'inception_4c_5x5', 'inception_4c_3x3', 'loss1_classifier_model_1', 'inception_4d_1x1', \
'inception_4d_3x3_reduce', 'inception_4d_5x5_reduce', 'inception_4d_pool_proj', 'inception_4d_3x3', 'inception_4d_5x5', 'inception_4e_1x1', 'inception_4e_3x3_reduce', \
'inception_4e_5x5_reduce', 'loss2_conv', 'inception_4e_pool_proj', 'inception_4e_3x3', 'inception_4e_5x5', 'loss2_fc_1', 'inception_5a_5x5_reduce', 'inception_5a_3x3_reduce', \
'inception_5a_1x1', 'inception_5a_pool_proj', 'loss2_classifier_model_1', 'inception_5a_5x5', 'inception_5a_3x3', 'inception_5b_3x3_reduce', 'inception_5b_1x1', \
'inception_5b_5x5_reduce', 'inception_5b_pool_proj', 'inception_5b_3x3', 'inception_5b_5x5']
return gnet_layers
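# Quick sanity check (a sketch, not part of the original file; assumes this
# module is run directly): build the untrained cars network, confirm the
# pooled feature shape, and list a few layer names usable with -dream_layers.
if __name__ == '__main__':
    net = GOOGLENET_CARS()
    net.eval()
    with torch.no_grad():
        feats = net(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # expected: torch.Size([1, 1024, 1, 1])
    print(googlenet_layer_names('cars')[:5])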

Setup

  1. Place all files except for dream.py inside a folder/directory called neural_dream.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Inception5h(nn.Module):
def __init__(self):
super(Inception5h, self).__init__()
self.conv2d0_pre_relu_conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.conv2d1_pre_relu_conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv2d2_pre_relu_conv = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.mixed3a_1x1_pre_relu_conv = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3a_3x3_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=192, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3a_5x5_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=192, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3a_pool_reduce_pre_relu_conv = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3a_3x3_pre_relu_conv = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.mixed3a_5x5_pre_relu_conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.mixed3b_1x1_pre_relu_conv = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3b_3x3_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3b_5x5_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=256, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3b_pool_reduce_pre_relu_conv = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed3b_3x3_pre_relu_conv = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.mixed3b_5x5_pre_relu_conv = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.mixed4a_1x1_pre_relu_conv = nn.Conv2d(in_channels=480, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed4a_3x3_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=480, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed4a_5x5_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=480, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed4a_pool_reduce_pre_relu_conv = nn.Conv2d(in_channels=480, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.mixed4a_3x3_pre_relu_conv = nn.Conv2d(in_channels=96, out_channels=204, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.mixed4a_5x5_pre_relu_conv = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=(5, 5), stride=(1, 1), groups=1, bias=True)
self.head0_bottleneck_pre_relu_conv = nn.Conv2d(in_channels=508, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.nn0_pre_relu_matmul = nn.Linear(in_features = 2048, out_features = 1024, bias = True)
self.softmax0_pre_activation_matmul = nn.Linear(in_features = 1024, out_features = 1008, bias = True)
def forward(self, x):
conv2d0_pre_relu_conv_pad = F.pad(x, (2, 3, 2, 3))
conv2d0_pre_relu_conv = self.conv2d0_pre_relu_conv(conv2d0_pre_relu_conv_pad)
conv2d0 = F.relu(conv2d0_pre_relu_conv)
maxpool0_pad = F.pad(conv2d0, (0, 1, 0, 1), value=float('-inf'))
maxpool0 = F.max_pool2d(maxpool0_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
localresponsenorm0 = F.local_response_norm(maxpool0, size=9, alpha=9.999999747378752e-05, beta=0.5, k=1)
conv2d1_pre_relu_conv = self.conv2d1_pre_relu_conv(localresponsenorm0)
conv2d1 = F.relu(conv2d1_pre_relu_conv)
conv2d2_pre_relu_conv_pad = F.pad(conv2d1, (1, 1, 1, 1))
conv2d2_pre_relu_conv = self.conv2d2_pre_relu_conv(conv2d2_pre_relu_conv_pad)
conv2d2 = F.relu(conv2d2_pre_relu_conv)
localresponsenorm1 = F.local_response_norm(conv2d2, size=9, alpha=9.999999747378752e-05, beta=0.5, k=1)
maxpool1_pad = F.pad(localresponsenorm1, (0, 1, 0, 1), value=float('-inf'))
maxpool1 = F.max_pool2d(maxpool1_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
mixed3a_1x1_pre_relu_conv = self.mixed3a_1x1_pre_relu_conv(maxpool1)
mixed3a_3x3_bottleneck_pre_relu_conv = self.mixed3a_3x3_bottleneck_pre_relu_conv(maxpool1)
mixed3a_5x5_bottleneck_pre_relu_conv = self.mixed3a_5x5_bottleneck_pre_relu_conv(maxpool1)
mixed3a_pool_pad = F.pad(maxpool1, (1, 1, 1, 1), value=float('-inf'))
mixed3a_pool = F.max_pool2d(mixed3a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
mixed3a_1x1 = F.relu(mixed3a_1x1_pre_relu_conv)
mixed3a_3x3_bottleneck = F.relu(mixed3a_3x3_bottleneck_pre_relu_conv)
mixed3a_5x5_bottleneck = F.relu(mixed3a_5x5_bottleneck_pre_relu_conv)
mixed3a_pool_reduce_pre_relu_conv = self.mixed3a_pool_reduce_pre_relu_conv(mixed3a_pool)
mixed3a_3x3_pre_relu_conv_pad = F.pad(mixed3a_3x3_bottleneck, (1, 1, 1, 1))
mixed3a_3x3_pre_relu_conv = self.mixed3a_3x3_pre_relu_conv(mixed3a_3x3_pre_relu_conv_pad)
mixed3a_5x5_pre_relu_conv_pad = F.pad(mixed3a_5x5_bottleneck, (2, 2, 2, 2))
mixed3a_5x5_pre_relu_conv = self.mixed3a_5x5_pre_relu_conv(mixed3a_5x5_pre_relu_conv_pad)
mixed3a_pool_reduce = F.relu(mixed3a_pool_reduce_pre_relu_conv)
mixed3a_3x3 = F.relu(mixed3a_3x3_pre_relu_conv)
mixed3a_5x5 = F.relu(mixed3a_5x5_pre_relu_conv)
mixed3a = torch.cat((mixed3a_1x1, mixed3a_3x3, mixed3a_5x5, mixed3a_pool_reduce), 1)
mixed3b_1x1_pre_relu_conv = self.mixed3b_1x1_pre_relu_conv(mixed3a)
mixed3b_3x3_bottleneck_pre_relu_conv = self.mixed3b_3x3_bottleneck_pre_relu_conv(mixed3a)
mixed3b_5x5_bottleneck_pre_relu_conv = self.mixed3b_5x5_bottleneck_pre_relu_conv(mixed3a)
mixed3b_pool_pad = F.pad(mixed3a, (1, 1, 1, 1), value=float('-inf'))
mixed3b_pool = F.max_pool2d(mixed3b_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
mixed3b_1x1 = F.relu(mixed3b_1x1_pre_relu_conv)
mixed3b_3x3_bottleneck = F.relu(mixed3b_3x3_bottleneck_pre_relu_conv)
mixed3b_5x5_bottleneck = F.relu(mixed3b_5x5_bottleneck_pre_relu_conv)
mixed3b_pool_reduce_pre_relu_conv = self.mixed3b_pool_reduce_pre_relu_conv(mixed3b_pool)
mixed3b_3x3_pre_relu_conv_pad = F.pad(mixed3b_3x3_bottleneck, (1, 1, 1, 1))
mixed3b_3x3_pre_relu_conv = self.mixed3b_3x3_pre_relu_conv(mixed3b_3x3_pre_relu_conv_pad)
mixed3b_5x5_pre_relu_conv_pad = F.pad(mixed3b_5x5_bottleneck, (2, 2, 2, 2))
mixed3b_5x5_pre_relu_conv = self.mixed3b_5x5_pre_relu_conv(mixed3b_5x5_pre_relu_conv_pad)
mixed3b_pool_reduce = F.relu(mixed3b_pool_reduce_pre_relu_conv)
mixed3b_3x3 = F.relu(mixed3b_3x3_pre_relu_conv)
mixed3b_5x5 = F.relu(mixed3b_5x5_pre_relu_conv)
mixed3b = torch.cat((mixed3b_1x1, mixed3b_3x3, mixed3b_5x5, mixed3b_pool_reduce), 1)
maxpool4_pad = F.pad(mixed3b, (0, 1, 0, 1), value=float('-inf'))
maxpool4 = F.max_pool2d(maxpool4_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
mixed4a_1x1_pre_relu_conv = self.mixed4a_1x1_pre_relu_conv(maxpool4)
mixed4a_3x3_bottleneck_pre_relu_conv = self.mixed4a_3x3_bottleneck_pre_relu_conv(maxpool4)
mixed4a_5x5_bottleneck_pre_relu_conv = self.mixed4a_5x5_bottleneck_pre_relu_conv(maxpool4)
mixed4a_pool_pad = F.pad(maxpool4, (1, 1, 1, 1), value=float('-inf'))
mixed4a_pool = F.max_pool2d(mixed4a_pool_pad, kernel_size=(3, 3), stride=(1, 1), padding=0, ceil_mode=False)
mixed4a_1x1 = F.relu(mixed4a_1x1_pre_relu_conv)
mixed4a_3x3_bottleneck = F.relu(mixed4a_3x3_bottleneck_pre_relu_conv)
mixed4a_5x5_bottleneck = F.relu(mixed4a_5x5_bottleneck_pre_relu_conv)
mixed4a_pool_reduce_pre_relu_conv = self.mixed4a_pool_reduce_pre_relu_conv(mixed4a_pool)
mixed4a_3x3_pre_relu_conv_pad = F.pad(mixed4a_3x3_bottleneck, (1, 1, 1, 1))
mixed4a_3x3_pre_relu_conv = self.mixed4a_3x3_pre_relu_conv(mixed4a_3x3_pre_relu_conv_pad)
mixed4a_5x5_pre_relu_conv_pad = F.pad(mixed4a_5x5_bottleneck, (2, 2, 2, 2))
mixed4a_5x5_pre_relu_conv = self.mixed4a_5x5_pre_relu_conv(mixed4a_5x5_pre_relu_conv_pad)
mixed4a_pool_reduce = F.relu(mixed4a_pool_reduce_pre_relu_conv)
mixed4a_3x3 = F.relu(mixed4a_3x3_pre_relu_conv)
mixed4a_5x5 = F.relu(mixed4a_5x5_pre_relu_conv)
mixed4a = torch.cat((mixed4a_1x1, mixed4a_3x3, mixed4a_5x5, mixed4a_pool_reduce), 1)
head0_pool = F.avg_pool2d(mixed4a, kernel_size=(5, 5), stride=(3, 3), padding=(0,), ceil_mode=False, count_include_pad=False)
head0_bottleneck_pre_relu_conv = self.head0_bottleneck_pre_relu_conv(head0_pool)
head0_bottleneck = F.relu(head0_bottleneck_pre_relu_conv)
avgpool_2d = nn.AdaptiveAvgPool2d((4, 4))
x = avgpool_2d(head0_bottleneck)
x = torch.flatten(x, 1)
nn0_pre_relu_matmul = self.nn0_pre_relu_matmul(x)
nn0 = F.relu(nn0_pre_relu_matmul)
nn0_reshape = torch.reshape(input = nn0, shape = (-1,1024))
softmax0_pre_activation_matmul = self.softmax0_pre_activation_matmul(nn0_reshape)
softmax0 = F.softmax(softmax0_pre_activation_matmul, dim=1)
return softmax0
def inception_layer_names(model_name='5h'):
if model_name == '5h':
layers = ['conv2d0_pre_relu_conv', 'conv2d1_pre_relu_conv', 'conv2d2_pre_relu_conv', 'mixed3a_1x1_pre_relu_conv', 'mixed3a_3x3_bottleneck_pre_relu_conv', \
'mixed3a_5x5_bottleneck_pre_relu_conv', 'mixed3a_pool_reduce_pre_relu_conv', 'mixed3a_3x3_pre_relu_conv', 'mixed3a_5x5_pre_relu_conv', 'mixed3b_1x1_pre_relu_conv', \
'mixed3b_3x3_bottleneck_pre_relu_conv', 'mixed3b_5x5_bottleneck_pre_relu_conv', 'mixed3b_pool_reduce_pre_relu_conv', 'mixed3b_3x3_pre_relu_conv', \
'mixed3b_5x5_pre_relu_conv', 'mixed4a_1x1_pre_relu_conv', 'mixed4a_3x3_bottleneck_pre_relu_conv', 'mixed4a_5x5_bottleneck_pre_relu_conv', \
'mixed4a_pool_reduce_pre_relu_conv', 'mixed4a_3x3_pre_relu_conv', 'mixed4a_5x5_pre_relu_conv', 'head0_bottleneck_pre_relu_conv', 'nn0_pre_relu_matmul', \
'softmax0_pre_activation_matmul']
return layers
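# Quick sanity check (a sketch, not part of the original file; assumes this
# module is run directly): the untrained 5h network should return a 1008-way
# softmax for a single 224x224 image.
if __name__ == '__main__':
    net = Inception5h()
    net.eval()
    with torch.no_grad():
        probs = net(torch.randn(1, 3, 224, 224))
    print(probs.shape)  # expected: torch.Size([1, 1008])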

neural-dream Installation

This guide will walk you through multiple ways to set up neural-dream on Ubuntu and Windows. If you wish to install PyTorch and neural-dream on a different operating system such as macOS, installation guides can be found here.

Note that in order to reduce their size, the pre-packaged binary releases (pip, Conda, etc.) have removed support for some older GPUs, so you will have to install from source in order to use those GPUs.

Ubuntu:

With A Package Manager:

The pip and Conda packages ship with CUDA and cuDNN already built in, so after you have installed PyTorch with pip or Conda, you can skip ahead to the Install neural-dream section.

pip:

The neural-dream PyPI page can be found here: https://pypi.org/project/neural-dream/

If you wish to install neural-dream as a pip package, then use the following command:

# in a terminal, run the command
pip install neural-dream

Or:

# in a terminal, run the command
pip3 install neural-dream

Next download the models with:

neural-dream -download_models

By default the models are downloaded to your home directory, but you can specify a download location with:

neural-dream -download_models -download_path <download_path>

To download specific models or specific groups of models, you can use a comma-separated list of models like this:

neural-dream -download_models all-caffe-googlenet,caffe-vgg19

To print all the models available for download, run the following command:

neural-dream -download_models print-all

Github and pip:

Following the pip installation instructions here, you can install PyTorch with the following commands:

# in a terminal, run the commands
cd ~/
pip install torch torchvision

Or:

cd ~/
pip3 install torch torchvision

Now continue on to the Install neural-dream section below to finish setting up neural-dream.

Conda:

Following the Conda installation instructions here, you can install PyTorch with the following command:

conda install pytorch torchvision -c pytorch

Now continue on to the Install neural-dream section below to finish setting up neural-dream.

From Source:

(Optional) Step 1: Install CUDA

If you have a CUDA-capable GPU from NVIDIA then you can speed up neural-dream with CUDA.

Instructions for downloading and installing the latest CUDA version on all supported operating systems can be found here.

(Optional) Step 2: Install cuDNN

cuDNN is a library from NVIDIA that efficiently implements many of the operations (like convolutions and pooling) that are commonly used in deep learning.

After registering as a developer with NVIDIA, you can download cuDNN here. Make sure that you use the appropriate version of cuDNN for your version of CUDA.

Follow the download instructions on NVIDIA's site to install cuDNN correctly.

Note that the cuDNN backend can only be used for GPU mode.

(Optional) Steps 1-3: Install PyTorch with support for AMD GPUs using Radeon Open Compute Stack (ROCm)

It is recommended that if you wish to use PyTorch with an AMD GPU, you install it via the official ROCm dockerfile: https://rocm.github.io/pytorch.html

  • Supported AMD GPUs for the dockerfile are: Vega10 / gfx900 generation discrete graphics cards (Vega56, Vega64, or MI25).

PyTorch does not officially support building from source on the host for AMD GPUs, but a user guide posted here reportedly works well.

ROCm utilizes a CUDA porting tool called HIP, which automatically converts CUDA code into HIP code. HIP code can run on both AMD and Nvidia GPUs.

Step 3: Install PyTorch

To install PyTorch from source on Ubuntu (instructions may differ if you are using a different OS):

cd ~/
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
python setup.py install

cd ~/
git clone --recursive https://github.com/pytorch/vision
cd vision
python setup.py install

To check that your torch installation is working, run the command python or python3 to enter the Python interpreter. Then type import torch and hit enter.

You can then type print(torch.version.cuda) and print(torch.backends.cudnn.version()) to confirm that you are using the desired versions of CUDA and cuDNN.

To quit just type exit() or use Ctrl-D.
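
Put together, the check looks like this (torch.__version__ and torch.cuda.is_available() are extra calls beyond those mentioned above; they confirm the installed PyTorch version and that a GPU is actually visible to PyTorch):

# in the Python interpreter, run:
import torch
print(torch.__version__)
print(torch.version.cuda)               # CUDA version PyTorch was built against
print(torch.backends.cudnn.version())   # cuDNN version, or None if unavailable
print(torch.cuda.is_available())        # True if a CUDA GPU can be used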

Now continue on to the Install neural-dream section below to finish setting up neural-dream.

Windows Installation

If you wish to install PyTorch on Windows From Source or via Conda, you can find instructions on the PyTorch website: https://pytorch.org/

Github and pip

First, you will need to download Python 3 and install it: https://www.python.org/downloads/windows/. I recommend using the executable installer for the latest version of Python 3.

Then using https://pytorch.org/, get the correct pip command, paste it into the Command Prompt (CMD) and hit enter:

pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html

After installing PyTorch, download the neural-dream Github repository and extract/unzip it to the desired location.

Then copy the file path of your neural-dream folder, paste it into the Command Prompt with cd in front of it, and hit enter.

In the example below, the neural-dream folder was placed on the desktop:

cd C:\Users\<User_Name>\Desktop\neural-dream-master

You can now continue on to installing neural-dream, skipping the git clone step.

Install neural-dream

First we clone neural-dream from GitHub:

cd ~/
git clone https://github.com/ProGamerGov/neural-dream.git
cd neural-dream

Next we need to download the pretrained neural network models:

python models/download_models.py

You should now be able to run neural-dream in CPU mode like this:

python neural_dream.py -gpu c -print_octave_iter 2

If you installed PyTorch with support for CUDA, then you should now be able to run neural-dream in GPU mode like this:

python neural_dream.py -gpu 0 -print_octave_iter 5

If you installed PyTorch with support for cuDNN, then you should now be able to run neural-dream with the cudnn backend like this:

python neural_dream.py -gpu 0 -backend cudnn -print_octave_iter 5

If everything is working properly you should see output like this:

Octave iter 1 iteration 25 / 50
  DeepDream 1 loss: 19534752.0
Octave iter 1 iteration 50 / 50
  DeepDream 1 loss: 23289720.0
Octave iter 2 iteration 25 / 50
  DeepDream 1 loss: 38870436.0
Octave iter 2 iteration 50 / 50
  DeepDream 1 loss: 47514664.0
Iteration 1 / 10
  DeepDream 1 loss: 71727704.0
  Total loss: 2767866014.0
Octave iter 1 iteration 25 / 50
  DeepDream 1 loss: 27209894.0
Octave iter 1 iteration 50 / 50
  DeepDream 1 loss: 31386542.0
Octave iter 2 iteration 25 / 50
  DeepDream 1 loss: 47773244.0
Octave iter 2 iteration 50 / 50
  DeepDream 1 loss: 51204812.0
Iteration 2 / 10
  DeepDream 1 loss: 87182300.0
  Total loss: 3758961954.0
The MIT License (MIT)
Copyright (c) 2020 ProGamerGov
Copyright (c) 2015 Justin Johnson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
import torch
import torch.nn as nn
import neural_dream.dream_utils as dream_utils
from neural_dream.dream_experimental import ChannelMod
# Define an nn Module to compute DeepDream loss in different ways
class DreamLossMode(torch.nn.Module):
def __init__(self, loss_mode, use_fft, r):
super(DreamLossMode, self).__init__()
self.get_mode(loss_mode)
self.use_fft = use_fft[0]
self.fft_tensor = dream_utils.FFTTensor(r, use_fft[1])
def get_mode(self, loss_mode):
self.loss_mode_string = loss_mode
if loss_mode.lower() == 'norm':
self.get_loss = self.norm_loss
elif loss_mode.lower() == 'mean':
self.get_loss = self.mean_loss
elif loss_mode.lower() == 'l2':
self.get_loss = self.l2_loss
elif loss_mode.lower() == 'mse':
self.crit = torch.nn.MSELoss()
self.get_loss = self.crit_loss
elif loss_mode.lower() == 'bce':
self.crit = torch.nn.BCEWithLogitsLoss()
self.get_loss = self.crit_loss
def norm_loss(self, input):
return input.norm()
def mean_loss(self, input):
return input.mean()
def l2_loss(self, input):
return input.pow(2).sum().sqrt()
def abs_mean(self, input):
return input.abs().mean()
def crit_loss(self, input, target):
return self.crit(input, target)
def forward(self, input):
if self.use_fft:
input = self.fft_tensor(input)
if self.loss_mode_string != 'bce' and self.loss_mode_string != 'mse':
loss = self.get_loss(input)
else:
target = torch.zeros_like(input.detach())
loss = self.crit_loss(input, target)
return loss
# Define an nn Module for DeepDream
class DeepDream(torch.nn.Module):
def __init__(self, loss_mode, channels='-1', channel_mode='strong', \
channel_capture='once', scale=4, sigma=1, use_fft=(True, 25), r=1, p_mode='fast', norm_p=0, abs_p=0, mean_p=0):
super(DeepDream, self).__init__()
self.get_loss = DreamLossMode(loss_mode, use_fft, r)
self.channels = [int(c) for c in channels.split(',')]
self.channel_mode = channel_mode
self.get_channels = dream_utils.RankChannels(self.channels, self.channel_mode)
self.lap_scale = scale
self.sigma = str(sigma).split(',')  # accept an int/float or a comma-separated string
self.channel_capture = channel_capture
self.zero_weak = ChannelMod(p_mode, self.channels[0], norm_p, abs_p, mean_p)
def capture(self, input):
if -1 not in self.channels and 'all' not in self.channel_mode:
self.channels = self.get_channels(input)
elif self.channel_mode == 'all' and -1 not in self.channels:
self.channels = self.channels
if self.lap_scale > 0:
self.lap_pyramid = dream_utils.LaplacianPyramid(input.clone(), self.lap_scale, self.sigma)
def get_channel_loss(self, input):
loss = 0
if 'once' not in self.channel_capture:
self.capture(input)  # capture() updates self.channels in place; it has no return value
for c in self.channels:
if input.dim() > 0:
if int(c) < input.size(1):
#test = input[:, int(c)].unsqueeze(0)
#input[input != test] = 0
#loss += self.get_loss(input)
loss += self.get_loss(input[:, int(c)])
return loss
def forward(self, input):
if self.lap_scale > 0:
input = self.lap_pyramid(input)
if self.zero_weak.enabled:
input = self.zero_weak(input)
if -1 in self.channels:
loss = self.get_loss(input)
else:
loss = self.get_channel_loss(input)
return loss
# Define an nn Module to collect DeepDream loss
class DreamLoss(torch.nn.Module):
def __init__(self, loss_mode, strength, channels, channel_mode='all', **kwargs):
super(DreamLoss, self).__init__()
self.dream = DeepDream(loss_mode, channels, channel_mode, **kwargs)
self.strength = strength
self.mode = 'None'
def forward(self, input):
if self.mode == 'loss':
self.loss = self.dream(input.clone()) * self.strength
elif self.mode == 'capture':
self.target_size = input.size()
self.dream.capture(input.clone())
return input
# Define a forward hook to collect DeepDream loss
class DreamLossHook(DreamLoss):
def forward(self, module, input, output):
if self.mode == 'loss':
self.loss = self.dream(output.clone()) * self.strength
elif self.mode == 'capture':
self.target_size = output.size()
self.dream.capture(output.clone())
# Define a pre forward hook to collect DeepDream loss
class DreamLossPreHook(DreamLoss):
def forward(self, module, output):
if self.mode == 'loss':
self.loss = self.dream(output[0].clone()) * self.strength
elif self.mode == 'capture':
self.target_size = output[0].size()
self.dream.capture(output[0].clone())
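# Usage sketch (not part of the original file): how these hooks are meant to be
# attached to a network layer. `net`, `some_layer`, and `img` are placeholders,
# and the keyword values below are illustrative only.
#
#   hook = DreamLossHook('mean', strength=100.0, channels='-1', sigma='1')
#   handle = net.some_layer.register_forward_hook(hook)
#   hook.mode = 'capture'
#   net(img)        # records the activation size and (optionally) ranks channels
#   hook.mode = 'loss'
#   net(img)        # hook.loss now holds strength * the DeepDream loss
#   handle.remove()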
import torch
import torch.nn as nn
import torch.nn.functional as F
from neural_dream.dream_utils import AdditionLayer
class ResNet_50_1by2_nsfw(nn.Module):
def __init__(self):
super(ResNet_50_1by2_nsfw, self).__init__()
self.conv_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
self.bn_1 = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block0_branch2a = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.conv_stage0_block0_proj_shortcut = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block0_branch2a = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.bn_stage0_block0_proj_shortcut = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block0_branch2b = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block0_branch2b = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block0_branch2c = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block0_branch2c = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block1_branch2a = nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block1_branch2a = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block1_branch2b = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block1_branch2b = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block1_branch2c = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block1_branch2c = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block2_branch2a = nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block2_branch2a = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block2_branch2b = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block2_branch2b = nn.BatchNorm2d(num_features=32, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage0_block2_branch2c = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage0_block2_branch2c = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block0_proj_shortcut = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.conv_stage1_block0_branch2a = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.bn_stage1_block0_proj_shortcut = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.bn_stage1_block0_branch2a = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block0_branch2b = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block0_branch2b = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block0_branch2c = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block0_branch2c = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block1_branch2a = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block1_branch2a = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block1_branch2b = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block1_branch2b = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block1_branch2c = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block1_branch2c = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block2_branch2a = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block2_branch2a = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block2_branch2b = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block2_branch2b = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block2_branch2c = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block2_branch2c = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block3_branch2a = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block3_branch2a = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block3_branch2b = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block3_branch2b = nn.BatchNorm2d(num_features=64, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage1_block3_branch2c = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage1_block3_branch2c = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block0_proj_shortcut = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.conv_stage2_block0_branch2a = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.bn_stage2_block0_proj_shortcut = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.bn_stage2_block0_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block0_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block0_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block0_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block0_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block1_branch2a = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block1_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block1_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block1_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block1_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block1_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block2_branch2a = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block2_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block2_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block2_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block2_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block2_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block3_branch2a = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block3_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block3_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block3_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block3_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block3_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block4_branch2a = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block4_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block4_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block4_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block4_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block4_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block5_branch2a = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block5_branch2a = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block5_branch2b = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block5_branch2b = nn.BatchNorm2d(num_features=128, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage2_block5_branch2c = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage2_block5_branch2c = nn.BatchNorm2d(num_features=512, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block0_proj_shortcut = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.conv_stage3_block0_branch2a = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=True)
self.bn_stage3_block0_proj_shortcut = nn.BatchNorm2d(num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
self.bn_stage3_block0_branch2a = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block0_branch2b = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block0_branch2b = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block0_branch2c = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block0_branch2c = nn.BatchNorm2d(num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block1_branch2a = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block1_branch2a = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block1_branch2b = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block1_branch2b = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block1_branch2c = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block1_branch2c = nn.BatchNorm2d(num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block2_branch2a = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block2_branch2a = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block2_branch2b = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block2_branch2b = nn.BatchNorm2d(num_features=256, eps=9.999999747378752e-06, momentum=0.0)
self.conv_stage3_block2_branch2c = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
self.bn_stage3_block2_branch2c = nn.BatchNorm2d(num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
self.fc_nsfw_1 = nn.Linear(in_features=1024, out_features=2, bias=True)
def add_layers(self):
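# Each AdditionLayer (defined earlier in this file) adds a residual block's shortcut to its
# main branch output. Registering them as named modules, rather than using a bare `+`,
# lets these eltwise layers be selected by name as DeepDream targets.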
self.eltwise_stage0_block0 = AdditionLayer()
self.eltwise_stage0_block1 = AdditionLayer()
self.eltwise_stage0_block2 = AdditionLayer()
self.eltwise_stage1_block0 = AdditionLayer()
self.eltwise_stage1_block1 = AdditionLayer()
self.eltwise_stage1_block2 = AdditionLayer()
self.eltwise_stage1_block3 = AdditionLayer()
self.eltwise_stage2_block0 = AdditionLayer()
self.eltwise_stage2_block1 = AdditionLayer()
self.eltwise_stage2_block2 = AdditionLayer()
self.eltwise_stage2_block3 = AdditionLayer()
self.eltwise_stage2_block4 = AdditionLayer()
self.eltwise_stage2_block5 = AdditionLayer()
self.eltwise_stage3_block0 = AdditionLayer()
self.eltwise_stage3_block1 = AdditionLayer()
self.eltwise_stage3_block2 = AdditionLayer()
def forward(self, x):
conv_1_pad = F.pad(x, (3, 3, 3, 3))
conv_1 = self.conv_1(conv_1_pad)
bn_1 = self.bn_1(conv_1)
relu_1 = F.relu(bn_1)
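# Pad with -inf before max pooling so the padded border can never win the max,
# matching the pooling geometry of the original Caffe model.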
pool1_pad = F.pad(relu_1, (0, 1, 0, 1), value=float('-inf'))
pool1 = F.max_pool2d(pool1_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
conv_stage0_block0_branch2a = self.conv_stage0_block0_branch2a(pool1)
conv_stage0_block0_proj_shortcut = self.conv_stage0_block0_proj_shortcut(pool1)
bn_stage0_block0_branch2a = self.bn_stage0_block0_branch2a(conv_stage0_block0_branch2a)
bn_stage0_block0_proj_shortcut = self.bn_stage0_block0_proj_shortcut(conv_stage0_block0_proj_shortcut)
relu_stage0_block0_branch2a = F.relu(bn_stage0_block0_branch2a)
conv_stage0_block0_branch2b_pad = F.pad(relu_stage0_block0_branch2a, (1, 1, 1, 1))
conv_stage0_block0_branch2b = self.conv_stage0_block0_branch2b(conv_stage0_block0_branch2b_pad)
bn_stage0_block0_branch2b = self.bn_stage0_block0_branch2b(conv_stage0_block0_branch2b)
relu_stage0_block0_branch2b = F.relu(bn_stage0_block0_branch2b)
conv_stage0_block0_branch2c = self.conv_stage0_block0_branch2c(relu_stage0_block0_branch2b)
bn_stage0_block0_branch2c = self.bn_stage0_block0_branch2c(conv_stage0_block0_branch2c)
eltwise_stage0_block0 = self.eltwise_stage0_block0(bn_stage0_block0_proj_shortcut, bn_stage0_block0_branch2c)
relu_stage0_block0 = F.relu(eltwise_stage0_block0)
conv_stage0_block1_branch2a = self.conv_stage0_block1_branch2a(relu_stage0_block0)
bn_stage0_block1_branch2a = self.bn_stage0_block1_branch2a(conv_stage0_block1_branch2a)
relu_stage0_block1_branch2a = F.relu(bn_stage0_block1_branch2a)
conv_stage0_block1_branch2b_pad = F.pad(relu_stage0_block1_branch2a, (1, 1, 1, 1))
conv_stage0_block1_branch2b = self.conv_stage0_block1_branch2b(conv_stage0_block1_branch2b_pad)
bn_stage0_block1_branch2b = self.bn_stage0_block1_branch2b(conv_stage0_block1_branch2b)
relu_stage0_block1_branch2b = F.relu(bn_stage0_block1_branch2b)
conv_stage0_block1_branch2c = self.conv_stage0_block1_branch2c(relu_stage0_block1_branch2b)
bn_stage0_block1_branch2c = self.bn_stage0_block1_branch2c(conv_stage0_block1_branch2c)
eltwise_stage0_block1 = self.eltwise_stage0_block1(relu_stage0_block0, bn_stage0_block1_branch2c)  # use the registered AdditionLayer so this eltwise layer can be targeted by name
relu_stage0_block1 = F.relu(eltwise_stage0_block1)
conv_stage0_block2_branch2a = self.conv_stage0_block2_branch2a(relu_stage0_block1)
bn_stage0_block2_branch2a = self.bn_stage0_block2_branch2a(conv_stage0_block2_branch2a)
relu_stage0_block2_branch2a = F.relu(bn_stage0_block2_branch2a)
conv_stage0_block2_branch2b_pad = F.pad(relu_stage0_block2_branch2a, (1, 1, 1, 1))
conv_stage0_block2_branch2b = self.conv_stage0_block2_branch2b(conv_stage0_block2_branch2b_pad)
bn_stage0_block2_branch2b = self.bn_stage0_block2_branch2b(conv_stage0_block2_branch2b)
relu_stage0_block2_branch2b = F.relu(bn_stage0_block2_branch2b)
conv_stage0_block2_branch2c = self.conv_stage0_block2_branch2c(relu_stage0_block2_branch2b)
bn_stage0_block2_branch2c = self.bn_stage0_block2_branch2c(conv_stage0_block2_branch2c)
eltwise_stage0_block2 = self.eltwise_stage0_block2(relu_stage0_block1, bn_stage0_block2_branch2c)
relu_stage0_block2 = F.relu(eltwise_stage0_block2)
conv_stage1_block0_proj_shortcut = self.conv_stage1_block0_proj_shortcut(relu_stage0_block2)
conv_stage1_block0_branch2a = self.conv_stage1_block0_branch2a(relu_stage0_block2)
bn_stage1_block0_proj_shortcut = self.bn_stage1_block0_proj_shortcut(conv_stage1_block0_proj_shortcut)
bn_stage1_block0_branch2a = self.bn_stage1_block0_branch2a(conv_stage1_block0_branch2a)
relu_stage1_block0_branch2a = F.relu(bn_stage1_block0_branch2a)
conv_stage1_block0_branch2b_pad = F.pad(relu_stage1_block0_branch2a, (1, 1, 1, 1))
conv_stage1_block0_branch2b = self.conv_stage1_block0_branch2b(conv_stage1_block0_branch2b_pad)
bn_stage1_block0_branch2b = self.bn_stage1_block0_branch2b(conv_stage1_block0_branch2b)
relu_stage1_block0_branch2b = F.relu(bn_stage1_block0_branch2b)
conv_stage1_block0_branch2c = self.conv_stage1_block0_branch2c(relu_stage1_block0_branch2b)
bn_stage1_block0_branch2c = self.bn_stage1_block0_branch2c(conv_stage1_block0_branch2c)
eltwise_stage1_block0 = self.eltwise_stage1_block0(bn_stage1_block0_proj_shortcut, bn_stage1_block0_branch2c)
relu_stage1_block0 = F.relu(eltwise_stage1_block0)
conv_stage1_block1_branch2a = self.conv_stage1_block1_branch2a(relu_stage1_block0)
bn_stage1_block1_branch2a = self.bn_stage1_block1_branch2a(conv_stage1_block1_branch2a)
relu_stage1_block1_branch2a = F.relu(bn_stage1_block1_branch2a)
conv_stage1_block1_branch2b_pad = F.pad(relu_stage1_block1_branch2a, (1, 1, 1, 1))
conv_stage1_block1_branch2b = self.conv_stage1_block1_branch2b(conv_stage1_block1_branch2b_pad)
bn_stage1_block1_branch2b = self.bn_stage1_block1_branch2b(conv_stage1_block1_branch2b)
relu_stage1_block1_branch2b = F.relu(bn_stage1_block1_branch2b)
conv_stage1_block1_branch2c = self.conv_stage1_block1_branch2c(relu_stage1_block1_branch2b)
bn_stage1_block1_branch2c = self.bn_stage1_block1_branch2c(conv_stage1_block1_branch2c)
eltwise_stage1_block1 = self.eltwise_stage1_block1(relu_stage1_block0, bn_stage1_block1_branch2c)
relu_stage1_block1 = F.relu(eltwise_stage1_block1)
conv_stage1_block2_branch2a = self.conv_stage1_block2_branch2a(relu_stage1_block1)
bn_stage1_block2_branch2a = self.bn_stage1_block2_branch2a(conv_stage1_block2_branch2a)
relu_stage1_block2_branch2a = F.relu(bn_stage1_block2_branch2a)
conv_stage1_block2_branch2b_pad = F.pad(relu_stage1_block2_branch2a, (1, 1, 1, 1))
conv_stage1_block2_branch2b = self.conv_stage1_block2_branch2b(conv_stage1_block2_branch2b_pad)
bn_stage1_block2_branch2b = self.bn_stage1_block2_branch2b(conv_stage1_block2_branch2b)
relu_stage1_block2_branch2b = F.relu(bn_stage1_block2_branch2b)
conv_stage1_block2_branch2c = self.conv_stage1_block2_branch2c(relu_stage1_block2_branch2b)
bn_stage1_block2_branch2c = self.bn_stage1_block2_branch2c(conv_stage1_block2_branch2c)
eltwise_stage1_block2 = self.eltwise_stage1_block2(relu_stage1_block1, bn_stage1_block2_branch2c)
relu_stage1_block2 = F.relu(eltwise_stage1_block2)
conv_stage1_block3_branch2a = self.conv_stage1_block3_branch2a(relu_stage1_block2)
bn_stage1_block3_branch2a = self.bn_stage1_block3_branch2a(conv_stage1_block3_branch2a)
relu_stage1_block3_branch2a = F.relu(bn_stage1_block3_branch2a)
conv_stage1_block3_branch2b_pad = F.pad(relu_stage1_block3_branch2a, (1, 1, 1, 1))
conv_stage1_block3_branch2b = self.conv_stage1_block3_branch2b(conv_stage1_block3_branch2b_pad)
bn_stage1_block3_branch2b = self.bn_stage1_block3_branch2b(conv_stage1_block3_branch2b)
relu_stage1_block3_branch2b = F.relu(bn_stage1_block3_branch2b)
conv_stage1_block3_branch2c = self.conv_stage1_block3_branch2c(relu_stage1_block3_branch2b)
bn_stage1_block3_branch2c = self.bn_stage1_block3_branch2c(conv_stage1_block3_branch2c)
eltwise_stage1_block3 = self.eltwise_stage1_block3(relu_stage1_block2, bn_stage1_block3_branch2c)
relu_stage1_block3 = F.relu(eltwise_stage1_block3)
conv_stage2_block0_proj_shortcut = self.conv_stage2_block0_proj_shortcut(relu_stage1_block3)
conv_stage2_block0_branch2a = self.conv_stage2_block0_branch2a(relu_stage1_block3)
bn_stage2_block0_proj_shortcut = self.bn_stage2_block0_proj_shortcut(conv_stage2_block0_proj_shortcut)
bn_stage2_block0_branch2a = self.bn_stage2_block0_branch2a(conv_stage2_block0_branch2a)
relu_stage2_block0_branch2a = F.relu(bn_stage2_block0_branch2a)
conv_stage2_block0_branch2b_pad = F.pad(relu_stage2_block0_branch2a, (1, 1, 1, 1))
conv_stage2_block0_branch2b = self.conv_stage2_block0_branch2b(conv_stage2_block0_branch2b_pad)
bn_stage2_block0_branch2b = self.bn_stage2_block0_branch2b(conv_stage2_block0_branch2b)
relu_stage2_block0_branch2b = F.relu(bn_stage2_block0_branch2b)
conv_stage2_block0_branch2c = self.conv_stage2_block0_branch2c(relu_stage2_block0_branch2b)
bn_stage2_block0_branch2c = self.bn_stage2_block0_branch2c(conv_stage2_block0_branch2c)
eltwise_stage2_block0 = self.eltwise_stage2_block0(bn_stage2_block0_proj_shortcut, bn_stage2_block0_branch2c)
relu_stage2_block0 = F.relu(eltwise_stage2_block0)
conv_stage2_block1_branch2a = self.conv_stage2_block1_branch2a(relu_stage2_block0)
bn_stage2_block1_branch2a = self.bn_stage2_block1_branch2a(conv_stage2_block1_branch2a)
relu_stage2_block1_branch2a = F.relu(bn_stage2_block1_branch2a)
conv_stage2_block1_branch2b_pad = F.pad(relu_stage2_block1_branch2a, (1, 1, 1, 1))
conv_stage2_block1_branch2b = self.conv_stage2_block1_branch2b(conv_stage2_block1_branch2b_pad)
bn_stage2_block1_branch2b = self.bn_stage2_block1_branch2b(conv_stage2_block1_branch2b)
relu_stage2_block1_branch2b = F.relu(bn_stage2_block1_branch2b)
conv_stage2_block1_branch2c = self.conv_stage2_block1_branch2c(relu_stage2_block1_branch2b)
bn_stage2_block1_branch2c = self.bn_stage2_block1_branch2c(conv_stage2_block1_branch2c)
eltwise_stage2_block1 = self.eltwise_stage2_block1(relu_stage2_block0, bn_stage2_block1_branch2c)
relu_stage2_block1 = F.relu(eltwise_stage2_block1)
conv_stage2_block2_branch2a = self.conv_stage2_block2_branch2a(relu_stage2_block1)
bn_stage2_block2_branch2a = self.bn_stage2_block2_branch2a(conv_stage2_block2_branch2a)
relu_stage2_block2_branch2a = F.relu(bn_stage2_block2_branch2a)
conv_stage2_block2_branch2b_pad = F.pad(relu_stage2_block2_branch2a, (1, 1, 1, 1))
conv_stage2_block2_branch2b = self.conv_stage2_block2_branch2b(conv_stage2_block2_branch2b_pad)
bn_stage2_block2_branch2b = self.bn_stage2_block2_branch2b(conv_stage2_block2_branch2b)
relu_stage2_block2_branch2b = F.relu(bn_stage2_block2_branch2b)
conv_stage2_block2_branch2c = self.conv_stage2_block2_branch2c(relu_stage2_block2_branch2b)
bn_stage2_block2_branch2c = self.bn_stage2_block2_branch2c(conv_stage2_block2_branch2c)
eltwise_stage2_block2 = self.eltwise_stage2_block2(relu_stage2_block1, bn_stage2_block2_branch2c)
relu_stage2_block2 = F.relu(eltwise_stage2_block2)
conv_stage2_block3_branch2a = self.conv_stage2_block3_branch2a(relu_stage2_block2)
bn_stage2_block3_branch2a = self.bn_stage2_block3_branch2a(conv_stage2_block3_branch2a)
relu_stage2_block3_branch2a = F.relu(bn_stage2_block3_branch2a)
conv_stage2_block3_branch2b_pad = F.pad(relu_stage2_block3_branch2a, (1, 1, 1, 1))
conv_stage2_block3_branch2b = self.conv_stage2_block3_branch2b(conv_stage2_block3_branch2b_pad)
bn_stage2_block3_branch2b = self.bn_stage2_block3_branch2b(conv_stage2_block3_branch2b)
relu_stage2_block3_branch2b = F.relu(bn_stage2_block3_branch2b)
conv_stage2_block3_branch2c = self.conv_stage2_block3_branch2c(relu_stage2_block3_branch2b)
bn_stage2_block3_branch2c = self.bn_stage2_block3_branch2c(conv_stage2_block3_branch2c)
eltwise_stage2_block3 = self.eltwise_stage2_block3(relu_stage2_block2, bn_stage2_block3_branch2c)
relu_stage2_block3 = F.relu(eltwise_stage2_block3)
conv_stage2_block4_branch2a = self.conv_stage2_block4_branch2a(relu_stage2_block3)
bn_stage2_block4_branch2a = self.bn_stage2_block4_branch2a(conv_stage2_block4_branch2a)
relu_stage2_block4_branch2a = F.relu(bn_stage2_block4_branch2a)
conv_stage2_block4_branch2b_pad = F.pad(relu_stage2_block4_branch2a, (1, 1, 1, 1))
conv_stage2_block4_branch2b = self.conv_stage2_block4_branch2b(conv_stage2_block4_branch2b_pad)
bn_stage2_block4_branch2b = self.bn_stage2_block4_branch2b(conv_stage2_block4_branch2b)
relu_stage2_block4_branch2b = F.relu(bn_stage2_block4_branch2b)
conv_stage2_block4_branch2c = self.conv_stage2_block4_branch2c(relu_stage2_block4_branch2b)
bn_stage2_block4_branch2c = self.bn_stage2_block4_branch2c(conv_stage2_block4_branch2c)
eltwise_stage2_block4 = self.eltwise_stage2_block4(relu_stage2_block3, bn_stage2_block4_branch2c)
relu_stage2_block4 = F.relu(eltwise_stage2_block4)
conv_stage2_block5_branch2a = self.conv_stage2_block5_branch2a(relu_stage2_block4)
bn_stage2_block5_branch2a = self.bn_stage2_block5_branch2a(conv_stage2_block5_branch2a)
relu_stage2_block5_branch2a = F.relu(bn_stage2_block5_branch2a)
conv_stage2_block5_branch2b_pad = F.pad(relu_stage2_block5_branch2a, (1, 1, 1, 1))
conv_stage2_block5_branch2b = self.conv_stage2_block5_branch2b(conv_stage2_block5_branch2b_pad)
bn_stage2_block5_branch2b = self.bn_stage2_block5_branch2b(conv_stage2_block5_branch2b)
relu_stage2_block5_branch2b = F.relu(bn_stage2_block5_branch2b)
conv_stage2_block5_branch2c = self.conv_stage2_block5_branch2c(relu_stage2_block5_branch2b)
bn_stage2_block5_branch2c = self.bn_stage2_block5_branch2c(conv_stage2_block5_branch2c)
eltwise_stage2_block5 = self.eltwise_stage2_block5(relu_stage2_block4, bn_stage2_block5_branch2c)
relu_stage2_block5 = F.relu(eltwise_stage2_block5)
conv_stage3_block0_proj_shortcut = self.conv_stage3_block0_proj_shortcut(relu_stage2_block5)
conv_stage3_block0_branch2a = self.conv_stage3_block0_branch2a(relu_stage2_block5)
bn_stage3_block0_proj_shortcut = self.bn_stage3_block0_proj_shortcut(conv_stage3_block0_proj_shortcut)
bn_stage3_block0_branch2a = self.bn_stage3_block0_branch2a(conv_stage3_block0_branch2a)
relu_stage3_block0_branch2a = F.relu(bn_stage3_block0_branch2a)
conv_stage3_block0_branch2b_pad = F.pad(relu_stage3_block0_branch2a, (1, 1, 1, 1))
conv_stage3_block0_branch2b = self.conv_stage3_block0_branch2b(conv_stage3_block0_branch2b_pad)
bn_stage3_block0_branch2b = self.bn_stage3_block0_branch2b(conv_stage3_block0_branch2b)
relu_stage3_block0_branch2b = F.relu(bn_stage3_block0_branch2b)
conv_stage3_block0_branch2c = self.conv_stage3_block0_branch2c(relu_stage3_block0_branch2b)
bn_stage3_block0_branch2c = self.bn_stage3_block0_branch2c(conv_stage3_block0_branch2c)
eltwise_stage3_block0 = self.eltwise_stage3_block0(bn_stage3_block0_proj_shortcut, bn_stage3_block0_branch2c)
relu_stage3_block0 = F.relu(eltwise_stage3_block0)
conv_stage3_block1_branch2a = self.conv_stage3_block1_branch2a(relu_stage3_block0)
bn_stage3_block1_branch2a = self.bn_stage3_block1_branch2a(conv_stage3_block1_branch2a)
relu_stage3_block1_branch2a = F.relu(bn_stage3_block1_branch2a)
conv_stage3_block1_branch2b_pad = F.pad(relu_stage3_block1_branch2a, (1, 1, 1, 1))
conv_stage3_block1_branch2b = self.conv_stage3_block1_branch2b(conv_stage3_block1_branch2b_pad)
bn_stage3_block1_branch2b = self.bn_stage3_block1_branch2b(conv_stage3_block1_branch2b)
relu_stage3_block1_branch2b = F.relu(bn_stage3_block1_branch2b)
conv_stage3_block1_branch2c = self.conv_stage3_block1_branch2c(relu_stage3_block1_branch2b)
bn_stage3_block1_branch2c = self.bn_stage3_block1_branch2c(conv_stage3_block1_branch2c)
eltwise_stage3_block1 = self.eltwise_stage3_block1(relu_stage3_block0, bn_stage3_block1_branch2c)
relu_stage3_block1 = F.relu(eltwise_stage3_block1)
conv_stage3_block2_branch2a = self.conv_stage3_block2_branch2a(relu_stage3_block1)
bn_stage3_block2_branch2a = self.bn_stage3_block2_branch2a(conv_stage3_block2_branch2a)
relu_stage3_block2_branch2a = F.relu(bn_stage3_block2_branch2a)
conv_stage3_block2_branch2b_pad = F.pad(relu_stage3_block2_branch2a, (1, 1, 1, 1))
conv_stage3_block2_branch2b = self.conv_stage3_block2_branch2b(conv_stage3_block2_branch2b_pad)
bn_stage3_block2_branch2b = self.bn_stage3_block2_branch2b(conv_stage3_block2_branch2b)
relu_stage3_block2_branch2b = F.relu(bn_stage3_block2_branch2b)
conv_stage3_block2_branch2c = self.conv_stage3_block2_branch2c(relu_stage3_block2_branch2b)
bn_stage3_block2_branch2c = self.bn_stage3_block2_branch2c(conv_stage3_block2_branch2c)
eltwise_stage3_block2 = self.eltwise_stage3_block2(relu_stage3_block1, bn_stage3_block2_branch2c)
relu_stage3_block2 = F.relu(eltwise_stage3_block2)
# Resize to a fixed 7x7 grid so the final 7x7 average pool works for arbitrary input sizes,
# then average down to a single value per channel as in the original Caffe model.
relu_stage3_block2 = F.adaptive_avg_pool2d(relu_stage3_block2, (7, 7))
pool = F.avg_pool2d(relu_stage3_block2, kernel_size=(7, 7), stride=(1, 1), padding=0, ceil_mode=False, count_include_pad=False)
fc_nsfw_0 = pool.view(pool.size(0), -1)
fc_nsfw_1 = self.fc_nsfw_1(fc_nsfw_0)
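# The two softmax outputs follow the open_nsfw convention (an assumption based on the
# original Caffe model): index 0 is the SFW probability and index 1 the NSFW probability.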
prob = F.softmax(fc_nsfw_1, dim=1)
return prob
def resnet_layer_names(mode):
layers = []
if mode == '50_1by2_nsfw':
layers = ['conv_1', 'bn_1', 'conv_stage0_block0_branch2a', 'conv_stage0_block0_proj_shortcut', 'bn_stage0_block0_branch2a', 'bn_stage0_block0_proj_shortcut', \
'conv_stage0_block0_branch2b', 'bn_stage0_block0_branch2b', 'conv_stage0_block0_branch2c', 'bn_stage0_block0_branch2c', 'conv_stage0_block1_branch2a', \
'bn_stage0_block1_branch2a', 'conv_stage0_block1_branch2b', 'bn_stage0_block1_branch2b', 'conv_stage0_block1_branch2c', 'bn_stage0_block1_branch2c', \
'conv_stage0_block2_branch2a', 'bn_stage0_block2_branch2a', 'conv_stage0_block2_branch2b', 'bn_stage0_block2_branch2b', 'conv_stage0_block2_branch2c', \
'bn_stage0_block2_branch2c', 'conv_stage1_block0_proj_shortcut', 'conv_stage1_block0_branch2a', 'bn_stage1_block0_proj_shortcut', 'bn_stage1_block0_branch2a', \
'conv_stage1_block0_branch2b', 'bn_stage1_block0_branch2b', 'conv_stage1_block0_branch2c', 'bn_stage1_block0_branch2c', 'conv_stage1_block1_branch2a', \
'bn_stage1_block1_branch2a', 'conv_stage1_block1_branch2b', 'bn_stage1_block1_branch2b', 'conv_stage1_block1_branch2c', 'bn_stage1_block1_branch2c', \
'conv_stage1_block2_branch2a', 'bn_stage1_block2_branch2a', 'conv_stage1_block2_branch2b', 'bn_stage1_block2_branch2b', 'conv_stage1_block2_branch2c', \
'bn_stage1_block2_branch2c', 'conv_stage1_block3_branch2a', 'bn_stage1_block3_branch2a', 'conv_stage1_block3_branch2b', 'bn_stage1_block3_branch2b', \
'conv_stage1_block3_branch2c', 'bn_stage1_block3_branch2c', 'conv_stage2_block0_proj_shortcut', 'conv_stage2_block0_branch2a', 'bn_stage2_block0_proj_shortcut', \
'bn_stage2_block0_branch2a', 'conv_stage2_block0_branch2b', 'bn_stage2_block0_branch2b', 'conv_stage2_block0_branch2c', 'bn_stage2_block0_branch2c', \
'conv_stage2_block1_branch2a', 'bn_stage2_block1_branch2a', 'conv_stage2_block1_branch2b', 'bn_stage2_block1_branch2b', 'conv_stage2_block1_branch2c', \
'bn_stage2_block1_branch2c', 'conv_stage2_block2_branch2a', 'bn_stage2_block2_branch2a', 'conv_stage2_block2_branch2b', 'bn_stage2_block2_branch2b', \
'conv_stage2_block2_branch2c', 'bn_stage2_block2_branch2c', 'conv_stage2_block3_branch2a', 'bn_stage2_block3_branch2a', 'conv_stage2_block3_branch2b', \
'bn_stage2_block3_branch2b', 'conv_stage2_block3_branch2c', 'bn_stage2_block3_branch2c', 'conv_stage2_block4_branch2a', 'bn_stage2_block4_branch2a', \
'conv_stage2_block4_branch2b', 'bn_stage2_block4_branch2b', 'conv_stage2_block4_branch2c', 'bn_stage2_block4_branch2c', 'conv_stage2_block5_branch2a', \
'bn_stage2_block5_branch2a', 'conv_stage2_block5_branch2b', 'bn_stage2_block5_branch2b', 'conv_stage2_block5_branch2c', 'bn_stage2_block5_branch2c', \
'conv_stage3_block0_proj_shortcut', 'conv_stage3_block0_branch2a', 'bn_stage3_block0_proj_shortcut', 'bn_stage3_block0_branch2a', 'conv_stage3_block0_branch2b', \
'bn_stage3_block0_branch2b', 'conv_stage3_block0_branch2c', 'bn_stage3_block0_branch2c', 'conv_stage3_block1_branch2a', 'bn_stage3_block1_branch2a', \
'conv_stage3_block1_branch2b', 'bn_stage3_block1_branch2b', 'conv_stage3_block1_branch2c', 'bn_stage3_block1_branch2c', 'conv_stage3_block2_branch2a', \
'bn_stage3_block2_branch2a', 'conv_stage3_block2_branch2b', 'bn_stage3_block2_branch2b', 'conv_stage3_block2_branch2c', 'bn_stage3_block2_branch2c', 'fc_nsfw_1', \
'eltwise_stage0_block0', 'eltwise_stage0_block1', 'eltwise_stage0_block2', 'eltwise_stage1_block0', 'eltwise_stage1_block1', 'eltwise_stage1_block2', \
'eltwise_stage1_block3', 'eltwise_stage2_block0', 'eltwise_stage2_block1', 'eltwise_stage2_block2', 'eltwise_stage2_block3', 'eltwise_stage2_block4', \
'eltwise_stage2_block5', 'eltwise_stage3_block0', 'eltwise_stage3_block1', 'eltwise_stage3_block2']
return layers
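# --- Usage sketch ---
# A minimal, hypothetical example of how resnet_layer_names() could be used to validate the
# layer names passed to neural_dream.py via -dream_layers for this model; the variable names
# below are illustrative and not part of neural-dream itself.
if __name__ == '__main__':
    valid_layers = resnet_layer_names('50_1by2_nsfw')
    requested = 'eltwise_stage2_block5,conv_stage3_block0_branch2a'.split(',')
    unknown = [name for name in requested if name not in valid_layers]
    if unknown:
        raise ValueError('Unknown dream layers: ' + ', '.join(unknown))
    print('All requested dream layers are valid for the 50_1by2_nsfw model.')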
You can find a label file for the VGG models here: https://gist.github.com/ProGamerGov/d4dd8dada7fa03e5558a091f3f284353
