Skip to content

Instantly share code, notes, and snippets.

@neelriyer
Created August 10, 2020 00:25
Show Gist options
  • Save neelriyer/d5164270769c653c38737c550a9a9abb to your computer and use it in GitHub Desktop.
Save neelriyer/d5164270769c653c38737c550a9a9abb to your computer and use it in GitHub Desktop.
Create a CNN for neural style transfer.
import torch
import torch.nn as nn
from torch.nn import ReLU, Conv1d
import torch.optim as optim
import numpy as np
import copy
class CNNModel(nn.Module):
    """Two-layer 1-D convolutional network: conv -> ReLU -> conv.

    Both convolutions use ``kernel_size=3``, ``stride=1`` and ``padding=1``,
    so the size of the last input dimension is preserved.

    Args:
        in_channels: Number of channels the first convolution expects.
        hidden_channels: Number of channels produced by the first
            convolution and both consumed and produced by the second.
    """

    def __init__(self, in_channels=1025, hidden_channels=4096):
        super(CNNModel, self).__init__()
        self.cnn1 = Conv1d(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu = ReLU()
        self.cnn2 = Conv1d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Run ``x`` through conv1, ReLU and conv2 and return the result.

        Args:
            x: Input tensor accepted by ``Conv1d`` whose channel dimension
                equals ``in_channels``.

        Returns:
            The output of the second convolution.
        """
        out = self.cnn1(x)
        out = self.relu(out)
        # The second convolution must consume the activated features, not the
        # raw input: feeding ``x`` here would discard the first layer and
        # fail on the channel mismatch between ``x`` and ``cnn2``.
        out = self.cnn2(out)
        return out
# Loss weights: the style term is weighted far more heavily than content.
content_weight = 2
style_weight = 1000

# Build the network and move it onto the GPU when CUDA is available.
cnn = CNNModel()
if torch.cuda.is_available():
    cnn = cnn.cuda()
def get_style_model_and_losses(
    cnn,
    style_float,
    content_float=content_float,
    style_weight=style_weight,
):
    """Assemble the style-transfer model together with its loss modules.

    A deep copy of ``cnn`` supplies the layers ``cnn1``, ``relu`` and
    ``cnn2``, which are registered in that order in a new ``nn.Sequential``.
    A style-loss module and then a content-loss module are appended after
    them. The content weight is read from the module-level
    ``content_weight``.

    Args:
        cnn: Network exposing ``cnn1``, ``relu`` and ``cnn2`` attributes.
        style_float: Input passed through the model to build the style
            target (presumably a tensor already on the model's device;
            verify against the caller).
        content_float: Input passed through the model to build the content
            target.
        style_weight: Weight handed to ``StyleLoss``.

    Returns:
        A tuple ``(model, style_losses, content_losses)`` where the two
        lists each hold the single loss module that was added to ``model``.
    """
    # Work on a private copy so the caller's network object is not shared.
    backbone = copy.deepcopy(cnn)

    model = nn.Sequential()
    # The Gram module converts feature maps into the style target.
    gram = GramMatrix()

    if torch.cuda.is_available():
        model = model.cuda()
        gram = gram.cuda()

    # conv -> relu -> conv, taken from the copied network.
    for layer_name, layer in (
        ('conv_1', backbone.cnn1),
        ('relu1', backbone.relu),
        ('conv_2', backbone.cnn2),
    ):
        model.add_module(layer_name, layer)

    # Style loss: the target is the Gram output of the style features.
    style_features = model(style_float).clone()
    style_loss = StyleLoss(gram(style_features), style_weight)
    model.add_module("style_loss_1", style_loss)

    # Content loss: the target is computed after the style-loss module has
    # been appended, so content_float passes through that module as well.
    content_target = model(content_float).detach()
    content_loss = ContentLoss(content_target, content_weight)
    model.add_module("content_loss_1", content_loss)

    return model, [style_loss], [content_loss]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment