Simple second-order optimization with TensorFlow.
"""Simple second order optimization with TensorFlow.""" | |
import tensorflow as tf | |
#### | |
# 1. Define the problem | |
#### | |
# Here are 3 linear data points we'll want to fit on:
data_x = [0., 1., 2.]
data_y = [-1., 1., 3.]
batch_size = len(data_x)
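# Note that these points lie exactly on the line y = 2x - 1, so a perfect
# fit with W = 2 and b = -1 is possible.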
# Input and output placeholders. The whole dataset is fed at once, so there
# is no separate batch dimension.
x = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="x")
y = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="y")
# Weight and bias.
# Computing Hessians is currently only supported for one-dimensional
# tensors, so I didn't bother reshaping the parameters into 2-D arrays.
W = tf.Variable(tf.ones(shape=[1]), dtype=tf.float32, name="W")
b = tf.Variable(tf.zeros(shape=[1]), dtype=tf.float32, name="b")
# Making a prediction and comparing it to the true output.
pred = x * W + b
loss = tf.reduce_mean(0.5 * (y - pred)**2)
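# For reference, the closed-form derivatives of this loss, with
# pred = x * W + b, are:
#   dL/dW = -mean((y - pred) * x)    d2L/dW2 = mean(x**2)
#   dL/db = -mean(y - pred)          d2L/db2 = 1
# These are exactly what tf.gradients and tf.hessians compute below.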
# Preparations for the weight update.
wrt_variables = [W, b]
grads = tf.gradients(loss, wrt_variables)
# Proceeding this way is equivalent to computing only the information
# contained in the diagonal blocks of one big Hessian, because the
# parameters are isolated from each other in the "wrt_variables" list.
hess = tf.hessians(loss, wrt_variables)
inv_hess = [tf.matrix_inverse(h) for h in hess]
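# Note: each parameter above is a single scalar, so every Hessian is a
# 1x1 matrix and tf.matrix_inverse(h) reduces to the reciprocal 1/h.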
# Second-order weight update rule. The learning rate is 1, because I
# trust the second derivatives obtained on such a small problem.
update_directions = [
    - tf.reduce_sum(h) * g
    for h, g in zip(inv_hess, grads)
]
op_apply_updates = [
    v.assign_add(up)
    for v, up in zip(wrt_variables, update_directions)
]
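# Together, these ops apply the Newton update theta <- theta - H^-1 * grad
# for each variable independently. Ignoring the W-b cross-term of the full
# Hessian makes this a block-diagonal Newton step, which converges linearly
# on this problem instead of jumping to the optimum in one step.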
####
# 2. Proceed to solve the regression
####
# Initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# First loss
initial_loss = sess.run(
    loss,
    feed_dict={
        x: data_x,
        y: data_y
    }
)
print("Initial loss:", initial_loss)
# Weight and bias update, "training" phase:
for iteration in range(25):
    new_loss, _ = sess.run(
        [loss, op_apply_updates],
        feed_dict={
            x: data_x,
            y: data_y
        }
    )
    print("Loss after iteration {}: {}".format(iteration, new_loss))
# Results:
print("Prediction:", sess.run(pred, feed_dict={x: data_x}))
print("Expected:", data_y)
####
# 3. Program output
####
# Initial loss: 0.333333
# Loss after iteration 0: 0.3333333432674408
# Loss after iteration 1: 0.19999998807907104
# Loss after iteration 2: 0.11999998241662979
# Loss after iteration 3: 0.07199998944997787
# Loss after iteration 4: 0.04319998249411583
# Loss after iteration 5: 0.025919994339346886
# Loss after iteration 6: 0.01555199921131134
# Loss after iteration 7: 0.009331194683909416
# Loss after iteration 8: 0.00559872156009078
# Loss after iteration 9: 0.003359234193339944
# Loss after iteration 10: 0.0020155382808297873
# Loss after iteration 11: 0.001209322945214808
# Loss after iteration 12: 0.0007255963864736259
# Loss after iteration 13: 0.0004353572439868003
# Loss after iteration 14: 0.00026121424161829054
# Loss after iteration 15: 0.00015672821609769017
# Loss after iteration 16: 9.403712465427816e-05
# Loss after iteration 17: 5.642247560899705e-05
# Loss after iteration 18: 3.3853048080345616e-05
# Loss after iteration 19: 2.0311867046984844e-05
# Loss after iteration 20: 1.2187217180326115e-05
# Loss after iteration 21: 7.312354227906326e-06
# Loss after iteration 22: 4.387426542962203e-06
# Loss after iteration 23: 2.6324989903514506e-06
# Loss after iteration 24: 1.5794303180882707e-06
# Prediction: [-0.99782324 1.0008707 2.99956465]
# Expected: [-1.0, 1.0, 3.0]
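Side note on the output above: the loss shrinks by a near-constant factor of about 0.6 per iteration, which is the linear convergence expected from the block-diagonal Newton step. A full Newton step that keeps the W-b cross-term would land on the exact solution of this quadratic loss in a single update. Here is a minimal NumPy sketch of that full step on the same data; it is not part of the original gist and assumes only that numpy is available:

import numpy as np

# Same 3 data points as above; a bias column is appended so that
# pred = X @ theta with theta = [W, b].
data_x = np.array([0., 1., 2.])
data_y = np.array([-1., 1., 3.])
X = np.stack([data_x, np.ones_like(data_x)], axis=1)
theta = np.array([1., 0.])  # Same initialization as tf.ones / tf.zeros.

# Closed-form gradient and full 2x2 Hessian of mean(0.5 * (y - pred)**2).
residual = data_y - X @ theta
grad = -X.T @ residual / len(data_x)
hess = X.T @ X / len(data_x)

# One exact Newton step: theta <- theta - H^-1 * grad.
theta = theta - np.linalg.solve(hess, grad)

print("Fitted [W, b]:", theta)  # [ 2. -1.], i.e. y = 2x - 1, zero loss.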