Last active
December 15, 2015 21:15
-
-
Save skaae/29282b1fb56864ebc0f9 to your computer and use it in GitHub Desktop.
Updated Confusion Class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[ A Confusion Matrix class

Example:

    conf = ConfusionMatrix( {'cat','dog','person'} )   -- new matrix
    conf:zero()                                        -- reset matrix
    for i = 1,N do
       conf:add( neuralnet:forward(sample), label )    -- accumulate errors
    end
    print(conf)                                        -- print matrix
    image.display(conf:render())                       -- render matrix
]]
--local ConfusionMatrix = torch.class('optim.MyConfusionMatrix')
local ConfusionMatrix = torch.class('ConfusionMatrix')

--- Construct a confusion matrix.
-- @param nclasses number of classes, OR a table of class names
-- @param classes  optional table of class names (ignored when nclasses is a table)
function ConfusionMatrix:__init(nclasses, classes)
   -- allow calling with just a table of class names
   if type(nclasses) == 'table' then
      classes  = nclasses
      nclasses = #classes
   end
   self.nclasses    = nclasses
   -- rows are targets, columns are predictions
   self.mat         = torch.FloatTensor(nclasses, nclasses):zero()
   self.valids      = torch.FloatTensor(nclasses):zero()
   self.unionvalids = torch.FloatTensor(nclasses):zero()
   self.totalValid   = 0
   self.averageValid = 0
   if classes then
      self.classes = classes
   else
      -- default labels are "1" .. tostring(nclasses)
      local names = {}
      for i = 1, nclasses do
         names[#names + 1] = tostring(i)
      end
      self.classes = names
   end
end
--- Accumulate a single (prediction, target) pair.
-- Each argument may be a class index (number) or a score vector of
-- length nclasses, in which case its argmax is used as the class.
function ConfusionMatrix:add(prediction, target)
   if type(prediction) == 'number' then
      -- both given as class indices
      self.mat[target][prediction] = self.mat[target][prediction] + 1
   elseif type(target) == 'number' then
      -- prediction is a score vector; target is a class index
      self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses)
      self.prediction_1d:copy(prediction)
      local _, pidx = self.prediction_1d:max(1)
      local p = pidx[1]
      self.mat[target][p] = self.mat[target][p] + 1
   else
      -- both are vectors: take the argmax of each
      self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses)
      self.prediction_1d:copy(prediction)
      self.target_1d = self.target_1d or torch.FloatTensor(self.nclasses)
      self.target_1d:copy(target)
      local _, pidx = self.prediction_1d:max(1)
      local _, tidx = self.target_1d:max(1)
      self.mat[tidx[1]][pidx[1]] = self.mat[tidx[1]][pidx[1]] + 1
   end
end
--- Accumulate a whole batch of predictions and targets.
-- Each argument may be: a 1-D vector of class indices, an Nx1 tensor
-- (flattened), or an NxC matrix of class scores (argmax over dim 2).
function ConfusionMatrix:batchAdd(predictions, targets)
   -- reduce either argument to a 1-D vector of class indices
   local function flatten(t, what)
      if t:dim() == 1 then
         -- already a vector of classes
         return t
      elseif t:dim() == 2 then
         if t:size(2) == 1 then
            -- Nx1 just needs flattening
            return t:select(2, 1)
         end
         -- NxC likelihoods / one-hot rows: take the argmax per row
         local _, idx = t:max(2)
         idx:resize(idx:size(1))
         return idx
      end
      error(what .. " has invalid number of dimensions")
   end
   local preds = flatten(predictions, "predictions")
   local targs = flatten(targets, "targets")
   -- accumulate each (target, prediction) pair
   for i = 1, preds:size(1) do
      self.mat[targs[i]][preds[i]] = self.mat[targs[i]][preds[i]] + 1
   end
end
--- Reset all accumulated counts and derived statistics.
function ConfusionMatrix:zero()
   -- clear the tensors
   for _, t in ipairs{self.mat, self.valids, self.unionvalids} do
      t:zero()
   end
   -- clear the scalar summaries
   self.totalValid   = 0
   self.averageValid = 0
end
--- NaN is the only value that is not equal to itself.
local function isNaN(x)
   return x ~= x
end
--- Replace NaN entries of a 1 x nclasses tensor with zero, in place.
-- (NaNs arise from 0/0 divisions when a class has no samples.)
local function remNaN(x, self)
   for i = 1, self.nclasses do
      local v = x[{1, i}]
      if v ~= v then -- NaN check
         x[{1, i}] = 0
      end
   end
   return x
end
-- Derive per-class counts from the confusion matrix.
-- Rows of self.mat are targets, columns are predictions (see :add),
-- so row sums count actual occurrences and column sums count predictions.
-- Returns four 1 x nclasses tensors: tp, tn, fp, fn.
local function getErrors(self)
   -- true positives: the diagonal, reshaped to a 1 x nclasses row
   local tp = torch.diag(self.mat):resize(1,self.nclasses )
   -- false negatives: row sums minus the diagonal, transposed to a row
   local fn = (torch.sum(self.mat,2)-torch.diag(self.mat)):t()
   -- false positives: column sums minus the diagonal
   local fp = torch.sum(self.mat,1)-torch.diag(self.mat)
   -- true negatives: total count minus everything attributed above
   local tn = torch.Tensor(1,self.nclasses):fill(torch.sum(self.mat)):typeAs(tp) - tp - fn - fp
   return tp, tn, fp, fn
end
--- Public accessor for the per-class counts.
-- @return tp, tn, fp, fn — each a 1 x nclasses tensor
function ConfusionMatrix:getConfusion()
   return getErrors(self)
end
--- Pretty-print one score vector as a small ASCII table, one column per class.
-- @param name    one of: 'sensitivity', 'specificity',
--                'positivePredictiveValue', 'negativePredictiveValue',
--                'falsePositiveRate', 'falseDiscoveryRate',
--                'classAccuracy', 'F1', 'matthewsCorrelation'
-- @param mytitle optional prefix shown before the score name in the header
function ConfusionMatrix:printscore(name,mytitle)
   -- whitelist of score methods; replaces the long if/elseif dispatch chain
   local known = {
      sensitivity = true, specificity = true,
      positivePredictiveValue = true, negativePredictiveValue = true,
      falsePositiveRate = true, falseDiscoveryRate = true,
      classAccuracy = true, F1 = true, matthewsCorrelation = true,
   }
   if not known[name] then
      print("Unknown error type")
      error()
   end
   -- self[name](self) is equivalent to self:<name>()
   local score = self[name](self)
   if mytitle then
      name = mytitle..": "..name
   end
   local ln = "|"  -- header row (class names)
   local ls = "|"  -- value row (formatted scores)
   for i = 1,self.nclasses do
      local val = string.format("%.4f", score[{1,i}])
      local class = self.classes[i]
      -- pad the class name so each header cell is wide enough for the value
      local class_app = math.max(1,4-math.floor(#class / 2))
      class = string.rep(" ",class_app)..class..string.rep(" ",class_app+1-#class%2)
      ln = ln..class.."|"
      ls = ls.." "..val
      -- right-pad the value cell to line up with the header cell
      if (#ls+1) < #ln then
         ls = ls .. string.rep(" ",#ln-#ls-1)
      end
      ls = ls .."|"
   end
   local line = string.rep("-",#ln)
   ln = ln.."\n"..line.."\n"..ls
   print(line)
   -- BUGFIX: was math.min(0, ...), which always yields a non-positive repeat
   -- count and therefore an empty pad; math.max actually centers the title.
   print(string.rep(" ",math.max(0,math.floor(#ls/2)-math.floor(#name/2) ))..name)
   print(line)
   print(ln)
   print(line)
end
--- Global accuracy: sum of true positives (the diagonal) over all samples.
function ConfusionMatrix:accuracy()
   local tp = getErrors(self) -- only the first return value is needed
   return tp:sum() / self.mat:sum()
end
--- Per-class Matthews correlation coefficient, NaNs replaced by 0.
-- MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
function ConfusionMatrix:matthewsCorrelation()
   local tp, tn, fp, fn = getErrors(self)
   local num = torch.cmul(tp, tn) - torch.cmul(fp, fn)
   local den = torch.sqrt((tp + fp):cmul(tp + fn):cmul(tn + fp):cmul(tn + fn))
   return remNaN(torch.cdiv(num, den), self)
end
--- Per-class sensitivity (recall): TP / (TP + FN), NaNs replaced by 0.
function ConfusionMatrix:sensitivity()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp, tp + fn), self)
end
--- Per-class specificity: TN / (TN + FP), NaNs replaced by 0.
function ConfusionMatrix:specificity()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tn, tn + fp), self)
end
--- Per-class precision (positive predictive value): TP / (TP + FP), NaNs replaced by 0.
function ConfusionMatrix:positivePredictiveValue()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp, tp + fp), self)
end
--- Per-class negative predictive value: TN / (TN + FN), NaNs replaced by 0.
function ConfusionMatrix:negativePredictiveValue()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tn, tn + fn), self)
end
--- Per-class false positive rate: FP / (FP + TN), NaNs replaced by 0.
function ConfusionMatrix:falsePositiveRate()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(fp, fp + tn), self)
end
--- Per-class false discovery rate: FP / (TP + FP), NaNs replaced by 0.
function ConfusionMatrix:falseDiscoveryRate()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(fp, tp + fp), self)
end
--- Per-class accuracy: (TP + TN) / (TP + TN + FP + FN), NaNs replaced by 0.
-- (The original comment here claimed "(TP + FN)/..." — the code computes TP + TN.)
function ConfusionMatrix:classAccuracy()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp + tn, tp + tn + fp + fn), self)
end
--- Per-class F1 score: 2*TP / (2*TP + FP + FN), NaNs replaced by 0.
function ConfusionMatrix:F1()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp * 2, tp * 2 + fp + fn), self)
end
--- Recompute aggregate statistics from self.mat.
-- Fills: self.valids (per-class recall), self.unionvalids (per-class
-- intersection-over-union), self.totalValid (global accuracy), and
-- self.averageValid / self.averageUnionValid (means over classes with
-- at least one sample; NaN classes are excluded from the averages).
function ConfusionMatrix:updateValids()
   local total = 0
   for t = 1,self.nclasses do
      -- row t holds samples whose target is class t; diagonal = correct
      self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum()
      -- union measure: correct / (row sum + column sum - diagonal)
      self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t])
      total = total + self.mat[t][t]
   end
   self.totalValid = total / self.mat:sum()
   self.averageValid = 0
   self.averageUnionValid = 0
   local nvalids = 0
   local nunionvalids = 0
   for t = 1,self.nclasses do
      -- classes with no samples yield NaN (0/0) and are skipped
      if not isNaN(self.valids[t]) then
         self.averageValid = self.averageValid + self.valids[t]
         nvalids = nvalids + 1
      end
      if not isNaN(self.valids[t]) and not isNaN(self.unionvalids[t]) then
         self.averageUnionValid = self.averageUnionValid + self.unionvalids[t]
         nunionvalids = nunionvalids + 1
      end
   end
   self.averageValid = self.averageValid / nvalids
   self.averageUnionValid = self.averageUnionValid / nunionvalids
end
--- Printable representation: the raw count matrix, one row per target
-- class with its per-class accuracy, followed by summary lines.
function ConfusionMatrix:__tostring__()
   self:updateValids()
   local str = {'ConfusionMatrix:\n'}
   local nclasses = self.nclasses
   table.insert(str, '[')
   for t = 1,nclasses do
      -- per-class accuracy (row-normalized diagonal), in percent
      local pclass = self.valids[t] * 100
      pclass = string.format('%2.3f', pclass)
      if t == 1 then
         table.insert(str, '[')
      else
         table.insert(str, ' [')
      end
      for p = 1,nclasses do
         table.insert(str, string.format('%8d', self.mat[t][p]))
      end
      if self.classes and self.classes[1] then
         -- annotate each row with its class label
         if t == nclasses then
            table.insert(str, ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n')
         else
            table.insert(str, '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n')
         end
      else
         if t == nclasses then
            table.insert(str, ']] ' .. pclass .. '% \n')
         else
            table.insert(str, '] ' .. pclass .. '% \n')
         end
      end
   end
   table.insert(str, ' + average row correct: ' .. (self.averageValid*100) .. '% \n')
   table.insert(str, ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n')
   table.insert(str, ' + global correct: ' .. (self.totalValid*100) .. '%')
   return table.concat(str)
end
--- Render the (sorted, normalized) confusion matrix as an image.
-- Requires the 'image', 'qtwidget' and 'qttorch' packages.
-- @param sortmode    'score' (default) or 'occurrence' — row/col ordering
-- @param display     if true, also show the result via image.display
-- @param block       cell size in pixels (default 25)
-- @param legendwidth width of the legend strip in pixels (default 200)
-- @return a tensor holding the rendered image
function ConfusionMatrix:render(sortmode, display, block, legendwidth)
   -- args
   local confusion = self.mat
   local classes = self.classes
   local sortmode = sortmode or 'score' -- 'score' or 'occurrence'
   local block = block or 25
   local legendwidth = legendwidth or 200
   local display = display or false
   -- legends (BUGFIX: corrected user-visible typos 'Confusiong'/'occurences')
   local legend = {
      ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]',
      ['occurrence'] = 'Confusion matrix [sorted by occurrences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]'
   }
   -- parse matrix / normalize rows / count scores
   local diag = torch.FloatTensor(#classes)
   local freqs = torch.FloatTensor(#classes)
   local unconf = confusion
   local confusion = confusion:clone()
   local corrects = 0
   local total = 0
   for target = 1,#classes do
      freqs[target] = confusion[target]:sum()
      corrects = corrects + confusion[target][target]
      total = total + freqs[target]
      confusion[target]:div( math.max(confusion[target]:sum(),1) )
      diag[target] = confusion[target][target]
   end
   -- accuracies
   local accuracy = corrects / total * 100
   local perclass = 0
   local total = 0
   for target = 1,#classes do
      if confusion[target]:sum() > 0 then
         perclass = perclass + diag[target]
         total = total + 1
      end
   end
   perclass = perclass / total * 100
   freqs:div(unconf:sum())
   -- sort matrix (BUGFIX: 'order' and '_' were accidental globals; now local)
   local order, _
   if sortmode == 'score' then
      _,order = torch.sort(diag,1,true)
   elseif sortmode == 'occurrence' then
      _,order = torch.sort(freqs,1,true)
   else
      error('sort mode must be one of: score | occurrence')
   end
   -- render matrix cells
   local render = torch.zeros(#classes*block, #classes*block)
   for target = 1,#classes do
      for prediction = 1,#classes do
         render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]]
      end
   end
   -- add grid lines
   for target = 1,#classes do
      render[{ {target*block},{} }] = 0.1
      render[{ {},{target*block} }] = 0.1
   end
   -- create rendering window
   require 'image'
   require 'qtwidget'
   require 'qttorch'
   local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] )
   image.display{image=render, win=win1}
   -- add legend
   for i in ipairs(classes) do
      -- background cell
      win1:setcolor{r=0,g=0,b=0}
      win1:rectangle((#render)[2],(i-1)*block,legendwidth,block)
      win1:fill()
      -- label frequency (% of all labels)
      -- NOTE(review): 'fontsize' is not defined anywhere in this file, so
      -- size is nil and qt falls back to its default — confirm intended.
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local gscale = freqs[order[i]]/freqs:max()*0.9+0.1
      win1:setcolor{r=gscale*0.5+0.2,g=gscale*0.5+0.2,b=gscale*0.8+0.2}
      win1:moveto((#render)[2]+10,i*block-block/3)
      win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100))
      -- class name, shaded by its score
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local gscale = diag[order[i]]*0.8+0.2
      win1:setcolor{r=gscale,g=gscale,b=gscale}
      win1:moveto(120+(#render)[2]+10,i*block-block/3)
      win1:show(classes[order[i]])
      for j in ipairs(classes) do
         -- per-cell percentage scores
         local score = confusion[order[j]][order[i]]
         local gscale = (1-score)*(score*0.8+0.2)
         win1:setcolor{r=gscale,g=gscale,b=gscale}
         win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3)
         win1:show(string.format('%02.0f',math.floor(score*100+0.5)))
      end
   end
   -- convert the window contents to a tensor
   local t = win1:image():toTensor()
   -- optional on-screen display
   if display then
      image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)}
   end
   -- return rendering
   return t
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Unit tests for the ConfusionMatrix class.
require 'MyConfusionMatrix'

local nObs = 15
local actual     = torch.Tensor( {1, 2, 3, 1, 2, 3, 1, 1, 2, 3, 2, 1, 1, 2, 3})
local prediction = torch.Tensor( {1 ,2 ,3, 1, 2, 3, 1, 1, 1, 2, 2, 1, 2, 1, 3})
-- Per-class outcome of each observation, with [tp, tn, fp, fn] totals:
--class 1 : A, B, B, A, B, B, A, A, C, B, B, A, D, C, B [5, 7, 2, 1]
--class 2 : B, A, B ,B, A, B ,B, B, D, C, A, B, C, D, B [3, 8, 2, 2]
--class 3 : B, B, A ,B, B, A, B ,B, B, D, B, B, B, B, A [3, 11, 0, 1]
--- A = true positive, B = true negative, C = false positive, D = false negative

-- approximate equality for tensors
local function equal(t1,t2)
   local diff = torch.abs(t1 -t2):sum()
   return diff < 10^-7
end
-- approximate equality for scalars
local function numEqual(val1,val2)
   local diff = math.abs(val1-val2)
   return diff < 10^-7
end

local conf = ConfusionMatrix(3)
local conf_class = ConfusionMatrix({'AAAAAAAAAAAA','BBBB','CCCCCCCCCCC'})
conf:batchAdd(prediction,actual)

local tp, tn, fp, fn = conf:getConfusion()
local tp_true = torch.Tensor({5,3,3}):view(1,3):float()
local tn_true = torch.Tensor({7,8,11}):view(1,3):float()
local fp_true = torch.Tensor({2,2,0}):view(1,3):float()
local fn_true = torch.Tensor({1,2,1}):view(1,3):float()
-- Check getErrors()
assert(equal(tp_true,tp))
assert(equal(tn_true,tn))
assert(equal(fp_true,fp))
-- BUGFIX: this assertion previously re-checked tn, so fn was never verified
assert(equal(fn_true,fn))

local tp1 = tp_true[{1,1}]
local tn1 = tn_true[{1,1}]
local fp1 = fp_true[{1,1}]
local fn1 = fn_true[{1,1}]
local tp2 = tp_true[{1,2}]
local tn2 = tn_true[{1,2}]
local fp2 = fp_true[{1,2}]
local fn2 = fn_true[{1,2}]
local tp3 = tp_true[{1,3}]
local tn3 = tn_true[{1,3}]
local fp3 = fp_true[{1,3}]
local fn3 = fn_true[{1,3}]

local acc_true = tp_true:sum()/nObs -- sum of diagonal divided by nObs
local acc = conf:accuracy()
assert(acc_true == acc)

local sens3 = conf:sensitivity()[{1,3}]
local sens3_true = tp3 / (tp3 + fn3)
assert(numEqual(sens3,sens3_true))

local spec1 = conf:specificity()[{1,1}]
local spec1_true = tn1 / (fp1 + tn1)
assert(numEqual(spec1,spec1_true))

local ppv2 = conf:positivePredictiveValue()[{1,2}]
local ppv2_true = tp2 / (tp2+fp2)
assert(numEqual(ppv2,ppv2_true))

local npv3 = conf:negativePredictiveValue()[{1,3}]
local npv3_true = tn3 / (tn3+fn3)
assert(numEqual(npv3,npv3_true))

local fpr1 = conf:falsePositiveRate()[{1,1}]
local fpr1_true = fp1 / (fp1+tn1)
assert(numEqual(fpr1,fpr1_true))

local fdr2 = conf:falseDiscoveryRate()[{1,2}]
local fdr2_true = fp2 / (tp2 + fp2)
assert(numEqual(fdr2,fdr2_true))

local acc3 = conf:classAccuracy()[{1,3}]
-- BUGFIX: used tp2 instead of tp3 (passed only because tp2 == tp3 here)
local acc3_true = (tp3 + tn3) / nObs
assert(numEqual(acc3,acc3_true))

local f11 = conf:F1()[{1,1}]
local f11_true = 2*tp1 / (2*tp1 + fp1 + fn1)
assert(numEqual(f11,f11_true))

local mcc2 = conf:matthewsCorrelation()[{1,2}]
local mcc2_true = (tp2*tn2-fp2*fn2) / math.sqrt((tp2+fp2)*(tp2+fn2)*(tn2+fp2)*(tn2+fn2) )
assert(numEqual(mcc2,mcc2_true))

-- Test printing (conf_class is empty, so scores are NaN->0; exercises padding)
print(conf_class:printscore('sensitivity'))
print(conf_class:printscore('sensitivity','VALIDATION'))
print(conf:printscore('sensitivity'))
print(conf:printscore('specificity'))
print(conf:printscore('positivePredictiveValue'))
print(conf:printscore('negativePredictiveValue'))
print(conf:printscore('falsePositiveRate'))
print(conf:printscore('falseDiscoveryRate'))
print(conf:printscore('classAccuracy'))
print(conf:printscore('F1'))
print(conf:printscore('matthewsCorrelation'))
conf:zero()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment