# Source: GitHub gist ppwwyyxx/e1900b0e49bfb6771af22cc882b57bb5
# (last active 2020-04-20)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: CycleGAN-replaybuffer.py
import os, sys
import argparse
import glob
from six.moves import map, zip, range
from collections import deque
import numpy as np
import random
from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANTrainer, GANModelDesc
"""
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
2. ./CycleGAN.py --data /path/to/datasets/horse2zebra
Training and testing visualizations will be in tensorboard.
"""
# Input/output image side length (images are SHAPE x SHAPE x 3).
SHAPE = 256
# Training batch size; the replay-buffer sampling in FakeBuffer draws BATCH items.
BATCH = 1
# Batch size used by VisualizeTestSet for test-time visualization.
TEST_BATCH = 32
NF = 64  # channel size: base number of filters in generator/discriminator
def INReLU(x, name=None):
    """InstanceNorm followed by ReLU (reconstructed indentation from gist scrape)."""
    x = InstanceNorm('inorm', x)
    return tf.nn.relu(x, name=name)
def INLReLU(x, name=None):
    """InstanceNorm followed by LeakyReLU (alpha is set via argscope in _build_graph)."""
    x = InstanceNorm('inorm', x)
    return LeakyReLU(x, name=name)
class Model(GANModelDesc):
    """CycleGAN model with a replay buffer of previously-generated fake images.

    Fixes vs. the scraped original:
    - indentation reconstructed (the gist scrape flattened all bodies);
    - `_get_optimizer` used the undefined name `symbolic_functions`; the
      module is imported as `symbf` at the top of the file.
    """

    def _get_inputs(self):
        # Real images arrive as NHWC from the dataflow; the fake images come
        # from the replay buffer, i.e. they are generator outputs and are
        # therefore already in NCHW layout.
        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
                InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB'),
                InputDesc(tf.float32, (None, 3, SHAPE, SHAPE), 'fakeinputA'),
                InputDesc(tf.float32, (None, 3, SHAPE, SHAPE), 'fakeinputB')]

    @staticmethod
    def build_res_block(x, name, chan, first=False):
        """One residual block: two SYMMETRIC-padded VALID 3x3 convs + skip.

        ``first`` is accepted for interface compatibility but unused here.
        """
        with tf.variable_scope(name):
            shortcut = x  # renamed from `input` (shadowed a builtin)
            return (LinearWrap(x)
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv0', chan, padding='VALID')
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv1', chan, padding='VALID', nl=tf.identity)
                    .InstanceNorm('inorm')()) + shortcut

    @auto_reuse_variable_scope
    def generator(self, img):
        """c7s1-64, d128, d256, 9 res blocks, u128, u64, c7s1-3 (tanh output)."""
        assert img is not None
        with argscope([Conv2D, Deconv2D], nl=INReLU, kernel_shape=3):
            l = (LinearWrap(img)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('conv0', NF, kernel_shape=7, padding='VALID')
                 .Conv2D('conv1', NF * 2, stride=2)
                 .Conv2D('conv2', NF * 4, stride=2)())
            for k in range(9):
                l = Model.build_res_block(l, 'res{}'.format(k), NF * 4, first=(k == 0))
            l = (LinearWrap(l)
                 .Deconv2D('deconv0', NF * 2, stride=2)
                 .Deconv2D('deconv1', NF * 1, stride=2)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('convlast', 3, kernel_shape=7, padding='VALID',
                         nl=tf.tanh, use_bias=True)())
        return l

    @auto_reuse_variable_scope
    def discriminator(self, img):
        """70x70 PatchGAN discriminator producing a map of real/fake logits."""
        with argscope(Conv2D, nl=INLReLU, kernel_shape=4, stride=2):
            l = (LinearWrap(img)
                 .Conv2D('conv0', NF, nl=LeakyReLU)
                 .Conv2D('conv1', NF * 2)
                 .Conv2D('conv2', NF * 4)
                 .Conv2D('conv3', NF * 8, stride=1)
                 .Conv2D('conv4', 1, stride=1, nl=tf.identity, use_bias=True)())
        return l

    def _build_graph(self, inputs):
        A, B, fakeA, fakeB = inputs
        # Scale pixel values to roughly [-1, 1] and convert NHWC -> NCHW.
        A = tf.transpose(A / 128.0 - 1.0, [0, 3, 1, 2])
        B = tf.transpose(B / 128.0 - 1.0, [0, 3, 1, 2])

        def viz3(name, a, b, c):
            # Concatenate (input, translated, reconstructed) side by side and
            # emit an image summary plus a named uint8 tensor for predictors.
            im = tf.concat([a, b, c], axis=3)
            im = tf.transpose(im, [0, 2, 3, 1])
            im = (im + 1.0) * 128
            im = tf.clip_by_value(im, 0, 255)
            im = tf.cast(im, tf.uint8, name='viz_' + name)
            tf.summary.image(name, im, max_outputs=50)

        # use the initializers from torch
        with argscope([Conv2D, Deconv2D], use_bias=False,
                      W_init=tf.random_normal_initializer(stddev=0.02)), \
                argscope([Conv2D, Deconv2D, InstanceNorm], data_format='NCHW'), \
                argscope(LeakyReLU, alpha=0.2):
            with tf.variable_scope('gen'):
                with tf.variable_scope('B'):
                    AB = self.generator(A)
                with tf.variable_scope('A'):
                    BA = self.generator(B)
                    ABA = self.generator(AB)   # cycle A -> B -> A
                with tf.variable_scope('B'):
                    BAB = self.generator(BA)   # cycle B -> A -> B
            viz3('A_recon', A, AB, ABA)
            viz3('B_recon', B, BA, BAB)

            def batch_discriminator(*values):
                # Share one discriminator forward pass over all inputs by
                # concatenating along the batch axis, then split the outputs.
                c = tf.concat(values, axis=0)
                outputs = self.discriminator(c)
                return tf.split(outputs, len(values), axis=0)

            with tf.variable_scope('discrim'):
                with tf.variable_scope('A'):
                    A_dis_real, A_dis_buffer = batch_discriminator(A, fakeA)
                    A_dis_fake = self.discriminator(BA)
                with tf.variable_scope('B'):
                    B_dis_real, B_dis_buffer = batch_discriminator(B, fakeB)
                    B_dis_fake = self.discriminator(AB)

        def LSGAN_losses(real, fake, fake_buffer):
            # Least-squares GAN losses with one-sided label smoothing (0.9).
            with tf.name_scope('LSGAN_losses'):
                d_real = tf.reduce_mean(tf.squared_difference(real, 0.9), name='d_real')
                # With prob. 0.5 train D on the current fake, otherwise on a
                # sample from the replay buffer of older fakes.
                choice = tf.less(tf.random_uniform([BATCH]), 0.5)
                d_fake = tf.reduce_mean(
                    tf.square(tf.where(choice, fake, fake_buffer)), name='d_fake')
                d_loss = tf.add(d_real, d_fake, name='d_loss')
                g_loss = tf.reduce_mean(tf.squared_difference(fake, 0.9), name='g_loss')
                add_moving_summary(g_loss, d_loss)
                return g_loss, d_loss

        with tf.name_scope('LossA'):
            # reconstruction (cycle-consistency) loss
            recon_loss_A = tf.reduce_mean(tf.abs(A - ABA), name='recon_loss')
            # gan loss
            G_loss_A, D_loss_A = LSGAN_losses(A_dis_real, A_dis_fake, A_dis_buffer)
        with tf.name_scope('LossB'):
            recon_loss_B = tf.reduce_mean(tf.abs(B - BAB), name='recon_loss')
            G_loss_B, D_loss_B = LSGAN_losses(B_dis_real, B_dis_fake, B_dis_buffer)

        LAMBDA = 10.0  # weight of the cycle-consistency term, as in the paper
        self.g_loss = tf.add((G_loss_A + G_loss_B),
                             (recon_loss_A + recon_loss_B) * LAMBDA,
                             name='G_loss_total')
        self.d_loss = tf.add(D_loss_A, D_loss_B, name='D_loss_total')
        self.collect_variables('gen', 'discrim')
        add_moving_summary(recon_loss_A, recon_loss_B, self.g_loss, self.d_loss)

    def _get_optimizer(self):
        # BUG FIX: the original called `symbolic_functions.get_scalar_var`, but
        # the module is imported as `symbf`; `symbolic_functions` was undefined.
        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
        return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_data(datadir, isTrain=True):
    """Build the paired A/B image dataflow.

    Args:
        datadir: directory containing trainA/trainB/testA/testB subfolders.
        isTrain: use training augmentation + shuffling when True, otherwise
            a plain resize for the test set.

    Returns:
        A BatchData dataflow yielding [imgA, imgB] pairs, prefetched in
        separate processes.
    """
    if isTrain:
        augs = [
            imgaug.Resize(int(SHAPE * 1.12)),
            imgaug.RandomCrop(SHAPE),
            imgaug.Flip(horiz=True),
        ]
    else:
        augs = [imgaug.Resize(SHAPE)]

    def get_image_pairs(dir1, dir2):
        def get_df(folder):  # renamed from `dir` (shadowed a builtin)
            files = sorted(glob.glob(os.path.join(folder, '*.jpg')))
            df = ImageFromFile(files, channel=3, shuffle=isTrain)
            return AugmentImageComponent(df, augs)
        # Zip the two (unpaired) domains into one datapoint per step.
        return JoinData([get_df(dir1), get_df(dir2)])

    names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
    df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
    df = BatchData(df, BATCH if isTrain else TEST_BATCH)
    df = PrefetchDataZMQ(df, 2 if isTrain else 1)
    return df
