Last active
April 20, 2020 21:32
-
-
Save ppwwyyxx/e1900b0e49bfb6771af22cc882b57bb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# File: CycleGAN-replaybuffer.py | |
import os, sys | |
import argparse | |
import glob | |
from six.moves import map, zip, range | |
from collections import deque | |
import numpy as np | |
import random | |
from tensorpack import * | |
from tensorpack.utils.viz import * | |
import tensorpack.tfutils.symbolic_functions as symbf | |
from tensorpack.tfutils.summary import add_moving_summary | |
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope | |
import tensorflow as tf | |
from GAN import GANTrainer, GANModelDesc | |
""" | |
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train | |
2. ./CycleGAN-replaybuffer.py --data /path/to/datasets/horse2zebra
Training and testing visualizations will be in tensorboard.
""" | |
# Side length (pixels) of the square images: used for the model input
# placeholders and for the crop/resize augmentors in get_data().
SHAPE = 256
# Batch size of the training dataflow; also the number of buffered fakes drawn
# per datapoint and the length of the random mask in the discriminator loss.
BATCH = 1
# Batch size of the test dataflow used for visualization.
TEST_BATCH = 32
NF = 64  # channel size (layer widths are multiples of this)
def INReLU(x, name=None):
    """Instance normalization (layer name 'inorm') followed by ReLU.

    Args:
        x: input tensor.
        name: optional name given to the ReLU output op.
    """
    normalized = InstanceNorm('inorm', x)
    return tf.nn.relu(normalized, name=name)
def INLReLU(x, name=None):
    """Instance normalization (layer name 'inorm') followed by LeakyReLU.

    Args:
        x: input tensor.
        name: optional name given to the LeakyReLU output.
    """
    normalized = InstanceNorm('inorm', x)
    return LeakyReLU(normalized, name=name)
