# GAN with normalization
# @hanzhanggit, December 15, 2017
#
# DCGAN-style GAN on 128x128 dog images: a GLU generator and leaky-ReLU
# discriminator with manual batch norm, unit-norm gradient rescaling before
# Adam, and an EMA copy of the generator for evaluation.
import os
import time
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from depot import inits
from depot.utils import find_trainable_variables, find_variables, iter_data, shuffle
from depot.vis import color_grid
from utils.inception import get_inception_score
import sys
# from depot.load import cifar10_with_valid_set
desc = 'dog_128_normalize'
t = time.time()
trX = np.load('/home/hanzhang/Data/imagenet/doggonet_128px_imgs.npy')
print('%.3f seconds to load' % (time.time() - t))
ntrain = len(trX)
print(trX.shape)
# X = tf.placeholder(tf.float32, [128, 64, 64, 3])
X = tf.placeholder(tf.float32, [128, 128, 128, 3])  # image batch, scaled to [-1, 1]
# X = tf.placeholder(tf.float32, [128, 32, 32, 3])
Z = tf.placeholder(tf.float32, [128, 100])  # latent noise
test_batches = 5000 // 128 + 1  # enough batches for >= 5000 Inception-score samples
bn_updates = []  # batch-norm moment assign ops, collected while the graph is built
file_name = './' + desc + '_log.txt'
log_file = open(file_name, 'w+')
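# Activations: lrelu is a leaky ReLU written via abs(); glu is a gated linear
# unit that splits the last dimension in half and gates one half with the
# sigmoid of the other, so every glu layer halves its channel count.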
def lrelu(x, leak=0.2):
    f1 = 0.5 * (1 + leak)
    f2 = 0.5 * (1 - leak)
    return f1 * x + f2 * tf.abs(x)

def glu(x):
    dim = len(x.get_shape()) - 1
    a, b = tf.split(x, 2, dim)
    return a * tf.nn.sigmoid(b)
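# Minimal batch norm: with ema=None (training) it normalizes with the current
# batch moments and queues assign ops into bn_updates so the non-trainable
# u/s variables track the latest batch statistics; with an ema object
# (evaluation) it normalizes with the moving averages of u/s instead.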
def _bn(x, g, b, e=1e-5, axes=[1], ema=None):
    shape = [s.value for s in x.get_shape()]
    for axis in axes:
        shape[axis] = 1
    uv = tf.get_variable("u", shape, initializer=inits.constant_init(0.0), trainable=False)
    sv = tf.get_variable("s", shape, initializer=inits.constant_init(1.0), trainable=False)
    if ema is not None:
        u = ema.average(uv)
        s = ema.average(sv)
    else:
        u, s = tf.nn.moments(x, axes=axes, keep_dims=True)
        bn_updates.append(uv.assign(u))
        bn_updates.append(sv.assign(s))
    x = (x - u) / tf.sqrt(s + e)
    x = x * g + b
    return x
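# Layer helpers. Each creates its variables under `scope`; passing an ema
# object swaps every variable for its exponential moving average, which is
# how the averaged evaluation copy of the generator is built below.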
def conv(x, scope, rf, nf, act, stride=1, pad='SAME', winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nin, nf], initializer=winit)
        b = tf.get_variable("b", [nf], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad)
        z = z + b
        h = act(z)
        return h
def bnconv(x, scope, rf, nf, act, stride=1, pad='SAME', winit=inits.ortho_init(1.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nin, nf], initializer=winit)
        g = tf.get_variable("g", [nf], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nf], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad)
        z = _bn(z, g, b, axes=[0, 1, 2], ema=ema)
        h = act(z)
        return h
def deconv(x, scope, shape, rf, nf, act, stride=2, pad='SAME', winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nf, nin], initializer=winit)
        b = tf.get_variable("b", [nf], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.nn.conv2d_transpose(x, w, shape, [1, stride, stride, 1], padding=pad)
        z = z + b
        h = act(z)
        return h
def bndeconv(x, scope, shape, rf, nf, act, stride=2, pad='SAME', winit=inits.ortho_init(1.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nf, nin], initializer=winit)
        g = tf.get_variable("g", [nf], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nf], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.nn.conv2d_transpose(x, w, shape, [1, stride, stride, 1], padding=pad)
        z = _bn(z, g, b, axes=[0, 1, 2], ema=ema)
        h = act(z)
        return h
def bnfc(x, scope, nh, act, ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[1].value
        w = tf.get_variable("w", [nin, nh], initializer=inits.ortho_init(1.0))
        g = tf.get_variable("g", [nh], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nh], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.matmul(x, w)
        z = _bn(z, g, b, axes=[0], ema=ema)
        h = act(z)
        return h
def fc(x, scope, nh, act, winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[1].value
        w = tf.get_variable("w", [nin, nh], initializer=winit)
        b = tf.get_variable("b", [nh], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.matmul(x, w)
        z = z + b
        h = act(z)
        return h
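# Generator: project z to 4x4x512 (glu halves the 4*4*1024 fc output), then
# four stride-2 batch-normed deconv+glu blocks up to 64x64, and a final
# deconv + tanh to 128x128x3.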
def generator(Z, reuse=False, ema=None):
    with tf.variable_scope('generator', reuse=reuse):
        h = bnfc(Z, scope='h', nh=4*4*1024, act=glu, ema=ema)
        h = tf.reshape(h, [128, 4, 4, 512])
        h2 = bndeconv(h, scope='h2', shape=[128, 8, 8, 512], rf=5, nf=512, act=glu, ema=ema)
        h3 = bndeconv(h2, scope='h3', shape=[128, 16, 16, 256], rf=5, nf=256, act=glu, ema=ema)
        h4 = bndeconv(h3, scope='h4', shape=[128, 32, 32, 128], rf=5, nf=128, act=glu, ema=ema)
        h5 = bndeconv(h4, scope='h5', shape=[128, 64, 64, 64], rf=5, nf=64, act=glu, ema=ema)
        h6 = deconv(h5, scope='h6', shape=[128, 128, 128, 3], rf=5, nf=3, act=tf.nn.tanh, ema=ema)
        return h6
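# Discriminator: five stride-2 5x5 convs with leaky ReLU (batch norm on all
# but the first), flattened to a single real/fake logit.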
def discriminator(X, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        h = conv(X, scope='h', rf=5, nf=32, act=lrelu, stride=2)
        h2 = bnconv(h, scope='h2', rf=5, nf=64, act=lrelu, stride=2)
        h3 = bnconv(h2, scope='h3', rf=5, nf=128, act=lrelu, stride=2)
        h4 = bnconv(h3, scope='h4', rf=5, nf=256, act=lrelu, stride=2)
        h5 = bnconv(h4, scope='h5', rf=5, nf=512, act=lrelu, stride=2)
        h5 = tf.reshape(h5, [128, -1])
        logits = fc(h5, scope='out', nh=1, act=lambda x: x, winit=inits.ortho_init(1.0))
        return logits
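# Wire up the graph: the discriminator sees a real batch and a generated
# batch (shared weights via reuse), and a second generator copy reads the
# EMA-averaged parameters for evaluation.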
gz = generator(Z)
dx = discriminator(X)
dgz = discriminator(gz, reuse=True)
ema_params = find_variables('generator')
for p in ema_params:
    print(p)
ema = tf.train.ExponentialMovingAverage(decay=0.999)
avg_params = ema.apply(ema_params)
ema_params = [ema.average(p) for p in ema_params]
gz_ema = generator(Z, reuse=True, ema=ema)
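# Standard non-saturating GAN losses: the generator is trained to make
# D(G(z)) output "real"; the discriminator loss averages its real and fake
# cross-entropy terms.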
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dgz, labels=tf.ones((128, 1))))
dx_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dx, labels=tf.ones((128, 1))))
dgz_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dgz, labels=tf.zeros((128, 1))))
d_loss = dx_loss*0.5 + dgz_loss*0.5
d_params = find_trainable_variables('discriminator')
g_params = find_trainable_variables('generator')
for p in g_params:
    print(p.name)
for p in d_params:
    print(p.name)
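# Rescale every gradient tensor to unit L2 norm before handing it to Adam;
# this is presumably the "normalization" in the gist title.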
d_grads = tf.gradients(d_loss, d_params)
g_grads = tf.gradients(g_loss, g_params)
for i in range(len(d_grads)):
    d_grads[i] = d_grads[i] / tf.norm(d_grads[i])
for i in range(len(g_grads)):
    g_grads[i] = g_grads[i] / tf.norm(g_grads[i])
d_trainer = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5)
g_trainer = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5)
d_train = d_trainer.apply_gradients(list(zip(d_grads, d_params)))
g_train = g_trainer.apply_gradients(list(zip(g_grads, g_params)))
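# Keep only the generator's batch-norm assign ops: the discriminator is never
# run with ema, so its stored moments are never read.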
bn_updates = [bn_update for bn_update in bn_updates if 'generator' in bn_update.name]
bn_updates = tf.group(*bn_updates)
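# Fixed noise batch so the sample grids track the same latents across epochs.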
sample_zmb = np.random.randn(128, 100).astype(np.float32)
nepochs = 0
nupdates = 0
nseconds = 0
config = tf.ConfigProto(allow_soft_placement=True,
                        intra_op_parallelism_threads=4,
                        inter_op_parallelism_threads=4)
with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run()
    samples = sess.run(gz, {Z: sample_zmb})
    print(samples.mean(), samples.std(), samples.min(), samples.max())
    img = color_grid((samples+1)/2., path='vis/%s/init.png' % desc)
    tstart = time.time()
    for i in range(1000):
        for xmb in tqdm(iter_data(*shuffle(trX), size=128), total=ntrain//128, leave=False, ncols=80):
            if len(xmb) == 128:  # drop the ragged final batch; placeholders are fixed at 128
                # one discriminator step, then one generator step, each on fresh noise;
                # the generator step also refreshes the EMA weights and bn statistics
                zmb = np.random.randn(128, 100).astype(np.float32)
                sess.run(d_train, {X: xmb/127.5-1., Z: zmb})
                zmb = np.random.randn(128, 100).astype(np.float32)
                sess.run([g_train, avg_params, bn_updates], {Z: zmb})
                nupdates += 1
        nseconds = time.time() - tstart
        zmb = np.random.randn(128, 100).astype(np.float32)
        xmb = trX[:128]
        print(i)
        # dump sample grids from the current and the EMA generator
        samples = sess.run(gz, {Z: sample_zmb})
        img = color_grid((samples+1)/2., path='vis/%s/cur/%d.png' % (desc, i))
        samples = sess.run(gz_ema, {Z: sample_zmb})
        img = color_grid((samples+1)/2., path='vis/%s/ema/%d.png' % (desc, i))
        test_sample = []
        test_sample_ema = []
        if (i+1) % 30 == 0:
            # every 30 epochs, score >= 5000 samples from both generators
            for t in range(test_batches):
                test_zmb = np.random.randn(128, 100).astype(np.float32)
                samples = sess.run(gz, {Z: test_zmb})
                samples_ema = sess.run(gz_ema, {Z: test_zmb})
                test_sample.append(samples)
                test_sample_ema.append(samples_ema)
            test_sample = np.concatenate(test_sample)
            test_sample_ema = np.concatenate(test_sample_ema)
            # map tanh outputs from [-1, 1] back to [0, 255] for the Inception network
            test_sample = [127.5 * (test_sample[j] + 1.) for j in range(test_sample.shape[0])]
            test_sample_ema = [127.5 * (test_sample_ema[j] + 1.) for j in range(test_sample_ema.shape[0])]
            inception_score = get_inception_score(test_sample, splits=1)
            log_file.write('epoch %d inception score was %.6f\n' % (i, inception_score[0]))
            inception_score = get_inception_score(test_sample_ema, splits=1)
            log_file.write('epoch %d EMA inception score was %.6f\n\n' % (i, inception_score[0]))
            log_file.flush()
    log_file.close()