class VisualizeTestSet(Callback):
    """Periodically run the generators on the test set and log visualizations."""

    def _setup_graph(self):
        # Predictor mapping real inputs to the uint8 viz tensors built by
        # Model._build_graph's viz3().
        self.pred = self.trainer.get_predictor(
            ['inputA', 'inputB'], ['viz_A_recon', 'viz_B_recon'])

    def _before_train(self):
        global args  # set in the __main__ block; TODO: pass datadir explicitly
        self.val_ds = get_data(args.data, isTrain=False)

    def _trigger(self):
        idx = 0
        for iA, iB in self.val_ds.get_data():
            vizA, vizB = self.pred(iA, iB)
            self.trainer.monitors.put_image('testA-{}'.format(idx), vizA)
            self.trainer.monitors.put_image('testB-{}'.format(idx), vizB)
            idx += 1
class FakeBuffer(ProxyDataFlow, Callback):
    """A buffer to hold previously-generated fake images.

    Acts both as a dataflow (appending buffered fakes to each datapoint) and
    as a callback (harvesting fresh generator outputs after every step).
    """

    def __init__(self, df, size):
        """
        Args:
            df: the upstream dataflow yielding [imgA, imgB].
            size: maximum number of fakes to keep per domain.
        """
        ProxyDataFlow.__init__(self, df)
        self._sz = size

    def get_data(self):
        # Extend each datapoint with BATCH random fakes from each buffer,
        # matching the fakeinputA/fakeinputB InputDescs of the model.
        for dp in super(FakeBuffer, self).get_data():
            fA = random.sample(list(self._bufferA), BATCH)
            fB = random.sample(list(self._bufferB), BATCH)
            dp.extend([fA, fB])
            yield dp

    def _setup_graph(self):
        G = tf.get_default_graph()
        # Outputs of the two generators (NCHW, in [-1, 1] via tanh).
        self.genA = G.get_tensor_by_name('gen/A/convlast/output:0')
        self.genB = G.get_tensor_by_name('gen/B/convlast/output:0')
        self._bufferA = self._create_queue()
        self._bufferB = self._create_queue()

    def _create_queue(self):
        # Seed with a few random images so sampling works before the first
        # real fakes arrive; old entries fall off once maxlen is reached.
        queue = deque(maxlen=self._sz)
        shp = self.genA.shape.as_list()[1:]
        for k in range(5):
            queue.append(np.random.rand(*shp) * 2. - 1.)
        return queue

    def _before_run(self, _):
        # Fetch the fakes generated by this training step.
        return tf.train.SessionRunArgs(fetches=[self.genA, self.genB])

    def _after_run(self, _, rv):
        gA, gB = rv.results
        self._bufferA.extend(gA)
        self._bufferB.extend(gB)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data', required=True,
        help='the image directory. should contain trainA/trainB/testA/testB')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()

    logger.auto_set_dir()

    data = get_data(args.data)
    data = PrintData(data)
    # Wrap the dataflow in the replay buffer; it is also registered as a
    # callback below so it can harvest generator outputs each step.
    data = FakeBuffer(data, 50)

    config = TrainConfig(
        model=Model(),
        dataflow=data,
        callbacks=[
            ModelSaver(),
            # Constant 2e-4 for 100 epochs, then linear decay to 0 at 200.
            ScheduledHyperParamSetter(
                'learning_rate',
                [(100, 2e-4), (200, 0)], interp='linear'),
            PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
            data,
        ],
        max_epoch=195,
        session_init=SaverRestore(args.load) if args.load else None
    )
    GANTrainer(config).train()