Alternating Training Schedule for an RNN regression and convolutional LSGAN model
""" Methods for Training on a distributed Network """ | |
""" | |
The session is set up to be run in a distributed local environment. | |
The workers should be assigned to GPUs because they do most of the | |
heavy lifting in the compuation. Assign a different domain:port double | |
for each GPU available on the local host. | |
The parameter servers should simply be assigned to CPUs. See the | |
__main__ conditional to see how the GPUs and CPUs are set for both | |
the worker and the parameter server. If the job_name is a worker, the | |
visible GPU is set to the task index of that worker. | |
USAGE: | |
Run a Worker: | |
python task.py --job_name=worker --task_index=0 (GPU set to 0) | |
python task.py --job_name=worker --task_index=1 (GPU set to 1) | |
Run a PS: | |
python task.py --job_name=ps --task_index=0 (GPU set to NULL) | |
python task.py --job_name=ps --task_index=1 (GPU set to NULL) | |
AUTHOR: Andrew Warner | |
DATE: 7 August 2018 | |
""" | |
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import sys
import os
# Import Libraries
import tensorflow as tf
import argparse
# Import Tensorflow Libraries
from tensorflow.contrib.training import HParams
from tensorflow.python.training import saver
from tensorflow.python import debug as tf_debug
# Import Methods
from tensorbose import tfutils, fileio
from functools import partial
# Import Modules
from trainer import data_provider
from trainer import preprocessing
from trainer import simon_ops
from trainer import hookdefs
from trainer import networks
from trainer import losses
# Constants
from trainer.params import get_hparams
HPARAMS = get_hparams()
# Set Environment Variables
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS = None
def main(argv=[]):
    """
    *** Graph Definition and Session Execution ***
    This function defines both the regression and GAN models for the
    alternating training routine executed by the provided session.
    """
    # Define the parameters and the cluster specs
    params = get_hparams()
    cluster = tf.train.ClusterSpec({'worker': ['localhost:2222', 'localhost:2223'],
                                    'ps': ['localhost:2221', 'localhost:2224']})
    # The server is specific to the job_name and task_index given in the process call
    server = tf.train.Server(
        cluster,
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_index)
    # Join the process, similar to multi-threading
    if FLAGS.job_name == 'ps':
        server.join()
    elif FLAGS.job_name == 'worker':
        # This is taken directly from Tensorflow's Distributed Example
        load_fn = tf.contrib.training.byte_size_load_fn
        strategy = tf.contrib.training.GreedyLoadBalancingStrategy(2, load_fn)
        replica_setter = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster,
            ps_strategy=strategy)
        with tf.device(replica_setter):
            """
            The input is synchronous. Thus, the clean speech and noisy speech
            should hold the same speech content. See data_provider.py for more.
            """
            with tf.variable_scope('input'):
                noisy_speech_data, clean_speech_signal = data_provider.input_fn(mode='TRAIN')
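            # Assumed contract for input_fn (data_provider.py is not shown in this
            # gist): noisy_speech_data is a (noisy_speech_signal, noisy_sequence_length)
            # pair and clean_speech_signal is the matching clean waveform batch, both
            # drawn synchronously from the same utterance.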
            ###################################################################
            """ Regression Model """
            ###################################################################
            """
            The regression model uses the generator to produce sequences in the
            frequency domain, on which a mean squared error loss is computed
            against the STFT of the clean speech. In addition, the cleaned
            frequency-domain sequence from the generator is inverted back into
            a time-domain signal, and another mean squared error is computed
            against the clean signal.
            """
            # Data Preparation for Generator and Loss Functions
            with tf.variable_scope('preprocess'):
                # Preprocess the Noisy Signal for Processing by the Generator
                noisy_speech_signal, noisy_sequence_length = noisy_speech_data
                noisy_frequency_sequence = preprocessing.gen_preprocess_data(
                    data=noisy_speech_signal,
                    data_type='features')
                noisy_speech_data = (noisy_frequency_sequence, noisy_sequence_length)
                # Preprocess the Clean Signal for the STFT Loss Function
                clean_frequency_sequence = preprocessing.gen_preprocess_data(
                    data=clean_speech_signal,
                    data_type='labels')
                # Preprocess the Clean Signal for the MSE Signal Loss Function
                clean_speech_signal = preprocessing._pad_to_max(
                    data=clean_speech_signal,
                    pad_size=params.max_signal_length,
                    params=params)
                clean_speech_signal = tf.reshape(
                    tensor=clean_speech_signal,
                    shape=[-1, params.max_signal_length, 1])
            # Generator and Postprocess Cleaning
            with tf.variable_scope('generator') as gen_scope:
                generator_fn = networks.generator_fn
                cleaned_frequency_sequence = generator_fn(noisy_speech_data)
            with tf.variable_scope('postprocess'):
                cleaned_speech_signal = simon_ops.inverse_stft(
                    sequence=cleaned_frequency_sequence,
                    params=params)
                cleaned_speech_signal = tf.reshape(
                    tensor=cleaned_speech_signal,
                    shape=[-1, params.max_signal_length, 1])
            # Loss Functions
            with tf.variable_scope('stft_loss'):
                stft_loss = losses.stft_loss(cleaned_frequency_sequence,
                                             clean_frequency_sequence)
            with tf.variable_scope('signal_loss'):
                mse_loss = losses.signal_loss(cleaned_speech_signal,
                                              clean_speech_signal)
            # Train Ops
            """
            NOTE:
            The value of mse_loss (the signal-domain MSE) is significantly
            smaller than the value of stft_loss, so it should be weighted
            accordingly if the mse_loss is deemed necessary.
            """
            global_step = tf.train.get_or_create_global_step()
            gen_lr = params.learning_rate_decay_fn(0.001, global_step)
            with tf.variable_scope('regression_train_op'):
                optimizer = tf.train.AdamOptimizer(learning_rate=gen_lr)
                with tf.variable_scope('compute_gradients'):
                    stft_gradients = losses.compute_gradients(
                        optimizer=optimizer,
                        loss_fn=stft_loss)
                    mse_gradients = losses.compute_gradients(
                        optimizer=optimizer,
                        loss_fn=mse_loss)
                with tf.variable_scope('apply_gradients'):
                    stft_train_op = optimizer.apply_gradients(stft_gradients)
                    mse_train_op = optimizer.apply_gradients(mse_gradients)
                with tf.name_scope('update_global_step'):
                    global_step = tf.train.get_or_create_global_step()
                    global_step_update = global_step.assign_add(1)
                regression_train_op = (stft_train_op, mse_train_op, global_step_update)
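            # sess.run() on this 3-tuple returns three results; the last one is the
            # updated global step, which the training loop below unpacks as `j`.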
            ###################################################################
            """ GAN Model Definition and Operations """
            ###################################################################
            """
            The GAN model uses the same generator as the regression model plus
            a new discriminator. The discriminator uses the VGGish model
            defined by Google and is initialized with frozen, pre-trained
            variables from that model. The generator takes a sequence in the
            frequency domain and produces a sequence in the frequency domain.
            That sequence is transformed into a signal and passed into another
            preprocessing routine, which performs another STFT on the signal,
            but with a much larger window, thus capturing much more frequency
            information. This sequence is then transformed into the log-mel
            space and passed into the VGGish CNN. The only trainable variable
            in the discriminator is the dense output layer that transforms the
            embedding produced by the CNN into a scalar logit, which is then
            passed into the GAN loss functions.
            The GAN loss functions used for this model are the least-squares
            GAN loss functions defined in the tf.contrib.gan module.
            """
            tfgan = tf.contrib.gan
            # Prepare the Signals for the Discriminator Network (Check Shapes in Params)
            with tf.variable_scope('preprocess_dis'):
                clean_frequency_sequence = preprocessing.dis_preprocess_data(
                    data=clean_speech_signal)
                cleaned_frequency_sequence = preprocessing.dis_preprocess_data(
                    data=cleaned_speech_signal)
            # Define both the Discriminators with shared variables
            discriminator_fn = networks.discriminator_fn
            with tf.variable_scope('discriminator') as dis_scope:
                logits_from_gen = discriminator_fn(data=cleaned_frequency_sequence)
            with tf.variable_scope(dis_scope, reuse=True):
                logits_from_real = discriminator_fn(data=clean_frequency_sequence)
            generator_variables = tf.trainable_variables(scope='generator')
            discriminator_variables = tf.trainable_variables(scope='discriminator')
            gan_model = tfgan.GANModel(
                # Generator Scope
                generator_inputs=noisy_speech_data,
                generated_data=cleaned_frequency_sequence,
                generator_variables=generator_variables,
                generator_scope=gen_scope,
                generator_fn=generator_fn,
                # Discriminator Scope
                real_data=clean_speech_signal,
                discriminator_real_outputs=logits_from_real,
                discriminator_gen_outputs=logits_from_gen,
                discriminator_variables=discriminator_variables,
                discriminator_scope=dis_scope,
                discriminator_fn=discriminator_fn)
            with tf.variable_scope('gan_loss'):
                gan_loss = tfgan.gan_loss(
                    gan_model,
                    generator_loss_fn=tfgan.losses.least_squares_generator_loss,
                    discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss,
                    add_summaries=False)
            with tf.variable_scope('gan_train_op'):
                gen_optimizer = tf.train.AdamOptimizer(gen_lr, 0.9)
                dis_optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
                with tf.name_scope('update_global_step'):
                    global_step = tf.train.get_or_create_global_step()
                    global_step_update = global_step.assign_add(1)
                with tf.variable_scope('generator'):
                    with tf.variable_scope('compute_gradients'):
                        gen_gan_loss = gan_loss.generator_loss
                        gen_gan_gradients = losses.compute_gradients(
                            optimizer=gen_optimizer,
                            loss_fn=gen_gan_loss)
                    with tf.variable_scope('apply_gradients'):
                        gen_gan_train_op = (gen_optimizer.apply_gradients(gen_gan_gradients),
                                            global_step_update)
                with tf.variable_scope('discriminator'):
                    with tf.variable_scope('compute_gradients'):
                        dis_gan_loss = gan_loss.discriminator_loss
                        dis_gan_gradients = losses.compute_gradients(
                            optimizer=dis_optimizer,
                            loss_fn=dis_gan_loss)
                    with tf.variable_scope('apply_gradients'):
                        dis_gan_train_op = (dis_optimizer.apply_gradients(dis_gan_gradients),
                                            global_step_update)
            ###################################################################
            """ Summaries """
            ###################################################################
            with tf.variable_scope('summaries'):
                with tf.variable_scope('audio_summaries'):
                    clean_signal_summary = simon_ops.add_audio_summary(
                        audio_data=clean_speech_signal,
                        name='clean_signal')
                    noisy_signal_summary = simon_ops.add_audio_summary(
                        audio_data=noisy_speech_signal,
                        name='noisy_signal')
                    cleaned_signal_summary = simon_ops.add_audio_summary(
                        audio_data=cleaned_speech_signal,
                        name='cleaned_signal')
                    audio_summaries = tf.summary.merge([clean_signal_summary,
                                                        noisy_signal_summary,
                                                        cleaned_signal_summary])
                with tf.variable_scope('scalar_summaries'):
                    stft_loss_summary = tf.summary.scalar(
                        name='stft_loss',
                        tensor=stft_loss)
                    mse_loss_summary = tf.summary.scalar(
                        name='mse_loss',
                        tensor=mse_loss)
                    mse_stft_scalar_summaries = tf.summary.merge([stft_loss_summary,
                                                                  mse_loss_summary])
                    gen_gan_loss_summary = tf.summary.scalar(
                        name='gen_gan',
                        tensor=gen_gan_loss)
                    dis_gan_loss_summary = tf.summary.scalar(
                        name='dis_gan',
                        tensor=dis_gan_loss)
            #######################################################################
            """ Training """
            #######################################################################
            """
            The training session needs to be handled manually to support the
            complex routine: the regression model and the GAN model alternate
            during training. The regression model is trained for 100 steps, then
            the GAN is trained for 50 steps.
            The variables for the GAN discriminator need to be preloaded from
            the VGGish model trained by Google. Moreover, it is necessary to
            allow GPU growth so as to avoid out-of-memory errors.
            """
            model_dir = 'models/LSGAN_clipping_larger_rnn/'
            # Set up the hook for transferring the VGG Variables
            variables_to_restore = tf.contrib.framework.get_model_variables()
            init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                'vars/vggish_variables/dis_vggish_model.ckpt',
                variables_to_restore)
            # Variables for the Monitored Training Session
            is_chief = FLAGS.task_index == 0
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            saver = tf.train.Saver(max_to_keep=2)
            hooks = [tf.train.StopAtStepHook(last_step=int(1e4)),
                     hookdefs.RestoreHook(init_fn)]
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=is_chief,
                                                   checkpoint_dir=model_dir,
                                                   config=config,
                                                   hooks=hooks) as sess:
                # Unwrap the MonitoredSession to reach the underlying raw tf.Session,
                # which saver.save() and the FileWriter graph argument require.
                real_sess = sess._sess._sess._sess._sess
                while not sess.should_stop():
                    print("Training Model")
                    # We only want to write summaries from the chief
                    if is_chief:
                        writer = tf.summary.FileWriter(logdir=model_dir, graph=real_sess.graph)
                    j = sess.run(tf.train.get_or_create_global_step())
                    while True:
                        try:
                            # MSE and STFT Training Routine
                            print('\nTraining Regression Train Ops for 100 Steps')
                            j_start = j
                            k = 0
                            while (j < j_start + 100):
                                # Save first time around
                                if k == 0 and is_chief:
                                    save_path = saver.save(
                                        sess=real_sess,
                                        save_path=model_dir + 'model',
                                        global_step=j,
                                        latest_filename='regression')
                                    print('Saved Model to %s' % save_path)
                                run = [stft_loss, mse_loss, audio_summaries,
                                       mse_stft_scalar_summaries, *regression_train_op]
                                stft_loss_, mse_loss_, audio_, scalars_, _, _, j = sess.run(run)
                                if k % 5 == 0:
                                    print('STFT Loss at Step %d: %f' % (j, stft_loss_))
                                    print('MSE Loss at Step %d: %f' % (j, mse_loss_))
                                if is_chief:
                                    writer.add_summary(audio_, j)
                                    writer.add_summary(scalars_, j)
                                k += 1
                            # GAN Training Routine
                            print('\nRunning GAN Train Ops for 50 Steps')
                            j_start = j
                            k = 0
                            while (j < j_start + 50):
                                # Save First Time Around
                                if k == 0 and is_chief:
                                    save_path = saver.save(
                                        sess=real_sess,
                                        save_path=model_dir + 'model',
                                        global_step=j,
                                        latest_filename='gan')
                                    print('Saved Model to %s' % save_path)
                                audio_ = None
                                if k % 2 == 0:
                                    run = [gen_gan_loss, gen_gan_loss_summary,
                                           audio_summaries, *gen_gan_train_op]
                                    run_type = 'GEN'
                                else:
                                    run = [dis_gan_loss, dis_gan_loss_summary,
                                           audio_summaries, *dis_gan_train_op]
                                    run_type = 'DIS'
                                loss_, scalars_, audio_, _, j = sess.run(run)
                                if k % 2 == 0:
                                    print(run_type + " Loss at Step %d: %f" % (j, loss_))
                                if is_chief:
                                    writer.add_summary(scalars_, j)
                                    writer.add_summary(audio_, j)
                                k += 1
                        except tf.errors.OutOfRangeError:
                            # Input pipeline exhausted: close the writer and leave the loop
                            if is_chief:
                                writer.close()
                            break
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    if FLAGS.job_name == 'ps':
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif FLAGS.job_name == 'worker':
        os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.task_index)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)