Skip to content

Instantly share code, notes, and snippets.

@yorkerlin yorkerlin/gist:d8acb388d03c6976728e Secret
Last active Aug 29, 2015

Embed
What would you like to do?
see https://github.com/yorkerlin/approxKLVB for detailed information
function [alpha, sW, L, nlZ, dnlZ] = approxDiagonalWithLBFGS(hyper, covfunc, lik, x, y)
% Approximation to the posterior Gaussian Process by minimization of the
% KL-divergence, using a diagonal posterior covariance parameterized as
% v = exp(lambda). The function takes a specified covariance function (see
% covFunction.m) and likelihood function (see likelihoods.m), and is designed
% to be used with binaryGP.m. See also approximations.m.
%
% Inputs:
%   hyper   - struct with fields .cov (and .lik/.mean when derivatives are asked)
%   covfunc - covariance specification (cell array), evaluated via feval
%   lik     - likelihood specification, forwarded to likKL
%   x, y    - training inputs (n x d) and targets (n x 1)
% Outputs:
%   alpha - posterior mean coefficients, m = K*alpha
%   sW    - not used by this approximation; returned empty
%   L     - K\((K\Sigma - eye(n))'), Sigma = diag(v)
%   nlZ   - bound on the negative log marginal likelihood
%   dnlZ  - derivatives of nlZ w.r.t. hyper.cov
%
% Written by Hannes Nickisch, 2007-03-29
% Modified by Wu Lin, 2014
n = size(x,1);
assert(n == length(y))
K = feval(covfunc{:}, hyper.cov, x);          % evaluate the covariance matrix
alla_init{1} = [zeros(n,1); log(ones(n,1))];  % stack alpha/log-variances together
for alla_id = 1:length(alla_init)             % iterate over initial conditions
  alla = alla_init{alla_id};
  % use L-BFGS to find the optimal alla for this initial condition
  [alla nlZ_new] = lbfgs(alla, K, y, lik, hyper);
  % save results
  alla_result{alla_id} = alla;
  nlZ_result(alla_id) = nlZ_new;
end
alla_id = find(nlZ_result==min(nlZ_result)); alla_id = alla_id(1);
alla = alla_result{alla_id};                  % extract best result
alpha = alla(1:n,1);
v = exp(alla(n+1:end,1));                     % diagonal posterior variances
Sigma = diag(v);
% bound on neg log marginal likelihood
nlZ = nlZ_result(alla_id);
sW = [];
L = K\((K\Sigma - eye(n))');
% estimate the hyper-parameter derivatives
% do we want derivatives?
if nargout >= 4
  v = exp(alla(n+1:end,1));
  m = K*alpha;                                % posterior mean after optimization
  invK_V = K\(diag(v));
  dnlZ = zeros(size(hyper.cov));              % allocate space for derivatives
  for j = 1:length(hyper.cov)                 % covariance hypers
    dK = feval(covfunc{:}, hyper.cov, x, j);
    dnlZ(j) = 0.5*sum(sum(dK.*(K\((eye(n) - (invK_V+alpha*m'))')),2),1);
  end
  dnlZ_lik = zeros(size(hyper.lik));
  for j = 1:length(hyper.lik)                 % likelihood hypers
    lp_dhyp = likKL(v, lik, hyper.lik, y, m, [], [], j);
    dnlZ_lik(j) = -sum(lp_dhyp);
  end
  dnlZ_mean = zeros(size(hyper.mean));
  for j = 1:length(hyper.mean)                % mean hypers (was hyp.mean: undefined variable)
    % NOTE(review): no mean-function handle is passed into this routine, so
    % `mean{:}` hits the MATLAB builtin and this feval will fail whenever
    % hyper.mean is non-empty — confirm the intended mean function upstream.
    dm_t = feval(mean{:}, hyper.mean, x, j);
    dnlZ_mean(j) = -alpha'*dm_t;
  end
end
%% evaluation of current negative log marginal likelihood depending on the
% parameters alpha (al) and lambda (la)
function [nlZ,dnlZ] = margLik_diagonal(alla,K,y,lik,hyper)
% Negative log marginal likelihood bound (and gradient) as a function of the
% stacked parameters alla = [alpha; log(v)], where m = K*alpha is the
% posterior mean and v holds the diagonal posterior variances.
n = length(y);
% unpack the stacked parameter vector
alpha = alla(1:n,1);
log_v = alla(n+1:end,1);
v = exp(log_v);
m = K*alpha;
invK = K\(eye(n));
invK_V = K\diag(v);
% expected log-likelihood terms (with derivatives only if gradient is needed)
if nargout == 1
  a = a_related2(m,v,y,lik,hyper);
else
  [a,dm,dV] = a_related2(m,v,y,lik,hyper);
end
% negative bound on the log marginal likelihood
nlZ = -a -logdet(invK_V)/2 -n/2 +(alpha'*K*alpha)/2 +trace(invK_V)/2;
if nargout > 1
  % gradient w.r.t. alpha, then w.r.t. log(v) via the chain rule dv/dlog(v)=v
  grad_alpha = K*(alpha - dm);
  grad_v = 0.5.*(diag(invK) - 1.0./v) - dV;
  dnlZ = [grad_alpha; grad_v .* v];
end
function [alla2 nlZ] = lbfgs(alla, K, y, lik, hyper,n)
% Minimize margLik_diagonal with minFunc's L-BFGS, starting from alla.
% Returns the optimized parameter stack and the attained objective value.
% Note: the trailing argument n is unused; kept for call compatibility.
opts.Display         = 'FULL';
opts.Method          = 'lbfgs';
opts.DerivativeCheck = 'off';
opts.LS_type         = 1;
opts.MaxIter         = 1000;
opts.LS_interp       = 1;
opts.MaxFunEvals     = 1000000;
opts.Corr            = 100;
opts.optTol          = 1e-15;
opts.progTol         = 1e-15;
[alla2, nlZ] = minFunc(@margLik_diagonal, alla, opts, K, y, lik, hyper);
%% log(det(A)) for det(A)>0
function y = logdet(A)
% log(det(A)) for a general square matrix with det(A) > 0; errors otherwise.
% Uses an LU factorization, so A need not be symmetric positive definite
% (a Cholesky-based 2*sum(log(diag(chol(A)))) would be faster but requires
% an s.p.d. matrix; a naive log(det(A)) is not numerically stable; an
% eigenvalue-based sum(log(eig(A))) would be slowest).
[L,U] = lu(A);
u = diag(U);
% det(A) = det(L)*prod(u) and det(L) is +/-1, so the signs must agree
if prod(sign(u)) ~= det(L)
  error('det(A)<=0')
end
y = sum(log(abs(u)));
%% compute all terms related to a
% derivatives w.r.t diag(V) and m, 2nd derivatives w.r.t diag(V) and m
% using GPML 3.4 likelihood files
function [a,dm,dV,d2m,d2V,dmdV]=a_related2(m,v,y,lik,hyper)
% Compute all terms related to a: the expected log likelihood under
% N(m, diag(v)) plus its derivatives w.r.t. m and diag(V) (and second
% derivatives when requested), using GPML 3.4 likelihood files via likKL.
% likKL returns [ll,df,d2f,dv,...], mapped here to (a, dm, d2m, dV, ...).
if nargout<4
  [a,dm,d2m,dV] = likKL(v, lik,hyper.lik,y,m);
  a = sum(a);
else
  % fixed: a missing semicolon here used to echo all six outputs to the
  % console on every call
  [a,dm,d2m,dV,d2V,dmdV] = likKL(v, lik,hyper.lik,y,m);
  a = sum(a);
end
%came from GPML 3.4/infKL.m which uses likelihood files to compute terms like a_related
function [ll,df,d2f,dv,d2v,dfdv] = likKL(v, lik, varargin)
% Gauss-Hermite quadrature estimate of the expected log likelihood under a
% Gaussian with mean varargin{3} and variances v, and of its derivatives.
% Came from GPML 3.4/infKL.m, which uses likelihood files to compute terms
% like a_related.
% Inputs:
%   v        - vector of (diagonal) variances of the smoothing Gaussian
%   lik      - likelihood specification, evaluated via feval in 'infLaplace' mode
%   varargin - GPML likelihood arguments; varargin{3} is the evaluation mean
% Outputs (computed only up to the requested nargout):
%   ll, df, d2f - integral value and 1st/2nd derivatives w.r.t. the mean
%   dv, d2v     - 1st/2nd derivatives w.r.t. the variance v
%   dfdv        - mixed second derivative
N = 20; % number of quadrature points
[t,w] = gauher(N); % location and weights for Gaussian-Hermite quadrature
f = varargin{3}; % obtain location of evaluation
sv = sqrt(v); % smoothing width
ll = 0; df = 0; d2f = 0; dv = 0; d2v = 0; dfdv = 0; % init return arguments
for i=1:N % use Gaussian quadrature
varargin{3} = f + sv*t(i); % coordinate transform of the quadrature points
[lp,dlp,d2lp] = feval(lik{:},varargin{1:3},[],'infLaplace',varargin{6:end});
if nargout>0, ll = ll + w(i)*lp; % value of the integral
if nargout>1, df = df + w(i)*dlp; % derivative w.r.t. mean
if nargout>2, d2f = d2f + w(i)*d2lp; % 2nd derivative w.r.t. mean
if nargout>3 % derivative w.r.t. variance
ai = t(i)./(2*sv+eps); dvi = dlp.*ai; dv = dv + w(i)*dvi; % no 0 div
if nargout>4 % 2nd derivative w.r.t. variance
d2v = d2v + w(i)*(d2lp.*(t(i)^2/2)-dvi)./(v+eps)/2; % no 0 div
if nargout>5 % mixed second derivatives
dfdv = dfdv + w(i)*(ai.*d2lp);
end
end
end
end
end
end
function res = low_matrix_to_vector(mat)
% Pack the lower-triangular entries (including the diagonal) of the square
% matrix mat into a column vector, in column-major order.
assert(size(mat,1)==size(mat,2));
n = size(mat,1);
mask = tril(true(n)); % logical mask selecting the lower triangle
res = mat(mask);
function res = vector_to_low_matrix(vet)
% Inverse of low_matrix_to_vector: unpack a column vector of n*(n+1)/2
% entries into an n x n lower-triangular matrix (column-major fill,
% upper triangle zero).
n = floor(sqrt(2*length(vet)));
res = zeros(n);
res(tril(true(n))) = vet;
function res = convert_dC(dv)
% Stack progressively shorter tails of the column vector dv:
% res = [dv(1:end); dv(2:end); ...; dv(end:end)].
if isempty(dv), res = dv; return; end
parts = cell(length(dv),1);
for k = 1:length(dv)
  parts{k} = dv(k:end);
end
res = vertcat(parts{:});
function res = convert_diag(d)
% Place each d(k) at the head of a zero block of length(d)-k+1 and
% concatenate the blocks: the packed lower-triangular form of diag(d).
n = length(d);
res = [];
for k = 1:n
  block = zeros(n-k+1,1); % d(k) followed by n-k zeros
  block(1) = d(k);
  res = [res; block];
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.