Last active: May 22, 2018 09:13
-
-
Save codescv/7973f516d1d23da1f5ad72efe9171c35 to your computer and use it in GitHub Desktop.
Distributed TensorFlow code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_model(filename):
    """Build the training graph and return handles to its key ops and tensors.

    The model / input-pipeline section is elided in this snippet ("same as
    before"). It must define ``optimizer``, ``loss`` and ``cols_to_vars``,
    which the visible lines reference. ``filename`` is not used by the
    visible lines; presumably the elided code reads training data from it
    (TODO confirm).

    Args:
        filename: Path handed to the elided model/input code.

    Returns:
        dict with keys:
          'train': {'train_op': ..., 'loss': ...} -- fetch together to run
              one optimization step and observe its loss.
          'init': {'global': global-variables initializer op,
                   'local': [local-variables initializer, tables initializer]}.
          'global_variables': list of the graph's global variables.
          'uninitialized': op reporting which of those variables are still
              uninitialized (an empty result means initialization is done).
          'cols_to_vars': produced by the elided model code.
          'summary': merged summary op from ``tf.summary.merge_all()``.
          'global_step': the global step variable.
    """
    # ... same as before
    global_step = tf.train.get_or_create_global_step()
    # Passing global_step makes each applied update increment it, so all
    # replicas share one measure of training progress.
    train_op = optimizer.minimize(loss, global_step=global_step)
    summary = tf.summary.merge_all()
    global_vars = tf.global_variables()
    # Reuse the list captured above rather than querying the collection twice.
    uninitialized = tf.report_uninitialized_variables(global_vars)
    global_init = tf.global_variables_initializer()
    # Kept separate from the global initializer so every task can run it for
    # itself, while only one task initializes the shared global variables.
    local_init = [tf.local_variables_initializer(), tf.tables_initializer()]
    return {
        'train': {
            'train_op': train_op,
            'loss': loss,
        },
        'init': {
            'global': global_init,
            'local': local_init,
        },
        'global_variables': global_vars,
        'uninitialized': uninitialized,
        'cols_to_vars': cols_to_vars,
        'summary': summary,
        'global_step': global_step,
    }
def main():
    """Run this process as one task of a distributed TensorFlow cluster.

    The task role is read from the module-level ``tf_config``. It is defined
    outside this snippet and presumably holds the parsed TF_CONFIG dict with
    'cluster' and 'task' keys (TODO confirm). After starting a
    ``tf.train.Server`` for this task:

    * ``ps`` tasks block in ``server.join()``;
    * the ``master`` task (the chief) initializes the global variables, then
      polls the global step and writes summaries until it reaches 2000;
    * every other task waits for initialization, then runs 1000 training
      steps.
    """
    task_type = tf_config['task']['type']
    task_index = tf_config['task']['index']

    # start server
    cluster = tf.train.ClusterSpec(tf_config['cluster'])
    server = tf.train.Server(
        cluster, job_name=task_type, task_index=task_index)
    if task_type == 'ps':
        server.join()
        return

    is_chief = task_type == 'master'

    # build graph: variables go to the ps job, compute ops to this task.
    with tf.device(tf.train.replica_device_setter(
            worker_device=f"/job:{task_type}/task:{task_index}",
            ps_device="/job:ps",
            cluster=tf_config['cluster'])):
        model = build_model(filename='census_data/adult.data')

    writer = None
    if is_chief:
        writer = tf.summary.FileWriter(logdir='tmp/model/lr-dist', graph=tf.get_default_graph())

    try:
        # create session
        config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(target=server.target, config=config) as sess:
            # Freeze the graph so accidental op creation inside the loops fails loudly.
            tf.get_default_graph().finalize()
            if is_chief:
                _run_chief(sess, model, writer)
            else:
                _run_worker(sess, model)
    finally:
        # Flush and close the event file; without this, buffered summaries
        # can be lost when the process exits.
        if writer is not None:
            writer.close()


def _run_chief(sess, model, writer):
    """Initialize shared state, then log the global step and summaries until it reaches 2000."""
    sess.run(model['init']['global'])
    sess.run(model['init']['local'])
    # NOTE(review): 2000 presumably equals (number of workers) x 1000 steps
    # per worker; with a different worker count this loop stops early or
    # never terminates -- confirm against the cluster spec.
    step = 0
    while step < 2000:
        step = sess.run(model['global_step'])
        logging.info('global step = %s', step)
        writer.add_summary(sess.run(model['summary']), global_step=step)
        time.sleep(1)


def _run_worker(sess, model):
    """Wait until the chief has initialized all global variables, then train for 1000 steps."""
    while True:
        uninitialized = sess.run(model['uninitialized'])
        # len() rather than truthiness: the fetched value is presumably a
        # NumPy array, whose truth value is ambiguous.
        if len(uninitialized) == 0:
            break
        logging.info('still waiting for variables to initialize: %s', uninitialized)
        time.sleep(5)
    sess.run(model['init']['local'])
    for step in range(1000):
        result = sess.run(model['train'])
        logging.info('step = %s, loss= %s, global step = %s',
                     step, result['loss'], sess.run(model['global_step']))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.