Skip to content

Instantly share code, notes, and snippets.

@spitis
Created April 24, 2017 20:12
Show Gist options
  • Save spitis/b5b49b5c8714e7b6b32865da3c302420 to your computer and use it in GitHub Desktop.
Save spitis/b5b49b5c8714e7b6b32865da3c302420 to your computer and use it in GitHub Desktop.
ewc
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import tensorflow as tf
class Model:
    """Fully-connected softmax classifier with an Elastic Weight Consolidation penalty.

    Builds a TF1 graph: a dropout-regularised ReLU MLP trained with plain
    gradient descent, plus the per-variable state and ops needed for the
    penalised objective

        ewc_loss = cross_entropy
                   + ewc_loss_coef * 0.5 * sum_i sum(fisher_i * (v_i - sticky_i) ** 2)

    where, for every trainable variable ``v_i`` created by this model,
    ``sticky_i`` is a stored snapshot of its weights and ``fisher_i`` a
    per-parameter importance estimate.

    Presumed workflow (inferred from the exposed ops; confirm against
    callers): run ``ts`` to train, which also refreshes ``grad_variances``.
    At a task boundary, run ``update_fisher`` or ``replace_fisher`` and
    ``update_sticky_weights``. Afterwards, feed a non-zero ``ewc_loss_coef``.

    Feedable placeholders:
        x: float32 inputs of shape ``[None, input_dim]``.
        y: int64 class labels of shape ``[None]``.
        ewc_loss_coef: EWC penalty weight (default 0., i.e. no penalty).
        beta: decay of the gradient-variance moving average (default 0.95).
        dropout: keep probability for hidden-layer dropout (default 1.,
            i.e. disabled). Any value below 1. also turns on input dropout
            with keep probability ``input_keep_prob``.
    """

    def __init__(self, input_dim=784, num_classes=10,
                 hidden_units=(1500, 1500, 1500), input_keep_prob=0.8,
                 learning_rate=0.1):
        """Build the model graph in the current default graph.

        All arguments are optional. The defaults reproduce the original
        hard-coded architecture (784 -> 3 x 1500 ReLU -> 10) and SGD step.

        Args:
            input_dim: Width of each input row fed to ``x``.
            num_classes: Number of output classes (logit width).
            hidden_units: Width of each hidden ReLU layer, in order.
            input_keep_prob: Keep probability for input dropout. It is used
                only when the fed ``dropout`` keep probability is below 1.
            learning_rate: Step size of the gradient-descent optimizer.
        """
        # Remember which trainable variables already exist, so var_list
        # covers only the variables this model creates. Taking the whole
        # graph's trainable_variables() breaks when anything else lives in
        # the graph. Unconnected variables get a None gradient, so g * g
        # raises TypeError. A second Model would also re-create the first
        # one's "grad_variance/..." variables, which tf.get_variable rejects.
        preexisting = {id(v) for v in tf.trainable_variables()}

        self.x = x = tf.placeholder(tf.float32, [None, input_dim])
        self.y = y = tf.placeholder(tf.int64, [None])
        self.ewc_loss_coef = tf.placeholder_with_default(0., [])
        self.beta = tf.placeholder_with_default(0.95, [])
        self.dropout = tf.placeholder_with_default(1., [])

        self.logits = logits = self._build_network(
            x, hidden_units, num_classes, input_keep_prob)
        self.softmax = tf.nn.softmax(logits)
        self.var_list = tvs = [
            v for v in tf.trainable_variables() if id(v) not in preexisting]

        self.cross_entropy = cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=y, logits=logits))
        opt = tf.train.GradientDescentOptimizer(learning_rate)
        self.grads = grads = tf.gradients(cross_entropy, tvs)

        # Per-variable EWC state (zero-initialised, non-trainable) and the
        # assign ops that maintain it.
        grad_variances, fisher, sticky_weights = [], [], []
        update_grad_variances, update_fisher, replace_fisher = [], [], []
        update_sticky_weights, restore_sticky_weights = [], []
        ewc_losses = []
        # NOTE(review): g is the gradient of a batch-*mean* loss. Scaling
        # g * g by the batch size presumably brings it toward a per-example
        # squared-gradient magnitude. Confirm this is the intended estimate.
        batch_size = tf.to_float(tf.shape(x)[0])
        for g, v in zip(grads, tvs):
            with tf.variable_scope("grad_variance"):
                gv = self._state_variable("gv", v)
                fi = self._state_variable("fisher", v)
            with tf.variable_scope("sticky_weights"):
                sticky = self._state_variable("sticky", v)
            grad_variances.append(gv)
            fisher.append(fi)
            sticky_weights.append(sticky)

            # Exponential moving average (decay `beta`) of g**2 * batch_size.
            update_grad_variances.append(tf.assign(
                gv, self.beta * gv + (1 - self.beta) * g * g * batch_size))
            # Either accumulate into, or overwrite, the Fisher estimate.
            update_fisher.append(tf.assign(fi, fi + gv))
            replace_fisher.append(tf.assign(fi, gv))
            # Snapshot the current weights as the EWC anchor, or roll back to it.
            update_sticky_weights.append(tf.assign(sticky, v))
            restore_sticky_weights.append(tf.assign(v, sticky))
            ewc_losses.append(tf.reduce_sum(tf.square(v - sticky) * fi))

        self.ewc_loss = ewc_loss = (
            cross_entropy + self.ewc_loss_coef * .5 * tf.add_n(ewc_losses))
        grads_ewc = tf.gradients(ewc_loss, tvs)

        self.sticky_weights = sticky_weights
        self.grad_variances = grad_variances
        self.fisher = fisher

        self.update_grad_variances = tf.group(
            *update_grad_variances, name='update_grad_variances')
        # Training step: refresh the gradient-variance averages first. Then
        # apply plain gradients when no EWC coefficient is fed, otherwise
        # the EWC-penalised ones.
        with tf.control_dependencies(update_grad_variances):
            self.ts = tf.cond(
                tf.equal(self.ewc_loss_coef, tf.constant(0.)),
                lambda: opt.apply_gradients(zip(grads, tvs)),
                lambda: opt.apply_gradients(zip(grads_ewc, tvs)))
        self.update_fisher = tf.group(*update_fisher, name='update_fisher')
        self.replace_fisher = tf.group(*replace_fisher, name='replace_fisher')
        self.update_sticky_weights = tf.group(
            *update_sticky_weights, name='update_sticky_weights')
        self.restore_sticky_weights = tf.group(
            *restore_sticky_weights, name='restore_sticky_weights')

        self.acc = tf.reduce_mean(
            tf.cast(tf.equal(y, tf.argmax(logits, 1)), tf.float32))

    def _build_network(self, x, hidden_units, num_classes, input_keep_prob):
        """Return the logits of the dropout-regularised ReLU MLP applied to `x`."""
        # Input dropout is enabled only when hidden dropout is active, i.e.
        # when the fed keep probability is below 1.
        hx = tf.cond(tf.less(self.dropout, 1.),
                     lambda: tf.nn.dropout(x, input_keep_prob),
                     lambda: x)
        for units in hidden_units:
            hx = tf.contrib.layers.fully_connected(
                inputs=hx, num_outputs=units, activation_fn=tf.nn.relu)
            # TF1 tf.nn.dropout takes a keep probability, so self.dropout == 1.
            # leaves activations unchanged.
            hx = tf.nn.dropout(hx, self.dropout)
        # fully_connected defaults to activation_fn=tf.nn.relu. Logits must be
        # linear, otherwise they are clamped to >= 0 before the softmax.
        return tf.contrib.layers.fully_connected(
            inputs=hx, num_outputs=num_classes, activation_fn=None)

    @staticmethod
    def _state_variable(prefix, v):
        """Create a zero-initialised, non-trainable float32 variable shaped like `v`.

        The variable is named ``<prefix>_<v.name>`` with ":" replaced by "_".
        It is created inside whatever variable scope is currently open.
        """
        return tf.get_variable(
            "{}_{}".format(prefix, v.name.replace(":", "_")),
            v.get_shape().as_list(),
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment