Skip to content

Instantly share code, notes, and snippets.

@warneracw21
Created August 24, 2018 18:31
Show Gist options
  • Save warneracw21/811d881afc19fbe79dbdf59318f61b35 to your computer and use it in GitHub Desktop.
Save warneracw21/811d881afc19fbe79dbdf59318f61b35 to your computer and use it in GitHub Desktop.
Alternating Training Schedule for an RNN regression and convolutional LSGAN model
""" Methods for Training on a distributed Network """
"""
The session is set up to be run in a distributed local environment.
The workers should be assigned to GPUs because they do most of the
heavy lifting in the computation. Assign a different host:port pair
to each GPU available on the local host.
The parameter servers should simply be assigned to CPUs. See the
__main__ conditional to see how the GPUs and CPUs are set for both
the worker and the parameter server. If the job_name is a worker, the
visible GPU is set to the task index of that worker.
USAGE:
Run a Worker:
python task.py --job_name=worker --task_index=0 (GPU set to 0)
python task.py --job_name=worker --task_index=1 (GPU set to 1)
Run a PS:
python task.py --job_name=ps --task_index=0 (GPU set to NULL)
python task.py --job_name=ps --task_index=1 (GPU set to NULL)
AUTHOR: Andrew Warner
DATE: 7 August 2018
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import sys
import os
# Import Libraries
import tensorflow as tf
import argparse
# Import Tensorflow Libraries
from tensorflow.contrib.training import HParams
from tensorflow.python.training import saver
from tensorflow.python import debug as tf_debug
# Import Methods
from tensorbose import tfutils, fileio
from functools import partial
# Import Modules
from trainer import data_provider
from trainer import preprocessing
from trainer import simon_ops
from trainer import hookdefs
from trainer import networks
from trainer import losses
# Constants
from trainer.params import get_hparams
# Hyperparameters loaded once at import time. NOTE(review): main() calls
# get_hparams() again instead of reading this constant; nothing in the
# visible part of this file uses HPARAMS.
HPARAMS = get_hparams()
# Emit INFO-level (and above) TensorFlow log messages
tf.logging.set_verbosity(tf.logging.INFO)
# Placeholder for the parsed command-line flags. The __main__ block assigns
# the argparse namespace here before main() runs; main() reads
# FLAGS.job_name and FLAGS.task_index.
FLAGS = None
def main(argv=None):
    """Define the regression and GAN graphs and run alternating training.

    Starts the ``tf.train.Server`` for the job/task named by the module-level
    ``FLAGS``. A parameter server ('ps') simply joins the server and blocks.
    A worker builds the graph (shared generator, regression losses, LSGAN
    losses, summaries) and then alternates between 100 global steps of the
    regression train ops and 50 global steps of the GAN train ops, where the
    GAN phase alternates generator and discriminator updates on every step.

    Training ends when the ``StopAtStepHook`` (10,000 global steps) requests
    a stop or when the input pipeline raises ``tf.errors.OutOfRangeError``.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``. Unused; the
            function reads the module-level ``FLAGS`` instead.
    """
    # Define the parameters and the cluster specs
    params = get_hparams()
    cluster = tf.train.ClusterSpec({
        'worker': ['localhost:2222', 'localhost:2223'],
        'ps': ['localhost:2221', 'localhost:2224'],
    })

    # The server is specific to the job_name and task_index given in the
    # process call
    server = tf.train.Server(
        cluster,
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_index)

    # Join the process similar to multi-threading
    if FLAGS.job_name == 'ps':
        server.join()

    elif FLAGS.job_name == 'worker':
        # NOTE(review): no `ps_strategy` is passed, so variables are placed on
        # the parameter servers with replica_device_setter's default strategy.
        # A GreedyLoadBalancingStrategy used to be constructed here but was
        # never handed to the device setter, so it had no effect.
        replica_setter = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)

        with tf.device(replica_setter):
            # The input is synchronous. Thus, the clean speech and noisy
            # speech should hold the same speech content. See
            # data_provider.py for more
            with tf.variable_scope('input'):
                noisy_speech_data, clean_speech_signal = (
                    data_provider.input_fn(mode='TRAIN'))

            ###################################################################
            # Regression Model
            ###################################################################
            # The regression model uses the generator to produce sequences in
            # the frequency domain, upon which it runs a mean squared error
            # loss against the stft frequency domain of the clean speech.
            # Moreover, the cleaned frequency domain sequence from the
            # generator is inverted back into a time domain signal, from
            # which another mean squared error is calculated against the
            # clean signal.

            # Data Preparation for Generator and Loss Functions
            with tf.variable_scope('preprocess'):
                # Preprocess the Noisy Signal for Processing by the Generator
                noisy_speech_signal, noisy_sequence_length = noisy_speech_data
                noisy_frequency_sequence = preprocessing.gen_preprocess_data(
                    data=noisy_speech_signal,
                    data_type='features')
                noisy_speech_data = (noisy_frequency_sequence,
                                     noisy_sequence_length)

                # Preprocess the Clean Signal for the STFT Loss Function
                clean_frequency_sequence = preprocessing.gen_preprocess_data(
                    data=clean_speech_signal,
                    data_type='labels')

                # Preprocess the Clean Signal for MSE Signal Loss Function
                clean_speech_signal = preprocessing._pad_to_max(
                    data=clean_speech_signal,
                    pad_size=params.max_signal_length,
                    params=params)
                clean_speech_signal = tf.reshape(
                    tensor=clean_speech_signal,
                    shape=[-1, params.max_signal_length, 1])

            # Generator and Postprocess Cleaning
            with tf.variable_scope('generator') as gen_scope:
                generator_fn = networks.generator_fn
                cleaned_frequency_sequence = generator_fn(noisy_speech_data)

            with tf.variable_scope('postprocess'):
                cleaned_speech_signal = simon_ops.inverse_stft(
                    sequence=cleaned_frequency_sequence,
                    params=params)
                cleaned_speech_signal = tf.reshape(
                    tensor=cleaned_speech_signal,
                    shape=[-1, params.max_signal_length, 1])

            # Loss Functions
            with tf.variable_scope('stft_loss'):
                stft_loss = losses.stft_loss(cleaned_frequency_sequence,
                                             clean_frequency_sequence)

            with tf.variable_scope('signal_loss'):
                mse_loss = losses.signal_loss(cleaned_speech_signal,
                                              clean_speech_signal)

            # Train Ops
            # NOTE:
            # The loss for the mse_signal_loss is significantly less than the
            # loss for the stft_loss, thus, should be weighted accordingly if
            # the mse_loss is deemed necessary
            global_step = tf.train.get_or_create_global_step()
            gen_lr = params.learning_rate_decay_fn(0.001, global_step)

            with tf.variable_scope('regression_train_op'):
                optimizer = tf.train.AdamOptimizer(learning_rate=gen_lr)

                with tf.variable_scope('compute_gradients'):
                    stft_gradients = losses.compute_gradients(
                        optimizer=optimizer,
                        loss_fn=stft_loss)
                    mse_gradients = losses.compute_gradients(
                        optimizer=optimizer,
                        loss_fn=mse_loss)

                with tf.variable_scope('apply_gradients'):
                    stft_train_op = optimizer.apply_gradients(stft_gradients)
                    mse_train_op = optimizer.apply_gradients(mse_gradients)

                with tf.name_scope('update_global_step'):
                    regression_step_update = global_step.assign_add(1)

                # The step update is fetched last so that sess.run returns
                # the new global step as the final value.
                regression_train_op = (stft_train_op, mse_train_op,
                                       regression_step_update)

            ###################################################################
            # GAN Model Definition and Operations
            ###################################################################
            # The GAN model uses the same generator as the regression model
            # and a new discriminator. The discriminator uses the VGGish
            # model defined by Google and is initialized with frozen,
            # trained, variables from that model. The generator takes a
            # sequence in the frequency domain and produces a sequence in the
            # frequency domain. This sequence is transformed into a signal
            # and passed into another preprocessing routine. This other
            # preprocessing routine performs another stft on the signal, but
            # with a much larger window, thus, capturing much more frequency
            # information. This sequence is then transformed into the mel log
            # space and passed into the VGGish CNN. The only trainable
            # variable in the discriminator is the dense output layer that
            # transforms the embedding produced by the CNN into a scalar
            # logit that is then passed into the GAN loss functions.
            #
            # The GAN loss functions used for this model are the least
            # squared gan loss functions defined in the tf.contrib.gan module.
            tfgan = tf.contrib.gan

            # Prepare the Signals for the Discriminator Network
            # (Check Shapes in Params)
            with tf.variable_scope('preprocess_dis'):
                dis_clean_sequence = preprocessing.dis_preprocess_data(
                    data=clean_speech_signal)
                dis_cleaned_sequence = preprocessing.dis_preprocess_data(
                    data=cleaned_speech_signal)

            # Define both the Discriminators with shared variables
            discriminator_fn = networks.discriminator_fn
            with tf.variable_scope('discriminator') as dis_scope:
                logits_from_gen = discriminator_fn(data=dis_cleaned_sequence)
            with tf.variable_scope(dis_scope, reuse=True):
                logits_from_real = discriminator_fn(data=dis_clean_sequence)

            generator_variables = tf.trainable_variables(scope='generator')
            discriminator_variables = tf.trainable_variables(
                scope='discriminator')

            gan_model = tfgan.GANModel(
                # Generator Scope
                generator_inputs=noisy_speech_data,
                generated_data=dis_cleaned_sequence,
                generator_variables=generator_variables,
                generator_scope=gen_scope,
                generator_fn=generator_fn,
                # Discriminator Scope
                real_data=clean_speech_signal,
                discriminator_real_outputs=logits_from_real,
                discriminator_gen_outputs=logits_from_gen,
                discriminator_variables=discriminator_variables,
                discriminator_scope=dis_scope,
                discriminator_fn=discriminator_fn)

            with tf.variable_scope('gan_loss'):
                gan_loss = tfgan.gan_loss(
                    gan_model,
                    generator_loss_fn=(
                        tfgan.losses.least_squares_generator_loss),
                    discriminator_loss_fn=(
                        tfgan.losses.least_squares_discriminator_loss),
                    add_summaries=False)

            with tf.variable_scope('gan_train_op'):
                gen_optimizer = tf.train.AdamOptimizer(
                    learning_rate=gen_lr, beta1=0.9)
                dis_optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

                with tf.name_scope('update_global_step'):
                    gan_step_update = global_step.assign_add(1)

                # Each network is trained with its own optimizer: the
                # generator with the decayed learning rate, the discriminator
                # with the fixed one.
                with tf.variable_scope('generator'):
                    with tf.variable_scope('compute_gradients'):
                        gen_gan_loss = gan_loss.generator_loss
                        gen_gan_gradients = losses.compute_gradients(
                            optimizer=gen_optimizer,
                            loss_fn=gen_gan_loss)
                    with tf.variable_scope('apply_gradients'):
                        gen_gan_train_op = (
                            gen_optimizer.apply_gradients(gen_gan_gradients),
                            gan_step_update)

                with tf.variable_scope('discriminator'):
                    with tf.variable_scope('compute_gradients'):
                        dis_gan_loss = gan_loss.discriminator_loss
                        dis_gan_gradients = losses.compute_gradients(
                            optimizer=dis_optimizer,
                            loss_fn=dis_gan_loss)
                    with tf.variable_scope('apply_gradients'):
                        dis_gan_train_op = (
                            dis_optimizer.apply_gradients(dis_gan_gradients),
                            gan_step_update)

            ###################################################################
            # Summaries
            ###################################################################
            with tf.variable_scope('summaries'):
                with tf.variable_scope('audio_summaries'):
                    clean_signal_summary = simon_ops.add_audio_summary(
                        audio_data=clean_speech_signal,
                        name='clean_signal')
                    noisy_signal_summary = simon_ops.add_audio_summary(
                        audio_data=noisy_speech_signal,
                        name='noisy_signal')
                    cleaned_signal_summary = simon_ops.add_audio_summary(
                        audio_data=cleaned_speech_signal,
                        name='cleaned_signal')
                    audio_summaries = tf.summary.merge([
                        clean_signal_summary,
                        noisy_signal_summary,
                        cleaned_signal_summary])

                with tf.variable_scope('scalar_summaries'):
                    stft_loss_summary = tf.summary.scalar(
                        name='stft_loss',
                        tensor=stft_loss)
                    mse_loss_summary = tf.summary.scalar(
                        name='mse_loss',
                        tensor=mse_loss)
                    mse_stft_scalar_summaries = tf.summary.merge([
                        stft_loss_summary,
                        mse_loss_summary])
                    gen_gan_loss_summary = tf.summary.scalar(
                        name='gen_gan',
                        tensor=gen_gan_loss)
                    dis_gan_loss_summary = tf.summary.scalar(
                        name='dis_gan',
                        tensor=dis_gan_loss)

        #######################################################################
        # Training
        #######################################################################
        # The training session needs to be handled manually to handle the
        # complex routine. The regression model and GAN model alternate
        # between training. The regression model is trained for 100 steps,
        # then the GAN is trained for 50 steps.
        #
        # The variables for the GAN Discriminator need to be preloaded from
        # the VGGish model trained by Google. Moreover, it is necessary to
        # allow GPU growth as to avoid Out Of Memory Errors
        model_dir = 'models/LSGAN_clipping_larger_rnn/'

        # Set up the hook for transferring the VGG Variables
        variables_to_restore = tf.contrib.framework.get_model_variables()
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            'vars/vggish_variables/dis_vggish_model.ckpt',
            variables_to_restore)

        # Variables for the Monitored Training Session
        is_chief = FLAGS.task_index == 0
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # Named model_saver so it does not shadow the imported `saver` module.
        model_saver = tf.train.Saver(max_to_keep=2)
        hooks = [tf.train.StopAtStepHook(last_step=int(1e4)),
                 hookdefs.RestoreHook(init_fn)]

        # The fetch lists never change, so build them once. In each list the
        # last fetch is the global-step update, whose value is the new step.
        regression_fetches = [stft_loss, mse_loss, audio_summaries,
                              mse_stft_scalar_summaries]
        regression_fetches.extend(regression_train_op)
        gen_fetches = [gen_gan_loss, gen_gan_loss_summary, audio_summaries]
        gen_fetches.extend(gen_gan_train_op)
        dis_fetches = [dis_gan_loss, dis_gan_loss_summary, audio_summaries]
        dis_fetches.extend(dis_gan_train_op)

        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chief,
                                               checkpoint_dir=model_dir,
                                               config=config,
                                               hooks=hooks) as sess:
            # HACK: the monitored session wraps the raw tf.Session in several
            # private layers; the Saver needs the raw session, so dig it out.
            real_sess = sess._sess._sess._sess._sess

            # We only want to write summaries from the chief
            writer = None
            if is_chief:
                writer = tf.summary.FileWriter(logdir=model_dir,
                                               graph=real_sess.graph)

            try:
                print("Training Model")
                j = sess.run(global_step)

                while not sess.should_stop():
                    # MSE and STFT Training Routine
                    print('\nTraining Regression Train Ops for 100 Steps')
                    j_start = j
                    k = 0
                    # Re-check should_stop() every step: calling run() after
                    # the stop hook has fired raises a RuntimeError.
                    while j < j_start + 100 and not sess.should_stop():
                        # Save first time around
                        if k == 0 and is_chief:
                            save_path = model_saver.save(
                                sess=real_sess,
                                save_path=model_dir + 'model',
                                global_step=j,
                                latest_filename='regression')
                            print('Saved Model to %s' % save_path)

                        (stft_loss_, mse_loss_, audio_, scalars_,
                         _, _, j) = sess.run(regression_fetches)

                        if k % 5 == 0:
                            print('STFT Loss at Step %d: %f'
                                  % (j, stft_loss_))
                            print('MSE Loss at Step %d: %f' % (j, mse_loss_))
                            if writer is not None:
                                writer.add_summary(audio_, j)
                                writer.add_summary(scalars_, j)
                        k += 1

                    if sess.should_stop():
                        break

                    # GAN Training Routine
                    print('\nRunning GAN Train Ops for 50 Steps')
                    j_start = j
                    k = 0
                    while j < j_start + 50 and not sess.should_stop():
                        # Save First Time Around
                        if k == 0 and is_chief:
                            save_path = model_saver.save(
                                sess=real_sess,
                                save_path=model_dir + 'model',
                                global_step=j,
                                latest_filename='gan')
                            print('Saved Model to %s' % save_path)

                        # Even steps update the generator, odd steps the
                        # discriminator.
                        if k % 2 == 0:
                            fetches = gen_fetches
                            run_type = 'GEN'
                        else:
                            fetches = dis_fetches
                            run_type = 'DIS'

                        loss_, scalars_, audio_, _, j = sess.run(fetches)

                        if k % 2 == 0:
                            print(run_type + " Loss at Step %d: %f"
                                  % (j, loss_))
                        if writer is not None:
                            writer.add_summary(scalars_, j)
                            writer.add_summary(audio_, j)
                        k += 1

            except tf.errors.OutOfRangeError:
                # The input pipeline is exhausted: stop training instead of
                # retrying run() forever.
                tf.logging.info('Input exhausted; stopping training.')
            finally:
                # Only the chief owns a writer; close it on every exit path.
                if writer is not None:
                    writer.close()
if __name__ == '__main__':
    # Script entry point: parse the cluster role, choose the visible GPU for
    # this process, then hand control to main() through tf.app.run.
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # Flags for defining the tf.train.Server
    # --job_name is required and restricted to the two jobs in the cluster
    # spec, so a missing or misspelled role fails here with a usage message
    # rather than later during server creation.
    parser.add_argument(
        "--job_name",
        type=str,
        required=True,
        choices=("ps", "worker"),
        help="One of 'ps', 'worker'"
    )
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    # Unrecognised arguments are kept and forwarded to tf.app.run below.
    FLAGS, unparsed = parser.parse_known_args()

    # Parameter servers see no GPU; each worker sees only the GPU whose
    # index equals its task index.
    if FLAGS.job_name == 'ps':
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif FLAGS.job_name == 'worker':
        os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.task_index)

    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment