from __future__ import absolute_import, division, print_function, unicode_literals
import logging.config
import tensorflow as tf
import os
import numpy as np
from tensorflow import ConfigProto
from argument_parse import args
from module.discriminator_vgg19 import Discriminator
from module.generator import Generator
from module.losses import l1_loss
from config import cfg
from data_tools.parse_records_dataset import input_fn
from calculate_average_gradients import get_perturbed_batch, average_gradients
from preprocessing.dataset import pre_processing
from utils import shuffle
if args.mode in ['train', 'test', 'val']:
    params = {'batch_size': cfg.train_batch_size,
              'tfrecords_path': cfg.tfrecords_path}
    train_dataset = input_fn(args.mode, params)
else:
    raise ValueError("mode must be one of 'train', 'test' or 'val'.")
if not all([isinstance(args.num_gpus, int), isinstance(args.batch_size, int)]):
    raise ValueError("num_gpus and batch_size must both be integers.")
# get TF logger
# load logging configuration and create log object
logging.config.fileConfig('logging.conf')
logging.basicConfig(filename='skin_generator.log', level=logging.DEBUG)
log = logging.getLogger('TensorFlow')
log.setLevel(logging.DEBUG)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# create file handler which logs even debug messages
fh = logging.FileHandler('module_2.log')
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
log.addHandler(fh)
if __name__ == '__main__':
    train_iterator = train_dataset.make_initializable_iterator()
    batch_data = train_iterator.get_next()
    image_label, body_parts, seg_parts, top_and_bottom = batch_data
    train_batch_size_step = args.batch_size // args.num_gpus

    # output of D for real images
    D_real, D_real_logits = Discriminator(image_label).feed_forward()
    # output of D for fake images
    gen, end_points = Generator(body_parts, seg_parts, top_and_bottom).feed_forward()
    D_fake, D_fake_logits = Discriminator(gen).feed_forward()
    label_input_perturbed = get_perturbed_batch(image_label)
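    # Note: get_perturbed_batch comes from calculate_average_gradients and is not shown
    # here; given the DRAGAN-style penalty below, it presumably returns the real batch
    # with small random noise added (an assumption about its implementation).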
    # get loss for discriminator
    with tf.name_scope('D_loss'):
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))
        d_loss = d_loss_real + d_loss_fake

        alpha = tf.random_uniform(shape=tf.shape(image_label), minval=0., maxval=1.)
        differences = label_input_perturbed - image_label  # This is different from WGAN-GP
        interpolates = image_label + (alpha * differences)
        _, D_inter = Discriminator(interpolates).feed_forward()
        gradients = tf.gradients(D_inter, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        lambd = 0.1
        d_loss += lambd * gradient_penalty
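    # The penalty above evaluates the discriminator at points interpolated between the
    # real images and their perturbed copies and pushes the gradient norm at those points
    # towards 1 (weighted by lambd). Interpolating real-vs-perturbed-real, rather than
    # real-vs-fake as in WGAN-GP, is the DRAGAN-style variant the comment above refers to.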
    # get loss for generator
    with tf.name_scope('G_loss'):
        g_mse_lambda = 100
        # reduce the per-pixel MSE map to a scalar before weighting it
        g_mse_loss = tf.reduce_mean(tf.keras.losses.MSE(y_true=image_label, y_pred=gen))
        g_mse_loss = g_mse_loss * g_mse_lambda
        g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake))) + g_mse_loss
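    # The generator objective combines a non-saturating adversarial term (label the fakes
    # as real) with a pixel-wise MSE reconstruction term scaled by g_mse_lambda, similar
    # in spirit to the L1-weighted objective used in pix2pix.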
    # Training: divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if "discriminator" in var.name]
    g_vars = [var for var in t_vars if "generator" in var.name]

    # Optimizers
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        d_train_opt = tf.train.AdamOptimizer(learning_rate=cfg.lr, beta1=0.5).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate=cfg.lr, beta1=0.5).minimize(g_loss, var_list=g_vars)
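    # The control dependency on UPDATE_OPS makes any ops registered in that collection
    # (typically batch-norm moving-average updates, assuming the Generator/Discriminator
    # use batch normalization) run before each optimizer step.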
    top = top_and_bottom[:, :, :, 0:3]

    # Summary
    gen__image_sum = tf.summary.image("fake", gen[:, :, :, ::-1], max_outputs=1)
    real_image_sum = tf.summary.image("real", image_label[:, :, :, ::-1], max_outputs=1)
    top_sum = tf.summary.image("top", top[:, :, :, ::-1], max_outputs=1)
    d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real, family="D_loss")
    d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake, family="D_loss")
    d_loss_sum = tf.summary.scalar("d_loss", d_loss, family="D_loss")
    g_loss_l1_sum = tf.summary.scalar("g_mse_loss", g_mse_loss, family="G_loss")
    g_loss_sum = tf.summary.scalar("g_loss", g_loss, family="G_loss")

    # final summary operations
    g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum, g_loss_l1_sum, gen__image_sum, real_image_sum, top_sum])
    d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum, gen__image_sum, real_image_sum])
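    # g_sum and d_sum bundle the generator-side and discriminator-side summaries; the
    # training loop below fetches the matching merged summary alongside each optimizer step.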
    sess = tf.Session()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())
    sess.run(train_iterator.initializer)

    # restore from the latest checkpoint if one exists
    latest_ckpt = tf.train.latest_checkpoint(cfg.path_save_model)
    if latest_ckpt is not None:
        saver.restore(sess, latest_ckpt)
        print("Restored checkpoint from {}".format(latest_ckpt))
    try:
        steps_per_epoch = cfg.dataset_size // cfg.train_batch_size
        for epoch in range(cfg.epoch_size):
            # Training
            for itr in range(steps_per_epoch):
                # noise_label = get_perturbed_batch(image_label)
                # Update Discriminator
                d_loss_val, d_summary_str, opt_d = sess.run([d_loss, d_sum, d_train_opt])
                # Update Generator
                g_loss_val, g_summary_str, opt_g = sess.run([g_loss, g_sum, g_train_opt])

                if itr % 50 == 0:
                    # use a global step so summaries from later epochs do not overwrite earlier ones
                    global_step = epoch * steps_per_epoch + itr
                    print("epoch - {} | iter - {} | d-loss - {}".format(epoch, itr, d_loss_val))
                    summary_writer.add_summary(d_summary_str, global_step)
                    print("epoch - {} | iter - {} | g-loss - {}".format(epoch, itr, g_loss_val))
                    summary_writer.add_summary(g_summary_str, global_step)
                    # summary_writer.add_summary(clothes_sumarry, itr)
                    saver.save(sess, cfg.path_save_model)
                    print("Checkpoint saved.")
    except Exception as es:
        log.debug(es)
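    # Typical TF1 cleanup once training ends (optional):
    #   coord.request_stop()
    #   coord.join(threads)
    #   sess.close()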