Last active: August 10, 2018 08:25
Save Multihuntr/b8cb68316842ff68cab3062740a2a730 to your computer and use it in GitHub Desktop.
Accumulating gradients to reduce the memory required per forward pass (using MNIST)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
def simple_model(input):
    '''
    Build a small two-layer network on top of `input`:
    dense(100) -> ReLU -> dense(10), returning the linear (un-activated)
    outputs of the final layer.

    Both kernels use `tf.constant_initializer` filled with values drawn from
    NumPy's global RNG at graph-construction time. The values are therefore
    baked into the graph, and every `global_variables_initializer` run in this
    process restores exactly the same weights, which allows runs to be compared.
    Biases are left to the `tf.layers.dense` default.
    '''
    def _uniform_init(shape):
        # Drawn once, here; the resulting constant is embedded in the graph.
        return tf.constant_initializer(np.random.uniform(-0.025, 0.025, size=shape))

    hidden = tf.layers.dense(input, 100, kernel_initializer=_uniform_init([784, 100]))
    activated = tf.nn.relu(hidden)
    return tf.layers.dense(activated, 10, kernel_initializer=_uniform_init([100, 10]))
# Inputs: a batch of flattened MNIST images (784 features) and one-hot targets (10 classes).
inp = tf.placeholder(tf.float32, [None, 784])
targ = tf.placeholder(tf.float32, [None, 10])

# Counts the accumulation steps run since the last reset. The accumulated
# gradient sum is divided by this count before it is applied.
divisor = tf.Variable(0, trainable=False)
div_fl = tf.to_float(divisor)
reset_divisor = divisor.assign(0)
inc_divisor = divisor.assign(divisor + 1)

# Model, optimizer and loss.
out = simple_model(inp)
opt = tf.train.GradientDescentOptimizer(learning_rate=1e-2)
# tf.losses.mean_squared_error takes (labels, predictions). Pass them by name so
# the roles are correct; the original call had them swapped (harmless only
# because squared error is symmetric).
loss = tf.losses.mean_squared_error(labels=targ, predictions=out)
t_vars = tf.trainable_variables()

# Gradients of the loss for whatever batch is fed to `inp`/`targ`.
grads, graph_vars = zip(*opt.compute_gradients(loss, t_vars))

# One non-trainable, zero-initialised accumulator per trainable variable.
accum_grads = [
    tf.Variable(tf.zeros_like(t_var.initialized_value()), trainable=False)
    for t_var in t_vars
]

# Zero every accumulator; the control dependency also resets the divisor.
with tf.control_dependencies([reset_divisor]):
    zero_ops = [accum.assign(tf.zeros_like(accum)) for accum in accum_grads]

# Add the current batch's gradients to the accumulators; the control
# dependency also increments the divisor once per run.
with tf.control_dependencies([inc_divisor]):
    accum_ops = [accum.assign_add(grad) for accum, grad in zip(accum_grads, grads)]

# Average the accumulated gradients over the number of accumulation steps, then
# apply them. The divisor is clamped to >= 1, so running train_op before any
# accumulation applies a zero update instead of writing 0/0 = NaN into the weights.
safe_div = tf.maximum(div_fl, 1.0)
normalised_accum_grads = [accum / safe_div for accum in accum_grads]
train_op = opt.apply_gradients(zip(normalised_accum_grads, graph_vars))
def graph_vars_equivalence():
    '''
    Checks that the `graph_vars` returned by `opt.compute_gradients` are exactly
    the trainable variables in `t_vars`:
    - same count,
    - same order,
    - the very same variable objects,
    - identical values after initialisation.
    '''
    # zip() silently truncates, so compare lengths first. Otherwise a missing
    # variable would never be noticed.
    assert len(graph_vars) == len(t_vars), 'Graph vars is not the same length as t_vars'
    for t, g in zip(t_vars, graph_vars):
        # Identity is the real requirement: apply_gradients must update these
        # exact variables, not copies that merely hold equal values.
        assert t is g, 'Graph vars is not the same as t_vars'
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Fetching variables needs no placeholder feed.
        t_vars_, graph_vars_ = sess.run((list(t_vars), list(graph_vars)))
    for t, g in zip(t_vars_, graph_vars_):
        assert t.shape == g.shape
        assert np.all(t == g), 'Graph vars is not the same as t_vars'
def initial_weights_same_after_reinit():
    '''
    Initialises the graph in two independent sessions and asserts that every
    trainable variable starts with identical values both times.
    '''
    def _fresh_weights():
        # Open a brand-new session, initialise from scratch and read the weights back.
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            return session.run(t_vars)

    first, second = _fresh_weights(), _fresh_weights()
    for before, after in zip(first, second):
        assert np.all(before == after), 'Weights not initialized the same'
def same_seed_gives_same_examples():
    '''
    Ensures that, for a fixed dataset seed, the stream of training examples does
    not depend on the batch size used to draw them.

    Two datasets are loaded with the same seed. From each, 100 examples are
    consumed: 10 batches of 10 from the first, and 100 batches of 1 from the
    second. The next example drawn from each (image and label) must be identical.
    '''
    mnist1 = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    for _ in range(10):
        mnist1.train.next_batch(10)
    inp_1, targ_1 = mnist1.train.next_batch(1)
    mnist2 = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    for _ in range(100):
        mnist2.train.next_batch(1)
    inp_2, targ_2 = mnist2.train.next_batch(1)
    assert np.all(inp_1 == inp_2), 'Batch size counts'
    # Compare the labels too, so the whole example is verified, not just the image.
    assert np.all(targ_1 == targ_2), 'Batch size counts'
def direct_comp(batch_size, atol=1e-7):
    '''
    Directly compares two ways of computing gradients for the same examples:
    - a standard forward pass with `batch_size` examples in a single batch;
    - the accumulated gradients from `batch_size` forward passes, one example each.

    If the accumulation method works, the two should agree to within `atol`
    (absolute, per element).

    Raises AssertionError naming the first variable whose gradients differ by
    `atol` or more.
    '''
    tf.set_random_seed(147258)
    # NOTE: the model's kernel values were drawn when the graph was built, so
    # this seed does not change the weights being compared.
    np.random.seed(123456)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    inp_, targ_ = mnist.train.next_batch(batch_size)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Accumulate one example at a time; each run also increments the divisor.
        for example, target in zip(inp_, targ_):
            sess.run(accum_ops, {inp: [example], targ: [target]})
        accum_grads_ = sess.run(normalised_accum_grads)
        # Reference: every example in a single forward/backward pass.
        standard_grads = sess.run(grads, {inp: inp_, targ: targ_})
    for idx, (acc, sta) in enumerate(zip(accum_grads_, standard_grads)):
        diff = np.max(np.abs(acc - sta))
        assert diff < atol, (
            'Accumulated gradients for variable {} differ by up to {} '
            '(tolerance {})'.format(idx, diff, atol))
def do_train(actual_batch, pseudo_batch, iterations=1000, eval_batches=10, eval_batch_size=128):
    '''
    Trains for `iterations` steps with gradient accumulation, then evaluates.

    Each step draws actual_batch*pseudo_batch training examples and splits them
    into `pseudo_batch` chunks of `actual_batch` examples. It accumulates each
    chunk's gradient, applies their average once, then zeroes the accumulators.
    Provided actual_batch*pseudo_batch doesn't change, neither should the final
    accuracy, the mean loss or the loss std deviation.

    Prints:
    - accuracy on eval_batches*eval_batch_size test examples;
    - mean and std of the per-step training loss;
    - two sums of the training inputs (a cheap check that different splits saw
      the same data).

    Raises ValueError if actual_batch or pseudo_batch is less than 1.
    '''
    if actual_batch < 1 or pseudo_batch < 1:
        # pseudo_batch == 0 would run train_op with nothing accumulated.
        raise ValueError('actual_batch and pseudo_batch must both be >= 1')
    tf.set_random_seed(147258)
    np.random.seed(123456)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)
    total_sum = 0
    total_sum_2 = 0
    losses = []
    n_correct = 0
    n_total = 0
    # Boundaries that split one drawn batch into pseudo_batch chunks of actual_batch
    # rows each. They are loop-invariant, so compute them once.
    split_points = np.arange(actual_batch, actual_batch * pseudo_batch, actual_batch)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Train
        for _ in range(iterations):
            # Always draw the full effective batch and split it afterwards, so the
            # stream of examples is identical however it is divided.
            inp_, targ_ = mnist.train.next_batch(actual_batch * pseudo_batch)
            total_sum += np.sum(inp_)
            inp_chunks = np.split(inp_, split_points)
            targ_chunks = np.split(targ_, split_points)
            iteration_loss = 0
            for inp_chunk, targ_chunk in zip(inp_chunks, targ_chunks):
                total_sum_2 += np.sum(inp_chunk)
                _, loss_ = sess.run((accum_ops, loss), {inp: inp_chunk, targ: targ_chunk})
                iteration_loss += loss_
            sess.run(train_op)
            sess.run(zero_ops)
            losses.append(iteration_loss / pseudo_batch)
        # Evaluate
        for _ in range(eval_batches):
            inp_, targ_ = mnist.test.next_batch(eval_batch_size)
            pred = sess.run(out, {inp: inp_})
            comp = np.argmax(targ_, 1) == np.argmax(pred, 1)
            n_correct += np.count_nonzero(comp)
            n_total += len(comp)
    # 100.0 forces float division; n_correct/n_total would truncate to 0 on Python 2.
    prop_correct = 100.0 * n_correct / n_total if n_total else float('nan')
    losses = np.array(losses)
    print('Accuracy: {:5.3f}%, Loss: mean: {:8.6f}, std: {:8.6f}'.format(prop_correct, np.mean(losses), np.std(losses)))
    print('Total sum (i.e. simplest hash): {}'.format(total_sum))
    print('Total sum 2 (different summing): {}'.format(total_sum_2))
# Sanity checks on the graph and on the data pipeline.
for check in (graph_vars_equivalence,
              initial_weights_same_after_reinit,
              same_seed_gives_same_examples):
    check()

# Accumulated gradients must match single-pass gradients at several batch sizes.
for comp_batch_size in (1, 10, 64):
    direct_comp(comp_batch_size)
print('All direct comparisons passed')

# Same effective batch of 64: one real batch versus 64 accumulated
# single-example pseudo-batches. The printed metrics should agree.
num_steps = 50
for actual, pseudo in ((64, 1), (1, 64)):
    do_train(actual, pseudo, num_steps)
# Optional contrast: the same number of examples (num_steps*64), but one update
# per example with no accumulation.
# do_train(1, 1, num_steps*64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
I also made a minimal example of it, without all the testing code.
Be aware, though, that it won't be quite the same if you use Momentum, Adam, Adagrad, Adadelta, or really any optimizer other than plain gradient descent. It should be mostly equivalent, but the numbers won't quite match, and the larger the `pseudo_batch`, the more difference you would expect.