Skip to content

Instantly share code, notes, and snippets.

@AruniRC
Created September 29, 2015 21:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AruniRC/dcdc81959e0658faeb02 to your computer and use it in GitHub Desktop.
Save AruniRC/dcdc81959e0658faeb02 to your computer and use it in GitHub Desktop.
Training SVMs on CNN encodings for 1:N identification and 1:1 verification tasks
% -------------------------------------------------------------------------
function info = traintest(opts, imdb, psi)
% -------------------------------------------------------------------------
% TRAINTEST  Train linear SVMs on encoded features and evaluate them.
%
%   INFO = TRAINTEST(OPTS, IMDB, PSI) runs one of two protocols, selected
%   by whether IMDB has a 'pairs' field.
%
%   * Verification (1:1), when IMDB.pairs exists. For every template pair
%     in IMDB.pairs.tid (one row per pair), a one-vs-rest linear SVM is
%     trained for each of the two templates: the template's media are the
%     positives, and all images with IMDB.images.set in {1,2} are the
%     negatives. The pair score is the mean of each template's SVM
%     applied to the average feature of the other template. Scores are
%     saved to OPTS.resultPath, then ROC/EER/TAR statistics are added.
%
%   * Open-set classification (1:N), otherwise. One SVM is trained per
%     class present in the gallery (set 1 or 2) and evaluated on every
%     image; set 3 is the probe/test set. Per-class AP, confusion
%     matrices and accuracies are reported for train and test.
%
%   Inputs:
%     OPTS  Struct; only OPTS.resultPath is read (verification branch).
%     IMDB  Image database struct. Fields read: images.set, images.label,
%           meta.inUse, and for verification images.template, pairs.tid
%           and pairs.label.
%     PSI   Encoded features, one column per entry of IMDB.images.
%
%   Output:
%     INFO  Struct of scores and evaluation metrics; see each branch for
%           the fields it sets.

verificationTask = isfield(imdb, 'pairs') ;
if verificationTask
  % Verification using SVMs (Jeff's idea).
  % For each of the templates in IJB-A, we can train a one-vs-rest
  % linear SVM. The positive encoded features are derived from the
  % media in the template, and negative encoded features for the media
  % in a subject-disjoint training set (IJB-A Train set, here).
  %
  % For verification (1:1), given two templates (P,Q), encode P and
  % evaluate the SVM for Q, then encode Q and evaluate the SVM for P
  % and take the mean. This is the similarity score. This is a
  % "probe classifier" and requires training an SVM at test time. The
  % only assumption about identity that is needed is that there exists
  % a large negative training set that does not contain either P or Q.
  % TODO - print verification dataset info

  % pairs.tid is indexed as tid(pair, slot), so the number of pairs is its
  % row count. length() returns the largest dimension instead, which is 2
  % (giving an out-of-range loop) when there is exactly one pair.
  numPairs = size(imdb.pairs.tid, 1) ;
  scores = zeros(1, numPairs) ;
  negatives = ismember(imdb.images.set, [1 2]) ; % Train set (remains constant)
  % numWorkers = matlabpool('size') ;
  parfor i = 1:numPairs
    fprintf('\n-------------------------------------- ');
    fprintf('\nVerification: pair: %d\n', i) ;
    % Train one SVM per template of this pair
    C = 1 ;
    w = {} ;
    b = {} ;
    for c = 1:2
      % negative data - Train set gets -1; every other image gets 0 and is
      % excluded from training by the (y ~= 0) mask below
      y = -negatives ;
      % positive data - media of this pair template get +1
      tid = imdb.pairs.tid(i,c) ;
      positives = (imdb.images.template == tid) ;
      y(positives) = 1 ;
      train = positives | negatives ;
      np = sum(y(train) > 0) ;
      nn = sum(y(train) < 0) ;
      n = np + nn ;
      fprintf('OVA-classifier: class: %d\n', c) ;
      % lambda = 1/(n*C) is the regularizer; iteration cap scales with n
      [w{c},b{c}] = vl_svmtrain(psi(:,train & y ~= 0), y(train & y ~= 0), 1/(n* C), ...
        'epsilon', 0.001, 'verbose', 'biasMultiplier', 1, ...
        'maxNumIterations', n * 200) ;
    end
    % Test pair: average-pool each template's features (template pooling)
    % and score it with the SVM trained on the *other* template.
    tid1 = imdb.pairs.tid(i,1) ;
    tid2 = imdb.pairs.tid(i,2) ;
    pred1 = w{2}'* mean(psi(:, imdb.images.template == tid1),2) + b{2} ;
    pred2 = w{1}'* mean(psi(:, imdb.images.template == tid2),2) + b{1} ;
    % mean of the two unnormalized classifier scores
    scores(i) = 0.5*(pred1 + pred2) ;
  end

  % !!! NOT SAVING SVM FOR EACH VERIFICATION PAIR NOW !!!!!
  % info.w = cat(2,w{:}) ;
  % info.b = cat(2,b{:}) ;
  info.scores = scores ; % saving verification scores
  % Only the scores are persisted; the ROC fields below are added afterwards.
  save(opts.resultPath, '-struct', 'info') ;

  % ROC curve
  [tpr, tnr, rocinfo] = vl_roc(imdb.pairs.label, scores) ;
  info.tpr = tpr ; info.tnr = tnr ;
  disp(rocinfo.auc) ; info.auc = rocinfo.auc ;
  disp(1-rocinfo.eer) ; info.eer = rocinfo.eer ; info.accu = 1-rocinfo.eer ;

  % TAR at FAR of 0.1 and 0.01
  fpr = 1 - tnr ;
  % NOTE(review): nonduplicate is defined elsewhere. interp1 needs strictly
  % distinct x values and equal-length x/y, so it presumably perturbs rather
  % than drops repeated values -- confirm.
  tprUnique = nonduplicate(tpr) ;
  fprUnique = nonduplicate(fpr) ;
  fprResampled = [0.01, 0.1] ;
  tprInterp = interp1(fprUnique, tprUnique, fprResampled);
  info.tar1 = tprInterp(2);   % TAR @ FAR = 0.1
  info.tar01 = tprInterp(1);  % TAR @ FAR = 0.01
  % vl_roc(imdb.pairs.label, info.scores, 'plot', 'fptp') ;
  info.test = info ;
  return ;
else
  % Open-set classification (1:N) task
  multiLabel = (size(imdb.images.label,1) > 1) ; % e.g. PASCAL VOC cls
  train = ismember(imdb.images.set, [1 2]) ; % images in train+val set (Gallery)
  test = ismember(imdb.images.set, 3) ; % images in test set (Probe)
  info.classes = find(imdb.meta.inUse) ;

  % Train classifiers
  C = 1 ;
  w = {} ;
  b = {} ;
  % IJB-A open-set: classifiers exist only for classes seen in the gallery;
  % probe classes absent from the gallery can never be predicted.
  galleryClasses = unique(imdb.images.label(train)) ;
  for c = 1:numel(galleryClasses)
    fprintf('\n-------------------------------------- ');
    fprintf('OVA-classifier: class: %d\n', c) ;
    if ~multiLabel
      y = 2*(imdb.images.label == galleryClasses(c)) - 1 ;
    else
      y = imdb.images.label(c,:) ;
    end
    np = sum(y(train) > 0) ;
    nn = sum(y(train) < 0) ;
    n = np + nn ;
    [w{c},b{c}] = vl_svmtrain(psi(:,train & y ~= 0), y(train & y ~= 0), 1/(n* C), ...
      'epsilon', 0.001, 'verbose', 'biasMultiplier', 1, ...
      'maxNumIterations', n * 200) ;
    pred = w{c}'*psi + b{c} ;

    % Cheap calibration: rescale so the median training-positive score maps
    % to 1 and the median training-negative score maps to 0, making scores
    % of different one-vs-rest classifiers comparable for the argmax below.
    mp = median(pred(train & y > 0)) ;
    mn = median(pred(train & y < 0)) ;
    scale = mp - mn ;
    % Only calibrate when the classifier separates the medians. A zero/NaN
    % scale would yield Inf/NaN scores (Inf would then win every argmax),
    % and a negative one would silently flip the classifier's ranking.
    if scale > 0
      b{c} = (b{c} - mn) / scale ;
      w{c} = w{c} / scale ;
      pred = w{c}'*psi + b{c} ;
    end
    scores{c} = pred ;

    [~,~,prInfo] = vl_pr(y(train), pred(train)) ; ap(c) = prInfo.ap ; ap11(c) = prInfo.ap_interp_11 ;
    [~,~,prInfo] = vl_pr(y(test), pred(test)) ; tap(c) = prInfo.ap ; tap11(c) = prInfo.ap_interp_11 ;
    [~,~,prInfo] = vl_pr(y(train), pred(train), 'normalizeprior', 0.01) ; nap(c) = prInfo.ap ;
    [~,~,prInfo] = vl_pr(y(test), pred(test), 'normalizeprior', 0.01) ; tnap(c) = prInfo.ap ;
  end

  % Book keeping
  info.w = cat(2,w{:}) ;
  info.b = cat(2,b{:}) ;
  info.scores = cat(1, scores{:}) ; % one row per gallery class
  info.train.ap = ap ;
  info.train.ap11 = ap11 ;
  info.train.nap = nap ;
  info.train.map = mean(ap) ;
  info.train.map11 = mean(ap11) ;
  info.train.mnap = mean(nap) ;
  info.test.ap = tap ;
  info.test.ap11 = tap11 ;
  info.test.nap = tnap ;
  info.test.map = mean(tap) ;
  info.test.map11 = mean(tap11) ;
  info.test.mnap = mean(tnap) ;
  clear ap nap tap tnap scores pred ;
  fprintf('mAP train: %.1f, test: %.1f\n', ...
    mean(info.train.ap)*100, ...
    mean(info.test.ap)*100);

  % Compute predictions, confusion and accuracy
  [~,preds] = max(info.scores,[],1) ;
  preds = galleryClasses(preds); % argmax row -> gallery class label (open-set)
  % Express predictions and ground truth in the same space: positions in
  % info.classes. Previously preds held raw labels while gts held positions,
  % which only agree when every class is in use.
  % NOTE(review): a prediction whose class is not in info.classes maps to 0;
  % confirm compute_confusion tolerates that.
  [~,preds] = ismember(preds, info.classes) ;
  [~,gts] = ismember(imdb.images.label, info.classes) ;
  usedClassIndex = (gts ~= 0) ; % remove class not in use
  gts = gts(usedClassIndex) ;
  preds = preds(usedClassIndex) ;
  train = train(usedClassIndex) ;
  test = test(usedClassIndex) ;
  [info.train.confusion, info.train.acc] = compute_confusion(numel(info.classes), gts(train), preds(train)) ;
  [info.test.confusion, info.test.acc] = compute_confusion(numel(info.classes), gts(test), preds(test)) ;
  fprintf('Accuracy train: %.1f, test: %.1f\n', ...
    mean(info.train.acc)*100, ...
    mean(info.test.acc)*100);
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment