Skip to content

Instantly share code, notes, and snippets.

@manish7294
Created July 17, 2018 17:35
Show Gist options
  • Save manish7294/123598515035fe5a37f0a049143e06ac to your computer and use it in GitHub Desktop.
Save manish7294/123598515035fe5a37f0a049143e06ac to your computer and use it in GitHub Desktop.
% File Type: Matlab
% Author: Junae Kim {junae.kim@gmail.com},
% Chunhua Shen {chhshen@gmail.com}
% Creation Tuesday 26/02/2009 19:56.
% Last Revision: Friday 06/03/2009 10:40.
%
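% Input : trn, training data
%         [ dim, num ] = size(trn.X), one sample per column
%         trn.y, the labels
%         varargin, optional parameters, given either as a single struct or
%         as 'name', value pairs (V, save_per_iter, k_nearest, EPS,
%         alpha_lowerbound, alpha_upperbound, p)
%
% Output: metric.X, snapshots of the learned metric X (X in the paper), one
%         vectorized column every `pars.save_per_iter' iterations, with the
%         final X appended as the last column
%         metric.iter_saved, the iteration of the last periodic snapshot
%         metric.L, the projection matrix built from the final X
%         (eigenvalues below 1e-6 are zeroed), so that `X = L^T * L'
%         `metric.L * trn.X' projects the data into the new space
%         metric.eigenvalues, the eigenvalues of X in descending order
%         info.time, computational time
%         info.num_iter, the iteration at which training stopped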
function [metric, info] = boost_metric(trn, varargin)
% BOOST_METRIC  Learn a metric matrix X by boosting rank-one updates
% (exponential loss) over triplet constraints.
%
%   [metric, info] = boost_metric(trn)              uses default parameters
%   [metric, info] = boost_metric(trn, pars)        pars is a struct
%   [metric, info] = boost_metric(trn, 'name', value, ...)
%
%   Recognised parameters (defaults in brackets):
%     V                - stop when the largest eigenvalue of the weighted
%                        sum of A_i drops below V; also the offset in the
%                        step-size equation [1e-7]
%     save_per_iter    - store a snapshot of X every this many iterations [10]
%     k_nearest        - number of neighbours passed to knn_triplets [3]
%     EPS              - tolerance of the bisection step-size search [1e-5]
%     alpha_lowerbound - lower end of the step-size search interval [0]
%     alpha_upperbound - upper end of the step-size search interval [10]
%     p                - if p < 1, weights above p/(1-p) are set to
%                        1/(1-p) and the rest divided by p before
%                        renormalising (presumably 0 < p < 1 -- TODO confirm) [1]
%     maxiter          - maximum number of boosting iterations [1000]
fprintf('\nboosted metric (exponential loss) solver\n');
% default parameters; any of them can be overridden through varargin
pars.V = 1e-7;
pars.save_per_iter = 10;
pars.k_nearest = 3;
pars.EPS = 1e-5;
pars.alpha_lowerbound = 0;
pars.alpha_upperbound = 10;
pars.p = 1;
% maxiter must be a default set BEFORE extract_pars; assigning it afterwards
% would silently discard a caller-supplied value
pars.maxiter = 1000;
pars = extract_pars(varargin, pars);
maxiter = pars.maxiter;
V = pars.V;
EPS = pars.EPS;
alpha_lowerbound = pars.alpha_lowerbound;
alpha_upperbound = pars.alpha_upperbound;
opts.disp = 0;
opts.issym = 1;
opts.isreal = 1;
fprintf('triplets generated using %d nearest neighbors\n', pars.k_nearest)
% knn_triplets is defined outside this file; the loop below reads columns
% 1..3 of each row as sample indices (i, j, k) into trn.X -- TODO confirm
triplets = knn_triplets(trn.X, trn.y, pars.k_nearest);
% d : number of features, m : number of triplets
d = size(trn.X, 1);
m = size(triplets, 1);
ui = ones(m, 1) / m;        % triplet weights, uniform to start
Ar = zeros(d*d, m);         % column i holds vec(A_i)
X = zeros(d, d);
alpha = zeros(maxiter, 1);
for i = 1:m
    ai = trn.X(:, triplets(i,1));
    aj = trn.X(:, triplets(i,2));
    ak = trn.X(:, triplets(i,3));
    % A_i = (ai-ak)(ai-ak)' - (ai-aj)(ai-aj)', so that
    % <A_i, X> = |ai-ak|_X^2 - |ai-aj|_X^2
    Ar(:, i) = vec( (ai - ak)*(ai - ak)' - (ai - aj)*(ai - aj)' );
end
tic;
k = 0;          % number of periodic snapshots stored in metric.X
num_iter = 0;   % stays 0 if maxiter is 0
for t = 1 : maxiter
    num_iter = t;
    % weak learner: leading eigenvector of the weighted sum of the A_i
    temp = reshape(Ar * ui, d, d);
    [v, DB] = eigs((temp+temp')/2, 1, 'la', opts);
    if DB < V, break; end
    Z = v*v';                   % rank-one PSD update, d x d
    H = Ar' * vec(Z);           % m x 1, H(i) = <A_i, Z>
    % bisection for the step size w: drive sum_i ui(i)*(H(i)*exp(-w*H(i)) - V)
    % to zero inside [alpha_lowerbound, alpha_upperbound]
    w_up = alpha_upperbound;
    w_low = alpha_lowerbound;
    while 1
        w = (w_low + w_up)/2.0;
        tmp_ = H .* exp( - w * H ) - V;   % m x 1
        lhs = tmp_' * ui(:);
        if lhs > 0
            w_low = w;
        else
            w_up = w;
        end
        if w_up - w_low < EPS || abs(lhs) < EPS, break; end
    end
    alpha(t) = w;
    %
    % update u: ui <- ui .* exp(-alpha(t) * H), with the exponent clipped to
    % [-700, 700] so that exp() stays finite in double precision
    %
    expo = - alpha(t) * H(:);
    expo(expo < -700) = -700;
    expo(expo > 700) = 700;
    ui = ui(:) .* exp(expo);
    ui = ui / sum(ui);
    if ( pars.p < 1 ) % robust boosting and its relation to bagging
        cap = pars.p / (1 - pars.p);
        % both masks are taken before modification; capped entries become
        % 1/(1-p) > cap, so this matches re-testing after the cap is applied
        over = ui > cap;
        under = ui <= cap;
        ui(over) = 1 / (1 - pars.p);
        ui(under) = ui(under) / pars.p;
    end
    ui = ui / sum(ui);
    % update X (exactly once per iteration)
    X = X + alpha(t)*Z;
    if mod(t, pars.save_per_iter) == 0
        k = k + 1;
        metric.X(:, k) = vec(X);
        metric.iter_saved = t;
    end
    %
    % check the process
    %
    fprintf('.'); if ~mod(t, 100), fprintf('\n'); end
end
fprintf('training completed.\n');
metric.X(:, k + 1) = vec(X);    % the final X is always the last column
info.time = toc;
info.num_iter = num_iter;
% decompose the (symmetrised) final X
[L, dd] = eig( (X+X')/2 );
dd = real( diag(dd) );
%
% reassemble X (eigenvalues below 1e-6, including negative ones, are zeroed)
%
dd( dd<1e-6 ) = 0;
[ dd, ii ] = sort( dd, 'descend' );
L = L( :, ii );
% metric.L = sqrt(D) * V', hence metric.L' * metric.L reproduces X and
% metric.L * x maps a sample into the learned space
metric.L = ( L*diag( sqrt(dd) ) )';
metric.eigenvalues = dd;
%
%-----------------------------------------------------------------------
%
function pars = extract_pars(vars,default)
% EXTRACT_PARS  Overlay user-supplied options on a struct of defaults.
%
%   pars = extract_pars(vars, default)
%   vars    - cell array (typically varargin) holding either exactly one
%             struct, whose fields are copied into pars, or 'name', value
%             pairs
%   default - struct of default values (optional, [] when omitted)
%   pars    - default with the supplied fields added or overwritten
%
% Values are stored through dynamic field names rather than eval-built
% assignments, so non-integer numbers keep full precision, arrays keep
% their shape, and strings containing quotes are stored verbatim.
% Entries at odd positions that are not strings are skipped.
% Raises an error when the last name in a 'name', value list has no value.
if nargin < 2, default = []; end
pars = default;
if length(vars) == 1
    opt = vars{1};
    names = fieldnames(opt);
    for i = 1:length(names)
        pars.(names{i}) = opt.(names{i});
    end
else
    for i = 1:2:length(vars)
        if ischar(vars{i})
            if i + 1 > length(vars)
                error('Parameter %s has no value\n', vars{i});
            end
            pars.(vars{i}) = vars{i+1};
        end
    end
end
%-----------------------------------------------------------------------
function M = vec(A)
% VEC  Stack the columns of A into a single column vector (column-major order).
M = reshape(A, [], 1);
%
% EoF
%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment