-
-
Save ProGamerGov/ad11076ad743677e8ab993cc1930af9b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import torch | |
import torch.nn as nn | |
from CaffeLoader import loadCaffemodel | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import torchvision.models as models | |
# Switch cuDNN on and enable its benchmark mode for the whole script.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# Define an nn Module to compute DeepDream loss
class DeepDreamBCELoss(torch.nn.Module):
    """Pass-through module that records a BCE-with-logits DeepDream loss.

    ``forward`` always returns its input unchanged; what it records on the
    way depends on ``self.mode``:

    * ``'capture'``: build ``self.target`` with the same shape, dtype and
      device as the input. It is all zeros, except for the value 100 at
      ``[0, label]`` when ``label`` is not -1.
    * ``'loss'``: store ``-BCEWithLogitsLoss(input, self.target) * strength``
      in ``self.loss``.
    * any other value (the initial ``'None'``): record nothing.

    Args:
        strength: Multiplier applied to the criterion value.
        label: Index along dimension 1 to set in the target, or -1 to keep
            the target all zeros.
    """

    def __init__(self, strength, label=-1):
        super(DeepDreamBCELoss, self).__init__()
        self.strength = strength
        self.label = label
        self.crit = torch.nn.BCEWithLogitsLoss()
        self.mode = 'None'
        # Filled in by a forward pass in 'capture' mode.
        self.target = None

    def forward(self, input):
        """Record the target or the loss according to ``self.mode``.

        Raises:
            RuntimeError: In 'loss' mode when no target has been captured.
        """
        if self.mode == 'loss':
            if self.target is None:
                raise RuntimeError(
                    "DeepDreamBCELoss needs a forward pass in 'capture' mode "
                    "before it can be used in 'loss' mode")
            self.loss = -self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            # zeros_like keeps the target on the input's device and dtype, so
            # the criterion also works for CUDA inputs.
            self.target = torch.zeros_like(input)
            if self.label != -1:
                self.target[0, self.label] = 100
        return input
def preprocess(image_name, image_size):
    """Load an image file and turn it into a model-ready 4-D tensor.

    The image is opened as RGB, resized, converted to a tensor, multiplied
    by 256, channel-reordered with index [2, 1, 0] (RGB to BGR), shifted by
    the per-channel mean pixel, and finally given a leading batch dimension.

    Args:
        image_name: Path of the image file to open.
        image_size: Either a tuple handed to ``transforms.Resize`` as-is, or
            a single number taken as the target length of the longer side.

    Returns:
        The preprocessed tensor with a leading dimension of size 1.
    """
    picture = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        # Scale both sides by the same ratio so the longer one hits image_size.
        ratio = float(image_size) / max(picture.size)
        image_size = (int(ratio * picture.height), int(ratio * picture.width))

    load = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    scaled = load(picture) * 256
    bgr = scaled[torch.LongTensor([2, 1, 0])]
    subtract_mean = transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    return subtract_mean(bgr).unsqueeze(0)
def deprocess(output_tensor, name):
    """Undo ``preprocess`` and save the result as ``<name>.png``.

    The tensor is detached from autograd, stripped of its batch dimension,
    moved to the CPU, shifted back by the per-channel mean pixel, reordered
    with index [2, 1, 0] (BGR to RGB), divided by 256 and clamped to [0, 1]
    before being converted to a PIL image.

    Args:
        output_tensor: Image tensor with a leading batch dimension of size 1.
            It may require grad (e.g. an ``nn.Parameter``).
        name: Output path without the ``.png`` extension.
    """
    add_mean = transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])
    # Detach first: this is called on the optimised nn.Parameter, and image
    # export must not depend on (or extend) the autograd graph.
    chw = output_tensor.detach().squeeze(0).cpu()
    rgb = add_mean(chw)[torch.LongTensor([2, 1, 0])] / 256
    rgb.clamp_(0, 1)
    image = transforms.ToPILImage()(rgb)
    image.save(name + '.png')
def simple_preprocess(image_name, image_size):
    """Load an image file as an RGB tensor with a leading batch dimension.

    Only resizing and ``ToTensor`` are applied; there is no scaling, channel
    reordering or mean subtraction.

    Args:
        image_name: Path of the image file to open.
        image_size: Either a tuple handed to ``transforms.Resize`` as-is, or
            a single number taken as the target length of the longer side.

    Returns:
        The image tensor with a leading dimension of size 1.
    """
    picture = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        # Scale both sides by the same ratio so the longer one hits image_size.
        ratio = float(image_size) / max(picture.size)
        image_size = (int(ratio * picture.height), int(ratio * picture.width))

    load = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    return load(picture).unsqueeze(0)
def simple_deprocess(output_tensor, name):
    """Undo ``simple_preprocess`` and save the result as ``<name>.png``.

    The tensor is detached, stripped of its batch dimension, moved to the
    CPU and clamped to [0, 1] before being converted to a PIL image.

    Args:
        output_tensor: Image tensor with a leading batch dimension of size 1.
        name: Output path without the ``.png`` extension.
    """
    chw = output_tensor.detach().squeeze(0).cpu()
    # Clamp out of place: for a tensor already on the CPU the line above only
    # produces views, so an in-place clamp would overwrite the caller's data.
    chw = chw.clamp(0, 1)
    image = transforms.ToPILImage()(chw)
    image.save(name + '.png')
def rescale_tensor(tensor, scale_factor):
    """Return ``tensor`` resampled by ``scale_factor``.

    Thin wrapper around ``torch.nn.functional.interpolate`` using its
    default interpolation settings.
    """
    return torch.nn.functional.interpolate(tensor, scale_factor=scale_factor)
def resize_tensor(tensor, size):
    """Resize a 3-D or 4-D tensor to ``size`` with ``interpolate``.

    Args:
        tensor: Input tensor. A 3-D tensor is given a temporary leading
            batch dimension for the call and gets it removed afterwards.
        size: Target size as a tuple, or a single value used for both sides.

    Returns:
        The resized tensor, with the same number of dimensions as the input.
    """
    if type(size) is not tuple:
        size = (size, size)

    lacks_batch_dim = tensor.dim() == 3
    batched = tensor.unsqueeze(0) if lacks_batch_dim else tensor
    resized = torch.nn.functional.interpolate(batched, size=size)
    return resized.squeeze(0) if lacks_batch_dim else resized
def setup_net(cnn, layerList, dream_layers, dream_weight, l_mode, channels, tv_weight=0):
    """Build a truncated copy of ``cnn.features`` with dream losses inserted.

    Convolution, ReLU and max/avg pooling layers are copied in order into a
    new ``nn.Sequential``; any other layer type is skipped. After each
    conv/ReLU layer whose name is listed in ``dream_layers`` a ``DreamLoss``
    module is appended. Construction stops once as many loss modules have
    been placed as names were requested, so layers past the last requested
    one are left out.

    Args:
        cnn: Model exposing the layer stack as ``cnn.features``; it is
            deep-copied, not modified.
        layerList: Mapping whose ``'C'`` and ``'R'`` entries hold the conv
            and ReLU layer names in the order those layers occur.
        dream_layers: Comma-separated names of the layers to attach losses to.
        dream_weight: Strength passed to each ``DreamLoss``.
        l_mode: Loss mode passed to each ``DreamLoss``.
        channels: Channel specification passed to each ``DreamLoss``.
        tv_weight: When greater than 0, a ``TVLoss`` module is placed at the
            front of the network.

    Returns:
        Tuple ``(net, dream_losses, tv_losses)``.
    """
    features = copy.deepcopy(cnn.features)
    dream_layers = dream_layers.split(',')
    dream_losses, tv_losses = [], []
    net = nn.Sequential()

    def append(module):
        # Modules are named by their position, like nn.Sequential's defaults.
        net.add_module(str(len(net)), module)

    def add_dream_loss_if_requested(position, layer_name):
        if layer_name not in dream_layers:
            return
        print("Setting up dream layer " + str(position) + ": " + str(layer_name))
        loss_module = DreamLoss(loss_mode=l_mode, strength=dream_weight, channels=channels)
        append(loss_module)
        dream_losses.append(loss_module)

    if tv_weight > 0:
        # NOTE(review): TVLoss is not defined in the visible part of this
        # file; it has to be provided elsewhere for tv_weight > 0 to work.
        tv_mod = TVLoss(tv_weight).type(torch.cuda.FloatTensor)
        append(tv_mod)
        tv_losses.append(tv_mod)

    conv_idx, relu_idx = 0, 0
    for i, layer in enumerate(features, 1):
        if len(dream_losses) >= len(dream_layers):
            # Every requested layer (conv or ReLU) has its loss module, so
            # nothing deeper in the network can contribute.
            break
        if isinstance(layer, nn.Conv2d):
            append(layer)
            add_dream_loss_if_requested(i, layerList['C'][conv_idx])
            conv_idx += 1
        elif isinstance(layer, nn.ReLU):
            append(layer)
            add_dream_loss_if_requested(i, layerList['R'][relu_idx])
            relu_idx += 1
        elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
            append(layer)
    return net, dream_losses, tv_losses
def zero_tensor(tensor):
    """Return ``tensor`` multiplied by 0, with a leading dimension of size 1.

    A leading dimension of size 1 is squeezed away before the multiplication
    and a new one is added afterwards; the input itself is not modified.
    """
    without_batch = tensor.squeeze(0)
    zeroed = without_batch * 0
    return zeroed.unsqueeze(0)
