Simple example of second-order optimization via Newton's method in Tensorflow on Iris dataset
# Newton's method in Tensorflow
# 'Vanilla' Newton's method is intended to work when the loss function being optimized is convex.
# A one-layer linear network without activation is convex.
# If the activation function is monotonic, the error surface associated with a single-layer model is convex.
# In other cases the Hessian will have negative eigenvalues at saddle points and in other non-convex regions of the surface.
# To fix that, you can try different methods. One approach is to eigendecompose H and flip its negative eigenvalues,
# making H "push out" in those directions, as described in this paper: Identifying and attacking the saddle point problem in high-dimensional non-convex optimization (https://papers.nips.cc/paper/5486-identifying-and-attacking-the-saddle-point-problem-in-high-dimensional-non-convex-optimization.pdf)
# (a sketch of that fix is included next to the Hessian inverse below).
# This example uses a one-layer network with leaky relu activation (the activation is just for fun).
# Notice that for a one-layer network the Hessian matrix has a multidiagonal structure,
# so in theory you can reduce computation by not computing derivatives which do not exist in a single-layer NN.
# See also this simple script:
# https://gist.github.com/guillaume-chevalier/6b01c4e43a123abf8db69fa97532993f
import os
import tensorflow as tf
import numpy as np
from sklearn import datasets
import random
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # suppress tensorflow verbosity
####
# 1. Training data, iris dataset
####
def data():
    # import some data to play with
    iris = datasets.load_iris()
    x = iris.data
    y = iris.target
    # shuffle data
    ntrain = x.shape[0]
    arrayindices = list(range(ntrain))
    random.shuffle(arrayindices)
    x = x[arrayindices]
    y = y[arrayindices]
    # one-hot encoder, e.g. labels [0, 2] -> [[1., 0., 0.], [0., 0., 1.]]
    target_vector = y
    n_labels = np.max(y)  # max label index is 2, so n_labels+1 = 3 classes
    y = np.equal.outer(target_vector, np.arange(n_labels + 1)).astype(float)
    # take the last 20 samples as test
    trainx = x[:-20]
    train_labels = y[:-20]
    testx = x[-20:]
    test_labels = y[-20:]
    return trainx, train_labels, testx, test_labels
# train and test data
train_datax, train_datay, test_datax, test_datay = data()
print("train data shape:{}, Train labels shape: {}".format(train_datax.shape,train_datay.shape))
####
# 2. Feed data to tensorflow
####
# Input and output placeholders. No batches, as this is vanilla Newton, which should be evaluated on the entire training data
x = tf.placeholder(shape=None, dtype=tf.float32, name="x")
y = tf.placeholder(shape=None, dtype=tf.float32, name="y")
# Weights (all packed into one vector - essential for getting the full Hessian) and bias for the linear model
# weight matrix dimensions are input*output, bias: output
shape_w = [train_datax.shape[1], train_datay.shape[1]]
shape_b = [train_datay.shape[1]]
num_variables = shape_w[0]*shape_w[1] + shape_b[0]
model_vars = tf.Variable(tf.ones(shape=[num_variables]), dtype=tf.float32, name="W")
# Delta is the increment to the weight vector
stored_delta = tf.Variable(tf.ones(shape=[num_variables]), dtype=tf.float32, name="delta")
# function to get matrices from weight vector
def vec_to_variables(vec):
    W = tf.reshape(vec[:shape_w[0]*shape_w[1]], shape_w)
    # The slice bounds are spelled out just to show how the data is packed
    b = tf.reshape(vec[shape_w[0]*shape_w[1]:(shape_w[0]*shape_w[1] + shape_b[0])], shape_b)
    return W, b
W, b = vec_to_variables(model_vars)
# learning rate
alpha = tf.Variable(1.0, dtype=tf.float32, name="alpha")
####
# 3. Training code
####
# Making a prediction
def predict(W, b):
    p = tf.matmul(x, W) + b  # linear model
    return tf.nn.leaky_relu(p, alpha=0.3)
# and comparing it to the true output
def compute_loss(labels, predicts):
    return tf.reduce_mean(0.5 * (labels - predicts)**2)
pred = predict(W,b)
loss = compute_loss(y, pred)
# Model gradients vector
grads = tf.gradients(loss, model_vars)[0]
# Newton 2nd-order weight update rule. The learning rate is found using line search.
# https://stackoverflow.com/questions/35266370/tensorflow-compute-hessian-matrix-and-higher-order-derivatives#comment78202919_37666032
# "Use tf.hessians, which will compute the portion of the Hessian relating to each variable in vars
# (so long as each variable is a vector). If you want the FULL Hessian (including all pairwise
# interactions between variables), you'll need to start with a single super-vector containing
# every variable you care about, then slice from there."
hess = tf.hessians(loss, model_vars)[0]
# print hessian
# hess = tf.Print(hess, [hess], summarize=10)
# compute inverse hessian
inv_hess = tf.matrix_inverse(hess)
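# Optional: a sketch of the saddle-free fix mentioned in the header comments (Dauphin et al.).
# Eigendecompose H and rebuild its inverse from the absolute eigenvalues (clipped away from zero),
# so that negative-curvature directions push the step away from saddle points instead of attracting it.
# The names eig_vals/eig_vecs/inv_hess_sf are illustrative; this variant is not used in the update below.
eig_vals, eig_vecs = tf.self_adjoint_eig(hess)
inv_hess_sf = tf.matmul(eig_vecs,
                        tf.matmul(tf.matrix_diag(1.0 / tf.maximum(tf.abs(eig_vals), 1e-6)),
                                  eig_vecs, transpose_b=True))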
# Compute the increment which updates the weight vector: W1 = W - alpha*H^(-1)*G
# delta := H^(-1)*G
delta = tf.transpose(tf.matmul(inv_hess, tf.expand_dims(grads, -1)))[0]
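# Quick sanity check of the update rule on a toy convex quadratic (plain numpy, illustration only):
# for f(w) = 0.5*w'Aw - b'w the Newton step w - H^(-1)*g, with H = A and g = Aw - b,
# lands exactly on the minimizer A^(-1)b in a single iteration.
A_toy = np.array([[3., 1.], [1., 2.]])
b_toy = np.array([1., 1.])
w_toy = np.zeros(2)
w_newton = w_toy - np.linalg.inv(A_toy).dot(A_toy.dot(w_toy) - b_toy)
assert np.allclose(w_newton, np.linalg.solve(A_toy, b_toy))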
# create function which depends on alpha with delta fixed
candidate_vars = model_vars - alpha*delta
W1, b1 = vec_to_variables(candidate_vars)
loss_alpha = compute_loss(y, predict(W1,b1))
# line search: find optimal value for alpha
opt = tf.train.AdamOptimizer(learning_rate=1.0)  # SGD explodes here for some reason
line_search = opt.minimize(loss_alpha, var_list=[alpha])
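# An alternative to the Adam-based search above: a plain backtracking-style search that tries
# geometrically shrinking values of alpha and keeps the best one. This helper is hypothetical
# (not part of the original gist) and is only defined here for illustration; it is not called
# in the training loop below.
def backtracking_alpha(session, feed, start=1.0, shrink=0.5, max_steps=20):
    best_alpha, best_loss = start, None
    a = start
    for _ in range(max_steps):
        session.run(alpha.assign(a))                        # try the candidate step size
        current = session.run(loss_alpha, feed_dict=feed)   # loss at W - a*delta
        if best_loss is None or current < best_loss:
            best_alpha, best_loss = a, current
        a *= shrink
    session.run(alpha.assign(best_alpha))                   # leave the best alpha in the variable
    return best_alpha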
# op to store delta (alpha itself is found by the line search above)
compute_delta = [
    stored_delta.assign(delta)
]
# op to update model weights
model_vars_new = model_vars - alpha*stored_delta
model_update = [
    model_vars.assign(model_vars_new)
]
# evaluate accuracy
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
####
# 4. Main loop to fit and evaluate the model
####
# Initialize variables
sess = tf.Session()
# Weight and bias update, "training" phase:
with sess.as_default():
    sess.run(tf.global_variables_initializer())
    # First loss
    initial_loss = sess.run(
        loss,
        feed_dict={
            x: train_datax,
            y: train_datay
        }
    )
    print("Initial loss:", initial_loss)
    for iteration in range(5):
        # set the learning rate value from which the line search is started
        sess.run(alpha.assign(1.0))
        # These operations are run separately on purpose,
        # so that hess and delta are computed before the line search
        hessian, _ = sess.run(
            [hess,
             compute_delta],
            feed_dict={
                x: train_datax,
                y: train_datay
            }
        )
        # find the optimal alpha
        for i in range(100):
            sess.run(line_search, feed_dict={
                x: train_datax,
                y: train_datay
            })
        # then update the weights
        sess.run(model_update)
        # print(hessian)
        print("Alpha{}: {}".format(iteration, alpha.eval()))
        # and then compute the new loss
        new_loss = sess.run(
            loss,
            feed_dict={
                x: train_datax,
                y: train_datay
            }
        )
        print("Loss after iteration {}: {}".format(iteration, new_loss))
    # Results:
    print("Train accuracy: %.2f %%" % sess.run(accuracy*100, feed_dict={x: train_datax, y: train_datay}))
    print("Test accuracy: %.2f %%" % sess.run(accuracy*100, feed_dict={x: test_datax, y: test_datay}))