"""Simple second order optimization with TensorFlow."""
import tensorflow as tf
####
# 1. Define the problem
####
# Here are 3 linear data points we'll want to fit on:
data_x = [0., 1., 2.]
data_y = [-1., 1., 3.]
batch_size = len(data_x)
# Input and output placeholders. The whole dataset is fed as a single batch.
x = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="x")
y = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="y")
# Weight and bias.
# Computing Hessians is currently only supported for one-dimensional tensors,
# so I did not bother reshaping the parameters into 2D arrays.
W = tf.Variable(tf.ones(shape=[1]), dtype=tf.float32, name="W")
b = tf.Variable(tf.zeros(shape=[1]), dtype=tf.float32, name="b")
# Making a prediction and comparing it to the true output
pred = x * W + b
loss = tf.reduce_mean(0.5 * (y - pred)**2)
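# For reference (my aside, not used by the graph below): this loss is
# quadratic in the parameters, so its derivatives have simple closed forms:
#   dL/dW = mean(-(y - pred) * x)      dL/db = mean(-(y - pred))
#   d2L/dW2 = mean(x**2) = 5/3         d2L/db2 = 1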
# Preprocessing for the weight update
wrt_variables = [W, b]
grads = tf.gradients(loss, wrt_variables)
# Taking the Hessian of each variable separately is equivalent to keeping
# only the diagonal blocks of one big Hessian: the cross-terms between W and
# b are ignored, because the parameters are isolated from each other in the
# "wrt_variables" list. This is also why the fit below takes several
# iterations rather than the single step a full Newton update would need on
# this quadratic loss.
hess = tf.hessians(loss, wrt_variables)
inv_hess = [tf.matrix_inverse(h) for h in hess]
# 2nd-order weight update rule (a Newton step per block). The learning rate
# is 1 because I trust the second derivatives obtained for such a small problem.
update_directions = [
    -tf.reduce_sum(h) * g
    for h, g in zip(inv_hess, grads)
]
op_apply_updates = [
    v.assign_add(up)
    for v, up in zip(wrt_variables, update_directions)
]
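# The update above is a per-block Newton step, theta <- theta - H^-1 * grad.
# Collapsing each inverse Hessian with reduce_sum only works because W and b
# each hold a single scalar, so every block is a 1x1 matrix. A sketch of the
# shape-correct direction for longer 1-D variables (my addition, unused here):
general_update_directions = [
    -tf.squeeze(tf.matmul(h_inv, tf.expand_dims(g, 1)), axis=1)
    for h_inv, g in zip(inv_hess, grads)
]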
####
# 2. Proceed to solve the regression
####
# Initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())
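# Optional sanity check (my addition, commented out so the logged output
# below stays exact): the Hessians of a quadratic loss are constant and
# should evaluate to [[5/3]] for W and [[1.]] for b.
# print(sess.run(hess, feed_dict={x: data_x, y: data_y}))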
# First loss
initial_loss = sess.run(
    loss,
    feed_dict={
        x: data_x,
        y: data_y
    }
)
print("Initial loss:", initial_loss)
# Weight and bias update, "training" phase:
for iteration in range(25):
    new_loss, _ = sess.run(
        [loss, op_apply_updates],
        feed_dict={
            x: data_x,
            y: data_y
        }
    )
    print("Loss after iteration {}: {}".format(iteration, new_loss))
# Results:
print("Prediction:", sess.run(pred, feed_dict={x: data_x}))
print("Expected:", data_y)
####
# 3. Program output
####
# Initial loss: 0.333333
# Loss after iteration 0: 0.3333333432674408
# Loss after iteration 1: 0.19999998807907104
# Loss after iteration 2: 0.11999998241662979
# Loss after iteration 3: 0.07199998944997787
# Loss after iteration 4: 0.04319998249411583
# Loss after iteration 5: 0.025919994339346886
# Loss after iteration 6: 0.01555199921131134
# Loss after iteration 7: 0.009331194683909416
# Loss after iteration 8: 0.00559872156009078
# Loss after iteration 9: 0.003359234193339944
# Loss after iteration 10: 0.0020155382808297873
# Loss after iteration 11: 0.001209322945214808
# Loss after iteration 12: 0.0007255963864736259
# Loss after iteration 13: 0.0004353572439868003
# Loss after iteration 14: 0.00026121424161829054
# Loss after iteration 15: 0.00015672821609769017
# Loss after iteration 16: 9.403712465427816e-05
# Loss after iteration 17: 5.642247560899705e-05
# Loss after iteration 18: 3.3853048080345616e-05
# Loss after iteration 19: 2.0311867046984844e-05
# Loss after iteration 20: 1.2187217180326115e-05
# Loss after iteration 21: 7.312354227906326e-06
# Loss after iteration 22: 4.387426542962203e-06
# Loss after iteration 23: 2.6324989903514506e-06
# Loss after iteration 24: 1.5794303180882707e-06
# Prediction: [-0.99782324 1.0008707 2.99956465]
# Expected: [-1.0, 1.0, 3.0]
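# Two remarks on the log above (my observations): iteration 0 repeats the
# initial loss because the fetched loss reflects the pre-update weights
# within the same sess.run call. Afterwards the loss shrinks by a constant
# factor of 0.6 per iteration: dropping the W-b cross-term of the full
# Hessian (mean(x) = 1) leaves an error contraction of sqrt(3/5) per step,
# i.e. a factor of 3/5 on the quadratic loss, whereas full Newton would
# solve this quadratic problem in a single step.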