Skip to content

Instantly share code, notes, and snippets.

@Multihuntr
Last active August 10, 2018 08:25
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Multihuntr/b8cb68316842ff68cab3062740a2a730 to your computer and use it in GitHub Desktop.
Save Multihuntr/b8cb68316842ff68cab3062740a2a730 to your computer and use it in GitHub Desktop.
Accumulating gradients to reduce memory requirement per forward pass (using MNIST)
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def simple_model(input):
    '''
    Builds a small two-layer dense network on top of `input` and returns
    the output tensor of the second layer.

    Both kernels use constant initializers whose values are drawn from
    `np.random.uniform(-0.025, 0.025)` at the moment this function is
    called, so re-running the variable initializers always restores the
    same starting kernels (useful for like-for-like comparisons).

    Args:
        input: tensor fed to the first dense layer. The first kernel is
            sized [784, 100], so the last dimension must be 784.

    Returns:
        Output of the second dense layer (10 units, no activation).
    '''
    # First layer: 784 -> 100, kernel frozen to one fixed random draw.
    first_kernel = np.random.uniform(-0.025, 0.025, size=[784, 100])
    first_layer = tf.layers.dense(
        input,
        100,
        kernel_initializer=tf.constant_initializer(first_kernel),
    )
    activated = tf.nn.relu(first_layer)

    # Second layer: 100 -> 10, again with a fixed random kernel.
    second_kernel = np.random.uniform(-0.025, 0.025, size=[100, 10])
    scores = tf.layers.dense(
        activated,
        10,
        kernel_initializer=tf.constant_initializer(second_kernel),
    )
    return scores
# Placeholders for a batch of inputs (784 features each) and targets
# (10 values each); the batch dimension is left open.
inp = tf.placeholder(tf.float32, [None, 784])
targ = tf.placeholder(tf.float32, [None, 10])

# Define our divisor, used to normalise gradients across pseudo_batches.
# It counts how many accumulation steps have run since the last reset.
divisor = tf.Variable(0, trainable=False)
div_fl = tf.to_float(divisor)
reset_divisor = divisor.assign(0)
inc_divisor = divisor.assign(divisor + 1)

# Make our model and optimizer and gradients
out = simple_model(inp)
opt = tf.train.GradientDescentOptimizer(learning_rate=1e-2)
# Pass by keyword: the positional order is (labels, predictions), which is
# easy to get backwards. The value is the same either way (MSE is symmetric).
loss = tf.losses.mean_squared_error(labels=targ, predictions=out)
t_vars = tf.trainable_variables()

# compute gradients for a batch
grads, graph_vars = zip(*opt.compute_gradients(loss, t_vars))

# Accumulation ops and variables
# create a copy of all trainable variables with `0` as initial values
accum_grads = [
    tf.Variable(tf.zeros_like(t_var.initialized_value()), trainable=False)
    for t_var in t_vars
]

# create a op to initialize all accums vars (and zero the divisor again)
with tf.control_dependencies([reset_divisor]):
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_grads]

# Create ops for accumulating the gradient (also adds one to the final divisor)
with tf.control_dependencies([inc_divisor]):
    accum_ops = [
        accum_grad.assign_add(grad)
        for (accum_grad, grad) in zip(accum_grads, grads)
    ]

# Create op that updates the weights (also divides accumulated gradients by
# the number of steps). The divisor is clamped to at least 1 so that running
# `train_op` before any accumulation step applies a zero update instead of
# computing 0/0 and writing NaN into every weight.
_safe_div = tf.maximum(div_fl, 1.0)
normalised_accum_grads = [accum_grad / _safe_div for accum_grad in accum_grads]
train_op = opt.apply_gradients(zip(normalised_accum_grads, graph_vars))
def graph_vars_equivalence():
    '''
    Checks that the variables returned by `opt.compute_gradients`
    (`graph_vars`) are the full set of trainable variables (`t_vars`),
    in the same order.

    Both lists are evaluated in a freshly initialised session and compared
    pairwise by shape and by value. The lengths are compared first because
    `zip` would otherwise silently ignore any variable missing from one side.

    Raises:
        AssertionError: if the two lists differ in length, or any pair
            differs in shape or value.
    '''
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    inp_, targ_ = mnist.train.next_batch(1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        t_vars_ = sess.run(t_vars)
        graph_vars_ = sess.run(graph_vars, {inp: inp_, targ: targ_})

    assert len(t_vars_) == len(graph_vars_), \
        'Graph vars and t_vars differ in length'
    for t, g in zip(t_vars_, graph_vars_):
        assert t.shape == g.shape
        # This is an element-wise value comparison of the evaluated
        # variables, not an identity check on the variable objects.
        assert np.all(t == g), 'Graph vars is not the same as t_vars'
def initial_weights_same_after_reinit():
    '''
    Checks that the trainable variables take identical values each time
    the graph is re-initialized in a new session.

    Raises:
        AssertionError: if any variable differs between the two sessions.
    '''
    snapshots = []
    # Two independent sessions, each initialised from scratch.
    for _ in range(2):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            snapshots.append(sess.run(t_vars))

    first_run, second_run = snapshots
    for before, after in zip(first_run, second_run):
        assert np.all(before == after), 'Weights not initialized the same'
def same_seed_gives_same_examples():
    '''
    Checks that two datasets created with the same seed yield the same
    example once the same number of examples has been consumed, no matter
    how that consumption was split into batches.

    The first dataset skips 100 examples as 10 batches of 10; the second
    skips 100 examples as 100 batches of 1. The next example drawn from
    each must then match in both image and label.

    Raises:
        AssertionError: if the images or the labels differ.
    '''
    mnist1 = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    for _ in range(10):
        mnist1.train.next_batch(10)
    inp_1, targ_1 = mnist1.train.next_batch(1)

    mnist2 = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    for _ in range(100):
        mnist2.train.next_batch(1)
    inp_2, targ_2 = mnist2.train.next_batch(1)

    assert np.all(inp_1 == inp_2), 'Batch size counts'
    # The labels were fetched alongside the images; verify them as well so
    # that a mismatch in targets cannot slip through unnoticed.
    assert np.all(targ_1 == targ_2), 'Batch size counts (targets differ)'
def direct_comp(batch_size):
    '''
    Compares gradients from one ordinary forward pass over a whole batch
    against gradients accumulated over one forward pass per example.

    If accumulation works, the accumulated gradients (after dividing by
    the number of accumulation steps) should be approximately equal to
    the gradients computed from the full batch in a single pass.

    Args:
        batch_size: number of training examples drawn for the comparison.

    Raises:
        AssertionError: if any gradient tensor differs by 1e-7 or more in
            its largest absolute element.
    '''
    tf.set_random_seed(147258)
    np.random.seed(123456)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    inp_, targ_ = mnist.train.next_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Accumulate one example at a time, each fed as a batch of size 1.
        for example, label in zip(inp_, targ_):
            sess.run(accum_ops, {inp: [example], targ: [label]})
        accum_grads_ = sess.run(normalised_accum_grads)
        # Reference: the same examples in a single forward pass.
        standard_grads = sess.run(grads, {inp: inp_, targ: targ_})

    for accumulated, reference in zip(accum_grads_, standard_grads):
        diff = np.max(np.abs(accumulated - reference))
        assert diff < 1e-7, 'Accumulated gradients out by at most {}'.format(diff)
def do_train(actual_batch, pseudo_batch, iterations=1000):
    '''
    Performs some number of steps of training and does some evaluation.

    Each iteration draws `actual_batch*pseudo_batch` examples, splits them
    into `pseudo_batch` chunks of `actual_batch` examples, accumulates the
    gradients of every chunk, applies one weight update, and then zeroes
    the accumulators.

    We expect that provided actual_batch*pseudo_batch doesn't change, then
    neither should the final accuracy or final loss or final loss std
    deviation.

    Args:
        actual_batch: number of examples per forward pass.
        pseudo_batch: number of forward passes accumulated per weight update.
        iterations: number of weight updates to perform.

    Prints the test accuracy, the mean/std of the per-iteration loss, and
    two sums over all training inputs (a cheap check that the same data
    was seen regardless of how it was split).
    '''
    tf.set_random_seed(147258)
    np.random.seed(123456)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)

    eval_batch = 128   # examples per evaluation batch
    eval_rounds = 10   # number of evaluation batches

    total_sum = 0
    total_sum_2 = 0
    losses = []
    n_correct = 0
    n_incorrect = 0

    # Row indices at which each combined batch is cut into chunks of
    # `actual_batch` examples; loop-invariant, so computed once.
    split_points = np.arange(actual_batch, actual_batch*pseudo_batch, actual_batch)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Train
        for _ in range(iterations):
            # Always pull the same total number of examples and split it as
            # needed, so the data seen is the same however it is chunked.
            inp_, targ_ = mnist.train.next_batch(actual_batch*pseudo_batch)
            total_sum += np.sum(inp_)
            # This makes them into a list of examples, each example shaped
            # [actual_batch, 784]
            inp_ = np.split(inp_, split_points)
            targ_ = np.split(targ_, split_points)

            iteration_loss = 0
            for y in range(pseudo_batch):
                total_sum_2 += np.sum(inp_[y])
                _, loss_ = sess.run((accum_ops, loss), {inp: inp_[y], targ: targ_[y]})
                iteration_loss += loss_

            # Apply the averaged gradients, then clear the accumulators
            # (and the divisor) ready for the next iteration.
            sess.run(train_op)
            sess.run(zero_ops)
            losses.append(iteration_loss/pseudo_batch)

        # Evaluate
        for _ in range(eval_rounds):
            inp_, targ_ = mnist.test.next_batch(eval_batch)
            pred = sess.run(out, {inp: inp_})
            comp = np.argmax(targ_, 1) == np.argmax(pred, 1)
            c = np.count_nonzero(comp)
            n_correct += c
            # Count misses from the actual number of predictions rather than
            # repeating the batch-size constant.
            n_incorrect += len(comp) - c

    total = n_correct + n_incorrect
    # float() keeps this a true division even where `/` on ints truncates.
    prop_correct = float(n_correct)/total*100
    losses = np.array(losses)
    print('Accuracy: {:5.3f}%, Loss: mean: {:8.6f}, std: {:8.6f}'.format(prop_correct, np.mean(losses), np.std(losses)))
    print('Total sum (i.e. simplest hash): {}'.format(total_sum))
    print('Total sum 2 (different summing): {}'.format(total_sum_2))
# Initial tests: sanity checks on the graph and the data pipeline, run in order.
for _sanity_check in (
        graph_vars_equivalence,
        initial_weights_same_after_reinit,
        same_seed_gives_same_examples,
):
    _sanity_check()

# Accumulated vs. single-pass gradients at several batch sizes.
for _batch_size in (1, 10, 64):
    direct_comp(_batch_size)
print('All direct comparisons passed')

num_steps = 50
# Same effective batch of 64 examples per update, reached two ways:
# one real batch of 64, then 64 accumulated batches of 1.
for _actual_batch, _pseudo_batch in ((64, 1), (1, 64)):
    do_train(_actual_batch, _pseudo_batch, num_steps)
# For reference, plain batch-size-1 training over the same number of examples:
# do_train(1, 1, num_steps*64)
@Multihuntr
Copy link
Author

Multihuntr commented Jan 23, 2018

I also made a minimal example of it, without all the testing code.

Be aware, though, that the results will not be exactly the same if you use Momentum, Adam, Adagrad, Adadelta, or really any other optimizer. It should be mostly equivalent, but the numbers won't quite match, and the larger the pseudo_batch, the bigger the difference you should expect.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment