Skip to content

Instantly share code, notes, and snippets.

@poutyface
Created April 4, 2016 23:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save poutyface/c1818e43e96d550deea0ceef48091eff to your computer and use it in GitHub Desktop.
Save poutyface/c1818e43e96d550deea0ceef48091eff to your computer and use it in GitHub Desktop.
m2
#!/usr/bin/env python
import os
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda
from chainer import Chain
from chainer import optimizers
from chainer import serializers
import cv2
import dataset
print("VAE p(x|y,z)")

# Load MNIST through the project-local `dataset` helper and divide the pixel
# values by 255 so the images are stored as float32.
data_dir = "./dataset/"
mnist = dataset.load_mnist_data(data_dir)
all_x = np.array(mnist['data'], dtype=np.float32) / 255.0
all_y_tmp = np.array(mnist['target'], dtype=np.float32)

# One-hot encode the labels.  The labels are kept as float32 in `all_y_tmp`,
# but array shapes and indices must be integers (NumPy rejects floats for
# both), so an integer copy drives the encoding.
all_labels = all_y_tmp.astype(np.int64)
num_classes = int(all_labels.max()) + 1
all_y = np.zeros((all_x.shape[0], num_classes), dtype=np.float32)
all_y[np.arange(all_labels.shape[0]), all_labels] = 1.0

# Split by position: first 50000 rows for training, next 10000 for
# validation, everything after row 60000 for testing.
train_x = all_x[:50000]
train_y = all_y[:50000]
valid_x = all_x[50000:60000]
valid_y = all_y[50000:60000]
test_x = all_x[60000:]
test_y = all_y[60000:]

# Model / training hyper-parameters read by the code below.
image_size = 28                # side length used when reshaping a flat image
nx = image_size * image_size   # width of the flattened image vector
nbatch = 10                    # mini-batch size of the training loop
nz = 300                       # width of the latent vector z
ny = 10                        # width of the one-hot label vector
class VAE(Chain):
    """Conditional variational auto-encoder made of fully connected layers.

    The recognition layers (``recog_*``) map an image ``x`` and a label
    vector ``y`` to the parameters of q(z|x,y); the generator layers
    (``gen_*``) map a latent ``z`` and a label ``y`` back to an image.
    Layer widths come from the module-level names ``nx``, ``ny`` and ``nz``;
    every hidden layer has 500 units.
    """

    def __init__(self):
        super(VAE, self).__init__(
            # Recognition network q(z|x,y).
            recog_x1=L.Linear(nx, 500, nobias=True),
            recog_x2=L.Linear(500, 500),
            recog_y1=L.Linear(ny, 500),
            recog_y2=L.Linear(500, 500),
            recog_x_y=L.Linear(500, 500),
            recog_mean=L.Linear(500, nz),
            recog_log_sigma=L.Linear(500, nz),
            # Generator network p(x|y,z).
            gen_y1=L.Linear(ny, 500),
            gen_y2=L.Linear(500, 500),
            gen_z1=L.Linear(nz, 500, nobias=True),
            gen_z2=L.Linear(500, 500),
            gen_z_y=L.Linear(500, 500),
            gen=L.Linear(500, nx),
        )

    def generate_z(self, x, y):
        """Sample a latent vector from q(z|x,y).

        Args:
            x: chainer.Variable holding a batch of flattened images.
            y: chainer.Variable holding the matching label vectors.

        Returns:
            Tuple ``(z, recog_mean, recog_log_sigma)``.  ``recog_log_sigma``
            is 0.5 times the output of the ``recog_log_sigma`` layer, i.e. it
            is used as the log of the standard deviation when sampling.
        """
        # q(z|x,y): separate image and label branches, merged by addition.
        hx = F.relu(self.recog_x1(x))
        hx = F.relu(self.recog_x2(hx))
        hy = F.relu(self.recog_y1(y))
        hy = F.relu(self.recog_y2(hy))
        hq = F.relu(self.recog_x_y(hx + hy))

        recog_mean = self.recog_mean(hq)
        recog_log_sigma = 0.5 * self.recog_log_sigma(hq)

        # Noise is drawn with NumPy on the host, one row per batch element.
        eps = np.random.normal(0, 1, (x.data.shape[0], nz)).astype(np.float32)
        eps = chainer.Variable(eps)

        # z = mu + sigma * epsilon
        z = recog_mean + F.exp(recog_log_sigma) * eps
        return z, recog_mean, recog_log_sigma

    def generate_x(self, z, y):
        """Decode a latent ``z`` and label ``y`` into an image.

        Returns:
            chainer.Variable with the sigmoid of the final linear layer.
        """
        # p(x|y,z): separate label and latent branches, merged by addition.
        hy = F.relu(self.gen_y1(y))
        hy = F.relu(self.gen_y2(hy))
        hz = F.relu(self.gen_z1(z))
        hz = F.relu(self.gen_z2(hz))
        hp = F.relu(self.gen_z_y(hy + hz))
        return F.sigmoid(self.gen(hp))

    def generate(self, x, y):
        """Re-render ``x`` once for every possible label.

        A latent ``z`` is sampled from q(z|x,y) and then decoded with each
        one-hot label in turn.

        NOTE(review): the per-label conditioning vector is built with batch
        size 1, so ``x`` and ``y`` presumably must hold a single example --
        confirm against callers.

        Returns:
            float32 NumPy array with one row per label and one column per
            input feature.
        """
        z, _, _ = self.generate_z(x, y)
        num_labels = y.data.shape[1]
        output = np.zeros((num_labels, x.data.shape[1]), dtype=np.float32)
        for label in range(num_labels):
            sample_y = np.zeros((1, num_labels), dtype=np.float32)
            sample_y[0][label] = 1.0
            out = self.generate_x(z, chainer.Variable(sample_y))
            output[label] = out.data
        return output

    def __call__(self, x, y):
        """Run the auto-encoder on a batch and compute both loss terms.

        Returns:
            Tuple ``(loss, kld, output)``: the mean squared reconstruction
            error, the KL term divided by (batch size * input width), and
            the reconstructed batch.
        """
        z, recog_mean, recog_log_sigma = self.generate_z(x, y)
        output = self.generate_x(z, y)
        loss = F.mean_squared_error(output, x)

        # The closed-form Gaussian KL is written in terms of the log
        # *variance*.  generate_z hands back the log standard deviation
        # (it samples with exp(recog_log_sigma)), so double it here to stay
        # consistent with the distribution that z was actually drawn from.
        log_var = 2 * recog_log_sigma
        kld = -0.5 * F.sum(1 + log_var - recog_mean ** 2 - F.exp(log_var)) / (x.data.shape[0] * x.data.shape[1])
        return loss, kld, output
