Skip to content

Instantly share code, notes, and snippets.

@d02k01
Last active December 13, 2015 12:12
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d02k01/c3686d15f8894558e2b3 to your computer and use it in GitHub Desktop.
Save d02k01/c3686d15f8894558e2b3 to your computer and use it in GitHub Desktop.
deep-goggle-chainer: a Python (Chainer) port of https://github.com/aravindhm/deep-goggle
#!/usr/bin/env python
import numpy as np
from PIL import Image
import chainer
from chainer import cuda
from chainer import optimizers
import chainer.functions as F
from chainer.links import caffe
# Device selection: a non-negative id uses that GPU, a negative id stays on
# the CPU.  `xp` is the matching array module (CuPy on GPU, NumPy on CPU).
gpu = 0
if gpu >= 0:
    cuda.check_cuda_available()
    xp = cuda.cupy
else:
    xp = np

# Loading the Caffe model can take a few minutes.
func = caffe.CaffeFunction('bvlc_reference_caffenet.caffemodel')
if gpu >= 0:
    cuda.get_device(gpu).use()
    func.to_gpu()
# image_size: side length of the stored mean image (presumably 256x256 --
# confirm against ilsvrc_2012_mean.npy).  in_size: side length that input
# images are resized to and that the mean image is center-cropped to.
image_size, in_size = 256, 227
def extract_feature(x, layer):
    """Run the Caffe network on `x` and return the output of `layer`.

    `x` is fed to the network's 'data' input and the forward pass is called
    with ``train=False``.  Exactly one output (the requested layer) is
    expected.
    """
    (feature,) = func(inputs={'data': x}, outputs=[layer], train=False)
    return feature
# Mean image (channel axis first), center-cropped from image_size x image_size
# down to the in_size x in_size region that preprocess() subtracts it from.
cropwidth = image_size - in_size
start = cropwidth // 2
stop = start + in_size
_center = slice(start, stop)
mean_image = np.load('ilsvrc_2012_mean.npy').astype('f')
mean_image = mean_image[:, _center, _center].copy()
def preprocess(pil_image, resize=True):
    """Turn a PIL image into a mean-subtracted network input array.

    The image is converted to RGB, optionally resized to in_size x in_size,
    reordered from (H, W, RGB) to (channel, H, W) with the channel order
    reversed (RGB -> BGR), and `mean_image` is subtracted.

    Args:
        pil_image: A PIL image in any mode.
        resize: When False the image keeps its own size; it should then
            already be in_size x in_size so the mean subtraction lines up.

    Returns:
        A float32 array of shape (3, H, W).
    """
    # Convert before resizing: Pillow always resamples "1" and "P" mode
    # images with nearest-neighbour, so resizing palette/bilevel images
    # first would degrade them.
    pil_image = pil_image.convert('RGB')
    if resize:
        pil_image = pil_image.resize((in_size, in_size))
    in_ = np.asarray(pil_image, dtype='f')  # (H, W, 3), RGB; fresh copy
    in_ = in_.transpose(2, 0, 1)  # -> (3, H, W)
    in_ = in_[::-1]  # RGB -> BGR
    in_ -= mean_image
    return in_
def deprocess(in_):
    """Invert preprocess(): turn a (3, H, W) BGR mean-subtracted array into a PIL image.

    The mean image is added back, channels are reversed to RGB, axes are
    moved to (H, W, 3), and values are clipped to [0, 255] and cast to
    uint8.  The caller's array is not modified.
    """
    restored = in_.copy()
    restored += mean_image
    hwc_rgb = restored[::-1].transpose(1, 2, 0)
    pixels = np.clip(hwc_rgb, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Target image whose representation is reconstructed.  Another example
# from the same dataset directory:
#   image_path = 'dataset/ILSVRC2012_val_00002012.JPEG'
image_path = 'dataset/ILSVRC2012_val_00000013.JPEG'
in_ = preprocess(Image.open(image_path), resize=True)
# x0_sigma: the average Euclidean norm of preprocessed natural images in a
# training set.  It can be computed as:
#     norms = [np.linalg.norm(preprocess(pil_image)) for pil_image in IMAGES]
#     x0_sigma = np.mean(norms).astype('f')
# This script uses the value distributed by the original author:
# https://github.com/aravindhm/deep-goggle/blob/master/experiments/x0_sigma.mat
x0_sigma = 27098.11571533

# Target image as a batch of one, (1, 3, H, W), on the selected device.
# It is volatile because it is only used for forward feature extraction.
x0_data = xp.asarray(in_[np.newaxis].copy())
x0 = chainer.Variable(x0_data, volatile=True)

# Random starting image, rescaled so that its norm equals x0_sigma.
noise = np.random.randn(*x0_data.shape).astype('f')
x_data = xp.asarray(noise / np.linalg.norm(noise) * x0_sigma)
# Target feature: the output of `layer` for the target image, plus its norm.
layer = 'conv3'
y0 = extract_feature(x0, layer)
y0_sigma = np.linalg.norm(cuda.to_cpu(y0.data))
# y0 came from a volatile input; clear the flag so it can serve as the
# regression target inside the loss graph built by lossfun().
y0.volatile = False
# Define hyper-parameters.
# One learning rate per iteration: 100 steps at 0.004, 200 at 0.0004 and
# 100 at 0.00004.
learning_rate = 0.004 * np.array([1] * 100
                                 + [0.1] * 200
                                 + [0.01] * 100, dtype='f')
momentum = 0.9
# Weight of the feature-reconstruction term: the squared-norm ratio of the
# target image to the target feature (presumably to balance image-space and
# feature-space magnitudes).
lambda_euc = x0_sigma**2 / y0_sigma**2
beta = 2

# Total-variation weight: a base of 0.5, scaled up for deeper layers.
_TV_SCALE_BY_LAYER = {
    'conv1': 1, 'pool1': 1, 'norm1': 1, 'conv2': 1,
    'pool2': 10, 'norm2': 10, 'conv3': 10, 'conv4': 10,
    'conv5': 100, 'pool5': 100, 'fc6': 100, 'fc7': 100, 'fc8': 100,
}
if layer not in _TV_SCALE_BY_LAYER:
    # Previously this only printed a misleading "does not exist" message
    # and carried on with an untuned TV weight.
    raise ValueError('No total-variation weight is configured for layer %r; '
                     'choose one of: %s'
                     % (layer, ', '.join(sorted(_TV_SCALE_BY_LAYER))))
lambda_tv = 0.5 * _TV_SCALE_BY_LAYER[layer]
# Total variation is built from two finite-difference filters.  Because
# convolution_2d computes a cross-correlation:
#   Wh, shape (1, 1, 2, 1), gives x[i, j] - x[i + 1, j]  (vertical)
#   Ww, shape (1, 1, 1, 2), gives x[i, j] - x[i, j + 1]  (horizontal)
Wh_data = xp.array([[[[1], [-1]]]], dtype='f')
Ww_data = xp.array([[[[1, -1]]]], dtype='f')
Wh = chainer.Variable(Wh_data)
Ww = chainer.Variable(Ww_data)


def tvh(x):
    """Return vertical finite differences of single-channel maps `x`."""
    return F.convolution_2d(x, W=Wh)


def tvw(x):
    """Return horizontal finite differences of single-channel maps `x`."""
    return F.convolution_2d(x, W=Ww)
def tv_norm(x, beta=2):
    """Total-variation regularizer of an image batch `x`.

    Returns ``(sum(dh**2) + sum(dw**2)) ** (beta / 2)`` as a scalar Variable,
    where dh and dw are the vertical and horizontal finite differences
    (from `tvh` / `tvw`) of every channel of every image in `x`.

    NOTE(review): the power is applied to the global sum.  This coincides
    with a per-pixel TV^beta penalty only when beta == 2 (the value this
    script uses).  Confirm before using other values of beta.

    Args:
        x: Variable of shape (N, C, H, W); the shape is read from `x`.
        beta: Exponent of the regularizer.
    """
    n, c, h, w = x.data.shape
    # Treat every channel of every image as its own single-channel map, so
    # the 1-input-channel difference filters apply to all of them.
    maps = F.reshape(x, (n * c, 1, h, w))
    diffh = tvh(maps)
    diffw = tvw(maps)
    return (F.sum(diffh**2) + F.sum(diffw**2))**(beta / 2.)
# Exponent and weight of the image prior term lambda_lp * sum(x**p).
p, lambda_lp = 6, 4e-10
def lossfun(x, layer):
    """Objective minimized over the image `x` when inverting `layer`.

    The sum of three terms:
      * reconstruction: lambda_euc times the mean squared error between the
        target feature y0 and the feature of `x`, rescaled by the number of
        elements in y0 (i.e. a summed squared error);
      * smoothness: lambda_tv * tv_norm(x, beta);
      * image prior: lambda_lp * sum(x**p).
    """
    feature = extract_feature(x, layer)
    n_elements = float(y0.data.size)
    reconstruction = (lambda_euc * n_elements
                      * F.mean_squared_error(y0, feature))
    smoothness = lambda_tv * tv_norm(x, beta)
    prior = lambda_lp * F.sum(x**p)
    return reconstruction + smoothness + prior
# Optimize the image with momentum SGD, following the learning_rate schedule.
# On each step where the learning rate changes, momentum is set to 0 for
# that single step, so none of the previous velocity is carried over.  It is
# then restored to the configured `momentum`.
opt = optimizers.MomentumSGD(momentum=momentum)
lr_prev = 0
x_link = chainer.links.Parameter(x_data)
opt.setup(x_link)
# The optimizer updates this parameter Variable in place, so one reference
# is enough.  Binding it before the loop also keeps `x` defined when the
# schedule is empty.
x = x_link.W
for lr in learning_rate:
    if lr != lr_prev:
        opt.lr = lr
        lr_prev = lr
        opt.momentum = 0
    else:
        opt.momentum = momentum  # was a hard-coded 0.9
    opt.update(lossfun, x, layer)
result = deprocess(cuda.to_cpu(x.data)[0])
result.save('result.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment