Training SVMs on CNN encodings for 1:N and 1:1 identification tasks
% -------------------------------------------------------------------------
function info = traintest(opts, imdb, psi)
% -------------------------------------------------------------------------
% Train for verification (1:1) if pair annotations are present in imdb,
% otherwise train one-vs-rest classifiers for identification (1:N).
verificationTask = isfield(imdb, 'pairs') ;
if verificationTask
    % Verification using SVMs (Jeff's idea).
    % For each of the templates in IJB-A we can train a one-vs-rest
    % linear SVM. The positive encoded features come from the media in
    % the template; the negative encoded features come from the media in
    % a subject-disjoint training set (here, the IJB-A Train set).
    %
    % For verification (1:1), given two templates (P,Q), encode P and
    % evaluate the SVM for Q, then encode Q and evaluate the SVM for P,
    % and take the mean. This is the similarity score. This is a
    % "probe classifier" and requires training an SVM at test time. The
    % only identity assumption needed is that there exists a large
    % negative training set that contains neither P nor Q.
    % TODO - print verification dataset info
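    %
    % Concretely, the similarity computed below works out to
    %   score(P,Q) = 0.5 * ( w_Q' * mean(psi_P) + b_Q  +  w_P' * mean(psi_Q) + b_P )
    % where (w_T, b_T) is the SVM trained with template T as positives and
    % mean(psi_T) is the average encoding over the media in template T.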

    scores = zeros(1, length(imdb.pairs.tid)) ;
    negatives = ismember(imdb.images.set, [1 2]) ; % Train set (remains constant)
    % numWorkers = matlabpool('size') ;
    parfor i = 1:length(imdb.pairs.tid)
        fprintf('\n-------------------------------------- ') ;
        fprintf('\nVerification: pair: %d\n', i) ;
        % Train an SVM for each verification pair
        C = 1 ;
        w = {} ;
        b = {} ;
        % Training classifier
        for c = 1:2
            % negative data - Train set
            y = -negatives ;
            % positive data - pair template
            tid = imdb.pairs.tid(i,c) ;
            positives = (imdb.images.template == tid) ;
            y(positives) = 1 ;
            train = positives | negatives ;
            np = sum(y(train) > 0) ;
            nn = sum(y(train) < 0) ;
            n = np + nn ;
            fprintf('OVA-classifier: class: %d\n', c) ;
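            % vl_svmtrain's third argument is the regulariser lambda;
            % lambda = 1/(n*C) roughly corresponds to a C-SVM with cost C
            % (here C = 1).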
            [w{c},b{c}] = vl_svmtrain(psi(:,train & y ~= 0), y(train & y ~= 0), 1/(n*C), ...
                'epsilon', 0.001, 'verbose', 'biasMultiplier', 1, ...
                'maxNumIterations', n * 200) ;
        end

        % Test pair
        tid1 = imdb.pairs.tid(i,1) ;
        tid2 = imdb.pairs.tid(i,2) ;
        % template pooling - average features
        pred1 = w{2}' * mean(psi(:, imdb.images.template == tid1), 2) + b{2} ;
        pred2 = w{1}' * mean(psi(:, imdb.images.template == tid2), 2) + b{1} ;
        % mean of the two unnormalized classifier scores
        scores(i) = 0.5*(pred1 + pred2) ;
    end

    % Note: the SVM trained for each verification pair is not saved.
    % info.w = cat(2,w{:}) ;
    % info.b = cat(2,b{:}) ;
    info.scores = scores ; % save verification scores
    save(opts.resultPath, '-struct', 'info') ;

    % ROC curve
    [tpr, tnr, rocinfo] = vl_roc(imdb.pairs.label, scores) ;
    info.tpr = tpr ; info.tnr = tnr ;
    disp(rocinfo.auc) ; info.auc = rocinfo.auc ;
    disp(1-rocinfo.eer) ; info.eer = rocinfo.eer ; info.accu = 1-rocinfo.eer ;

    % TAR at FAR of 0.1 and 0.01
    fpr = 1 - tnr ;
    tprUnique = nonduplicate(tpr) ;
    fprUnique = nonduplicate(fpr) ;
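    % nonduplicate is assumed to be a helper defined elsewhere in this
    % project that drops repeated ROC values, so that interp1 below
    % receives unique sample points.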
    fprResampled = [0.01, 0.1] ;
    tprInterp = interp1(fprUnique, tprUnique, fprResampled) ;
    info.tar1 = tprInterp(2) ;  % TAR at FAR = 0.1
    info.tar01 = tprInterp(1) ; % TAR at FAR = 0.01
    % vl_roc(imdb.pairs.label, info.scores, 'plot', 'fptp') ;
    info.test = info ;
    return ;
else
    % classification task
    multiLabel = (size(imdb.images.label,1) > 1) ; % e.g. PASCAL VOC cls
    train = ismember(imdb.images.set, [1 2]) ;     % images in train+val set (Gallery)
    test = ismember(imdb.images.set, 3) ;          % images in test set (Probe)
    info.classes = find(imdb.meta.inUse) ;

    % Train classifiers
    C = 1 ;
    w = {} ;
    b = {} ;
    % IJB-A open-set
    galleryClasses = unique(imdb.images.label(train)) ;
    probeClasses = unique(imdb.images.label(test)) ;
    distractorClasses = setdiff(probeClasses, galleryClasses) ;
    for c = 1:numel(galleryClasses)
        fprintf('\n-------------------------------------- ') ;
        fprintf('OVA-classifier: class: %d\n', c) ;
        if ~multiLabel
            y = 2*(imdb.images.label == galleryClasses(c)) - 1 ;
        else
            y = imdb.images.label(c,:) ;
        end
        np = sum(y(train) > 0) ;
        nn = sum(y(train) < 0) ;
        n = np + nn ;
        [w{c},b{c}] = vl_svmtrain(psi(:,train & y ~= 0), y(train & y ~= 0), 1/(n*C), ...
            'epsilon', 0.001, 'verbose', 'biasMultiplier', 1, ...
            'maxNumIterations', n * 200) ;
        pred = w{c}'*psi + b{c} ;
        % try cheap calibration
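        % Affine rescaling of each classifier so that the median positive
        % training score maps to 1 and the median negative training score
        % maps to 0, making scores roughly comparable across classes.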
        mp = median(pred(train & y > 0)) ;
        mn = median(pred(train & y < 0)) ;
        b{c} = (b{c} - mn) / (mp - mn) ;
        w{c} = w{c} / (mp - mn) ;
        pred = w{c}'*psi + b{c} ;
        scores{c} = pred ;
        [~,~,prinfo] = vl_pr(y(train), pred(train)) ; ap(c) = prinfo.ap ; ap11(c) = prinfo.ap_interp_11 ;
        [~,~,prinfo] = vl_pr(y(test), pred(test)) ; tap(c) = prinfo.ap ; tap11(c) = prinfo.ap_interp_11 ;
        [~,~,prinfo] = vl_pr(y(train), pred(train), 'normalizeprior', 0.01) ; nap(c) = prinfo.ap ;
        [~,~,prinfo] = vl_pr(y(test), pred(test), 'normalizeprior', 0.01) ; tnap(c) = prinfo.ap ;
    end

    % Bookkeeping
    info.w = cat(2,w{:}) ;
    info.b = cat(2,b{:}) ;
    info.scores = cat(1, scores{:}) ;
    info.train.ap = ap ;
    info.train.ap11 = ap11 ;
    info.train.nap = nap ;
    info.train.map = mean(ap) ;
    info.train.map11 = mean(ap11) ;
    info.train.mnap = mean(nap) ;
    info.test.ap = tap ;
    info.test.ap11 = tap11 ;
    info.test.nap = tnap ;
    info.test.map = mean(tap) ;
    info.test.map11 = mean(tap11) ;
    info.test.mnap = mean(tnap) ;
    clear ap ap11 nap tap tap11 tnap scores pred ;
    fprintf('mAP train: %.1f, test: %.1f\n', ...
        mean(info.train.ap)*100, ...
        mean(info.test.ap)*100) ;

    % Compute predictions, confusion and accuracy
    [~,preds] = max(info.scores, [], 1) ;
    preds = galleryClasses(preds) ; % changed this for open-set
    [~,gts] = ismember(imdb.images.label, info.classes) ;
    usedClassIndex = (gts ~= 0) ; % remove classes not in use
    gts = gts(usedClassIndex) ;
    preds = preds(usedClassIndex) ;
    train = train(usedClassIndex) ;
    test = test(usedClassIndex) ;
    [info.train.confusion, info.train.acc] = compute_confusion(numel(info.classes), gts(train), preds(train)) ;
    [info.test.confusion, info.test.acc] = compute_confusion(numel(info.classes), gts(test), preds(test)) ;
    fprintf('Accuracy train: %.1f, test: %.1f\n', ...
        mean(info.train.acc)*100, ...
        mean(info.test.acc)*100) ;
end
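
For reference, a minimal sketch of how this function might be called, assuming psi is a D-by-N matrix of CNN encodings (one column per image or media item) and imdb carries the fields used above (images.set, images.template, images.label, meta.inUse, plus pairs.tid and pairs.label for the 1:1 protocol). The file paths and saved-variable layout below are hypothetical.

% Illustrative call; paths and the loaded variable names are assumptions.
opts.resultPath = 'results/ijba_svm_info.mat' ;
data = load('data/ijba_imdb.mat') ;       % assumed to contain an imdb struct
feat = load('data/ijba_encodings.mat') ;  % assumed to contain psi (D x N)
info = traintest(opts, data.imdb, feat.psi) ;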