class Disc(Chain):
    """Convolutional network: four stride-2 convolutions and a linear head.

    ``nc`` and ``ndf`` are looked up as module-level names when the
    constructor runs.
    NOTE(review): neither name is defined in the visible part of this file,
    and nothing visible instantiates this class -- confirm before use.
    """

    def __init__(self):
        # All four convolutions share the same geometry.
        conv_kw = dict(ksize=4, stride=2, pad=1)
        super(Disc, self).__init__(
            bn1=L.BatchNormalization(ndf * 2),
            bn2=L.BatchNormalization(ndf * 4),
            bn3=L.BatchNormalization(ndf * 8),
            c1=L.Convolution2D(nc, ndf, **conv_kw),
            c2=L.Convolution2D(ndf, ndf * 2, **conv_kw),
            c3=L.Convolution2D(ndf * 2, ndf * 4, **conv_kw),
            c4=L.Convolution2D(ndf * 4, ndf * 8, **conv_kw),
            l1=L.Linear(ndf * 8 * 6 * 6, 1),
        )

    def __call__(self, x, test=False):
        """Apply the network to ``x`` and return the linear layer's output.

        ``test`` is forwarded to every batch-normalization layer.  The shape
        of the input and of each intermediate result is printed on every
        call.
        """
        # trace[0] is the input; each later entry is the next activation.
        trace = [x, F.leaky_relu(self.c1(x))]
        stages = ((self.c2, self.bn1), (self.c3, self.bn2), (self.c4, self.bn3))
        for conv, norm in stages:
            trace.append(F.leaky_relu(norm(conv(trace[-1]), test=test)))
        trace.append(self.l1(trace[-1]))

        for stage_output in trace:
            print(stage_output.data.shape)
        return trace[-1]
#image_path = "./lfwcrop_grey/faces"
#fs = os.listdir(image_path)
#print len(fs)
#dataset = []
#for fn in fs:
# read as grey
# img = cv2.imread('%s/%s'%(image_path, fn), 0)
# img = cv2.resize(img, (image_size,image_size))
# img = img.astype(np.float32)
# img = img / 255
# img = img.reshape(image_size*image_size)
# dataset.append(img)
def _to_display_image(flat_pixels):
    """Scale a flat float image by 255 and reshape it to a square uint8 image."""
    scaled = flat_pixels * 255
    return scaled.reshape(image_size, image_size).astype(np.uint8)


vae = VAE()
opt = optimizers.Adam(alpha=0.0002, beta1=0.5)
opt.setup(vae)

# The preview always re-renders training example 1; slicing keeps the
# leading batch axis of size 1 without touching the training batch variables.
preview_x = train_x[1:2]
preview_y = train_y[1:2]

num_train = train_x.shape[0]
for epoch in range(500000):
    print("epoch: %d" % epoch)

    # Draw a fresh visiting order every epoch; a single permutation reused
    # across epochs would present identical mini-batches in identical order.
    indexes = np.random.permutation(num_train)

    for start in range(0, num_train, nbatch):
        batch_ids = indexes[start:start + nbatch]
        x_batch = train_x[batch_ids]
        y_batch = train_y[batch_ids]

        # VAE
        recog_loss, kld_loss, output = vae(chainer.Variable(x_batch), chainer.Variable(y_batch))
        loss = recog_loss + kld_loss
        print("%s %s %s" % (loss.data, recog_loss.data, kld_loss.data))

        vae.zerograds()
        loss.backward()
        opt.update()

    # Live preview: show the fixed input next to its re-rendering under each
    # label, one window per label.
    # NOTE(review): the original indentation was lost; the preview is placed
    # once per epoch here -- confirm it was not meant to run after every batch.
    generated = vae.generate(chainer.Variable(preview_x), chainer.Variable(preview_y))
    cv2.imshow("input", _to_display_image(train_x[1]))
    for label, flat_pixels in enumerate(generated):
        cv2.imshow("%d" % label, _to_display_image(flat_pixels))
    cv2.waitKey(1)

    # Disabled snapshot code kept from the original script:
    # if epoch % 1000 == 0:
    #     for j in range(0, nbatch):
    #         img = output.data[j]
    #         img = img * 255
    #         img = img.reshape(image_size, image_size)
    #         img = img.astype(np.uint8)
    #         cv2.imwrite("out_images/%d_%d.jpg"%(epoch, j), img)
    #     serializers.save_hdf5("out_models/model_%d.h5"%(epoch), vae)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment