# guillaume-chevalier/tensorflow-simple-second-order-optimization.py

Created Jul 30, 2017

Simple second order optimization with TensorFlow.
 """Simple second order optimization with TensorFlow.""" import tensorflow as tf #### # 1. Define the problem #### # Here are 3 linear data points we'll want to fit on: data_x = [0., 1., 2.] data_y = [-1., 1., 3.] batch_size = len(data_x) # Input and Output. No batch_size for simplicity. x = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="x") y = tf.placeholder(shape=[batch_size], dtype=tf.float32, name="y") # Weight and bias. # Computing hessians is currently only supported for one-dimensional tensors, # so I did not bothered reshaping some 2D arrays for the parameters. W = tf.Variable(tf.ones(shape=), dtype=tf.float32, name="W") b = tf.Variable(tf.zeros(shape=), dtype=tf.float32, name="b") # Making a prediction and comparing it to the true output pred = x * W + b loss = tf.reduce_mean(0.5 * (y - pred)**2) # Preprocessings to the weight update wrt_variables = [W, b] grads = tf.gradients(loss, wrt_variables) # The way I proceed here is equivalent to only compute the information # contained in the diagonal of a single big hessian, because we isolated # parameters from each others in the "wrt_variables" list. hess = tf.hessians(loss, wrt_variables) inv_hess = [tf.matrix_inverse(h) for h in hess] # 2nd order weights update rule. Learning rate is of 1, because I # trust the second derivatives obtained for such a small problem. update_directions = [ - tf.reduce_sum(h) * g for h, g in zip(inv_hess, grads) ] op_apply_updates = [ v.assign_add(up) for v, up in zip(wrt_variables, update_directions) ] #### # 2. Proceed to solve the regression #### # Initialize variables sess = tf.Session() sess.run(tf.global_variables_initializer()) # First loss initial_loss = sess.run( loss, feed_dict={ x: data_x, y: data_y } ) print("Initial loss:", initial_loss) # Weight and bias update, "training" phase: for iteration in range(25): new_loss, _ = sess.run( [loss, op_apply_updates], feed_dict={ x: data_x, y: data_y } ) print("Loss after iteration {}: {}".format(iteration, new_loss)) # Results: print("Prediction:", sess.run(pred, feed_dict={x: data_x})) print("Expected:", data_y) #### # 3. Program output #### # Initial loss: 0.333333 # Loss after iteration 0: 0.3333333432674408 # Loss after iteration 1: 0.19999998807907104 # Loss after iteration 2: 0.11999998241662979 # Loss after iteration 3: 0.07199998944997787 # Loss after iteration 4: 0.04319998249411583 # Loss after iteration 5: 0.025919994339346886 # Loss after iteration 6: 0.01555199921131134 # Loss after iteration 7: 0.009331194683909416 # Loss after iteration 8: 0.00559872156009078 # Loss after iteration 9: 0.003359234193339944 # Loss after iteration 10: 0.0020155382808297873 # Loss after iteration 11: 0.001209322945214808 # Loss after iteration 12: 0.0007255963864736259 # Loss after iteration 13: 0.0004353572439868003 # Loss after iteration 14: 0.00026121424161829054 # Loss after iteration 15: 0.00015672821609769017 # Loss after iteration 16: 9.403712465427816e-05 # Loss after iteration 17: 5.642247560899705e-05 # Loss after iteration 18: 3.3853048080345616e-05 # Loss after iteration 19: 2.0311867046984844e-05 # Loss after iteration 20: 1.2187217180326115e-05 # Loss after iteration 21: 7.312354227906326e-06 # Loss after iteration 22: 4.387426542962203e-06 # Loss after iteration 23: 2.6324989903514506e-06 # Loss after iteration 24: 1.5794303180882707e-06 # Prediction: [-0.99782324 1.0008707 2.99956465] # Expected: [-1.0, 1.0, 3.0]