Skip to content

Instantly share code, notes, and snippets.

@alphaville
Last active December 7, 2018 19:12
Show Gist options
  • Save alphaville/deacf56bff72937cb9750a7ce474cbea to your computer and use it in GitHub Desktop.
Save alphaville/deacf56bff72937cb9750a7ce474cbea to your computer and use it in GitHub Desktop.
function [r,a,rho,beta,q] = lbfgs(H0, v, Y, S)
%LBFGS applies the L-BFGS two-loop recursion to a vector
%
% Input arguments:
% - H0: initial inverse-Hessian estimate (matrix); it is applied as H0*q
%       between the two loops. Typical choice: H0 = gamma*I with
%       gamma = (s'y)/(y'y), where (s,y) is the most recent pair of s and y
% - v:  the vector on which the L-BFGS inverse-Hessian estimate should be
%       applied; the function will return r = Hk*v, for given v
% - Y and S: buffers stored column-wise, most recent pair first:
%       Y = [y(k-1), ..., y(0)], where y(k-1) = g(k) - g(k-1)
%       S = [s(k-1), ..., s(0)], where s(k-1) = x(k) - x(k-1)
%       The number of pairs used is size(Y,2).
%
% Output arguments:
% - r:    Hk*v
% - a:    column vector with a(i) = rho(i)*S(:,i)'*q, computed in the first loop
% - rho:  column vector with rho(i) = 1/(S(:,i)'*Y(:,i))
% - beta: scalar computed in the last pass of the second loop (i = 1);
%         empty when the buffers hold no pairs
% - q:    the vector v after the first loop has been applied to it
%
q = v;
sz = size(Y,2);
rho = zeros(sz, 1);
a = zeros(sz, 1);
% beta is only assigned inside the second loop; give it a value up front so
% that requesting it with empty buffers (sz = 0) does not raise an
% "output argument not assigned" error.
beta = [];

% First loop: from the most recent pair (column 1) to the oldest (column sz)
for i=1:sz
    si = S(:, i);
    yi = Y(:, i);
    rho(i) = 1./(si'*yi);
    a(i) = rho(i) * si'*q;
    q = q - a(i)*yi;
end

% Apply the initial estimate
r = H0*q;

% Second loop: from the oldest pair back to the most recent one
for i=sz:-1:1
    yi = Y(:, i);
    si = S(:, i);
    beta = rho(i) * yi' * r;
    r = r + si*(a(i) - beta);
end
/// Asserts that `x` and `y` are almost equal, i.e. `|x - y| <= tol`.
///
/// `msg` is a label included in the panic message to identify which quantity
/// failed the comparison.
///
/// # Panics
///
/// Panics when the absolute difference exceeds `tol`, and also when the
/// comparison cannot be satisfied because `x`, `y` or `tol` is NaN. The panic
/// message reports both values and the base-10 logarithm of `tol`.
fn assert_ae(x: f64, y: f64, tol: f64, msg: &str) {
    // `<=` evaluates to false whenever a NaN is involved, so NaN inputs fail
    // the assertion instead of slipping through a `> tol` check. The exact
    // equality clause keeps equal infinities (whose difference is NaN) passing.
    let within_tol = (x - y).abs() <= tol || (x == y && tol >= 0.0);
    if !within_tol {
        panic!("({}) {} != {} [log(tol) = {}]", msg, x, y, tol.log10());
    }
}
/// Element-wise version of `assert_ae` for two slices.
///
/// Corresponding entries of `x` and `y` are compared with the absolute
/// tolerance `tol`; `msg` is forwarded to `assert_ae` for the panic message.
/// The slices are walked in lockstep, so only the first
/// `min(x.len(), y.len())` entries are compared.
fn assert_array_ae(x: &[f64], y: &[f64], tol: f64, msg: &str) {
    for (xi, yi) in x.iter().copied().zip(y.iter().copied()) {
        assert_ae(xi, yi, tol, msg);
    }
}
/// After a single update with all-zero vectors, applying the estimator must
/// leave the input vector unchanged.
#[test]
fn correctneess_buff_empty() {
    let expected = [-3.1, 1.5, 2.1];
    let mut direction = [-3.1, 1.5, 2.1];

    let mut estimator = Estimator::new(3, 3);
    // The trailing scalar arguments are passed as zero; their meaning is not
    // visible from this file.
    estimator.update_hessian(&[0.0; 3], &[0.0; 3], 0.0, 0.0);
    estimator.apply_hessian(&mut direction);

    assert_array_ae(&expected, &direction, 1e-10, "direction");
}
/// Two updates (the all-zero one followed by one non-trivial pair): checks the
/// first `alpha` and `rho` entries and the resulting direction against
/// precomputed reference values.
#[test]
fn correctneess_buff_1() {
    // Reference values
    let rho_correct = 2.325581395348837;
    let alpha_correct = -1.488372093023256;
    let correct_dir = [-1.100601247872944, -0.086568349404424, 0.948633011911515];

    let mut estimator = Estimator::new(3, 3);
    // NOTE(review): the two slices look like (gradient, iterate) values whose
    // successive differences form the y and s pairs — confirm against
    // `Estimator::update_hessian`.
    estimator.update_hessian(&[0.0; 3], &[0.0; 3], 0.0, 0.0);
    estimator.update_hessian(&[-0.5, 0.6, -1.2], &[0.1, 0.2, -0.3], 0., 0.);

    let mut direction = [-3.1, 1.5, 2.1];
    estimator.apply_hessian(&mut direction);
    println!("{:#.3?}", estimator);

    assert_ae(alpha_correct, estimator.alpha[0], 1e-10, "alpha");
    assert_ae(rho_correct, estimator.rho[0], 1e-10, "rho");
    assert_array_ae(&correct_dir, &direction, 1e-10, "direction");
}
/// Three updates (the all-zero one followed by two non-trivial pairs): checks
/// the resulting direction against a precomputed reference value.
#[test]
fn correctneess_buff_2() {
    // Arguments of the successive `update_hessian` calls, in call order.
    let updates = [
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
        ([-0.5, 0.6, -1.2], [0.1, 0.2, -0.3]),
        ([-0.75, 0.9, -1.9], [0.19, 0.19, -0.44]),
    ];
    let correct_dir = [-1.814749861477524, 0.895232314736337, 1.871795942557546];

    let mut estimator = Estimator::new(3, 3);
    for (first, second) in updates.iter() {
        estimator.update_hessian(first, second, 0.0, 0.0);
    }

    let mut direction = [-3.1, 1.5, 2.1];
    estimator.apply_hessian(&mut direction);

    assert_array_ae(&correct_dir, &direction, 1e-10, "direction");
}
/// Five updates pushed into an estimator created with `Estimator::new(3, 3)`:
/// checks `gamma`, `alpha`, `rho` and the resulting direction against
/// precomputed reference values.
#[test]
fn correctneess_buff_overfull() {
    // Reference values
    let gamma_correct = 0.077189939288812;
    let alpha_correct = [-0.044943820224719, -0.295345104333868, -1.899418829910887];
    let rho_correct = [1.123595505617978, 1.428571428571429, 13.793103448275861];
    let dir_correct = [-0.933604237447365, -0.078865807539102, 1.016318412551302];

    // Arguments of the successive `update_hessian` calls, in call order.
    let updates = [
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
        ([-0.5, 0.6, -1.2], [0.1, 0.2, -0.3]),
        ([-0.75, 0.9, -1.9], [0.19, 0.19, -0.44]),
        ([-2.25, 3.5, -3.1], [0.39, 0.39, -0.84]),
        ([-3.75, 6.3, -4.3], [0.49, 0.59, -1.24]),
    ];

    let mut estimator = Estimator::new(3, 3);
    for (first, second) in updates.iter() {
        estimator.update_hessian(first, second, 0.0, 0.0);
    }

    let mut direction = [-2.0, 0.2, -0.3];
    estimator.apply_hessian(&mut direction);
    println!("{:#.3?}", estimator);

    assert_ae(gamma_correct, estimator.gamma, 1e-10, "gamma");
    assert_array_ae(&alpha_correct, &estimator.alpha, 1e-10, "alpha");
    assert_array_ae(&rho_correct, &estimator.rho, 1e-10, "rho");
    assert_array_ae(&dir_correct, &direction, 1e-10, "direction");
}
% Reference script: builds the S and Y buffers step by step (most recent pair
% in the first column), calls lbfgs with H0 = gamma*I and checks the returned
% direction against stored reference values.

% --- First: one (s, y) pair in the buffers
S = [0.1, 0.2, -0.3]';
Y = [-0.5, 0.6, -1.2]';
va = [-3.1, 1.5, 2.1]';
gamma = (Y(:,1)'*S(:,1))/(Y(:,1)'*Y(:,1));
dir_a_correct = [-1.100601247872944, -0.086568349404424, 0.948633011911515]';
[d,a,rho,beta,q] = lbfgs(gamma*eye(3), va, Y, S);
assert ( norm(dir_a_correct - d) < 1e-10 )

% --- Second: a new pair is prepended (two pairs in the buffers)
S = [[0.09, -0.01, -0.14]', S];
Y = [[-0.25, 0.30, -0.70]', Y];
vb = [-3.1, 1.5, 2.1]';
gamma = (Y(:,1)'*S(:,1))/(Y(:,1)'*Y(:,1));
dir_b_correct = [-1.814749861477524, 0.895232314736337, 1.871795942557546]';
% Use this case's own vector vb (it has the same entries as va)
d = lbfgs(gamma*eye(3), vb, Y, S);
assert ( norm(dir_b_correct - d) < 1e-10 )

% --- Third: a new pair is prepended (three pairs in the buffers)
S = [[0.2, 0.2, -0.4]', S];
Y = [[-1.5, 2.6, -1.2]', Y];
vc = [1.1, 0.2, -0.3]';
gamma = (Y(:,1)'*S(:,1))/(Y(:,1)'*Y(:,1));
dir_c_correct = [1.025214973501680, 0.070767249318312, -1.444856343354091]';
d = lbfgs(gamma*eye(3), vc, Y, S);
assert ( norm(dir_c_correct - d) < 1e-10 )

% --- Fourth: a new pair is prepended and the oldest pair (last column) is
% dropped, so the buffers still hold three pairs
S = [[0.1, 0.2, -0.4]', S]; S(:, end) = [];
Y = [[-1.5, 2.8, -1.2]', Y]; Y(:, end) = [];
vd = [-2.0, 0.2, -0.3]';
gamma = (Y(:,1)'*S(:,1))/(Y(:,1)'*Y(:,1));
dir_d_correct = [-0.933604237447365, -0.078865807539102, 1.016318412551302]';
[dd,ad,rhod,betad,qd] = lbfgs(gamma*eye(3), vd, Y, S);
assert ( norm(dir_d_correct - dd) < 1e-10 )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment