@dongkwan-kim · Created May 9, 2019
from utils import *

# Explicit imports so the script is self-contained even if utils does not
# re-export them (TensorFlow 1.x API).
import numpy as np
import tensorflow as tf

params = {}
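
# `get_available_gpu_names` is expected to come from utils. The fallback below
# is a hedged sketch (an assumption, not the original helper): it lists local
# GPU devices via TensorFlow's device_lib and keeps only the requested indices.
if "get_available_gpu_names" not in globals():
    from tensorflow.python.client import device_lib

    def get_available_gpu_names(gpu_ids=None):
        # Collect names such as "/device:GPU:0", "/device:GPU:1", ...
        gpu_names = [d.name for d in device_lib.list_local_devices()
                     if d.device_type == "GPU"]
        if gpu_ids is not None:
            gpu_names = [name for idx, name in enumerate(gpu_names)
                         if idx in gpu_ids]
        return gpu_names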


def create_variable(scope, name, shape, trainable=True, on_cpu=True, **kwargs) -> tf.Variable:
    """Create a variable under `scope`, by default pinned to the CPU so that
    every GPU tower shares the same parameter copy."""

    def _create_variable():
        with tf.variable_scope(scope):
            _w = tf.get_variable(name, shape, trainable=trainable, **kwargs)
        params[_w.name] = _w
        return _w

    if on_cpu:
        with tf.device("/cpu:0"):
            w = _create_variable()
    else:
        w = _create_variable()
    return w


def get_variable(scope, name, trainable=True) -> tf.Variable:
    """Fetch an existing variable from `scope` (reuse=True)."""
    with tf.variable_scope(scope, reuse=True):
        w = tf.get_variable(name, trainable=trainable)
    params[w.name] = w
    return w


def get_toy_data(n, xd):
    """Build a toy binary dataset: n points of class 0 uniform in [0, 0.5)^xd
    and n points of class 1 uniform in [0.5, 1)^xd, shuffled, with one-hot labels."""
    xs = np.concatenate([np.random.random((n, xd)) / 2, np.random.random((n, xd)) / 2 + 0.5])
    ys = np.concatenate([np.zeros((n,), dtype=np.int64), np.ones((n,), dtype=np.int64)])
    permut = np.random.permutation(len(xs))
    xs = xs[permut]
    ys = ys[permut]
    return xs, np.eye(2)[ys]


def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer list
            is over towers; the inner list is over the (gradient, variable) pairs
            computed by that tower.
    Returns:
        List of (gradient, variable) pairs where the gradient has been averaged
        across all towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add a 0th dimension to the gradient to represent the tower.
            expanded_g = tf.expand_dims(g, 0)
            # Append on a 'tower' dimension which we will average over below.
            grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # The Variables are redundant because they are shared across towers,
        # so we just return the first tower's pointer to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
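

def _average_gradients_example():
    """Illustrative sketch only (not part of the original gist): with two towers
    sharing one variable v, average_gradients returns a single averaged pair,
    so apply_gradients updates v once per step."""
    v = tf.Variable(0.0, name="example_v")
    tower_grads = [[(tf.constant(1.0), v)],   # gradient of v on tower 0
                   [(tf.constant(3.0), v)]]   # gradient of v on tower 1
    (avg_grad, var), = average_gradients(tower_grads)
    # avg_grad evaluates to (1.0 + 3.0) / 2 = 2.0; var is the shared Variable v.
    return avg_grad, var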


def main():
    n = 6000
    xd = 14 * 14
    hd = 100
    xs, ys = get_toy_data(n, xd)

    X = tf.placeholder(tf.float32, [None, xd], name="X")
    Y = tf.placeholder(tf.float32, [None, 2], name="Y")

    # A 3-layer MLP whose weights live on the CPU (see create_variable),
    # so every GPU tower reads the same parameters.
    w1 = create_variable("layer1", "weight", (xd, hd))
    h = tf.nn.relu(tf.matmul(X, w1))
    w2 = create_variable("layer2", "weight", (hd, hd))
    h = tf.nn.relu(tf.matmul(h, w2))
    w3 = create_variable("layer3", "weight", (hd, 2))
    h = tf.matmul(h, w3)
    hhat = tf.nn.softmax(h)  # predicted probabilities (not used in training)

    opt = tf.train.AdamOptimizer(learning_rate=0.001, name="opt")

    gpu_names = get_available_gpu_names([1])
    batch_size = 300
    batch_size_per_gpu = batch_size // len(gpu_names)

    # Build one loss/gradient tower per GPU, each on its own slice of the batch.
    grad_list = []
    loss_list = []
    with tf.variable_scope(tf.get_variable_scope()):
        for i, gpu_name in enumerate(gpu_names):
            with tf.device(gpu_name):
                idx_start = i * batch_size_per_gpu
                idx_end = (i + 1) * batch_size_per_gpu
                loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=h[idx_start:idx_end], labels=Y[idx_start:idx_end],
                ))
                # Allow later towers to reuse variables (standard multi-tower pattern).
                tf.get_variable_scope().reuse_variables()
                grad = opt.compute_gradients(loss)
                loss_list.append(loss)
                grad_list.append(grad)

    # Average the per-tower gradients and apply a single synchronized update.
    grads = average_gradients(grad_list)
    train_op = opt.apply_gradients(grads)

    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True))
    sess.run(tf.global_variables_initializer())

    # Note: num_batch covers only the first n of the 2n generated samples.
    num_batch = n // batch_size
    for epoch in range(100):
        total_loss = 0
        for batch_idx in range(num_batch):
            idx_start = batch_idx * batch_size
            idx_end = (batch_idx + 1) * batch_size
            xs_b = xs[idx_start:idx_end]
            ys_b = ys[idx_start:idx_end]
            _, loss_value = sess.run([train_op, loss_list], feed_dict={
                X: xs_b,
                Y: ys_b,
            })
            total_loss += np.mean(loss_value)
        print("epoch {}: total loss {}".format(epoch, total_loss))


if __name__ == '__main__':
    main()