Simple example of second-order optimization using Newton's method in Tensorflow
# Newton's method in Tensorflow
# 'Vanilla' Newton's method is intended to work when the loss function being optimized is convex.
# A one-layer linear network without an activation has a convex loss surface.
# If the activation function is monotonic, the error surface associated with a single-layer model is convex.
# In other cases, the Hessian will have negative eigenvalues at saddle points and in other non-convex regions of the surface.
# To fix that, you can try different methods. One approach is to do an eigendecomposition of H and invert the negative eigenvalues,
# making H "push out" in those directions, as described in this paper: Identifying and attacking the saddle point problem
# in high-dimensional non-convex optimization (https://papers.nips.cc/paper/5486-identifying-and-attacking-the-saddle-point-problem-in-high-dimensional-non-convex-optimization.pdf).
# A short sketch of that correction is given below, next to the Hessian inversion.
# This example uses a one-layer network with leaky relu activation (the activation is just for fun).
# See also this simple script:
# https://gist.github.com/guillaume-chevalier/6b01c4e43a123abf8db69fa97532993f
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # suppress tensorflow verbosity (set before importing tensorflow)
import tensorflow as tf
import numpy as np
####
# 1. Training data
####
# Some 3d convex function (elliptic paraboloid)
# This model does not actually require an activation
a = 1
b = 2
c = 3
def paraboloid(x2dpoint, a=a, b=b, c=c):
    x1, x2 = x2dpoint
    return x1**2/a**2 + x2**2/b**2 + c
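# Quick sanity check of the target function at a hypothetical point, just to illustrate:
# paraboloid((1.0, 2.0)) = 1.0/1**2 + 4.0/2**2 + 3 = 5.0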
# sample data randomly
N = 5
# train data
train_datax = np.random.rand(N, 2)
train_datay = np.array([paraboloid(point) for point in train_datax])
# test data (same dimensions for simplicity)
test_datax = np.random.rand(N, 2)
test_datay = np.array([paraboloid(point) for point in test_datax])
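# Resulting shapes (with N = 5): train_datax and test_datax are (5, 2), train_datay and test_datay are (5,)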
####
# 2. Feed data to tensorflow
####
# Input and output placeholders. No batches: vanilla Newton is evaluated on the entire training set
x = tf.placeholder(shape=train_datax.shape, dtype=tf.float32, name="x")
y = tf.placeholder(shape=train_datay.shape, dtype=tf.float32, name="y")
# Weights and bias for the linear model, packed into a single vector (essential for computing the full Hessian)
# weight matrix dimensions are input*output, bias: output
shape_w = [train_datax.shape[1], 1] # train_datay.shape[1]]
shape_b = [1] # train_datay.shape[1]
num_variables = shape_w[0]*shape_w[1] + shape_b[0]
model_vars = tf.Variable(tf.ones(shape=[num_variables]), dtype=tf.float32, name="W")
# Delta is an increment to the weight vector
stored_delta = tf.Variable(tf.ones(shape=[num_variables]), dtype=tf.float32, name="delta")
# function to get matrices from the weight vector
def vec_to_variables(vec):
    W = tf.reshape(vec[:shape_w[0]*shape_w[1]], shape_w)
    # This is spelled out for clarity, just to show how the data is packed
    b = tf.reshape(vec[shape_w[0]*shape_w[1]:(shape_w[0]*shape_w[1] + shape_b[0])], shape_b)
    return W, b
W, b = vec_to_variables(model_vars)
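# Illustration of the packing, assuming the shapes above (shape_w = [2, 1], shape_b = [1], num_variables = 3):
# model_vars = [w_11, w_21, b_1], so vec[:2] is reshaped into W and vec[2:3] into b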
# learning rate
alpha = tf.Variable(1.0, dtype=tf.float32, name="alpha")
####
# 3. Training code
####
# Making a prediction
def predict(W, b):
    p = tf.matmul(x, W) + b  # linear model
    return tf.nn.leaky_relu(p, alpha=0.3)
# and comparing it to the true output
def compute_loss(labels, predicts):
    # squeeze the (N, 1) predictions to (N,) so the difference broadcasts elementwise with the (N,) labels
    return tf.reduce_mean(0.5 * (labels - tf.squeeze(predicts, axis=-1))**2)
pred = predict(W,b)
loss = compute_loss(y, pred)
# Model gradients vector
grads = tf.gradients(loss, model_vars)[0]
# Newton's 2nd-order weight update rule. The learning rate is found using line search
# https://stackoverflow.com/questions/35266370/tensorflow-compute-hessian-matrix-and-higher-order-derivatives#comment78202919_37666032
# Use tf.hessians, which will compute the portion of the Hessian relating to each variable in vars
# (so long as each variable is a vector). If you want the FULL Hessian (including all pairwise
# interactions between variables), you'll need to start with a single super-vector containing
# every variable you care about, then slice from there
hess = tf.hessians(loss, model_vars)[0]
# print hessian
# hess = tf.Print(hess, [hess], summarize=10)
# compute inverse hessian
inv_hess = tf.matrix_inverse(hess)
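# As noted in the header comments: for non-convex losses the Hessian (shape [num_variables, num_variables])
# can have negative eigenvalues at saddle points, so plain inversion can push the update in the wrong direction.
# A hedged sketch of the eigendecomposition fix from the paper cited above (replace each eigenvalue by its
# absolute value before inverting), defined here for illustration only and NOT used in the updates below:
eigvals, eigvecs = tf.self_adjoint_eig(hess)
saddle_free_inv_hess = tf.matmul(eigvecs,
                                 tf.matmul(tf.matrix_diag(1.0 / tf.abs(eigvals)),
                                           eigvecs, transpose_b=True))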
# Compute the increment which updates the weight vector: W1 = W - alpha*H^(-1)*G
# delta := H^(-1)*G
delta = tf.transpose(tf.matmul(inv_hess, tf.expand_dims(grads,-1)))[0]
# create a loss that depends on alpha, with delta held fixed
candidate_vars = model_vars - alpha*delta
W1, b1 = vec_to_variables(candidate_vars)
loss_alpha = compute_loss(y, predict(W1,b1))
# line search: find an optimal value for alpha (a simpler grid-search sketch is defined after the update ops below)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
line_search = opt.minimize(loss_alpha, var_list=[alpha])
# op that stores delta, so it stays fixed while alpha is optimized by the line search
compute_delta = [
    stored_delta.assign(delta)
]
# update model weights
model_vars_new = model_vars - alpha*stored_delta
model_update = [
    model_vars.assign(model_vars_new)
]
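# Hedged alternative to the SGD-based line search above (not used in the loop below): a crude grid
# search over candidate step sizes. alpha_candidate and grid_line_search are illustrative names
# introduced here for this sketch, not part of the original recipe.
alpha_candidate = tf.placeholder(shape=[], dtype=tf.float32, name="alpha_candidate")
set_alpha = alpha.assign(alpha_candidate)
def grid_line_search(sess, feed, candidates=(1.0, 0.5, 0.25, 0.125)):
    # evaluate loss_alpha for each candidate step size and keep the best one
    best_a, best_loss = candidates[0], np.inf
    for a in candidates:
        sess.run(set_alpha, feed_dict={alpha_candidate: a})
        current_loss = sess.run(loss_alpha, feed_dict=feed)
        if current_loss < best_loss:
            best_a, best_loss = a, current_loss
    sess.run(set_alpha, feed_dict={alpha_candidate: best_a})
    return best_a
# Usage would be e.g. grid_line_search(sess, {x: train_datax, y: train_datay}) in place of the 100 SGD steps.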
####
# 4. Main loop to fit and evaluate the model
####
# Initialize variables
sess = tf.Session()
# Weight and bias update, "training" phase:
with sess.as_default():
    sess.run(tf.global_variables_initializer())
    # First loss
    initial_loss = sess.run(
        loss,
        feed_dict={
            x: train_datax,
            y: train_datay
        }
    )
    print("Initial loss:", initial_loss)
    for iteration in range(5):
        # set the learning rate value from which the line search starts
        sess.run(alpha.assign(1.0))
        # I'm cautious and separate these operations,
        # as I want the Hessian and delta to be computed first
        hessian, _ = sess.run(
            [hess,
             compute_delta],
            feed_dict={
                x: train_datax,
                y: train_datay
            }
        )
        # find an optimal alpha
        for i in range(100):
            sess.run(line_search, feed_dict={
                x: train_datax,
                y: train_datay
            })
        # then update the weights
        sess.run(model_update)
        print(hessian)
        print("Alpha{}: {}".format(iteration, alpha.eval()))
        # and then compute the new loss
        new_loss = sess.run(
            loss,
            feed_dict={
                x: train_datax,
                y: train_datay
            }
        )
        print("Loss after iteration {}: {}".format(iteration, new_loss))
    # Results:
    print("Test prediction:", sess.run(pred, feed_dict={x: test_datax}).T)
    print("Expected:", test_datay)