class Model(GANModelDesc):
    """CycleGAN graph: two generators ('gen/A', 'gen/B') and two
    discriminators ('discrim/A', 'discrim/B').

    Besides the two real image inputs, the model takes two extra inputs
    ('fakeinputA', 'fakeinputB') holding previously generated fake images;
    the discriminator loss randomly uses either the freshly generated fake or
    the fed one.
    """

    def _get_inputs(self):
        """Declare the four inputs.

        'inputA'/'inputB' are channel-last images; 'fakeinputA'/'fakeinputB'
        are channel-first, matching the layout the generators work in (see
        the transpose in `_build_graph`).
        """
        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
                InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB'),
                InputDesc(tf.float32, (None, 3, SHAPE, SHAPE), 'fakeinputA'),
                InputDesc(tf.float32, (None, 3, SHAPE, SHAPE), 'fakeinputB')]

    @staticmethod
    def build_res_block(x, name, chan, first=False):
        """Build one residual block: pad-conv, pad-conv, InstanceNorm, plus
        the block input as a skip connection.

        Args:
            x: input tensor (padding is applied to axes 2 and 3).
            name: variable scope of the block.
            chan: number of output channels of both convolutions.
            first: unused; kept so existing callers keep working.

        Returns:
            The output tensor of the block.
        """
        with tf.variable_scope(name):
            # Renamed from `input` to avoid shadowing the builtin.
            shortcut = x
            return (LinearWrap(x)
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv0', chan, padding='VALID')
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv1', chan, padding='VALID', nl=tf.identity)
                    .InstanceNorm('inorm')()) + shortcut

    @auto_reuse_variable_scope
    def generator(self, img):
        """Build a generator: three convolutions (two with stride 2), nine
        residual blocks, two stride-2 deconvolutions and a final tanh
        convolution named 'convlast' with 3 output channels.

        Args:
            img: input image tensor; must not be None.

        Returns:
            The generated image tensor (tanh output).
        """
        assert img is not None
        # Everything, including the residual blocks, must be built inside this
        # argscope: they rely on the default nl=INReLU and kernel_shape=3.
        with argscope([Conv2D, Deconv2D], nl=INReLU, kernel_shape=3):
            l = (LinearWrap(img)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('conv0', NF, kernel_shape=7, padding='VALID')
                 .Conv2D('conv1', NF * 2, stride=2)
                 .Conv2D('conv2', NF * 4, stride=2)())
            for k in range(9):
                l = Model.build_res_block(l, 'res{}'.format(k), NF * 4, first=(k == 0))
            l = (LinearWrap(l)
                 .Deconv2D('deconv0', NF * 2, stride=2)
                 .Deconv2D('deconv1', NF * 1, stride=2)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('convlast', 3, kernel_shape=7, padding='VALID', nl=tf.tanh, use_bias=True)())
        return l

    @auto_reuse_variable_scope
    def discriminator(self, img):
        """Build a discriminator: five 4x4 convolutions, the last one with a
        single output channel and no nonlinearity.

        Args:
            img: image tensor to score.

        Returns:
            The raw (un-activated) output of the last convolution.
        """
        with argscope(Conv2D, nl=INLReLU, kernel_shape=4, stride=2):
            l = (LinearWrap(img)
                 .Conv2D('conv0', NF, nl=LeakyReLU)
                 .Conv2D('conv1', NF * 2)
                 .Conv2D('conv2', NF * 4)
                 .Conv2D('conv3', NF * 8, stride=1)
                 .Conv2D('conv4', 1, stride=1, nl=tf.identity, use_bias=True)())
        return l

    def _build_graph(self, inputs):
        """Build generators, discriminators, image summaries and losses.

        Sets `self.g_loss` and `self.d_loss` and calls
        `self.collect_variables('gen', 'discrim')`.

        Args:
            inputs: the four tensors declared by `_get_inputs`, in order
                (A, B, fakeA, fakeB).
        """
        A, B, fakeA, fakeB = inputs
        # Rescale the pixel values (viz3 applies the inverse mapping) and move
        # the channel axis to position 1 for the NCHW layers below.
        A = tf.transpose(A / 128.0 - 1.0, [0, 3, 1, 2])
        B = tf.transpose(B / 128.0 - 1.0, [0, 3, 1, 2])

        def viz3(name, a, b, c):
            """Concatenate three channel-first images along axis 3 and emit
            them as a uint8 image summary; the tensor is named 'viz_<name>'."""
            im = tf.concat([a, b, c], axis=3)
            im = tf.transpose(im, [0, 2, 3, 1])
            im = (im + 1.0) * 128
            im = tf.clip_by_value(im, 0, 255)
            im = tf.cast(im, tf.uint8, name='viz_' + name)
            tf.summary.image(name, im, max_outputs=50)

        # use the initializers from torch
        with argscope([Conv2D, Deconv2D], use_bias=False,
                      W_init=tf.random_normal_initializer(stddev=0.02)), \
                argscope([Conv2D, Deconv2D, InstanceNorm], data_format='NCHW'), \
                argscope(LeakyReLU, alpha=0.2):
            with tf.variable_scope('gen'):
                # Scope 'B' holds the A->B generator, scope 'A' the B->A one.
                with tf.variable_scope('B'):
                    AB = self.generator(A)
                with tf.variable_scope('A'):
                    BA = self.generator(B)
                    ABA = self.generator(AB)
                with tf.variable_scope('B'):
                    BAB = self.generator(BA)

            viz3('A_recon', A, AB, ABA)
            viz3('B_recon', B, BA, BAB)

            def batch_discriminator(*values):
                """Run the discriminator once on the tensors concatenated
                along axis 0 and split the output back into one per input."""
                c = tf.concat(values, axis=0)
                outputs = self.discriminator(c)
                return tf.split(outputs, len(values), axis=0)

            with tf.variable_scope('discrim'):
                with tf.variable_scope('A'):
                    A_dis_real, A_dis_buffer = batch_discriminator(A, fakeA)
                    A_dis_fake = self.discriminator(BA)
                with tf.variable_scope('B'):
                    B_dis_real, B_dis_buffer = batch_discriminator(B, fakeB)
                    B_dis_fake = self.discriminator(AB)

        def LSGAN_losses(real, fake, fake_buffer):
            """Return (g_loss, d_loss) as least-squares losses with a target
            of 0.9 for "real" and 0 for "fake"."""
            with tf.name_scope('LSGAN_losses'):
                d_real = tf.reduce_mean(tf.squared_difference(real, 0.9), name='d_real')
                # Per-sample coin flip between the score of the fresh fake and
                # the score of the fed (buffered) fake. The mask has length
                # BATCH, so the runtime batch size is assumed to equal BATCH.
                choice = tf.less(tf.random_uniform([BATCH]), 0.5)
                d_fake = tf.reduce_mean(tf.square(tf.where(choice, fake, fake_buffer)), name='d_fake')
                d_loss = tf.add(d_real, d_fake, name='d_loss')
                # The generator loss only uses the freshly generated fake.
                g_loss = tf.reduce_mean(tf.squared_difference(fake, 0.9), name='g_loss')
                add_moving_summary(g_loss, d_loss)
                return g_loss, d_loss

        with tf.name_scope('LossA'):
            # reconstruction loss
            recon_loss_A = tf.reduce_mean(tf.abs(A - ABA), name='recon_loss')
            # gan loss
            G_loss_A, D_loss_A = LSGAN_losses(A_dis_real, A_dis_fake, A_dis_buffer)

        with tf.name_scope('LossB'):
            recon_loss_B = tf.reduce_mean(tf.abs(B - BAB), name='recon_loss')
            G_loss_B, D_loss_B = LSGAN_losses(B_dis_real, B_dis_fake, B_dis_buffer)

        # Weight of the reconstruction terms in the generator loss.
        LAMBDA = 10.0
        self.g_loss = tf.add((G_loss_A + G_loss_B),
                             (recon_loss_A + recon_loss_B) * LAMBDA, name='G_loss_total')
        self.d_loss = tf.add(D_loss_A, D_loss_B, name='D_loss_total')
        # NOTE(review): presumably gathers the variables under these two
        # scopes for the GAN trainer -- defined in GANModelDesc, not shown here.
        self.collect_variables('gen', 'discrim')

        add_moving_summary(recon_loss_A, recon_loss_B, self.g_loss, self.d_loss)

    def _get_optimizer(self):
        """Return an Adam optimizer whose learning rate is the scalar
        variable 'learning_rate' (initial value 2e-4)."""
        # Use the explicitly imported `symbf` alias instead of relying on the
        # wildcard tensorpack import to expose `symbolic_functions`.
        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
        return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_data(datadir, isTrain=True):
    """Build the dataflow of batched (imageA, imageB) pairs.

    Args:
        datadir: directory containing the 'trainA'/'trainB' (training) or
            'testA'/'testB' (testing) sub-directories of '*.jpg' images.
        isTrain: selects the sub-directories, the augmentors, whether files
            are shuffled, the batch size and the number of prefetch processes.

    Returns:
        The prefetching dataflow.
    """
    if isTrain:
        # Resize slightly larger than needed, then random-crop and flip.
        augmentors = [
            imgaug.Resize(int(SHAPE * 1.12)),
            imgaug.RandomCrop(SHAPE),
            imgaug.Flip(horiz=True)
        ]
        subdirs = ['trainA', 'trainB']
        batch_size, nr_proc = BATCH, 2
    else:
        augmentors = [imgaug.Resize(SHAPE)]
        subdirs = ['testA', 'testB']
        batch_size, nr_proc = TEST_BATCH, 1

    def load_domain(subdir):
        # One augmented image stream per domain, files in sorted order.
        pattern = os.path.join(datadir, subdir, '*.jpg')
        images = ImageFromFile(sorted(glob.glob(pattern)), channel=3, shuffle=isTrain)
        return AugmentImageComponent(images, augmentors)

    paired = JoinData([load_domain(subdir) for subdir in subdirs])
    batched = BatchData(paired, batch_size)
    return PrefetchDataZMQ(batched, nr_proc)
class VisualizeTestSet(Callback):
    """Callback that runs the model on the test set and logs the
    'viz_A_recon' / 'viz_B_recon' images to the trainer's monitors."""

    def _setup_graph(self):
        input_names = ['inputA', 'inputB']
        output_names = ['viz_A_recon', 'viz_B_recon']
        self.pred = self.trainer.get_predictor(input_names, output_names)

    def _before_train(self):
        # `args` is the module-level namespace parsed in the __main__ section.
        self.val_ds = get_data(args.data, isTrain=False)

    def _trigger(self):
        monitors = self.trainer.monitors
        # One pair of images is logged per test batch, tagged by batch index.
        for idx, (iA, iB) in enumerate(self.val_ds.get_data()):
            vizA, vizB = self.pred(iA, iB)
            monitors.put_image('testA-{}'.format(idx), vizA)
            monitors.put_image('testB-{}'.format(idx), vizB)
class FakeBuffer(ProxyDataFlow, Callback):
    """ A buffer to hold previously-generated fake images.

    It acts as a dataflow, appending BATCH buffered fakes per domain to every
    datapoint of the wrapped dataflow. It also acts as a callback, fetching
    the generator outputs on every run and pushing them into the buffers.
    """

    def __init__(self, df, size):
        """
        Args:
            df: the dataflow to wrap.
            size: maximum number of images kept per buffer; it must be at
                least BATCH for sampling in `get_data` to succeed.
        """
        ProxyDataFlow.__init__(self, df)
        self._sz = size

    def get_data(self):
        """Yield each datapoint of the wrapped dataflow, extended with two
        lists of BATCH images sampled without replacement from the buffers."""
        for dp in super(FakeBuffer, self).get_data():
            fA = random.sample(list(self._bufferA), BATCH)
            fB = random.sample(list(self._bufferB), BATCH)
            dp.extend([fA, fB])
            yield dp

    def _setup_graph(self):
        G = tf.get_default_graph()
        # Outputs of the last generator layer under the 'gen/A' and 'gen/B'
        # scopes, looked up by name.
        self.genA = G.get_tensor_by_name('gen/A/convlast/output:0')
        self.genB = G.get_tensor_by_name('gen/B/convlast/output:0')
        self._bufferA = self._create_queue()
        self._bufferB = self._create_queue()

    def _create_queue(self):
        """Create a bounded buffer pre-filled with uniform noise in [-1, 1)."""
        queue = deque(maxlen=self._sz)
        # Per-image shape: the generator output shape without its first axis.
        shp = self.genA.shape.as_list()[1:]
        # Seed at least BATCH images: get_data() samples BATCH of them before
        # any generated image has arrived, and random.sample raises ValueError
        # when asked for more items than the population holds. The original
        # fixed count of 5 is kept as the minimum.
        for _ in range(max(5, BATCH)):
            queue.append(np.random.rand(*shp) * 2. - 1.)
        return queue

    def _before_run(self, _):
        # Fetch both generator outputs alongside every training step.
        return tf.train.SessionRunArgs(fetches=[self.genA, self.genB])

    def _after_run(self, _, rv):
        gA, gB = rv.results
        # extend() adds one buffer entry per image along the first axis.
        self._bufferA.extend(gA)
        self._bufferB.extend(gB)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data', required=True,
        help='the image directory. should contain trainA/trainB/testA/testB')
    parser.add_argument('--load', help='load model')
    # `args` stays a module-level name: VisualizeTestSet reads it.
    args = parser.parse_args()

    logger.auto_set_dir()

    # The FakeBuffer is both the training dataflow and a callback, so the same
    # object is passed to TrainConfig twice.
    replay_df = FakeBuffer(PrintData(get_data(args.data)), 50)

    model = Model()
    callbacks = [
        ModelSaver(),
        # Linear interpolation of the learning rate between the two points.
        ScheduledHyperParamSetter(
            'learning_rate',
            [(100, 2e-4), (200, 0)], interp='linear'),
        PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
        replay_df,
    ]
    if args.load:
        session_init = SaverRestore(args.load)
    else:
        session_init = None

    config = TrainConfig(
        model=model,
        dataflow=replay_df,
        callbacks=callbacks,
        max_epoch=195,
        session_init=session_init
    )
    GANTrainer(config).train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment