Last active March 31, 2024 02:36
Optimize TensorFlow & Keras models with L-BFGS from TensorFlow Probability
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# Copyright © 2019 Pi-Yueh Chuang <>
# Distributed under terms of the MIT license.
"""An example of using tfp.optimizer.lbfgs_minimize to optimize a TensorFlow model.
This code shows a naive way to wrap a tf.keras.Model and optimize it with the L-BFGS
optimizer from TensorFlow Probability.

Python interpreter version: 3.6.9
TensorFlow version: 2.0.0
TensorFlow Probability version: 0.8.0
NumPy version: 1.17.2
Matplotlib version: 3.1.1
"""
import numpy
import tensorflow as tf
import tensorflow_probability as tfp
from matplotlib import pyplot
def function_factory(model, loss, train_x, train_y):
    """Create a loss-and-gradient function usable by tfp.optimizer.lbfgs_minimize.

    tfp.optimizer.lbfgs_minimize works on a single 1D position tensor, while a
    Keras model keeps its parameters in several variables of different shapes.
    This factory precomputes the index bookkeeping needed to convert between the
    two layouts (tf.dynamic_stitch: variables -> 1D, tf.dynamic_partition:
    1D -> variables). It then returns a closure that loads a 1D parameter vector
    into the model and evaluates the loss and its flattened gradient there.

    Args:
        model [in]: an instance of `tf.keras.Model` or its subclasses. It must
            already be built (e.g. a Sequential model starting with
            `tf.keras.Input`) so that `model.trainable_variables` is non-empty.
        loss [in]: a function with signature loss_value = loss(pred_y, true_y).
            NOTE: built-in Keras losses take (y_true, y_pred); the order does
            not matter for symmetric losses such as MeanSquaredError, but check
            it for asymmetric ones.
        train_x [in]: the input part of training data.
        train_y [in]: the output part of training data.

    Returns:
        A function that has a signature of:
            loss_value, gradients = f(model_parameters).
        The returned function also carries these attributes:
            f.iter: tf.Variable incremented once per evaluation of f.
            f.idx: stitch indices, one int32 tensor per trainable variable.
            f.part: 1D int32 tensor mapping each flat entry to its variable.
            f.shapes: shapes of the trainable variables (from tf.shape_n).
            f.assign_new_model_parameters: writes a 1D tensor into the model.
            f.history: Python list receiving the loss of every evaluation.

    Raises:
        ValueError: if the model has no trainable variables.
    """

    if not model.trainable_variables:
        # An unbuilt Keras model has no variables yet, so there is nothing to
        # flatten or optimize; fail early with an actionable message.
        raise ValueError(
            "model has no trainable variables; build it first "
            "(e.g. start a Sequential model with tf.keras.Input)")

    # obtain the shapes of all trainable parameters in the model
    shapes = tf.shape_n(model.trainable_variables)
    n_tensors = len(shapes)

    # we'll use tf.dynamic_stitch and tf.dynamic_partition later, so we need to
    # prepare required information first
    count = 0
    idx = []  # stitch indices: positions of each variable's entries in the 1D vector
    part = []  # partition indices: owning variable of each entry of the 1D vector

    for i, shape in enumerate(shapes):
        # numpy.product is deprecated (removed in NumPy 2.0); int() so that the
        # element count can be used for list repetition below
        n = int(numpy.prod(shape))
        idx.append(tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape))
        part.extend([i] * n)
        count += n

    part = tf.constant(part)

    @tf.function
    def assign_new_model_parameters(params_1d):
        """Copy a flat 1D parameter tensor into the model's trainable variables.

        Args:
            params_1d [in]: a 1D tf.Tensor representing the model's trainable
                parameters, in the same layout produced by
                tf.dynamic_stitch(idx, model.trainable_variables).
        """
        params = tf.dynamic_partition(params_1d, part, n_tensors)
        for i, (shape, param) in enumerate(zip(shapes, params)):
            model.trainable_variables[i].assign(tf.reshape(param, shape))

    # now create a function that will be returned by this factory
    @tf.function
    def f(params_1d):
        """Evaluate the loss and its gradient at a flat parameter vector.

        This function is created by function_factory and is meant to be passed
        as `value_and_gradients_function` to tfp.optimizer.lbfgs_minimize.

        Args:
            params_1d [in]: a 1D tf.Tensor.

        Returns:
            A scalar loss and the gradients w.r.t. the `params_1d` (flattened
            to a 1D tf.Tensor in the same layout as `params_1d`).
        """
        # use GradientTape so that we can calculate the gradient of loss w.r.t. parameters
        with tf.GradientTape() as tape:
            # load the position proposed by L-BFGS into the model; without this
            # the loss would always be evaluated at the initial weights
            assign_new_model_parameters(params_1d)
            # calculate the loss
            loss_value = loss(model(train_x, training=True), train_y)

        # calculate gradients and convert to 1D tf.Tensor
        grads = tape.gradient(loss_value, model.trainable_variables)
        grads = tf.dynamic_stitch(idx, grads)

        # count this evaluation, then print out iteration & loss
        f.iter.assign_add(1)
        tf.print("Iter:", f.iter, "loss:", loss_value)

        # store loss value so we can retrieve later (py_function lets the
        # Python list be appended to from inside the traced graph)
        tf.py_function(f.history.append, inp=[loss_value], Tout=[])

        return loss_value, grads

    # store these information as members so we can use them outside the scope
    f.iter = tf.Variable(0)
    f.idx = idx
    f.part = part
    f.shapes = shapes
    f.assign_new_model_parameters = assign_new_model_parameters
    f.history = []

    return f
def plot_helper(inputs, outputs, title, fname):
    """Draw a filled contour plot of `outputs` over 2D `inputs` and save it.

    Each call opens a new figure, so repeated calls do not draw on top of each
    other.

    Args:
        inputs [in]: array whose columns 0 and 1 are the x and y coordinates.
        outputs [in]: values at the input points; flattened before plotting.
        title [in]: title of the figure.
        fname [in]: path of the image file to write.

    Returns:
        The matplotlib Figure that was drawn on.
    """
    fig = pyplot.figure()
    pyplot.tricontourf(inputs[:, 0], inputs[:, 1], outputs.flatten(), 100)
    pyplot.xlabel("x")
    pyplot.ylabel("y")
    pyplot.title(title)
    pyplot.colorbar()
    pyplot.savefig(fname)
    return fig
if __name__ == "__main__":

    # use float64 by default so the model's variables match the float64 NumPy
    # training data and L-BFGS operates in double precision
    tf.keras.backend.set_floatx("float64")

    # prepare training data: x^2 + y^2 sampled on an 11x11 grid over [-1, 1]^2
    x_1d = numpy.linspace(-1., 1., 11)
    x1, x2 = numpy.meshgrid(x_1d, x_1d)
    inps = numpy.stack((x1.flatten(), x2.flatten()), 1)
    outs = numpy.reshape(inps[:, 0]**2+inps[:, 1]**2, (x_1d.size**2, 1))

    # prepare prediction model, loss function, and the function passed to L-BFGS solver;
    # the explicit Input layer builds the model immediately, so its trainable
    # variables exist before function_factory flattens them
    pred_model = tf.keras.Sequential([
        tf.keras.Input(shape=(2,)),
        tf.keras.layers.Dense(64, "tanh"),
        tf.keras.layers.Dense(64, "tanh"),
        tf.keras.layers.Dense(1, None),
    ])

    loss_fun = tf.keras.losses.MeanSquaredError()
    func = function_factory(pred_model, loss_fun, inps, outs)

    # convert initial model parameters to a 1D tf.Tensor
    init_params = tf.dynamic_stitch(func.idx, pred_model.trainable_variables)

    # train the model with L-BFGS solver
    results = tfp.optimizer.lbfgs_minimize(
        value_and_gradients_function=func, initial_position=init_params, max_iterations=500)

    # after training, the final optimized parameters are still in results.position
    # so we have to manually put them back to the model
    func.assign_new_model_parameters(results.position)

    # do some prediction
    pred_outs = pred_model.predict(inps)
    err = numpy.abs(pred_outs-outs)
    # normalize by the number of sample points (121), giving the RMS error;
    # dividing by sqrt(11) used the grid width instead of the point count
    print("L2-error norm: {}".format(numpy.linalg.norm(err)/numpy.sqrt(err.size)))

    # plot figures
    plot_helper(inps, outs, "Exact solution", "ext_soln.png")
    plot_helper(inps, pred_outs, "Predicted solution", "pred_soln.png")
    plot_helper(inps, err, "Absolute error", "abs_err.png")

    # print out history
    print(*func.history, sep='\n')
Copy link

Thanks for this work — it's a pity I only found it now. I have a question: if I build the model manually with low-level TF operations instead of using tf.keras.Sequential, can I still use the code above?

Copy link

piyueh commented Jun 19, 2023

I think so, but I haven't used TF in more than 3 years, so I can't say for sure.

Copy link

Thank you for your quick reply.

Copy link

I'm not sure whether the problem was specific to my implementation or to my TF version, but I had to modify line 40

shapes = tf.shape_n(model.trainable_variables)

to this:

shapes = []
for i in range(len(model.trainable_variables)):

I'm using a tf.keras.Sequential() model.

Thank you very much!!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment