"""Simple second order optimization with TensorFlow."""
import tensorflow as tf
####
# 1. Define the problem
####
# Here are 3 linear data points we want to fit:
data_x = [0., 1., 2.]
data_y = [-1., 1., 3.]
batch_size = len(data_x)
# Input and output placeholders. The batch size is hardcoded rather than
# left dynamic (None), for simplicity.
x = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="x")
y = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="y")
# Weight and bias.
# Computing Hessians is currently only supported for one-dimensional tensors,
# so I did not bother reshaping the parameters into 2D arrays.
W = tf.Variable(tf.ones(shape=[1]), dtype=tf.float32, name="W")
b = tf.Variable(tf.zeros(shape=[1]), dtype=tf.float32, name="b")
# Making a prediction and comparing it to the true output
pred = x * W + b
loss = tf.reduce_mean(0.5 * (y - pred)**2)
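# Added sanity check, worked by hand: with the initial values W=1 and b=0,
# the predictions are [0, 1, 2] against targets [-1, 1, 3], so the errors
# are [-1, 0, 1] and the loss is mean(0.5 * [1, 0, 1]) = 1/3 ~ 0.333333,
# which matches the first loss printed below.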
# Preprocessing for the weight update
wrt_variables = [W, b]
grads = tf.gradients(loss, wrt_variables)
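# For reference (derived by hand), these gradients are
# [mean((pred - y) * x), mean(pred - y)] for W and b respectively.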
# Proceeding this way is equivalent to computing only the diagonal blocks
# of one big Hessian, because the parameters are isolated from each other
# in the "wrt_variables" list.
hess = tf.hessians(loss, wrt_variables)
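# Derived by hand for this quadratic loss: d2(loss)/dW2 = mean(x**2) = 5/3
# and d2(loss)/db2 = 1. The cross term d2(loss)/dWdb = mean(x) = 1 is
# ignored by this per-variable scheme, which is presumably why the loss
# below shrinks by a constant factor of about 0.6 per iteration instead of
# converging in a single step, as full Newton would on a quadratic.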
inv_hess = [tf.matrix_inverse(h) for h in hess]
# Second-order weight update rule (a Newton step). The learning rate is 1,
# because the second derivatives can be trusted on such a small problem.
update_directions = [
    -tf.reduce_sum(h) * g
    for h, g in zip(inv_hess, grads)
]
op_apply_updates = [
    v.assign_add(up)
    for v, up in zip(wrt_variables, update_directions)
]
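# Added note: each inverted Hessian is a 1x1 matrix here, so tf.reduce_sum()
# merely extracts its single entry, and each update direction reduces to the
# scalar Newton step -grad / hessian.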
####
# 2. Proceed to solve the regression
####
# Initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# First loss
initial_loss = sess.run(
    loss,
    feed_dict={
        x: data_x,
        y: data_y
    }
)
print("Initial loss:", initial_loss)
# Weight and bias update, "training" phase:
for iteration in range(25):
    new_loss, _ = sess.run(
        [loss, op_apply_updates],
        feed_dict={
            x: data_x,
            y: data_y
        }
    )
    print("Loss after iteration {}: {}".format(iteration, new_loss))
# Results:
print("Prediction:", sess.run(pred, feed_dict={x: data_x}))
print("Expected:", data_y)
####
# 3. Program output
####
# Initial loss: 0.333333
# Loss after iteration 0: 0.3333333432674408
# Loss after iteration 1: 0.19999998807907104
# Loss after iteration 2: 0.11999998241662979
# Loss after iteration 3: 0.07199998944997787
# Loss after iteration 4: 0.04319998249411583
# Loss after iteration 5: 0.025919994339346886
# Loss after iteration 6: 0.01555199921131134
# Loss after iteration 7: 0.009331194683909416
# Loss after iteration 8: 0.00559872156009078
# Loss after iteration 9: 0.003359234193339944
# Loss after iteration 10: 0.0020155382808297873
# Loss after iteration 11: 0.001209322945214808
# Loss after iteration 12: 0.0007255963864736259
# Loss after iteration 13: 0.0004353572439868003
# Loss after iteration 14: 0.00026121424161829054
# Loss after iteration 15: 0.00015672821609769017
# Loss after iteration 16: 9.403712465427816e-05
# Loss after iteration 17: 5.642247560899705e-05
# Loss after iteration 18: 3.3853048080345616e-05
# Loss after iteration 19: 2.0311867046984844e-05
# Loss after iteration 20: 1.2187217180326115e-05
# Loss after iteration 21: 7.312354227906326e-06
# Loss after iteration 22: 4.387426542962203e-06
# Loss after iteration 23: 2.6324989903514506e-06
# Loss after iteration 24: 1.5794303180882707e-06
# Prediction: [-0.99782324 1.0008707 2.99956465]
# Expected: [-1.0, 1.0, 3.0]
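# Added check: the exact fit is W = 2, b = -1, since 2 * [0, 1, 2] - 1
# = [-1, 1, 3] reproduces data_y exactly; the learned parameters converge
# toward those values.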