% Source: GitHub gist manish7294/123598515035fe5a37f0a049143e06ac
% (created July 17, 2018 17:35).
% File Type: Matlab
% Author: Junae Kim {junae.kim@gmail.com},
%         Chunhua Shen {chhshen@gmail.com}
% Creation Tuesday 26/02/2009 19:56.
% Last Revision: Friday 06/03/2009 10:40.
%
% Input : trn, training data
%           [ dim, num ] = size(trn.X), i.e. one sample per column
%           trn.y is the labels
%         varargin, parameters: either a single struct or 'name', value
%           pairs (parsed by extract_pars). Recognised names and defaults:
%             V                = 1e-7  stop when the largest eigenvalue of
%                                      sum_r u_r*A_r falls below V; also the
%                                      constant in the step-size equation
%             save_per_iter    = 10    snapshot X every this many iterations
%             k_nearest        = 3     passed to knn_triplets
%             EPS              = 1e-5  tolerance of the step-size bisection
%             alpha_lowerbound = 0     bisection interval for the step size
%             alpha_upperbound = 10
%             p                = 1     p < 1 enables the robust re-weighting
%             maxiter          = 1000  maximum number of boosting iterations
%
% Output: metric.X, the learned metric per `pars.save_per_iter' iterations, (X in the paper)
%                   one column vec(X) per snapshot; the last column is the
%                   final X
%         metric.iter_saved, iteration number of the last periodic snapshot
%                   (only set if at least one periodic snapshot was taken)
%         metric.L, the last projection matrix; the learned metric is
%                   `X = metric.L'' * metric.L' (eigenvalues of X below 1e-6
%                   are dropped), so
%                   `metric.L * trn.X' projects the data into new space
%         metric.eigenvalues, the eigenvalues used for metric.L, descending
%         info.time, computational time of the boosting loop
%         info.num_iter, value of the iteration counter when the loop ended
%                   (the iteration that triggered the stop, or maxiter)
function [metric, info] = boost_metric(trn, varargin)
fprintf('\nboosted metric (exponential loss) solver\n');

% default parameters; any of them can be overridden through varargin
pars.V = 1e-7;
pars.save_per_iter = 10;
pars.k_nearest = 3;
pars.EPS = 1e-5;
pars.alpha_lowerbound = 0;
pars.alpha_upperbound = 10;
pars.p = 1;
% maxiter is a default like the others. Assigning it after extract_pars
% (as before) silently discarded any user-supplied value.
pars.maxiter = 1000;
pars = extract_pars(varargin,pars);

maxiter = pars.maxiter;
V = pars.V;
EPS = pars.EPS;
alpha_lowerbound = pars.alpha_lowerbound;
alpha_upperbound = pars.alpha_upperbound;

opts.disp = 0;
opts.issym = 1;
opts.isreal = 1;

fprintf('triplets generated using %d nearest neighbors\n', pars.k_nearest)
% knn_triplets is defined outside this file; each row is used below as
% sample indices (i, j, k). Presumably j is a same-label neighbour of i and
% k an impostor -- TODO confirm against knn_triplets.
triplets = knn_triplets(trn.X, trn.y, pars.k_nearest);

% d : the number of features
% m : the number of triplets
d = size(trn.X, 1);
m = size(triplets, 1);

ui = ones(m, 1) / m;   % triplet weights, uniform at the start
Ar = zeros(d*d, m);    % column r holds vec(A_r)
X = zeros(d,d);
alpha = zeros(maxiter, 1);

% A_r = (a_i - a_k)(a_i - a_k)' - (a_i - a_j)(a_i - a_j)', so that
% <A_r, X> = (a_i-a_k)'X(a_i-a_k) - (a_i-a_j)'X(a_i-a_j)
for r = 1:m
    ai = trn.X(:, triplets(r,1));
    aj = trn.X(:, triplets(r,2));
    ak = trn.X(:, triplets(r,3));
    Ar(:, r) = vec( (ai - ak)*(ai - ak)' - (ai - aj)*(ai - aj)' );
end

tic;
k = 0;   % number of periodic snapshots stored in metric.X
for t = 1 : maxiter
    % weak learner: leading eigenvector of the weighted sum of the A_r
    temp = reshape(Ar * ui, d, d);
    [v, DB] = eigs((temp+temp')/2, 1, 'la', opts );
    if DB < V, break; end
    Z = v*v';              % rank-one update direction, d x d
    H = Ar' * vec(Z);      % H(r) = <A_r, Z>, m x 1

    % step size: bisection on w for
    %   sum_r ui(r) * ( H(r)*exp(-w*H(r)) - V ) = 0.
    % For non-negative ui the left-hand side is non-increasing in w, so a
    % positive value means the root lies above the current w.
    w_up = alpha_upperbound;
    w_low = alpha_lowerbound;
    while 1
        w = (w_low + w_up)/2.0;
        lhs = (H .* exp( - w * H ) - V)' * ui(:);   % scalar
        if lhs > 0
            w_low = w;
        else
            w_up = w;
        end
        if w_up - w_low < EPS || abs(lhs) < EPS, break; end
    end
    alpha(t) = w;

    % update u: ui <- ui .* exp(-alpha(t)*H), then normalise. The exponent
    % is clipped to [-700, 700] so exp() stays finite and non-zero in
    % double precision.
    expo = - alpha(t) * H(:);
    expo(expo < -700) = -700;
    expo(expo > 700) = 700;
    ui = ui(:) .* exp( expo );
    ui = ui / sum(ui);
    if ( pars.p < 1 ) % robust boosting and its relation to bagging
        thr = pars.p/(1-pars.p);
        ui( ui > thr ) = 1/(1-pars.p);
        ui( ui <= thr ) = ui( ui <= thr ) / pars.p;
    end
    ui = ui / sum(ui);

    % update X exactly once per iteration (this update and its snapshot
    % used to be duplicated, doubling X and storing two columns per save)
    X = X + alpha(t)*Z;
    if mod(t, pars.save_per_iter) == 0
        k = k + 1;
        metric.X(:, k) = vec(X);
        metric.iter_saved = t;
    end

    % progress: one dot per iteration, a new line every 100
    fprintf('.'); if ~mod(t, 100), fprintf('\n'); end
end
fprintf('training completed.\n');
metric.X(:, k + 1) = vec(X);
info.time = toc;
info.num_iter = t;

% decompose X
[L, dd] = eig( (X+X')/2 );
dd = real( diag(dd) );
% reassemble X, ignoring negative and near-zero eigenvalues
dd( dd<1e-6 ) = 0;
[ dd, ii ] = sort( dd, 'descend' );
L = L( :, ii );
% metric.L' * metric.L = L*diag(dd)*L' reproduces X
metric.L = ( L*diag( sqrt(dd) ) )';
metric.eigenvalues = dd;
%
%-----------------------------------------------------------------------
%
function pars = extract_pars(vars,default)
% EXTRACT_PARS  Overlay user-supplied options onto a struct of defaults.
%
%   pars = extract_pars(vars, default)
%
%   vars    : cell array of options, normally a caller's varargin, in one
%             of two forms:
%               { optionStruct }              every field is copied
%               { 'name1', value1, ... }      name/value pairs
%             In the pair form, entries whose name is not a char array are
%             skipped, and a trailing name without a value raises an error.
%   default : struct of default values (optional, defaults to []).
%
%   pars    : default with the supplied options assigned on top.
%
% Values are assigned through dynamic field names rather than eval'd text.
% Numbers therefore keep their full precision and shape. Previously,
% sprintf('%i') truncated non-integers to 7 significant digits and
% concatenated arrays. Strings containing quotes and cell values now also
% work.
if(nargin<2), default=[]; end
pars = default;
if length(vars) == 1
    % single struct: copy all of its fields
    p = vars{1};
    names = fieldnames(p);
    for i = 1:length(names)
        pars.(names{i}) = p.(names{i});
    end
else
    % 'name', value pairs
    for i = 1:2:length(vars)
        name = vars{i};
        if ~ischar(name), continue; end
        if i + 1 > length(vars)
            error('Parameter %s has no value\n', name);
        end
        pars.(name) = vars{i+1};
    end
end
%-----------------------------------------------------------------------
function M = vec(A)
% VEC  Stack the columns of A into one column vector (column-major order),
%      e.g. vec([1 2; 3 4]) is [1; 3; 2; 4].
M = reshape(A, [], 1);
%
% EoF
%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment