import tensorflow as tf
import keras
from tensorflow.python.ops.control_flow_ops import with_dependencies

# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")

# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

# Flags for the optimizer hyperparameters (default values are placeholders)
tf.app.flags.DEFINE_float("learning_rate", 0.001, "RMSProp learning rate")
tf.app.flags.DEFINE_float("momentum", 0.0, "RMSProp momentum")

FLAGS = tf.app.flags.FLAGS


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        # Assign ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            # Set the Keras learning phase to train.
            keras.backend.set_learning_phase(1)
            # Do not initialize variables on the fly.
            keras.backend.manual_variable_initialization(True)

            # Build the Keras model.
            model = ...

            # Keras model predictions.
            preds = model.output
            # Placeholder for the training targets.
            targets = tf.placeholder(...)
            # Categorical crossentropy loss.
            xent_loss = tf.reduce_mean(
                keras.objectives.categorical_crossentropy(targets, preds))

            # Create a global_step tensor for distributed training
            # (a counter of iterations).
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Apply regularizers, if any.
            if model.regularizers:
                total_loss = xent_loss * 1.  # copy tensor
                for regularizer in model.regularizers:
                    total_loss = regularizer(total_loss)
            else:
                total_loss = xent_loss

            # Set up the TF optimizer.
            optimizer = tf.train.RMSPropOptimizer(
                FLAGS.learning_rate, decay=0.9,
                momentum=FLAGS.momentum, epsilon=1e-8)

            # Set up model update ops (batch norm ops).
            # The gradients should only be computed after updating the moving
            # average of the batch normalization parameters, in order to prevent
            # a data race between the parameter updates and the moving average
            # computations.
            with tf.control_dependencies(model.updates):
                barrier = tf.no_op(name='update_barrier')

            # Define the gradient updates.
            with tf.control_dependencies([barrier]):
                grads = optimizer.compute_gradients(
                    total_loss,
                    model.trainable_weights,
                    gate_gradients=tf.train.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False)
                grad_updates = optimizer.apply_gradients(grads,
                                                         global_step=global_step)

            # Define the train tensor: evaluating it runs the gradient updates
            # and returns the current loss value.
            train_tensor = with_dependencies([grad_updates],
                                             total_loss,
                                             name='train')

            # Saver, summaries, and variable initialization.
            saver = tf.train.Saver()
            summary_op = tf.merge_all_summaries()
            init_op = tf.initialize_all_variables()

        # Create a "supervisor", which oversees the training process.
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 logdir="/tmp/train_logs",
                                 init_op=init_op,
                                 summary_op=summary_op,
                                 saver=saver,
                                 global_step=global_step,
                                 save_model_secs=600)

        # The supervisor takes care of session initialization, restoring from
        # a checkpoint, and closing when done or when an error occurs.
        with sv.managed_session(server.target) as sess:
            # Loop until the supervisor shuts down or 1000000 steps have completed.
            step = 0
            while not sv.should_stop() and step < 1000000:
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on
                # how to perform *synchronous* training.
                # feed_dict must contain the model inputs (the tensors listed in
                # model.inputs) and the "targets" placeholder we created earlier.
                # It is a dictionary mapping tensors to batches of Numpy data, e.g.:
                # feed_dict={model.inputs[0]: np_train_data_batch,
                #            targets: np_train_labels_batch}
                loss_value, step_value = sess.run([train_tensor, global_step],
                                                  feed_dict={...})
                step = step_value

        # Ask for all the services to stop.
        sv.stop()


if __name__ == "__main__":
    tf.app.run()
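For concreteness, the elided pieces above (the `model = ...`, the `targets` placeholder, and the feed_dict) could look something like the following for a hypothetical 784-input, 10-class classifier. The architecture, shapes, and random batch below are illustrative assumptions, not part of the gist:

import numpy as np
import tensorflow as tf
from keras.layers import Input, Dense
from keras.models import Model

# Hypothetical model: a 784 -> 10 softmax classifier (MNIST-like data).
inputs = Input(shape=(784,))
x = Dense(128, activation='relu')(inputs)
preds = Dense(10, activation='softmax')(x)
model = Model(inputs, preds)  # model.inputs == [inputs], model.output is preds

# Matching placeholder for the training targets (one-hot labels).
targets = tf.placeholder(tf.float32, shape=(None, 10))

# One training batch as Numpy arrays (random data, for illustration only).
np_train_data_batch = np.random.rand(32, 784).astype('float32')
np_train_labels_batch = np.eye(10, dtype='float32')[np.random.randint(0, 10, size=32)]

feed_dict = {model.inputs[0]: np_train_data_batch,
             targets: np_train_labels_batch}

This `feed_dict` is what would be passed to `sess.run([train_tensor, global_step], feed_dict=feed_dict)` inside the training loop.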
Is this still the "official" way to do data parallelization in Keras on a single machine?
@smhoang `from tensorflow.python.ops.control_flow_ops import with_dependencies`
@THUfl12 there is an issue with the above script when the first layer is an Embedding layer.
Since training is done in mini-batches, a batch might not contain all of the indices in the dictionary (e.g. input_dim = 100 but the mini-batch only covers 60). In this case, TF throws an exception because the matrix sizes of the RHS and the LHS don't match.
Any hint on how to fix this?
@jerrypaytm, regarding the Embedding layer's parameters:
model = Sequential()
model.add(Embedding(input_dim=1000, output_dim=64, input_length=10))
The model will take as input an integer matrix of size (batch, input_length).
The largest integer (i.e. word index) in the input should be no larger than 999 (one less than the vocabulary size of 1000).
Now model.output_shape == (None, 10, 64), where None is the batch dimension.
For this example, your input data can be something like:
data_example = numpy.array([[1,2,3,4,5,6,7,8,9,10], [1,2,3,4,5,6,7,8,9,10]]). The shape of data_example is (2, 10) (here the batch size is 2), and its largest element is 10, which is smaller than 1000, so the constraint is satisfied.
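For reference, here is that Embedding example as a minimal runnable snippet; the printed shapes follow from the parameters above:

import numpy
from keras.models import Sequential
from keras.layers import Embedding

# Vocabulary of 1000 word indices (0..999), 64-dim embeddings, sequences of length 10.
model = Sequential()
model.add(Embedding(input_dim=1000, output_dim=64, input_length=10))

# Integer input of shape (batch, input_length); every entry must be < 1000.
data_example = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

print(model.output_shape)                 # (None, 10, 64)
print(model.predict(data_example).shape)  # (2, 10, 64)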
@THUfl12 Thanks for your help. I was fooled by the model checkpoint. Once I deleted the checkpoint, the embedding dimensions were rebuilt to match the expected dictionary size.
I'm also confused about whether this is the official way to do data parallelism. According to this blog post (https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html) it's as simple as:
`server = tf.train.Server.create_local_server()
sess = tf.Session(server.target)
from keras import backend as K
K.set_session(sess)`
Can anyone confirm this method from the blog works?
@vagarwal87 the code below still needs to be linked to the cluster. In short, keras_distributed.py is the official way to do this.
`server = tf.train.Server.create_local_server()
sess = tf.Session(server.target)
from keras import backend as K
K.set_session(sess)`
The code is out of date: `train_tensor = tf.with_dependencies([grad_updates], total_loss, name='train')` no longer works. Can anyone fix it? Thanks.
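As noted earlier in the thread, `with_dependencies` has to be imported from `tensorflow.python.ops.control_flow_ops` rather than called as `tf.with_dependencies`. A minimal self-contained check of that pattern (TF 1.x graph mode; the variable and constant here are arbitrary stand-ins for `grad_updates` and `total_loss`):

import tensorflow as tf
from tensorflow.python.ops.control_flow_ops import with_dependencies

# Stand-ins: an update op (like grad_updates) and a loss tensor (like total_loss).
counter = tf.Variable(0, name='counter')
update_op = tf.assign_add(counter, 1)
loss = tf.constant(3.14, name='loss')

# Evaluating train_tensor forces update_op to run, then returns the loss value.
train_tensor = with_dependencies([update_op], loss, name='train')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(train_tensor))  # 3.14
    print(sess.run(counter))       # 1 -- the update ran as a dependency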