class DreamLossType(torch.nn.Module):
    """Compute one scalar dream loss from a layer activation.

    Args:
        loss_mode: One of ``'norm'``, ``'mean'``, ``'mse'`` or ``'bce'``
            (case-insensitive).
        channels: Channel indices along dimension 1 to restrict the loss to,
            given as a list of ints/strings or a comma-separated string.
            If -1 is among them, all channels are used.

    Raises:
        ValueError: If ``loss_mode`` is not one of the supported modes.
    """

    def __init__(self, loss_mode, channels):
        super(DreamLossType, self).__init__()
        self.get_mode(loss_mode)
        if isinstance(channels, str):
            channels = channels.split(',')
        self.channels = channels

    def get_mode(self, loss_mode):
        """Bind ``self.loss_mode`` to the loss function named by ``loss_mode``."""
        mode = loss_mode.lower()
        if mode == 'norm':
            self.loss_mode = self.norm_loss
        elif mode == 'mean':
            self.loss_mode = self.mean_loss
        elif mode == 'mse':
            self.crit = torch.nn.MSELoss()
            self.loss_mode = self.crit_loss
        elif mode == 'bce':
            self.crit = torch.nn.BCEWithLogitsLoss()
            self.loss_mode = self.crit_loss
        else:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass.
            raise ValueError(
                "Unknown loss_mode %r; expected 'norm', 'mean', 'mse' or 'bce'"
                % (loss_mode,))

    def norm_loss(self, input):
        """Return the norm of the selected activations."""
        return self.ch(input).norm()

    def mean_loss(self, input):
        """Return the mean of the selected activations."""
        return self.ch(input).mean()

    def crit_loss(self, input):
        """Return the criterion value of the selected activations vs. zeros."""
        selected = self.ch(input)
        # Build the target from the same selection that is fed to the
        # criterion so both always agree in shape, dtype and device.
        target = torch.zeros_like(selected)
        return self.crit(selected, target)

    def ch(self, input):
        """Return ``input`` restricted to the configured channels.

        The selection indexes dimension 1 with the whole list at once, so
        several channels can be combined and the other dimensions are kept.
        """
        indices = [int(c) for c in self.channels]
        if -1 in indices:
            return input
        return input[:, indices]

    def forward(self, input):
        """Return the loss for ``input`` using the configured mode."""
        return self.loss_mode(input)
class DreamLoss(torch.nn.Module):
    """Pass-through module that records a weighted dream loss.

    Args:
        loss_mode: Loss mode handed to ``DreamLossType``.
        strength: Multiplier applied to the loss value.
        channels: Comma-separated channel string; it is split on ``','``
            before being handed to ``DreamLossType``.
    """

    def __init__(self, loss_mode, strength, channels):
        super(DreamLoss, self).__init__()
        channel_list = channels.split(',')
        self.dream = DreamLossType(loss_mode, channel_list)
        self.strength = strength
        self.mode = 'None'

    def forward(self, input):
        """Return ``input`` unchanged; in 'loss' mode also set ``self.loss``."""
        if self.mode != 'loss':
            return input
        self.loss = self.dream(input) * self.strength
        return input
# Driver script: load the model, attach dream losses and optimise the image.
cnn, layerList = loadCaffemodel("vgg19-dcbb9e9d.pth", 'max', 0, True, False)
print('Model loaded')

net, dream_losses, tv_losses = setup_net(
    cnn, layerList, 'relu1_1,relu2_1,relu4_2,relu5_1',
    dream_weight=4000, l_mode='bce', channels='-1', tv_weight=0)
print('DreamCNN')

img = preprocess('golden_gate.jpg', 512).cuda()
#img = preprocess('dream_test_input3.png', 512).cuda()

# Loss modules only record a loss while in 'loss' mode.
for loss_module in dream_losses:
    loss_module.mode = 'loss'

# Only the image is optimised; the network weights stay fixed.
for param in net.parameters():
    param.requires_grad = False
img = torch.nn.Parameter(img)

learning_rate = 0.1
num_iterations = 1000
optimizer = torch.optim.Adam([img], lr=learning_rate)

print(img.sum())
print('Running Closure')

# One-element list so the closure can update the counter in place.
num_calls = [0]


def closure():
    """Run one forward/backward pass and return the total loss.

    The dream losses are negated so that minimising the total drives the
    recorded dream losses up. Any TV losses returned by ``setup_net`` are
    added as they are.
    """
    num_calls[0] += 1
    optimizer.zero_grad()
    # The forward pass makes every loss module refresh its ``loss``.
    net(img)
    loss = 0
    for mod in dream_losses:
        loss = loss - mod.loss
    for mod in tv_losses:
        loss = loss + mod.loss
    loss.backward()
    print('Iter: ', num_calls[0], 'Total loss: ', loss.item())
    # Optimizer.step(closure) expects the closure to return the loss.
    return loss


# Strict comparison: exactly num_iterations closure evaluations are run.
while num_calls[0] < num_iterations:
    optimizer.step(closure)

print(img.sum())
deprocess(img, 'dream_test')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment