Fiji's Segmentation Metrics
Gist by @npinto, created July 25, 2012 21:19
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import trainableSegmentation.utils.Utils;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.process.ByteProcessor;
import ij.process.ImageProcessor;
import ij.process.ShortProcessor;
/**
* This class implements the adjusted Rand error, defined as 1 - adjusted Rand index.
* We follow the Rand index definition described by Lawrence Hubert and Phipps Arabie \cite{Hubert85}.
*
* BibTeX:
* <pre>
* &#64;article{Hubert85,
* author = {Lawrence Hubert and Phipps Arabie},
* title = {Comparing partitions},
* journal = {Journal of Classification},
* year = {1985},
* volume = {2},
* issue = {1},
* pages = {193-218},
* doi = {10.1007/BF01908075}
* }
* </pre>
*
*/
public class AdjustedRandError extends Metrics
{
/**
* Initialize adjusted Rand error metric.
*
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
*/
public AdjustedRandError(ImagePlus originalLabels, ImagePlus proposedLabels) {
super(originalLabels, proposedLabels);
}
/**
* Calculate the adjusted Rand error in 2D between some original labels
* and the corresponding proposed labels. Both images are binarized.
* The adjusted Rand error is defined as 1 - adjusted Rand index,
* following Lawrence Hubert and Phipps Arabie \cite{Hubert85}; the underlying
* Rand index was described by William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return adjusted Rand error
*/
public double getMetricValue(double binaryThreshold)
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double randError = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<Double> > futures = new ArrayList< Future<Double> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getAdjustedRandErrorConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<Double> f : futures)
{
randError += f.get();
}
}
catch(Exception ex)
{
IJ.log("Error when calculating rand error in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return randError / labelSlices.getSize();
}
/**
* Calculate the adjusted Rand error between some 2D original labels
* and the corresponding proposed labels. Both images are binarized.
* The adjusted Rand error is defined as 1 - adjusted Rand index,
* as described by Lawrence Hubert and Phipps Arabie \cite{Hubert85}.
*
* BibTeX:
* <pre>
* &#64;article{Hubert85,
* author = {Lawrence Hubert and Phipps Arabie},
* title = {Comparing partitions},
* journal = {Journal of Classification},
* year = {1985},
* volume = {2},
* issue = {1},
* pages = {193-218},
* doi = {10.1007/BF01908075}
* }
* </pre>
*
* @param label 2D image with the original labels
* @param proposal 2D image with the proposed labels
* @param binaryThreshold threshold value to binarize the input images
* @return adjusted Rand error
*/
public static double adjustedRandError(
ImageProcessor label,
ImageProcessor proposal,
double binaryThreshold)
{
// Binarize inputs
ByteProcessor binaryLabel = new ByteProcessor( label.getWidth(), label.getHeight() );
ByteProcessor binaryProposal = new ByteProcessor( proposal.getWidth(), proposal.getHeight() );
for(int x=0; x<label.getWidth(); x++)
for(int y=0; y<label.getHeight(); y++)
{
binaryLabel.set( x, y, label.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
binaryProposal.set(x, y, proposal.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
}
// Find components
final ImagePlus im1 = new ImagePlus("binary labels", binaryLabel);
//im1.show();
ShortProcessor components1 = ( ShortProcessor ) Utils.connectedComponents(
im1, 4).allRegions.getProcessor();
final ImagePlus im2 = new ImagePlus("proposal labels", binaryProposal);
//im2.show();
ShortProcessor components2 = ( ShortProcessor ) Utils.connectedComponents(
im2, 4).allRegions.getProcessor();
return 1 - adjustedRandIndex( components1, components2 );
}
/**
* Get adjusted Rand error between two images in a concurrent way
* (to be submitted to an Executor Service). Both images
* are binarized.
* The adjusted Rand error is defined as 1 - adjusted Rand index,
* as described by Lawrence Hubert and Phipps Arabie \cite{Hubert85}.
*
* BibTeX:
* <pre>
* &#64;article{Hubert85,
* author = {Lawrence Hubert and Phipps Arabie},
* title = {Comparing partitions},
* journal = {Journal of Classification},
* year = {1985},
* volume = {2},
* issue = {1},
* pages = {193-218},
* doi = {10.1007/BF01908075}
* }
* </pre>
*
* @param image1 first image
* @param image2 second image
* @param binaryThreshold threshold to apply to both images
* @return adjusted Rand error
*/
public Callable<Double> getAdjustedRandErrorConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final double binaryThreshold)
{
return new Callable<Double>()
{
public Double call()
{
return adjustedRandError ( image1, image2, binaryThreshold );
}
};
}
/**
* Calculate the adjusted Rand index between two clusters, as described by
* Lawrence Hubert and Phipps Arabie \cite{Hubert85}.
*
* BibTeX:
* <pre>
* &#64;article{Hubert85,
* author = {Lawrence Hubert and Phipps Arabie},
* title = {Comparing partitions},
* journal = {Journal of Classification},
* year = {1985},
* volume = {2},
* issue = {1},
* pages = {193-218},
* doi = {10.1007/BF01908075}
* }
* </pre>
*
* @param cluster1 2D segmented image (objects are labeled with different numbers)
* @param cluster2 2D segmented image (objects are labeled with different numbers)
* @return adjusted Rand index
*/
public static double adjustedRandIndex(
ShortProcessor cluster1,
ShortProcessor cluster2)
{
final short[] pixels1 = (short[]) cluster1.getPixels();
final short[] pixels2 = (short[]) cluster2.getPixels();
double n = pixels1.length;
// Form contingency matrix
int[][]cont = new int[ (int) cluster1.getMax() + 1 ] [ (int) cluster2.getMax() + 1 ]; // labels range from 0 to max
for(int i=0; i<n; i++)
cont[ pixels1[i] ] [ pixels2[i] ] ++;
// sum over rows & columnns of nij^2
double t2 = 0;
// sum of squares of sums of rows
double[] ni = new double[ cont.length ];
for(int i=0; i<cont.length; i++)
for(int j=0; j<cont[i].length; j++)
ni[ i ] += cont[ i ][ j ];
double nis = 0;
for(int k=0; k<ni.length; k++)
nis += ni[ k ] * ni[ k ];
// sum of squares of sums of columns
double[] nj = new double[ cont[0].length ]; // one entry per column
for(int j=0; j<cont[0].length; j++)
for(int i=0; i<cont.length; i++)
{
nj[ j ] += cont[ i ][ j ];
t2 += cont[ i ][ j ] * cont[ i ][ j ];
}
double njs = 0;
for(int k=0; k<nj.length; k++)
njs += nj[ k ] * nj[ k ];
// total number of pairs of entities
double t1 = n * (n - 1) / 2 ;
double t3 = 0.5 * (nis+njs);
//Expected index (for adjustment)
double nc = ( n*(n*n+1) - (n+1)*nis - (n+1)*njs+2*(nis*njs)/n) / (2*(n-1) );
double agreements=t1+t2-t3; // number of agreements
/*
double D= -t2+t3; // number of disagreements
double RI=agreements/t1; // Rand 1971 %Probability of agreement
double MI=D/t1; // Mirkin 1970 %p(disagreement)
double HI=(agreements-D)/t1; // Hubert 1977 %p(agree)-p(disagree)
*/
//IJ.log("n = " + n + ", nis = " + nis + ", njs = " + njs);
//IJ.log("t1 = " + t1 + ", t2 = " + t2 + ", t3 = " + t3);
//IJ.log("nc = " + nc);
double adjustedRandIndex;
if ( t1 == nc )
adjustedRandIndex=0; // avoid division by zero; if k=1, define Rand = 0
else
adjustedRandIndex=(agreements-nc)/(t1-nc); // adjusted Rand - Hubert & Arabie 1985
return adjustedRandIndex;
}
}
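/*
 * Usage sketch (illustrative, not part of the original source): how the
 * AdjustedRandError metric above might be invoked from an ImageJ/Fiji plugin
 * or script. The file names and the 0.5 binarization threshold are assumptions
 * made only for this example.
 *
 *   ImagePlus groundTruth = IJ.openImage( "original-labels.tif" );   // hypothetical path
 *   ImagePlus proposal    = IJ.openImage( "proposed-labels.tif" );   // hypothetical path
 *   AdjustedRandError metric = new AdjustedRandError( groundTruth, proposal );
 *   double error = metric.getMetricValue( 0.5 );                     // threshold in (0, 1)
 *   IJ.log( "Adjusted Rand error: " + error );
 */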
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu), Verena Kaynig (verena.kaynig@inf.ethz.ch),
* Albert Cardona (acardona@ini.phys.ethz.ch)
*/
/**
* This class stores statistics from a classification
*/
public class ClassificationStatistics
{
/** number of true positives */
public double truePositives = 0;
/** number of true negatives */
public double trueNegatives = 0;
/** number of false positives */
public double falsePositives = 0;
/** number of false negatives */
public double falseNegatives = 0;
/** value of the classification metric */
public double metricValue = 0;
/** precision: true positives / ( true positives + false positives ) */
public double precision = 0;
/** recall (also called sensitivity or hit rate): true positives / ( true positives + false negatives ) */
public double recall = 0;
/** F-score, harmonic mean of precision and recall */
public double fScore = 0;
/** specificity, also called true negative rate (TNR): true negatives / (true negatives + false positives) */
public double specificity = 0;
/**
* Create classification statistics
*
* @param truePositives number of true positives
* @param trueNegatives number of true negatives
* @param falsePositives number of false positives
* @param falseNegatives number of false negatives
* @param metricValue value of the classification metric
*/
public ClassificationStatistics(
double truePositives,
double trueNegatives,
double falsePositives,
double falseNegatives,
double metricValue)
{
this.truePositives = truePositives;
this.trueNegatives = trueNegatives;
this.falsePositives = falsePositives;
this.falseNegatives = falseNegatives;
this.metricValue = metricValue;
final double totalNegatives = trueNegatives + falsePositives;
this.specificity = (totalNegatives > 0) ? trueNegatives / totalNegatives : 0;
// no false positives implies maximum precision
if( falsePositives == 0 )
this.precision = 1;
else
this.precision = truePositives / (truePositives + falsePositives);
// no false negatives implies maximum recall
if( falseNegatives == 0)
this.recall = 1;
else
this.recall = truePositives / (truePositives + falseNegatives);
if( (precision + recall) > 0)
this.fScore = 2 * precision * recall / ( precision + recall );
}
}
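/*
 * Worked example (illustrative, not part of the original source): with
 * tp = 90, tn = 80, fp = 10 and fn = 20, the constructor above yields
 * precision = 90 / (90 + 10) = 0.9, recall = 90 / (90 + 20) ~ 0.818,
 * specificity = 80 / (80 + 10) ~ 0.889 and
 * F-score = 2 * 0.9 * 0.818 / (0.9 + 0.818) ~ 0.857. The metric value 0.0 is a
 * placeholder for whatever metric produced the counts.
 *
 *   ClassificationStatistics cs = new ClassificationStatistics( 90, 80, 10, 20, 0.0 );
 *   System.out.println( "precision = " + cs.precision + ", recall = " + cs.recall + ", F = " + cs.fScore );
 */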
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
/**
* This class stores the number of mismatches after applying
* a topology-preserving warping. The mismatches are clustered
* by the different type of possible mistakes.
*/
public class ClusteredWarpingMismatches
{
public int numOfObjectAdditions = 0;
public int numOfHoleDeletions = 0;
public int numOfMergers = 0;
public int numOfHoleAdditions = 0;
public int numOfObjectDeletions = 0;
public int numOfSplits = 0;
public ClusteredWarpingMismatches(
int numOfObjectAdditions,
int numOfHoleDeletions,
int numOfMergers,
int numOfHoleAdditions,
int numOfObjectDeletions,
int numOfSplits
)
{
this.numOfHoleAdditions = numOfHoleAdditions;
this.numOfHoleDeletions = numOfHoleDeletions;
this.numOfMergers = numOfMergers;
this.numOfObjectAdditions = numOfObjectAdditions;
this.numOfObjectDeletions = numOfObjectDeletions;
this.numOfSplits = numOfSplits;
}
}
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import ij.ImagePlus;
/**
* This is the mother class for 2D segmentation metrics
*/
public abstract class Metrics
{
/** original labels (single 2D image or stack) */
ImagePlus originalLabels;
/** proposed new labels (single 2D image or stack of the same size as the original labels) */
ImagePlus proposedLabels;
public Metrics(ImagePlus originalLabels, ImagePlus proposedLabels)
{
this.originalLabels = originalLabels;
this.proposedLabels = proposedLabels;
}
public abstract double getMetricValue(double binaryThreshold);
}
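/*
 * Sketch (illustrative, not part of the original source): a new 2D metric is
 * added by extending Metrics and implementing getMetricValue(double). The
 * class name below is hypothetical.
 *
 *   public class MyMetric extends Metrics
 *   {
 *       public MyMetric( ImagePlus originalLabels, ImagePlus proposedLabels )
 *       {
 *           super( originalLabels, proposedLabels );
 *       }
 *
 *       @Override
 *       public double getMetricValue( double binaryThreshold )
 *       {
 *           // compare originalLabels and proposedLabels slice by slice here
 *           return 0;
 *       }
 *   }
 */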
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.process.ImageProcessor;
/**
* This class implements the pixel error metric
*/
public class PixelError extends Metrics
{
/** boolean flag to set the level of detail on the standard output messages */
private boolean verbose = true;
/**
* Initialize pixel error metric
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
*/
public PixelError(ImagePlus originalLabels, ImagePlus proposedLabels)
{
super(originalLabels, proposedLabels);
}
/**
* Calculate the pixel error in 2D between some original labels
* and the corresponding proposed labels. Both images are binarized.
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return pixel error
*/
@Override
public double getMetricValue(double binaryThreshold)
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double pixelError = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<Double> > futures = new ArrayList< Future<Double> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getPixelErrorConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<Double> f : futures)
{
pixelError += f.get();
}
}
catch(Exception ex)
{
IJ.log("Error when warping ground truth in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return pixelError / labelSlices.getSize();
}
/**
* Get pixel error between two images in a concurrent way
* (to be submitted to an Executor Service). Both images
* are binarized.
*
* @param image1 first image
* @param image2 second image
* @param binaryThreshold threshold to apply to both images
* @return pixel error
*/
public Callable<Double> getPixelErrorConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final double binaryThreshold)
{
return new Callable<Double>()
{
public Double call()
{
double pixelError = 0;
for(int x=0; x<image1.getWidth(); x++)
for(int y=0; y<image1.getHeight(); y++)
{
double pix1 = image1.getPixelValue(x, y) > binaryThreshold ? 1 : 0;
double pix2 = image2.getPixelValue(x, y) > binaryThreshold ? 1 : 0;
pixelError += ( pix1 - pix2 ) * ( pix1 - pix2 ) ;
}
return pixelError / (image1.getWidth() * image1.getHeight());
}
};
}
/**
* Calculate the pixel error in 2D between some original labels
* and the corresponding proposed labels (without thresholding).
*
* @return pixel error
*/
public double getMetricValue()
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double pixelError = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<Double> > futures = new ArrayList< Future<Double> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getPixelErrorConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat() ) ) );
}
// Wait for the jobs to be done
for(Future<Double> f : futures)
{
pixelError += f.get();
}
}
catch(Exception ex)
{
IJ.log("Error when warping ground truth in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return pixelError / labelSlices.getSize();
}
/**
* Get pixel error between two images in a concurrent way
* (to be submitted to an Executor Service).
*
* @param image1 first image
* @param image2 second image
* @return pixel error
*/
public Callable<Double> getPixelErrorConcurrent(
final ImageProcessor image1,
final ImageProcessor image2)
{
return new Callable<Double>()
{
public Double call()
{
double pixelError = 0;
for(int x=0; x<image1.getWidth(); x++)
{
for(int y=0; y<image1.getHeight(); y++)
{
double pix1 = image1.getPixelValue(x, y);
double pix2 = image2.getPixelValue(x, y);
pixelError += ( pix1 - pix2 ) * ( pix1 - pix2 ) ;
}
}
return pixelError / (image1.getWidth() * image1.getHeight());
}
};
}
/**
* Calculate the precision-recall values based on pixel error between
* some 2D original labels and the corresponding proposed labels.
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return pixel error value and derived statistics for each threshold
*/
public ArrayList< ClassificationStatistics > getPrecisionRecallStats(
double minThreshold,
double maxThreshold,
double stepThreshold )
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList< ClassificationStatistics > cs = new ArrayList<ClassificationStatistics>();
double bestFscore = 0;
double bestTh = minThreshold;
for(double th = minThreshold; th <= maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating pixel error statistics for threshold value " + String.format("%.3f", th) + "...");
cs.add( getPrecisionRecallStats( th ));
final double fScore = cs.get( cs.size()-1 ).fScore;
if( fScore > bestFscore )
{
bestFscore = fScore;
bestTh = th;
}
if( verbose )
IJ.log(" F-score = " + fScore);
}
if( verbose )
IJ.log(" ** Best F-score = " + bestFscore + ", with threshold = " + bestTh + " **\n");
return cs;
}
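/*
 * Example (illustrative, not part of the original source): sweep thresholds
 * from 0.1 to 0.9 in steps of 0.1 and print the derived statistics for each
 * one. The variable name "pe" stands for a hypothetical PixelError instance.
 *
 *   ArrayList< ClassificationStatistics > stats = pe.getPrecisionRecallStats( 0.1, 0.9, 0.1 );
 *   for( ClassificationStatistics s : stats )
 *       IJ.log( "precision = " + s.precision + ", recall = " + s.recall + ", F = " + s.fScore );
 */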
/**
* Calculate the pixel error and its derived statistics in 2D between
* some original labels and the corresponding proposed labels. Both images
* are binarized.
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return pixel error value and derived statistics
*/
public ClassificationStatistics getPrecisionRecallStats( double binaryThreshold )
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double pixelError = 0;
double tp = 0;
double tn = 0;
double fp = 0;
double fn = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<ClassificationStatistics> > futures = new ArrayList< Future<ClassificationStatistics> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getPrecisionRecallStatsConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<ClassificationStatistics> f : futures)
{
ClassificationStatistics cs = f.get();
pixelError += cs.metricValue;
tp += cs.truePositives;
tn += cs.trueNegatives;
fp += cs.falsePositives;
fn += cs.falseNegatives;
}
}
catch(Exception ex)
{
IJ.log("Error when calculating pixel error in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return new ClassificationStatistics( tp, tn, fp, fn, pixelError / labelSlices.getSize() );
}
/**
* Calculate the pixel error and its derived statistics in 2D between
* some original labels and the corresponding proposed labels. Both images
* are binarized.
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @param mask mask image
* @return pixel error value and derived statistics
*/
public ClassificationStatistics getPrecisionRecallStats(
double binaryThreshold,
ImagePlus mask )
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double pixelError = 0;
double tp = 0;
double tn = 0;
double fp = 0;
double fn = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<ClassificationStatistics> > futures = new ArrayList< Future<ClassificationStatistics> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getPrecisionRecallStatsConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
( null != mask ) ? mask.getImageStack().getProcessor(i).convertToFloat() : null,
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<ClassificationStatistics> f : futures)
{
ClassificationStatistics cs = f.get();
pixelError += cs.metricValue;
tp += cs.truePositives;
tn += cs.trueNegatives;
fp += cs.falsePositives;
fn += cs.falseNegatives;
}
}
catch(Exception ex)
{
IJ.log("Error when calculating pixel error in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return new ClassificationStatistics( tp, tn, fp, fn, pixelError / labelSlices.getSize() );
}
/**
* Get pixel error value and derived statistics between two images
* in a concurrent way (to be submitted to an Executor Service).
* Both images are binarized.
*
* @param image1 first image
* @param image2 second image
* @param binaryThreshold threshold to apply to both images
* @return pixel error value and derived statistics
*/
public Callable<ClassificationStatistics> getPrecisionRecallStatsConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final double binaryThreshold)
{
return new Callable<ClassificationStatistics>()
{
public ClassificationStatistics call()
{
return precisionRecallStats( image1, image2, binaryThreshold );
}
};
}
/**
* Get pixel error value and derived statistics between two images
* in a concurrent way (to be submitted to an Executor Service).
* Both images are binarized.
*
* @param image1 first image
* @param image2 second image
* @param mask mask image
* @param binaryThreshold threshold to apply to both images
* @return pixel error value and derived statistics
*/
public Callable<ClassificationStatistics> getPrecisionRecallStatsConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final ImageProcessor mask,
final double binaryThreshold)
{
return new Callable<ClassificationStatistics>()
{
public ClassificationStatistics call()
{
if(null == mask)
return precisionRecallStats( image1, image2, binaryThreshold );
else
return precisionRecallStats( image1, image2, mask, binaryThreshold );
}
};
}
/**
* Calculate the pixel error and derived statistics between some 2D original labels
* and the corresponding proposed labels. Both images are binarized.
*
* @param label 2D image with the original labels
* @param proposal 2D image with the proposed labels
* @param binaryThreshold threshold value to binarize the input images
* @return rand index value and derived statistics
*/
public ClassificationStatistics precisionRecallStats(
ImageProcessor label,
ImageProcessor proposal,
double binaryThreshold)
{
// Binarize inputs
float[] labelPix = (float[]) label.getPixels();
float[] proposalPix = (float[]) proposal.getPixels();
double truePositives = 0;
double trueNegatives = 0;
double falsePositives = 0;
double falseNegatives = 0;
double pixelError = 0;
for(int i=0; i<labelPix.length; i++)
{
// make sure labels are binary
int pix1 = (labelPix[ i ] > 0) ? 1 : 0;
// threshold proposal
int pix2 = (proposalPix[ i ] > binaryThreshold) ? 1 : 0;
if (pix2 == 1)
{
if(pix1 == 1)
truePositives ++;
else
falsePositives ++;
}
else
{
if(pix1 == 1)
falseNegatives ++;
else
trueNegatives ++;
}
pixelError += ( pix1 - pix2 ) * ( pix1 - pix2 ) ;
}
pixelError /= label.getWidth() * label.getHeight();
return new ClassificationStatistics(truePositives, trueNegatives, falsePositives, falseNegatives, pixelError);
}
/**
* Calculate the pixel error and derived statistics between some 2D original labels
* and the corresponding proposed labels. Both images are binarized.
*
* @param label 2D image with the original labels
* @param proposal 2D image with the proposed labels
* @param mask 2D image representing the binary mask
* @param binaryThreshold threshold value to binarize the input images
* @return classification statistics
*/
public ClassificationStatistics precisionRecallStats(
ImageProcessor label,
ImageProcessor proposal,
ImageProcessor mask,
double binaryThreshold)
{
// Binarize inputs
float[] labelPix = (float[]) label.getPixels();
float[] proposalPix = (float[]) proposal.getPixels();
float[] maskPixels = (float[]) mask.getPixels();
double truePositives = 0;
double trueNegatives = 0;
double falsePositives = 0;
double falseNegatives = 0;
double pixelError = 0;
double n = 0;
for(int i=0; i<labelPix.length; i++)
{
// make sure labels are binary
int pix1 = (labelPix[ i ] > 0) ? 1 : 0;
// threshold proposal
int pix2 = (proposalPix[ i ] > binaryThreshold) ? 1 : 0;
// check mask
if ( maskPixels[ i ] > 0 )
{
if (pix2 == 1)
{
if(pix1 == 1)
truePositives ++;
else
falsePositives ++;
}
else
{
if(pix1 == 1)
falseNegatives ++;
else
trueNegatives ++;
}
pixelError += ( pix1 - pix2 ) * ( pix1 - pix2 ) ;
n++;
}
}
if ( n > 0 )
pixelError /= n;
return new ClassificationStatistics(truePositives, trueNegatives, falsePositives, falseNegatives, pixelError);
}
/**
* Get the best F-score of the pixel error over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return maximal F-score of the pixel error
*/
public double getPixelErrorMaximalFScore(
double minThreshold,
double maxThreshold,
double stepThreshold )
{
ArrayList<ClassificationStatistics> stats = getPrecisionRecallStats( minThreshold, maxThreshold, stepThreshold );
// trainableSegmentation.utils.Utils.plotPrecisionRecall( stats );
double maxFScore = 0;
for(ClassificationStatistics stat : stats)
{
if (stat.fScore > maxFScore)
maxFScore = stat.fScore;
}
return maxFScore;
}
/**
* Main method to calculate the pixel error metric from the command line
*
* @param args arguments to decide the action
*/
public static void main(String args[])
{
if (args.length<1)
{
dumpSyntax();
System.exit(1);
}
else
{
if( args[0].equals("-help") )
dumpSyntax();
else if (args[0].equals("-maxFScore"))
System.out.println( maximalFScoreCommandLine(args) );
else
dumpSyntax();
}
System.exit(0);
}
/**
* Calculate the maximal F-score of pixel similarity based on the
* parameters provided on the command line
*
* @param args command line arguments
* @return maximal F-score
*/
static double maximalFScoreCommandLine(String[] args)
{
if (args.length != 6)
{
dumpSyntax();
return -1;
}
final ImagePlus label = new ImagePlus( args[ 1 ] );
final ImagePlus proposal = new ImagePlus( args[ 2 ] );
final double minThreshold = Double.parseDouble( args[ 3 ] );
final double maxThreshold = Double.parseDouble( args[ 4 ] );
final double stepThreshold = Double.parseDouble( args[ 5 ] );
PixelError pe = new PixelError(label, proposal);
pe.setVerboseMode( false );
return pe.getPixelErrorMaximalFScore(minThreshold, maxThreshold, stepThreshold );
}
/**
* Method to write the syntax of the program in the command line.
*/
private static void dumpSyntax ()
{
System.out.println("Purpose: calculate pixel error between proposed and original labels.\n");
System.out.println("Usage: PixelError ");
System.out.println(" -help : show this message");
System.out.println("");
System.out.println(" -maxFScore : calculate the best F-score of the pixel error over a set of thresholds");
System.out.println(" labels : image with the original labels");
System.out.println(" proposal : image with the proposed labels");
System.out.println(" minThreshold : minimum threshold value to binarize the proposal");
System.out.println(" maxThreshold : maximum threshold value to binarize the proposal");
System.out.println(" stepThreshold : threshold step value to use during binarization\n");
System.out.println("Examples:");
System.out.println("Calculate the maximal F-score of pixel similarity between proposed and original labels over a set of");
System.out.println("thresholds (from 0.0 to 1.0 in steps of 0.1):");
System.out.println(" PixelError -maxFScore original-labels.tif proposed-labels.tif 0.0 1.0 0.1");
}
/**
* Set verbose mode
* @param verbose true to display more information in the standard output
*/
public void setVerboseMode(boolean verbose)
{
this.verbose = verbose;
}
}
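/*
 * Usage sketch (illustrative, not part of the original source): programmatic
 * use of the PixelError metric above. File names, the 0.5 threshold and the
 * threshold sweep range are assumptions made only for this example.
 *
 *   ImagePlus labels   = IJ.openImage( "original-labels.tif" );     // hypothetical path
 *   ImagePlus proposal = IJ.openImage( "proposed-labels.tif" );     // hypothetical path
 *   PixelError pe = new PixelError( labels, proposal );
 *   double error = pe.getMetricValue( 0.5 );                        // thresholded pixel error
 *   double bestF = pe.getPixelErrorMaximalFScore( 0.0, 1.0, 0.1 );  // best F-score over thresholds
 *   IJ.log( "Pixel error = " + error + ", best F-score = " + bestF );
 *
 * The same sweep is available from the command line, as dumpSyntax() prints:
 *
 *   PixelError -maxFScore original-labels.tif proposed-labels.tif 0.0 1.0 0.1
 */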
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import trainableSegmentation.utils.Utils;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.process.ByteProcessor;
import ij.process.ImageProcessor;
import ij.process.ShortProcessor;
/**
* This class implements the Rand error metric.
* The Rand error is defined as 1 - Rand index. We follow the
* definition of Rand index as described by William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
*/
public class RandError extends Metrics
{
/** boolean flag to set the level of detail on the standard output messages */
private boolean verbose = true;
/**
* Initialize Rand error metric
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
*/
public RandError(ImagePlus originalLabels, ImagePlus proposedLabels)
{
super(originalLabels, proposedLabels);
}
/**
* Calculate the Rand error in 2D between some original labels
* and the corresponding proposed labels. Both images are binarized.
* The Rand error is defined as 1 - Rand index, as described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return Rand error
*/
public double getMetricValue(double binaryThreshold)
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double randError = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<Double> > futures = new ArrayList< Future<Double> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getRandErrorConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<Double> f : futures)
{
randError += f.get();
}
}
catch(Exception ex)
{
IJ.log("Error when calculating rand error in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return randError / labelSlices.getSize();
}
/**
* Calculate the Rand index and its derived statistics in 2D between
* some original labels and the corresponding proposed labels. Both images
* are binarized. We follow the definition of Rand index described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return Rand index value and derived statistics
*/
public ClassificationStatistics getRandIndexStats( double binaryThreshold )
{
final ImageStack labelSlices = originalLabels.getImageStack();
final ImageStack proposalSlices = proposedLabels.getImageStack();
double randIndex = 0;
double tp = 0;
double tn = 0;
double fp = 0;
double fn = 0;
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<ClassificationStatistics> > futures = new ArrayList< Future<ClassificationStatistics> >();
try{
for(int i = 1; i <= labelSlices.getSize(); i++)
{
futures.add(exe.submit( getRandIndexStatsConcurrent(labelSlices.getProcessor(i).convertToFloat(),
proposalSlices.getProcessor(i).convertToFloat(),
binaryThreshold ) ) );
}
// Wait for the jobs to be done
for(Future<ClassificationStatistics> f : futures)
{
ClassificationStatistics cs = f.get();
randIndex += cs.metricValue;
tp += cs.truePositives;
tn += cs.trueNegatives;
fp += cs.falsePositives;
fn += cs.falseNegatives;
}
}
catch(Exception ex)
{
IJ.log("Error when calculating rand error in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return new ClassificationStatistics( tp, tn, fp, fn, randIndex / labelSlices.getSize() );
}
/**
* Calculate the precision-recall values based on Rand index between
* some 2D original labels and the corresponding proposed labels.
* We follow the definition of Rand index as described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,label
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return Rand index value and derived statistics for each threshold
*/
public ArrayList< ClassificationStatistics > getRandIndexStats(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList< ClassificationStatistics > cs = new ArrayList<ClassificationStatistics>();
double bestFscore = 0;
double bestTh = minThreshold;
for(double th = minThreshold; th <= maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating Rand index statistics for threshold value " + String.format("%.3f", th) + "...");
cs.add( getRandIndexStats( th ));
final double fScore = cs.get( cs.size()-1 ).fScore;
if( fScore > bestFscore )
{
bestFscore = fScore;
bestTh = th;
}
if( verbose )
IJ.log(" F-score = " + fScore);
}
if( verbose )
IJ.log(" ** Best F-score = " + bestFscore + ", with threshold = " + bestTh + " **\n");
return cs;
}
/**
* Get Rand error between two images in a concurrent way
* (to be submitted to an Executor Service). Both images
* are binarized.
* The Rand error is defined as 1 - Rand index, as described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param image1 first image
* @param image2 second image
* @param binaryThreshold threshold to apply to both images
* @return Rand error
*/
public Callable<Double> getRandErrorConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final double binaryThreshold)
{
return new Callable<Double>()
{
public Double call()
{
return randError ( image1, image2, binaryThreshold );
}
};
}
/**
* Get Rand index value and derived statistics between two images
* in a concurrent way (to be submitted to an Executor Service).
* Both images are binarized.
* We follow the Rand index definition described by William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param image1 first image
* @param image2 second image
* @param binaryThreshold threshold to apply to both images
* @return Rand index value and derived statistics
*/
public Callable<ClassificationStatistics> getRandIndexStatsConcurrent(
final ImageProcessor image1,
final ImageProcessor image2,
final double binaryThreshold)
{
return new Callable<ClassificationStatistics>()
{
public ClassificationStatistics call()
{
return randIndexStats( image1, image2, binaryThreshold );
}
};
}
/**
* Calculate the Rand error between some 2D original labels
* and the corresponding proposed labels. Both images are binarized.
* The Rand error is defined as 1 - Rand index, as described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param label 2D image with the original labels
* @param proposal 2D image with the proposed labels
* @param binaryThreshold threshold value to binarize the input images
* @return Rand error
*/
public double randError(
ImageProcessor label,
ImageProcessor proposal,
double binaryThreshold)
{
// Binarize inputs
ByteProcessor binaryLabel = new ByteProcessor( label.getWidth(), label.getHeight() );
ByteProcessor binaryProposal = new ByteProcessor( label.getWidth(), label.getHeight() );
for(int x=0; x<label.getWidth(); x++)
for(int y=0; y<label.getHeight(); y++)
{
binaryLabel.set( x, y, label.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
binaryProposal.set(x, y, proposal.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
}
// Find components
ShortProcessor components1 = ( ShortProcessor ) Utils.connectedComponents(
new ImagePlus("binary labels", binaryLabel), 4).allRegions.getProcessor();
ShortProcessor components2 = ( ShortProcessor ) Utils.connectedComponents(
new ImagePlus("proposal labels", binaryProposal), 4).allRegions.getProcessor();
return 1 - randIndex( components1, components2 );
}
/**
* Calculate the Rand index between some 2D original labels
* and the corresponding proposed labels. Both images are binarized.
* We follow the definition of Rand index as described by
* William M. Rand \cite{Rand71}.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param label 2D image with the original labels
* @param proposal 2D image with the proposed labels
* @param binaryThreshold threshold value to binarize the input images
* @return rand index value and derived statistics
*/
public ClassificationStatistics randIndexStats(
ImageProcessor label,
ImageProcessor proposal,
double binaryThreshold)
{
// Binarize inputs
ByteProcessor binaryLabel = new ByteProcessor( label.getWidth(), label.getHeight() );
ByteProcessor binaryProposal = new ByteProcessor( label.getWidth(), label.getHeight() );
for(int x=0; x<label.getWidth(); x++)
for(int y=0; y<label.getHeight(); y++)
{
binaryLabel.set( x, y, label.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
binaryProposal.set(x, y, proposal.getPixelValue( x, y ) > binaryThreshold ? 255 : 0);
}
// Find components
ShortProcessor components1 = ( ShortProcessor ) Utils.connectedComponents(
new ImagePlus("binary labels", binaryLabel), 4).allRegions.getProcessor();
ShortProcessor components2 = ( ShortProcessor ) Utils.connectedComponents(
new ImagePlus("proposal labels", binaryProposal), 4).allRegions.getProcessor();
return getRandIndexStats( components1, components2 );
}
/**
* Calculate the Rand index between two clusters, as described by
* William M. Rand \cite{Rand71}. Note that this version of the
* Rand index treats the zero component (background) as another
* object.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param cluster1 2D segmented image (objects are labeled with different numbers)
* @param cluster2 2D segmented image (objects are labeled with different numbers)
* @return Rand index
*/
public double classicRandIndex(
ShortProcessor cluster1,
ShortProcessor cluster2)
{
final short[] pixels1 = (short[]) cluster1.getPixels();
final short[] pixels2 = (short[]) cluster2.getPixels();
double n = pixels1.length;
// Form contingency matrix
int[][]cont = new int[ (int) cluster1.getMax() + 1 ] [ (int) cluster2.getMax() + 1 ]; // labels range from 0 to max
for(int i=0; i<n; i++)
cont[ pixels1[i] ] [ pixels2[i] ] ++;
// sum over rows & columnns of nij^2
double t2 = 0;
// sum of squares of sums of rows
double[] ni = new double[ cont.length ];
for(int i=0; i<cont.length; i++)
for(int j=0; j<cont[i].length; j++)
ni[ i ] += cont[ i ][ j ];
double nis = 0;
for(int k=0; k<ni.length; k++)
nis += ni[ k ] * ni[ k ];
// sum of squares of sums of columns
double[] nj = new double[ cont[0].length ]; // one entry per column
for(int j=0; j<cont[0].length; j++)
for(int i=0; i<cont.length; i++)
{
nj[ j ] += cont[ i ][ j ];
t2 += cont[ i ][ j ] * cont[ i ][ j ];
}
double njs = 0;
for(int k=0; k<nj.length; k++)
njs += nj[ k ] * nj[ k ];
// total number of pairs of entities
double t1 = n * (n - 1) / 2 ;
double t3 = 0.5 * (nis+njs);
double agreements=t1+t2-t3; // number of agreements
return agreements/t1;
}
/**
* Calculate the Rand index between two clusters, as described by
* William M. Rand \cite{Rand71}, but pruning out the zero component.
* Otherwise the Rand index gets symmetric.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param cluster1 2D segmented image (objects are labeled with different numbers)
* @param cluster2 2D segmented image (objects are labeled with different numbers)
* @return Rand index
*/
public double randIndex(
ShortProcessor cluster1,
ShortProcessor cluster2)
{
final short[] pixels1 = (short[]) cluster1.getPixels();
final short[] pixels2 = (short[]) cluster2.getPixels();
//(new ImagePlus("cluster 1", cluster1)).show();
//(new ImagePlus("cluster 2", cluster2)).show();
double nPixels = pixels1.length;
// number of pixels that are "in" (not background)
double n = 0;
// Form the contingency matrix
int[][]cont = new int[(int) cluster1.getMax() + 1] [ (int) cluster2.getMax() + 1];
for(int i=0; i<nPixels; i++)
{
cont[ pixels1[i] ] [ pixels2[i] ] ++;
if( pixels1[ i ] > 0)
n++;
}
// sum of squares of sums of rows
// (skip background objects in the first cluster)
double[] ni = new double[ cont.length ];
for(int i=1; i<cont.length; i++)
for(int j=0; j<cont[0].length; j++)
{
ni[ i ] += cont[ i ][ j ];
}
// sum of squares of sums of columns
// (prune out the zero component in the labeling (un-assigned "out" space))
double[] nj = new double[ cont[0].length ];
for(int j=1; j<cont[0].length; j++)
for(int i=1; i<cont.length; i++)
{
nj[ j ] += cont[ i ][ j ];
}
// true positives - type (i): objects in the pair are placed in the
// same class in cluster1 and in the same class in cluster2
// (prune out the zero component in the labeling (un-assigned "out" space))
double truePositives = 0;
for(int j=1; j<cont[0].length; j++)
for(int i=1; i<cont.length; i++)
truePositives += cont[ i ][ j ] * ( cont[ i ][ j ] - 1 ) / 2;
// total number of pairs
double nPairsTotal = n * (n-1) / 2 ;
double nPosTrue = 0;
for(int k=0; k<ni.length; k++)
nPosTrue += ni[ k ] * (ni[ k ]-1) /2;
double nPosActual = 0;
for(int k=0; k<nj.length; k++)
nPosActual += nj[ k ] * (nj[ k ]-1)/2;
// true negatives - type (ii): objects in the pair are placed in different
// classes in cluster1 and in different classes in cluster2
//double trueNegatives = (n*n + t2 - nis - njs) / 2;
double trueNegatives = nPairsTotal + truePositives - nPosTrue - nPosActual;
double agreements = truePositives + trueNegatives; // number of agreements
double randIndex = agreements / nPairsTotal;
return randIndex;
}
/**
* Calculate the Rand index between two clusters, as described by
* William M. Rand \cite{Rand71}, but pruning out the zero component.
* Otherwise the Rand index gets symmetric.
*
* BibTeX:
* <pre>
* &#64;article{Rand71,
* author = {William M. Rand},
* title = {Objective criteria for the evaluation of clustering methods},
* journal = {Journal of the American Statistical Association},
* year = {1971},
* volume = {66},
* number = {336},
* pages = {846--850},
* doi = {10.2307/2284239}
* }
* </pre>
*
* @param cluster1 2D segmented image (objects are labeled with different numbers)
* @param cluster2 2D segmented image (objects are labeled with different numbers)
* @return Rand index value and derived statistics
*/
public ClassificationStatistics getRandIndexStats(
ShortProcessor cluster1,
ShortProcessor cluster2)
{
final short[] pixels1 = (short[]) cluster1.getPixels();
final short[] pixels2 = (short[]) cluster2.getPixels();
//(new ImagePlus("cluster 1", cluster1)).show();
//(new ImagePlus("cluster 2", cluster2)).show();
double nPixels = pixels1.length;
// number of pixels that are "in" (not background)
double n = 0;
// reset min and max of the cluster processors (needed in order to have correct values)
cluster1.resetMinAndMax();
cluster2.resetMinAndMax();
// Form the contingency matrix
int[][]cont = new int[(int) cluster1.getMax() + 1] [ (int) cluster2.getMax() + 1];
//IJ.log(" cont.length = " +cont.length );
//IJ.log(" cont[0].length = " +cont[0].length );
for(int i=0; i<nPixels; i++)
{
cont[ pixels1[i] ] [ pixels2[i] ] ++;
if( pixels1[ i ] > 0)
n++;
}
// sum over rows & columnns of nij^2
//double t2 = 0;
// sums of rows
// (skip background objects in the first cluster)
double[] ni = new double[ cont.length ];
for(int i=1; i<cont.length; i++)
for(int j=0; j<cont[0].length; j++)
{
ni[ i ] += cont[ i ][ j ];
}
/*
// sum of squares of sums of rows
double nis = 0;
for(int k=0; k<ni.length; k++)
nis += ni[ k ] * ni[ k ];
*/
// sums of columns
// (prune out the zero component in the labeling (un-assigned "out" space))
double[] nj = new double[ cont[0].length ];
for(int j=1; j<cont[0].length; j++)
for(int i=1; i<cont.length; i++)
{
nj[ j ] += cont[ i ][ j ];
//t2 += cont[ i ][ j ] * cont[ i ][ j ];
}
/*
// sum of squares of sums of columns
double njs = 0;
for(int k=0; k<nj.length; k++)
njs += nj[ k ] * nj[ k ];
*/
// true positives - type (i): objects in the pair are placed in the
// same class in cluster1 and in the same class in cluster2
// (prune out the zero component in the labeling (un-assigned "out" space))
double truePositives = 0;
for(int j=1; j<cont[0].length; j++)
for(int i=1; i<cont.length; i++)
{
truePositives += cont[ i ][ j ] * ( cont[ i ][ j ] - 1.0 ) / 2.0;
}
// total number of pairs
double nPairsTotal = n * (n-1) / 2 ;
//
double nPosTrue = 0;
for(int k=0; k<ni.length; k++)
nPosTrue += ni[ k ] * (ni[ k ]-1) /2;
// number of pairs that were actually classified as positive
double nPosActual = 0;
for(int k=0; k<nj.length; k++)
nPosActual += nj[ k ] * (nj[ k ]-1)/2;
double nNegCorrect = nPairsTotal + truePositives - nPosTrue - nPosActual;
// true negatives - type (ii): objects in the pair are placed in different
// classes in cluster1 and in different classes in cluster2
//double trueNegatives = (n*n + t2 - nis - njs) / 2;
double trueNegatives = nNegCorrect;
// false positives - type (iii): objects in the pair are placed in different
// classes in cluster1 and in the same class in cluster2
double falsePositives = nPosActual - truePositives; //(njs - t2) / 2;
// number of pairs actually classified as negative
double nNegActual = nPairsTotal - nPosActual;
// false negatives - type (iv): objects in the pair are placed in the same
// class in cluster1 and in different classes in cluster2
double falseNegatives = nNegActual - nNegCorrect; //(nis - t2) / 2;
// number of pairs classified as negative
//double nNegTrue = nPairsTotal - nPosTrue;
// the number of incorrectly classified pairs
//double nPosIncorrect = nPosTrue-truePositives;
//double nNegIncorrect = nNegTrue-nNegCorrect;
//double nPairsIncorrect = nPosIncorrect + nNegIncorrect;
// clustering error
//double clusteringError = nPairsIncorrect/nPairsTotal;
double agreements = truePositives + trueNegatives; // number of agreements
double randIndex = agreements / nPairsTotal;
/*
IJ.log(" In getRandIndexStats:");
IJ.log(" tp = " + truePositives);
IJ.log(" tn = " + trueNegatives);
IJ.log(" fp = " + falsePositives);
IJ.log(" fn = " + falseNegatives);
IJ.log(" nPairsTotal = " + nPairsTotal);
IJ.log(" nPosTrue = " + nPosTrue);
IJ.log(" nPosActual = " + nPosActual);
IJ.log(" nNegCorrect = " + nNegCorrect);
IJ.log(" nNegActual = " + nNegActual);
IJ.log(" clusteringError = " + clusteringError);
IJ.log(" agreements / nPairsTotal = " + (agreements / nPairsTotal));
*/
return new ClassificationStatistics( truePositives, trueNegatives,
falsePositives, falseNegatives, randIndex);
}
/**
* Get the best F-score of the Rand index over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return maximal F-score of the Rand index
*/
public double getRandIndexMaximalFScore(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
ArrayList<ClassificationStatistics> stats = getRandIndexStats( minThreshold, maxThreshold, stepThreshold );
// trainableSegmentation.utils.Utils.plotPrecisionRecall( stats );
double maxFScore = 0;
for(ClassificationStatistics stat : stats)
{
if (stat.fScore > maxFScore)
maxFScore = stat.fScore;
}
return maxFScore;
}
/**
* Main method to calculate the Rand error metric
* from the command line
*
* @param args arguments to decide the action
*/
public static void main(String args[])
{
if (args.length<1)
{
dumpSyntax();
System.exit(1);
}
else
{
if( args[0].equals("-help") )
dumpSyntax();
else if (args[0].equals("-maxFScoreRandIndex"))
System.out.println( maxFScoreRandIndex(args) );
else
dumpSyntax();
}
System.exit(0);
}
/**
* Calculate the maximum F-score of the Rand index based on the
* parameters provided on the command line
*
* @param args command line arguments
* @return maximal F-score of the Rand index
*/
static double maxFScoreRandIndex(String[] args)
{
if (args.length != 6)
{
dumpSyntax();
return -1;
}
final ImagePlus label = new ImagePlus( args[ 1 ] );
final ImagePlus proposal = new ImagePlus( args[ 2 ] );
final double minThreshold = Double.parseDouble( args[ 3 ] );
final double maxThreshold = Double.parseDouble( args[ 4 ] );
final double stepThreshold = Double.parseDouble( args[ 5 ] );
RandError re = new RandError(label, proposal);
re.setVerboseMode( false );
return re.getRandIndexMaximalFScore(minThreshold, maxThreshold, stepThreshold);
}
/**
* Set verbose mode
* @param verbose true to display more information in the standard output
*/
public void setVerboseMode(boolean verbose)
{
this.verbose = verbose;
}
/**
* Method to print the program syntax to the standard output.
*/
private static void dumpSyntax ()
{
System.out.println("Purpose: calculate Rand error between proposed and original labels.\n");
System.out.println("Usage: RandError ");
System.out.println(" -help : show this message");
System.out.println("");
System.out.println(" -maxFScoreRandIndex : calculate maximum F-score of the Rand index over a set of thresholds");
System.out.println(" labels : image with the original labels");
System.out.println(" proposal : image with the proposed labels");
System.out.println(" minThreshold : minimum threshold value to binarize the proposal");
System.out.println(" maxThreshold : maximum threshold value to binarize the proposal");
System.out.println(" stepThreshold : threshold step value to use during binarization");
System.out.println("Examples:");
System.out.println("Calculate the maximum F-score of the Rand index between proposed and original labels over a set of");
System.out.println("thresholds (from 0.0 to 1.0 in steps of 0.1)");
System.out.println(" WarpingError -maxFScoreRandIndex original-labels.tif proposed-labels.tif 0.0 1.0 0.1");
}
} // end class RandError
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import javax.vecmath.Point3f;
import trainableSegmentation.utils.Utils;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.process.Blitter;
import ij.process.ByteProcessor;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
/**
* This class implements the warping error metric \cite{Jain10}
*
* BibTeX:
* <pre>
* &#64;inproceedings{Jain10,
* author = {V. Jain, B. Bollmann, M. Richardson, D.R. Berger, M.N. Helmstaedter,
* K.L. Briggman, W. Denk, J.B. Bowden, J.M. Mendenhall, W.C. Abraham,
* K.M. Harris, N. Kasthuri, K.J. Hayworth, R. Schalek, J.C. Tapia,
* J.W. Lichtman, S.H. Seung},
* title = {Boundary Learning by Optimization with Topological Constraints},
* booktitle = {2010 IEEE CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)},
* year = {2010},
* series = {IEEE Conference on Computer Vision and Pattern Recognition},
* pages = {2488-2495},
* doi = {10.1109/CVPR.2010.5539950}
* }
* </pre>
*/
public class WarpingError extends Metrics {
/** simple point threshold value */
public static final double SIMPLE_POINT_THRESHOLD = 0;
/** merger flag */
public static final int MERGE = 0x1;
/** split flag */
public static final int SPLIT = 0x2;
/** hole addition error flag */
public static final int HOLE_ADDITION = 0x4;
/** object deletion error flag */
public static final int OBJECT_DELETION = 0x8;
/** object addition error flag */
public static final int OBJECT_ADDITION = 0x10;
/** hole deletion error flag */
public static final int HOLE_DELETION = 0x20;
/** default flags */
public static final int DEFAULT_FLAGS = MERGE + SPLIT + HOLE_ADDITION + OBJECT_DELETION + OBJECT_ADDITION + HOLE_DELETION;
/** image mask with the areas where warping is allowed in white (null for no geometric constraints) */
ImagePlus mask = null;
/** flags to select which errors are taken into account and which are not */
int flags = DEFAULT_FLAGS;
/** boolean flag to set the level of detail on the standard output messages */
private boolean verbose = true;
/**
* Initialize warping error metric
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
*/
public WarpingError(
ImagePlus originalLabels,
ImagePlus proposedLabels)
{
super(originalLabels, proposedLabels);
}
/**
* Initialize warping error metric
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
* @param mask image mask with the areas where warping is allowed in white (null for no geometric constraints)
*/
public WarpingError(
ImagePlus originalLabels,
ImagePlus proposedLabels,
ImagePlus mask)
{
super(originalLabels, proposedLabels);
this.mask = mask;
}
/**
* Initialize warping error metric
* @param originalLabels original labels (single 2D image or stack)
* @param proposedLabels proposed new labels (single 2D image or stack of the same size as the original labels)
* @param mask image mask with the areas where warping is allowed in white (null for no geometric constraints)
* @param flags flags to select which errors are taken into account and which are not
*/
public WarpingError(
ImagePlus originalLabels,
ImagePlus proposedLabels,
ImagePlus mask,
int flags)
{
super(originalLabels, proposedLabels);
this.mask = mask;
this.flags = flags;
}
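/*
* Construction sketch (editorial example): a WarpingError restricted to topological
* splits and mergers, with no geometric mask. The image file names are hypothetical.
*
*   ImagePlus labels   = new ImagePlus( "original-labels.tif" );
*   ImagePlus proposal = new ImagePlus( "proposed-labels.tif" );
*   WarpingError we = new WarpingError( labels, proposal, null,
*                                       WarpingError.MERGE + WarpingError.SPLIT );
*/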
/**
* Calculate the classic topology-preserving warping error \cite{Jain10}
* in 2D between some original labels and the corresponding proposed labels.
* Both original and proposed labels are expected to have float values
* between 0 and 1; otherwise, they will be converted.
*
* BibTeX:
* <pre>
* &#64;inproceedings{Jain10,
* author = {V. Jain, B. Bollmann, M. Richardson, D.R. Berger, M.N. Helmstaedter,
* K.L. Briggman, W. Denk, J.B. Bowden, J.M. Mendenhall, W.C. Abraham,
* K.M. Harris, N. Kasthuri, K.J. Hayworth, R. Schalek, J.C. Tapia,
* J.W. Lichtman, S.H. Seung},
* title = {Boundary Learning by Optimization with Topological Constraints},
* booktitle = {2010 IEEE CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)},
* year = {2010},
* series = {IEEE Conference on Computer Vision and Pattern Recognition},
* pages = {2488-2495},
* doi = {10.1109/CVPR.2010.5539950}
* }
* </pre>
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @return total warping error (it counts all types of mismatches as errors)
*/
@Override
public double getMetricValue(double binaryThreshold)
{
if( verbose )
IJ.log(" Warping ground truth...");
// Warp ground truth, relax original labels to proposal. Only simple
// points warping is allowed.
WarpingResults[] wrs = simplePointWarp2dMT(super.originalLabels, super.proposedLabels, mask, binaryThreshold);
if(null == wrs)
return -1;
double error = 0;
for(int j=0; j<wrs.length; j++)
error += wrs[ j ].warpingError;
if(wrs.length != 0)
return error / wrs.length;
else
return -1;
}
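/*
* Usage sketch (editorial example): classic per-pixel warping error of a proposal
* binarized at 0.5, assuming 'we' is a WarpingError built as in the sketch above.
*
*   double classicError = we.getMetricValue( 0.5 );
*/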
/**
* Calculate the topology-preserving warping error in 2D between some
* original labels and the corresponding proposed labels. Pixels belonging
* to the same mistake are counted only once. For example, a line of 15 pixels
* that prevents a merger counts as 1 error instead of 15, as it would in the
* classic warping error method.
* Both original and proposed labels are expected to have float values between
* 0 and 1; otherwise, they will be converted.
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @return clustered warping error (it clusters the mismatches that belong to the same type and/or error together)
*/
public double getMetricValue(
double binaryThreshold,
boolean clusterByError)
{
if( verbose )
IJ.log(" Warping ground truth...");
// Get clustered mismatches after warping ground truth, i.e. relaxing original labels to proposal.
// Only simple points warping is allowed.
ClusteredWarpingMismatches[] cwm = getClusteredWarpingMismatches(originalLabels, proposedLabels,
mask, binaryThreshold, clusterByError, -1);
if(null == cwm)
return -1;
double error = 0;
double count = originalLabels.getWidth() * originalLabels.getHeight() * originalLabels.getImageStackSize();
if( (flags & HOLE_ADDITION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfHoleAdditions;
if( (flags & HOLE_DELETION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfHoleDeletions;
if( (flags & MERGE) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfMergers;
if( (flags & OBJECT_ADDITION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfObjectAdditions;
if( (flags & OBJECT_DELETION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfObjectDeletions;
if( (flags & SPLIT) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfSplits;
if(count != 0)
return error / count;
else
return -1;
}
/**
* Calculate the topology-preserving warping error in 2D between some
* original labels and the corresponding proposed labels. Pixels belonging
* to the same mistake are counted only once. For example, a line of 15 pixels
* that prevents a merger counts as 1 error instead of 15, as it would in the
* classic warping error method.
* Both original and proposed labels are expected to have float values between
* 0 and 1; otherwise, they will be converted.
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @param radius radius in pixels to use when classifying mismatches
* @return clustered warping error (it clusters the mismatches that belong to the same type and/or error together)
*/
public double getMetricValue(
double binaryThreshold,
boolean clusterByError,
int radius)
{
if( verbose )
IJ.log(" Warping ground truth...");
// Get clustered mismatches after warping ground truth, i.e. relaxing original labels to proposal.
// Only simple points warping is allowed.
ClusteredWarpingMismatches[] cwm = getClusteredWarpingMismatches(originalLabels, proposedLabels,
mask, binaryThreshold, clusterByError, radius);
if(null == cwm)
return -1;
double error = 0;
double count = originalLabels.getWidth() * originalLabels.getHeight() * originalLabels.getImageStackSize();
if( (flags & HOLE_ADDITION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfHoleAdditions;
if( (flags & HOLE_DELETION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfHoleDeletions;
if( (flags & MERGE) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfMergers;
if( (flags & OBJECT_ADDITION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfObjectAdditions;
if( (flags & OBJECT_DELETION) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfObjectDeletions;
if( (flags & SPLIT) != 0)
for(int j=0; j<cwm.length; j++)
error += cwm[ j ].numOfSplits;
if(count != 0)
return error / count;
else
return -1;
}
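/*
* Usage sketch (editorial example): clustered warping error where each mistake is
* counted once, grouping mismatches by type only and restricting the local analysis
* to a 20-pixel radius (an arbitrary value). Assumes 'we' as in the sketches above.
*
*   double clusteredError = we.getMetricValue( 0.5, false, 20 );
*/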
/**
* Calculate the number of splits and mergers for different thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @return list of two-element arrays with the number of splits and mergers per threshold
*/
public ArrayList<int[]> getSplitsAndMergers(
double minThreshold,
double maxThreshold,
double stepThreshold,
boolean clusterByError)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList < int[] > listOfSplitsAndMergers = new ArrayList<int[]>();
for(double th = minThreshold; th<=maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating splits and mergers for threshold value " + String.format("%.3f", th) + "...");
ClusteredWarpingMismatches[] cwm =
getClusteredWarpingMismatches(originalLabels, proposedLabels,
mask, th, clusterByError, -1);
if(null == cwm)
return null;
int[] splitsAndMergers = new int[2];
for(int j=0; j<cwm.length; j++)
{
splitsAndMergers[ 0 ] += cwm[ j ].numOfSplits;
splitsAndMergers[ 1 ] += cwm[ j ].numOfMergers;
}
listOfSplitsAndMergers.add( splitsAndMergers );
if( verbose )
IJ.log( " # splits = " + splitsAndMergers[ 0 ] + ", # mergers = " + splitsAndMergers[ 1 ]);
}
return listOfSplitsAndMergers;
}
/**
* Calculate the number of splits and mergers for different thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @param radius radius in pixels to use when classifying mismatches
* @return list of two-element arrays with the number of splits and mergers per threshold
*/
public ArrayList<int[]> getSplitsAndMergers(
double minThreshold,
double maxThreshold,
double stepThreshold,
boolean clusterByError,
int radius )
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList < int[] > listOfSplitsAndMergers = new ArrayList<int[]>();
for(double th = minThreshold; th<=maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating splits and mergers for threshold value " + String.format("%.3f", th) + "...");
ClusteredWarpingMismatches[] cwm =
getClusteredWarpingMismatches(originalLabels, proposedLabels,
mask, th, clusterByError, radius );
if(null == cwm)
return null;
int[] splitsAndMergers = new int[2];
for(int j=0; j<cwm.length; j++)
{
splitsAndMergers[ 0 ] += cwm[ j ].numOfSplits;
splitsAndMergers[ 1 ] += cwm[ j ].numOfMergers;
}
listOfSplitsAndMergers.add( splitsAndMergers );
if( verbose )
IJ.log( " # splits = " + splitsAndMergers[ 0 ] + ", # mergers = " + splitsAndMergers[ 1 ]);
}
return listOfSplitsAndMergers;
}
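/*
* Usage sketch (editorial example): counting splits and mergers per threshold;
* thresholds and radius are arbitrary. Assumes 'we' as in the sketches above.
*
*   ArrayList<int[]> sm = we.getSplitsAndMergers( 0.1, 0.9, 0.1, false, 20 );
*   if( sm != null )
*       for( int[] pair : sm )
*           IJ.log( "splits = " + pair[ 0 ] + ", mergers = " + pair[ 1 ] );
*/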
/**
* Calculate the minimum error (in number of splits and mergers) over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @return minimum error value (based on the number of splits and mergers) over the thresholds
*/
public double getMinimumSplitsAndMergersErrorValue(
double minThreshold,
double maxThreshold,
double stepThreshold,
boolean clusterByError)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return -1;
}
flags = MERGE + SPLIT;
double minError = Double.MAX_VALUE;
double bestTh = minThreshold;
for(double th = minThreshold; th<=maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating splits and mergers for threshold value " + String.format("%.3f", th) + "...");
double error = getMetricValue( th, clusterByError );
if ( verbose )
IJ.log(" error = " + error);
if ( error < minError)
{
minError = error;
bestTh = th;
}
}
if (verbose)
IJ.log(" ** Minimum error = " + minError + ", with threshold = " + bestTh + " **\n");
return minError;
}
/**
* Calculate the minimum error (in number of splits and mergers) over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @param radius radius in pixels to use when classifying mismatches
* @return minimum error value (based on the number of splits and mergers) over the thresholds
*/
public double getMinimumSplitsAndMergersErrorValue(
double minThreshold,
double maxThreshold,
double stepThreshold,
boolean clusterByError,
int radius )
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return -1;
}
flags = MERGE + SPLIT;
double minError = Double.MAX_VALUE;
double bestTh = minThreshold;
for(double th = minThreshold; th<=maxThreshold; th += stepThreshold)
{
if ( verbose )
IJ.log(" Calculating splits and mergers for threshold value " + String.format("%.3f", th) + "...");
double error = getMetricValue( th, clusterByError, radius );
if ( verbose )
IJ.log(" error = " + error);
if ( error < minError)
{
minError = error;
bestTh = th;
}
}
if (verbose)
IJ.log(" ** Minimum error = " + minError + ", with threshold = " + bestTh + " **\n");
return minError;
}
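/*
* Usage sketch (editorial example): minimum split-and-merge error over a set of
* thresholds. Note that this call overwrites the error flags with MERGE + SPLIT.
*
*   double minSplitMergeError = we.getMinimumSplitsAndMergersErrorValue( 0.1, 0.9, 0.1, false, 20 );
*/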
/**
* Get the best F-score of the pixel error between proposed and original labels
* over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return maximal F-score of the pixel error
*/
public double getPixelErrorMaximalFScore(
double minThreshold,
double maxThreshold,
double stepThreshold )
{
ArrayList<ClassificationStatistics> stats = getPrecisionRecallStats( minThreshold, maxThreshold, stepThreshold );
// trainableSegmentation.utils.Utils.plotPrecisionRecall( stats );
double maxFScore = 0;
double th = 0;
double bestTh = 0;
for(ClassificationStatistics stat : stats)
{
if (stat.fScore > maxFScore)
{
maxFScore = stat.fScore;
bestTh = th;
}
th += stepThreshold;
}
if( verbose )
IJ.log(" ** Best F-score = " + maxFScore + ", with threshold = " + bestTh + " **\n");
return maxFScore;
}
/**
* Calculate the precision-recall values based on pixel error between
* some warped 2D original labels and the corresponding proposed labels.
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return pixel error value and derived statistics for each threshold
*/
public ArrayList< ClassificationStatistics > getPrecisionRecallStats(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList< ClassificationStatistics > cs = new ArrayList<ClassificationStatistics>();
for(double th = minThreshold; th <= maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating warping error statistics for threshold value " + String.format("%.3f", th) + "...");
WarpingResults[] wrs = simplePointWarp2dMT(originalLabels, proposedLabels, mask, th);
ImageStack is = new ImageStack( originalLabels.getWidth(), originalLabels.getHeight() );
for(int i = 0; i < wrs.length; i ++)
is.addSlice("warped source slice " + (i+1), wrs[i].warpedSource.getProcessor() );
ImagePlus warpedSource = new ImagePlus ("warped source", is);
// We calculate the precision-recall value between the warped original labels and the
// proposed labels
PixelError pixelError = new PixelError( warpedSource, proposedLabels);
ClassificationStatistics stats = pixelError.getPrecisionRecallStats( th );
if( verbose )
IJ.log(" F-score = " + stats.fScore );
cs.add( stats );
}
return cs;
}
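/*
* Usage sketch (editorial example): per-threshold pixel-error statistics of the
* warped ground truth versus the proposal. Assumes 'we' as in the sketches above.
*
*   ArrayList<ClassificationStatistics> stats = we.getPrecisionRecallStats( 0.1, 0.9, 0.1 );
*   if( stats != null )
*       for( ClassificationStatistics s : stats )
*           IJ.log( "F-score = " + s.fScore );
*/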
/**
* Get the best F-score of the pixel error between proposed and original labels
* (and the other way around) over a set of thresholds
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return maximal F-score of the pixel error
*/
public double getDualPixelErrorMaximalFScore(
double minThreshold,
double maxThreshold,
double stepThreshold )
{
ArrayList<ClassificationStatistics> stats = getDualPrecisionRecallStats( minThreshold, maxThreshold, stepThreshold );
// trainableSegmentation.utils.Utils.plotPrecisionRecall( stats );
double maxFScore = 0;
double th = 0;
double bestTh = 0;
for(ClassificationStatistics stat : stats)
{
if (stat.fScore > maxFScore)
{
maxFScore = stat.fScore;
bestTh = th;
}
th += stepThreshold;
}
if( verbose )
IJ.log(" ** Best F-score = " + maxFScore + ", with threshold = " + bestTh + " **\n");
return maxFScore;
}
/**
* Calculate the precision-recall values based on pixel error between
* some warped 2D original labels and the corresponding proposed labels
* in both directions (from original labels to proposal and vice versa).
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return pixel error value and derived statistics for each threshold
*/
public ArrayList< ClassificationStatistics > getDualPrecisionRecallStats(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList< ClassificationStatistics > cs = new ArrayList<ClassificationStatistics>();
for(double th = minThreshold; th <= maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating warping error statistics for threshold value " + String.format("%.3f", th) + "...");
WarpingResults[] wrs = simplePointWarp2dMT(originalLabels, proposedLabels, mask, th);
ImageStack is = new ImageStack( originalLabels.getWidth(), originalLabels.getHeight() );
for(int i = 0; i < wrs.length; i ++)
is.addSlice("warped source slice " + (i+1), wrs[i].warpedSource.getProcessor() );
ImagePlus warpedSource = new ImagePlus ("warped source", is);
// We calculate first the precision-recall value between the warped
// original labels and the proposed labels
PixelError pixelError = new PixelError( warpedSource, proposedLabels );
ClassificationStatistics stats = pixelError.getPrecisionRecallStats( th );
// ... and then from warped proposed labels to original labels
// apply threshold to proposed labels so they are binary
double max = proposedLabels.getImageStack().getProcessor( 1 ) instanceof ByteProcessor ? 255 : 1.0;
ImagePlus proposal8bit = proposedLabels.duplicate();
IJ.setThreshold( proposal8bit, th + 0.00001, max);
IJ.run( proposal8bit, "Convert to Mask", " black");
// warp proposal into original labels
wrs = simplePointWarp2dMT( proposal8bit, originalLabels, mask, th);
is = new ImageStack( originalLabels.getWidth(), originalLabels.getHeight() );
for(int i = 0; i < wrs.length; i ++)
is.addSlice("warped source slice " + (i+1), wrs[i].warpedSource.getProcessor() );
warpedSource = new ImagePlus ("warped source", is);
// then calculate pixel error
pixelError = new PixelError( warpedSource, originalLabels );
ClassificationStatistics statsInverse = pixelError.getPrecisionRecallStats( th );
// Join statistics and average errors
stats.metricValue = (stats.metricValue + statsInverse.metricValue) / 2.0;
ClassificationStatistics finalStats = new ClassificationStatistics(
stats.truePositives + statsInverse.truePositives,
stats.trueNegatives + statsInverse.trueNegatives,
stats.falsePositives + statsInverse.falsePositives,
stats.falseNegatives + statsInverse.falseNegatives,
(stats.metricValue + statsInverse.metricValue) / 2.0);
if( verbose )
IJ.log(" F-score = " + finalStats.fScore );
cs.add( finalStats );
}
return cs;
}
/**
* Calculate the precision-recall values based on Rand index between
* some warped 2D original labels and the corresponding proposed labels.
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return Rand index value and derived statistics for each threshold
*/
public ArrayList< ClassificationStatistics > getRandIndexStats(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
if( minThreshold < 0 || minThreshold >= maxThreshold || maxThreshold > 1)
{
IJ.log("Error: unvalid threshold values.");
return null;
}
ArrayList< ClassificationStatistics > cs = new ArrayList<ClassificationStatistics>();
for(double th = minThreshold; th <= maxThreshold; th += stepThreshold)
{
if( verbose )
IJ.log(" Calculating warping error statistics for threshold value " + String.format("%.3f", th) + "...");
WarpingResults[] wrs = simplePointWarp2dMT(originalLabels, proposedLabels, mask, th);
ImageStack is = new ImageStack( originalLabels.getWidth(), originalLabels.getHeight() );
for(int i = 0; i < wrs.length; i ++)
is.addSlice("warped source slice " + (i+1), wrs[i].warpedSource.getProcessor() );
ImagePlus warpedSource = new ImagePlus ("warped source", is);
// We calculate the precision-recall value between the warped original labels and the
// proposed labels
RandError randError = new RandError( warpedSource, proposedLabels );
ClassificationStatistics stats = randError.getRandIndexStats( th );
if( verbose )
IJ.log(" F-score = " + stats.fScore );
cs.add( stats );
}
return cs;
}
/**
* Get the best F-score of the Rand index between some warped 2D original
* labels and the corresponding proposed labels.
*
* @param minThreshold minimum threshold value to binarize the input images
* @param maxThreshold maximum threshold value to binarize the input images
* @param stepThreshold threshold step value to use during binarization
* @return maximal F-score of the Rand index
*/
public double getRandIndexMaximalFScore(
double minThreshold,
double maxThreshold,
double stepThreshold)
{
ArrayList<ClassificationStatistics> stats = getRandIndexStats( minThreshold, maxThreshold, stepThreshold );
// trainableSegmentation.utils.Utils.plotPrecisionRecall( stats );
double maxFScore = 0;
double th = 0;
double bestTh = 0;
for(ClassificationStatistics stat : stats)
{
if (stat.fScore > maxFScore)
{
maxFScore = stat.fScore;
bestTh = th;
}
th += stepThreshold;
}
if( verbose )
IJ.log(" ** Best F-score = " + maxFScore + ", with threshold = " + bestTh + " **\n");
return maxFScore;
}
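/*
* Usage sketch (editorial example): maximal Rand-index F-score computed on the
* warped ground truth rather than on the raw labels. Assumes 'we' as above.
*
*   double maxWarpedFScore = we.getRandIndexMaximalFScore( 0.0, 1.0, 0.1 );
*/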
/**
* Check if a point is simple (in 2D) based on 3D code from Mark Richardson
* inspired by the work of Bertrand et al. \cite{Bertrand94}
*
* BibTeX:
* <pre>
* &#64;article{Bertrand94,
* author = {Bertrand, Gilles and Malandain, Gr\'{e}goire},
* title = {A new characterization of three-dimensional simple points},
* journal = {Pattern Recogn. Lett.},
* volume = {15},
* issue = {2},
* month = {February},
* year = {1994},
* issn = {0167-8655},
* pages = {169--175},
* numpages = {7},
* url = {http://dl.acm.org/citation.cfm?id=179348.179356},
* doi = {10.1016/0167-8655(94)90046-9},
* acmid = {179356},
* publisher = {Elsevier Science Inc.},
* address = {New York, NY, USA},
* keywords = {digital topology, simple points, thinning algorithms, three dimensions},
* }
* </pre>
* @param im input patch
* @param n neighbors
* @return true if the center pixel of the patch is a simple point
*/
public boolean simple2DBertrand(ImagePlus im, int n)
{
float[] input = new float[27];
float[] center = (float[])im.getProcessor().getPixels();
for(int i=0; i<9; i++)
input[i+9] = center[i];
switch (n)
{
case 4:
return simple3d( input, 6);
case 8:
return simple3d(input, 26);
default:
IJ.error("Non valid adjacency value");
return false;
}
}
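/*
* Usage sketch (editorial example): testing whether the center of a 3x3 binary
* patch is a simple point under 4-adjacency. The patch values are arbitrary.
*
*   double[] vals = new double[]{ 0, 1, 0,
*                                 0, 1, 0,
*                                 0, 1, 0 };
*   ImagePlus patch = new ImagePlus( "patch", new FloatProcessor( 3, 3, vals ) );
*   boolean isSimple = we.simple2DBertrand( patch, 4 );
*/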
/**
* Check if a point is simple (in 2D)
* @param im input patch
* @param n neighbors
* @return true if the center pixel of the patch is a simple point
*/
public boolean simple2D(ImagePlus im, int n)
{
final ImagePlus invertedIm = new ImagePlus("inverted", im.getProcessor().duplicate());
//IJ.run(invertedIm, "Invert","");
final float[] pix = (float[])invertedIm.getProcessor().getPixels();
for(int i=0; i<pix.length; i++)
pix[i] = pix[i] == 0f ? 1f : 0f;
switch (n)
{
case 4:
return topo(im, 4) == 1 && topo(invertedIm, 8) == 1;
case 8:
return topo(im, 8) == 1 && topo(invertedIm, 4) == 1;
default:
IJ.error("Non valid adjacency value");
return false;
}
}
/**
* Computes topological numbers for the central point of an image patch.
* These numbers can be used as the basis of a topological classification.
* T_4 and T_8 are used when the input is a 2D image patch of size 3x3, as
* defined on p. 172 of Bertrand & Malandain, Patt. Recogn. Lett. 15, 169-175 (1994).
*
* @param im input image
* @param adjacency number of neighbors
* @return number of components in the patch excluding the center pixel
*/
public int topo(final ImagePlus im, final int adjacency)
{
ImageProcessor components = null;
final ImagePlus im2 = new ImagePlus("copy of im", im.getProcessor().duplicate());
switch (adjacency)
{
case 4:
if( im.getStack().getSize() > 1 )
{
IJ.error("n=4 is valid for a 2d image");
return -1;
}
if( im.getProcessor().getWidth() > 3 || im.getProcessor().getHeight() > 3)
{
IJ.error("must be 3x3 image patch");
return -1;
}
// ignore the central point
im2.getProcessor().set(1, 1, 0);
components = Utils.connectedComponents(im2, adjacency).allRegions.getProcessor();
// zero out locations that are not in the four-neighborhood
components.set(0,0,0);
components.set(0,2,0);
components.set(1,1,0);
components.set(2,0,0);
components.set(2,2,0);
break;
case 8:
if( im.getStack().getSize() > 1 )
{
IJ.error("n=8 is valid for a 2d image");
return -1;
}
if( im.getProcessor().getWidth() > 3 || im.getProcessor().getHeight() > 3)
{
IJ.error("must be 3x3 image patch");
return -1;
}
// ignore the central point
im2.getProcessor().set(1, 1, 0);
components = Utils.connectedComponents(im2, adjacency).allRegions.getProcessor();
break;
default:
IJ.error("Non valid adjacency value");
return -1;
}
if(null == components)
return -1;
int t = 0;
ArrayList<Integer> uniqueId = new ArrayList<Integer>();
for(int i = 0; i < 3; i++)
for(int j = 0; j < 3; j++)
{
if(( t = components.get(i, j) ) != 0)
if(!uniqueId.contains(t))
uniqueId.add(t);
}
return uniqueId.size();
}
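/*
* Usage sketch (editorial example): the center of a 3x3 patch is 4-simple when the
* topological numbers T_4 of the foreground and T_8 of the inverted (background)
* patch are both 1, as done in simple2D above. 'patch' is the hypothetical 3x3
* ImagePlus from the previous sketch and 'invertedPatch' its binary inverse.
*
*   int t4 = we.topo( patch, 4 );
*   int t8 = we.topo( invertedPatch, 8 );
*   boolean simple4 = ( t4 == 1 && t8 == 1 );
*/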
/**
* Use simple point relaxation to warp 2D source into 2D target.
* Source is only modified at nonzero locations in the mask
*
* @param source input 2D image to be relaxed
* @param target target 2D image
* @param mask 2D image mask
* @param binaryThreshold binarization threshold
* @return warped source image and warping error
*/
public WarpingResults simplePointWarp2d(
final ImageProcessor source,
final ImageProcessor target,
final ImageProcessor mask,
double binaryThreshold)
{
if(binaryThreshold < 0 || binaryThreshold > 1.01)
binaryThreshold = 0.5;
// Grayscale target
final ImagePlus targetReal;// = new ImagePlus("target_real", target.duplicate());
// Binarized target
final ImagePlus targetBin; // = new ImagePlus("target_aux", target.duplicate());
final ImagePlus sourceReal; // = new ImagePlus("source_real", source.duplicate());
final ImagePlus maskReal; // = (null != mask) ? new ImagePlus("mask_real", mask.duplicate().convertToFloat()) : null;
final int width = target.getWidth();
final int height = target.getHeight();
// Resize canvas to avoid checking the borders
ImageProcessor ip = target.createProcessor(width+2, height+2);
ip.insert(target, 1, 1);
targetReal = new ImagePlus("target_real", ip.duplicate());
targetBin = new ImagePlus("target_aux", ip.duplicate());
ip = target.createProcessor(width+2, height+2);
ip.insert(source, 1, 1);
sourceReal = new ImagePlus("source_real", ip.duplicate());
if(null != mask)
{
ip = target.createProcessor(width+2, height+2);
ip.insert(mask, 1, 1);
maskReal = new ImagePlus("mask_real", ip.duplicate());
}
else{
maskReal = null;
}
// make sure source and target are binary images
final float[] sourceRealPix = (float[])sourceReal.getProcessor().getPixels();
for(int i=0; i < sourceRealPix.length; i++)
if(sourceRealPix[i] > 0)
sourceRealPix[i] = 1.0f;
final float[] targetBinPix = (float[])targetBin.getProcessor().getPixels();
for(int i=0; i < targetBinPix.length; i++)
targetBinPix[i] = (targetBinPix[i] <= binaryThreshold) ? 0.0f : 1.0f;
double diff = Double.MIN_VALUE;
double diff_before = 0;
final WarpingResults result = new WarpingResults();
while(true)
{
ImageProcessor missclass_points_image = sourceReal.getProcessor().duplicate();
missclass_points_image.copyBits(targetBin.getProcessor(), 0, 0, Blitter.DIFFERENCE);
diff_before = diff;
// Count mismatches
float pixels[] = (float[]) missclass_points_image.getPixels();
float mask_pixels[] = (null != maskReal) ? (float[]) maskReal.getProcessor().getPixels() : new float[pixels.length];
if(null == maskReal)
Arrays.fill(mask_pixels, 1f);
diff = 0;
for(int k = 0; k < pixels.length; k++)
if(pixels[k] != 0 && mask_pixels[k] != 0)
diff ++;
//IJ.log("Difference = " + diff);
if( diff == 0 )
{
result.mismatches = new ArrayList<Point3f>();
break;
}
if(diff == diff_before)
break;
final ArrayList<Point3f> mismatches = new ArrayList<Point3f>();
final float[] realTargetPix = (float[])targetReal.getProcessor().getPixels();
// Sort mismatches by the absolute value of the target pixel value - threshold
for(int x = 1; x < width+1; x++)
for(int y = 1; y < height+1; y++)
{
if(pixels[x+y*(width+2)] != 0 && mask_pixels[x+y*(width+2)] != 0)
mismatches.add(new Point3f(x , y , (float) Math.abs( realTargetPix[x+y*(width+2)] - binaryThreshold) ));
}
// Sort mismatches in descending order
Collections.sort(mismatches, new Comparator<Point3f>() {
public int compare(Point3f o1, Point3f o2) {
return (int)((o2.z - o1.z) *10000);
}});
// Process mismatches
for(final Point3f p : mismatches)
{
final int x = (int) p.x;
final int y = (int) p.y;
if(p.z < SIMPLE_POINT_THRESHOLD)
continue;
double[] val = new double[]{
sourceRealPix[ (x-1) + (y-1) * (width+2) ],
sourceRealPix[ (x ) + (y-1) * (width+2) ],
sourceRealPix[ (x+1) + (y-1) * (width+2) ],
sourceRealPix[ (x-1) + (y ) * (width+2) ],
sourceRealPix[ (x ) + (y ) * (width+2) ],
sourceRealPix[ (x+1) + (y ) * (width+2) ],
sourceRealPix[ (x-1) + (y+1) * (width+2) ],
sourceRealPix[ (x ) + (y+1) * (width+2) ],
sourceRealPix[ (x+1) + (y+1) * (width+2) ]
};
final double pix = val[4];
final ImagePlus patch = new ImagePlus("patch", new FloatProcessor(3,3,val));
if( simple2DBertrand(patch, 4) )
{
sourceRealPix[ x + y * (width+2)] = pix > 0.0 ? 0.0f : 1.0f ;
//IJ.log("flipping pixel x: " + x + " y: " + y + " to " + (pix > 0 ? 0.0 : 1.0));
}
}
result.mismatches = mismatches;
}
//IJ.run(sourceReal, "Canvas Size...", "width="+ width + " height=" + height + " position=Center zero");
ip = source.createProcessor(width, height);
ip.insert(sourceReal.getProcessor(), -1, -1);
sourceReal.setProcessor(ip.duplicate());
// Adjust mismatches coordinates
final ArrayList<Point3f> mismatches = new ArrayList<Point3f>();
for(Point3f p : result.mismatches)
{
mismatches.add(new Point3f( p.x - 1, p.y - 1, p.z));
}
sourceReal.setTitle("Warped source");
result.mismatches = mismatches;
result.warpedSource = sourceReal;
result.warpingError = diff / (width * height);
return result;
}
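/*
* Usage sketch (editorial example): warping a single 2D slice and reading the
* per-slice error and mismatch list from the returned WarpingResults. Assumes
* 'labels' and 'proposal' as in the construction sketch above.
*
*   WarpingResults wr = we.simplePointWarp2d( labels.getProcessor().convertToFloat(),
*                                             proposal.getProcessor().convertToFloat(),
*                                             null, 0.5 );
*   IJ.log( "warping error = " + wr.warpingError
*           + ", mismatches = " + wr.mismatches.size() );
*/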
/**
* Calculate the simple point warping in a concurrent way
* (to be submitted to an Executor Service)
* @param source moving image
* @param target fixed image
* @param mask mask image
* @param binaryThreshold binary threshold to use
* @return warping results (warped labels, warping error value and mismatching points)
*/
public Callable<WarpingResults> simplePointWarp2DConcurrent(
final ImageProcessor source,
final ImageProcessor target,
final ImageProcessor mask,
final double binaryThreshold)
{
return new Callable<WarpingResults>(){
public WarpingResults call(){
return simplePointWarp2d(source, target, mask, binaryThreshold);
}
};
}
/**
* Calculate the simple point warping in a concurrent way
* (to be submitted to an Executor Service)
* @param source moving image
* @param target fixed image
* @param mask mask image
* @param binaryThreshold binary threshold to use
* @param calculateMismatchImage boolean flag to calculate the image of classified mismatches
* @param radius radius in pixels to use while classifying pixels
* @return warping results (warped labels, warping error value and mismatching points)
*/
public Callable<WarpingResults> simplePointWarp2DConcurrent(
final ImageProcessor source,
final ImageProcessor target,
final ImageProcessor mask,
final double binaryThreshold,
final boolean calculateMismatchImage,
final int radius )
{
return new Callable<WarpingResults>(){
public WarpingResults call(){
WarpingResults wr = simplePointWarp2d(source, target, mask, binaryThreshold);
if( calculateMismatchImage )
wr.classifiedMismatches = getMismatchImage( wr, radius );
return wr;
}
};
}
/**
* Get the image with the classified mismatches
*
* @param wr warping results
* @param radius radius in pixels to use while classifying pixels
* @return image with classified mismatches
*/
public ImagePlus getMismatchImage(WarpingResults wr, int radius)
{
int[] mismatchesLabels = classifyMismatches2d( wr.warpedSource, wr.mismatches, radius );
ByteProcessor bp = new ByteProcessor( wr.warpedSource.getWidth(), wr.warpedSource.getHeight() );
for(int i=0; i < wr.mismatches.size(); i++)
{
Point3f p = wr.mismatches.get( i );
bp.set( (int)p.x, (int)p.y, mismatchesLabels[ i ] );
}
return new ImagePlus( "Mismatches", bp );
}
/**
* Get the image with the classified mismatches
*
* @param wr warping results
* @param radius radius in pixels to use while classifying pixels
* @param flags flags indicating which types of mismatches to keep in the output
* @return image with classified mismatches
*/
public ImagePlus getMismatchImage(WarpingResults wr, int radius, int flags)
{
int[] mismatchesLabels = classifyMismatches2d( wr.warpedSource, wr.mismatches, radius );
ByteProcessor bp = new ByteProcessor( wr.warpedSource.getWidth(), wr.warpedSource.getHeight() );
for(int i=0; i < wr.mismatches.size(); i++)
{
Point3f p = wr.mismatches.get( i );
bp.set( (int)p.x, (int)p.y, mismatchesLabels[ i ] & flags );
}
return new ImagePlus( "Mismatches", bp );
}
/**
* Get the image with the classified mismatches
*
* @param wr warping results
* @param mismatchesLabels labels of the warping mismatches
* @param flags flags indicating which types of mismatches to keep in the output
* @return image with classified mismatches
*/
public ImagePlus getMismatchImage(WarpingResults wr, int[] mismatchesLabels, int flags)
{
ByteProcessor bp = new ByteProcessor( wr.warpedSource.getWidth(), wr.warpedSource.getHeight() );
for(int i=0; i < wr.mismatches.size(); i++)
{
Point3f p = wr.mismatches.get( i );
bp.set( (int)p.x, (int)p.y, mismatchesLabels[ i ] & flags );
}
return new ImagePlus( "Mismatches", bp );
}
/**
* Use simple point relaxation to warp 2D source into 2D target.
* Source is only modified at nonzero locations in the mask
* (multi-thread version)
*
* @param source input image to be relaxed
* @param target target image
* @param mask image mask
* @param binaryThreshold binarization threshold
* @param mismatches (output) array of per-slice lists of points that could not be flipped
* @return warped source image
*/
public ImagePlus simplePointWarp2dMT(
ImagePlus source,
ImagePlus target,
ImagePlus mask,
double binaryThreshold,
ArrayList<Point3f>[] mismatches)
{
if(source.getWidth() != target.getWidth()
|| source.getHeight() != target.getHeight()
|| source.getImageStackSize() != target.getImageStackSize())
{
IJ.log("Error: label and training image sizes do not fit.");
return null;
}
final ImageStack sourceSlices = source.getImageStack();
final ImageStack targetSlices = target.getImageStack();
final ImageStack maskSlices = (null != mask) ? mask.getImageStack() : null;
final ImageStack warpedSource = new ImageStack(source.getWidth(), source.getHeight());
if(null == mismatches)
mismatches = new ArrayList[sourceSlices.getSize()];
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<WarpingResults> > futures = new ArrayList< Future<WarpingResults> >();
try{
for(int i = 1; i <= sourceSlices.getSize(); i++)
{
futures.add(exe.submit( simplePointWarp2DConcurrent(sourceSlices.getProcessor(i),
targetSlices.getProcessor(i),
null != maskSlices ? maskSlices.getProcessor(i) : null,
binaryThreshold ) ) );
}
double warpingError = 0;
int i = 0;
// Wait for the jobs to be done
for(Future<WarpingResults> f : futures)
{
final WarpingResults wr = f.get();
if(null != wr.warpedSource)
warpedSource.addSlice("warped source " + i, wr.warpedSource.getProcessor());
if(wr.warpingError != -1)
warpingError += wr.warpingError;
if(null != wr.mismatches)
mismatches[i] = wr.mismatches;
i++;
}
if( verbose )
IJ.log("Warping error = " + (warpingError / sourceSlices.getSize()));
}
catch(Exception ex)
{
IJ.log("Error when warping ground truth in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return new ImagePlus("warped source", warpedSource);
}
/**
* Use simple point relaxation to warp 2D source into 2D target.
* Source is only modified at nonzero locations in the mask
*
* @param source input image to be relaxed
* @param target target image
* @param mask image mask
* @param binaryThreshold binarization threshold
* @return warped source image
*/
public ImagePlus simplePointWarp2d(
ImagePlus source,
ImagePlus target,
ImagePlus mask,
double binaryThreshold)
{
if(source.getWidth() != target.getWidth()
|| source.getHeight() != target.getHeight()
|| source.getImageStackSize() != target.getImageStackSize())
{
IJ.log("Error: label and training image sizes do not fit.");
return null;
}
final ImageStack sourceSlices = source.getImageStack();
final ImageStack targetSlices = target.getImageStack();
final ImageStack maskSlices = (null != mask) ? mask.getImageStack() : null;
final ImageStack warpedSource = new ImageStack(source.getWidth(), source.getHeight());
double warpingError = 0;
for(int i = 1; i <= sourceSlices.getSize(); i++)
{
WarpingResults wr = simplePointWarp2d(sourceSlices.getProcessor(i),
targetSlices.getProcessor(i), null != mask ? maskSlices.getProcessor(i) : null,
binaryThreshold);
if(null != wr.warpedSource)
warpedSource.addSlice("warped source " + i, wr.warpedSource.getProcessor());
if(wr.warpingError != -1)
warpingError += wr.warpingError;
}
//IJ.log("Warping error = " + (warpingError / sourceSlices.getSize()));
return new ImagePlus("warped source", warpedSource);
}
/**
* Use simple point relaxation to warp 2D source into 2D target.
* Source is only modified at nonzero locations in the mask
* (multi-thread version)
*
* @param source input image to be relaxed (2D image or stack)
* @param target target image (2D image or stack)
* @param mask image mask (2D image or stack)
* @param binaryThreshold binarization threshold
* @return warping results for each slice of the source
*/
public WarpingResults[] simplePointWarp2dMT(
ImagePlus source,
ImagePlus target,
ImagePlus mask,
double binaryThreshold)
{
if(source.getWidth() != target.getWidth()
|| source.getHeight() != target.getHeight()
|| source.getImageStackSize() != target.getImageStackSize())
{
IJ.log("Error: label and training image sizes do not fit.");
return null;
}
final ImageStack sourceSlices = source.getImageStack();
final ImageStack targetSlices = target.getImageStack();
final ImageStack maskSlices = (null != mask) ? mask.getImageStack() : null;
final WarpingResults[] wrs = new WarpingResults[ source.getImageStackSize() ];
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<WarpingResults> > futures = new ArrayList< Future<WarpingResults> >();
try{
for(int i = 1; i <= sourceSlices.getSize(); i++)
{
futures.add(exe.submit( simplePointWarp2DConcurrent(sourceSlices.getProcessor(i).convertToFloat(),
targetSlices.getProcessor(i).convertToFloat(),
null != maskSlices ? maskSlices.getProcessor(i) : null,
binaryThreshold ) ) );
}
int i = 0;
// Wait for the jobs to be done
for(Future<WarpingResults> f : futures)
{
wrs[ i ] = f.get();
i++;
}
}
catch(Exception ex)
{
IJ.log("Error when warping ground truth in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return wrs;
}
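/*
* Usage sketch (editorial example): multi-threaded warping of a whole stack, with
* one WarpingResults entry per slice. Assumes 'labels' and 'proposal' as above.
*
*   WarpingResults[] wrs = we.simplePointWarp2dMT( labels, proposal, null, 0.5 );
*   if( wrs != null )
*       for( WarpingResults slice : wrs )
*           IJ.log( "slice error = " + slice.warpingError );
*/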
/**
* Use simple point relaxation to warp 2D labels into the 2D proposal.
* Source is only modified at nonzero locations in the mask
* (multi-thread version)
*
* @param binaryThreshold binarization threshold
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @param calculateMismatchImage boolean flag to calculate mismatch image
* @param radius radius in pixels to use while classifying mismatches
* @return warping results for each slice of the source
*/
public WarpingResults[] simplePointWarp2dMT(
double binaryThreshold,
boolean clusterByError,
boolean calculateMismatchImage,
int radius )
{
final ImageStack sourceSlices = originalLabels.getImageStack();
final ImageStack targetSlices = proposedLabels.getImageStack();
final ImageStack maskSlices = (null != mask) ? mask.getImageStack() : null;
final WarpingResults[] wrs = new WarpingResults[ originalLabels.getImageStackSize() ];
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<WarpingResults> > futures = new ArrayList< Future<WarpingResults > >();
try{
for(int i = 1; i <= sourceSlices.getSize(); i++)
{
futures.add(exe.submit( getWarpingResultsConcurrent(sourceSlices.getProcessor(i).convertToFloat(),
targetSlices.getProcessor(i).convertToFloat(),
null != maskSlices ? maskSlices.getProcessor(i) : null,
binaryThreshold, clusterByError, radius,
flags, calculateMismatchImage ) ) );
}
int i = 0;
// Wait for the jobs to be done
for(Future<WarpingResults> f : futures)
{
wrs[ i ] = f.get();
i++;
}
}
catch(Exception ex)
{
IJ.log("Error when warping ground truth in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return wrs;
}
/**
* Calculate warping error (single thread version)
*
* @param label original labels (single image or stack)
* @param proposal proposed new labels
* @param mask image mask
* @param binaryThreshold binary threshold to binarize proposal
* @return total warping error
*/
public double warpingErrorSingleThread(
ImagePlus label,
ImagePlus proposal,
ImagePlus mask,
double binaryThreshold)
{
final ImagePlus warpedLabels = simplePointWarp2d(label, proposal, mask, binaryThreshold);
if(null == warpedLabels)
return -1;
double error = 0;
double count = 0;
for(int j=1; j<=proposal.getImageStackSize(); j++)
{
final float[] proposalPixels = (float[])proposal.getImageStack().getProcessor(j).getPixels();
final float[] warpedPixels = (float[])warpedLabels.getImageStack().getProcessor(j).getPixels();
for(int i=0; i<proposalPixels.length; i++)
{
count ++;
final float thresholdedProposal = (proposalPixels[i] <= binaryThreshold) ? 0.0f : 1.0f;
if (warpedPixels[i] != thresholdedProposal)
error++;
}
}
if(count != 0)
return error / count;
else
return -1;
}
/**
* Get all the mismatches of warping a source image into a target image,
* clustering them when they belong to the same error. Simple point
* relaxation is used for the warping. The source is only modified at
* nonzero locations in the mask (multi-thread static version)
*
* @param source input image to be relaxed (2D image or stack)
* @param target target image (2D image or stack)
* @param mask image mask (2D image or stack)
* @param binaryThreshold binarization threshold
* @param clusterByError if false, cluster mismatches by type, otherwise cluster them by error and type
* @param radius radius in pixels of the local area to look at when deciding some cases (a small radius speeds up the method a lot; use -1 for the whole image)
* @return clustered warping mismatches for each slice of the source
*/
public ClusteredWarpingMismatches[] getClusteredWarpingMismatches(
ImagePlus source,
ImagePlus target,
ImagePlus mask,
double binaryThreshold,
boolean clusterByError,
int radius)
{
if(source.getWidth() != target.getWidth()
|| source.getHeight() != target.getHeight()
|| source.getImageStackSize() != target.getImageStackSize())
{
IJ.log("Error: label and training image sizes do not fit.");
return null;
}
final ImageStack sourceSlices = source.getImageStack();
final ImageStack targetSlices = target.getImageStack();
final ImageStack maskSlices = (null != mask) ? mask.getImageStack() : null;
final ClusteredWarpingMismatches[] cwm = new ClusteredWarpingMismatches[ source.getImageStackSize() ];
// Executor service to produce concurrent threads
final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
final ArrayList< Future<ClusteredWarpingMismatches> > futures = new ArrayList< Future<ClusteredWarpingMismatches> >();
try{
for(int i = 1; i <= sourceSlices.getSize(); i++)
{
futures.add(exe.submit( getClusteredWarpingMismatchesConcurrent(sourceSlices.getProcessor(i).convertToFloat(),
targetSlices.getProcessor(i).convertToFloat(),
null != maskSlices ? maskSlices.getProcessor(i) : null,
binaryThreshold, clusterByError, radius ) ) );
}
int i = 0;
// Wait for the jobs to be done
for(Future<ClusteredWarpingMismatches> f : futures)
{
cwm[ i ] = f.get();
i++;
}
}
catch(Exception ex)
{
IJ.log("Error when getting the clustered warping mismatches in a concurrent way.");
ex.printStackTrace();
}
finally{
exe.shutdown();
}
return cwm;
}
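/*
* Usage sketch (editorial example): clustered mismatches per slice, grouping by
* type only (clusterByError = false) and using the whole image (radius = -1).
* Assumes 'labels' and 'proposal' as above.
*
*   ClusteredWarpingMismatches[] cwm =
*           we.getClusteredWarpingMismatches( labels, proposal, null, 0.5, false, -1 );
*   if( cwm != null )
*       IJ.log( "splits in slice 1 = " + cwm[ 0 ].numOfSplits
*               + ", mergers = " + cwm[ 0 ].numOfMergers );
*/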
/**
* Calculate the simple point warping in a concurrent way
* (to be submitted to an Executor Service)
*
* @param source moving image
* @param target fixed image
* @param mask mask image
* @param binaryThreshold binary threshold to use
* @param clusterByError boolean flag to use clustering by error or only by type
* @param radius radius in pixels of the local area to look at when deciding some cases (a small radius speeds up the method a lot; use -1 for the whole image)
* @return clustered mismatching points after warping
*/
public Callable<ClusteredWarpingMismatches> getClusteredWarpingMismatchesConcurrent(
final ImageProcessor source,
final ImageProcessor target,
final ImageProcessor mask,
final double binaryThreshold,
final boolean clusterByError,
final int radius)
{
return new Callable<ClusteredWarpingMismatches>()
{
public ClusteredWarpingMismatches call()
{
WarpingResults wr = simplePointWarp2d(source, target, mask, binaryThreshold);
//wr.warpedSource.show();
int[] mismatchesLabels = classifyMismatches2d( wr.warpedSource, wr.mismatches, radius );
if( clusterByError )
return clusterMismatchesByError( wr.warpedSource, wr.mismatches, mismatchesLabels );
else
return clusterMismatchesByType( mismatchesLabels );
}
};
}
/**
* Calculate the simple point warping in a concurrent way
* (to be submitted to an Executor Service)
*
* @param source moving image
* @param target fixed image
* @param mask mask image
* @param binaryThreshold binary threshold to use
* @param clusterByError boolean flag to use clustering by error or only by type
* @param radius radius in pixels of the local area to look at when deciding some cases (a small radius speeds up the method a lot; use -1 for the whole image)
* @param flags flags indicating the type of errors to take into account
* @param calculateMismatchImage boolean flag to determine if the mismatches image should be calculated
* @return warping results, including the classified mismatches and the warping error restricted to the selected error types
*/
public Callable<WarpingResults> getWarpingResultsConcurrent(
final ImageProcessor source,
final ImageProcessor target,
final ImageProcessor mask,
final double binaryThreshold,
final boolean clusterByError,
final int radius,
final int flags,
final boolean calculateMismatchImage)
{
return new Callable<WarpingResults>()
{
public WarpingResults call()
{
WarpingResults wr = simplePointWarp2d(source, target, mask, binaryThreshold);
//wr.warpedSource.show();
int[] mismatchesLabels = classifyMismatches2d( wr.warpedSource, wr.mismatches, radius );
if( calculateMismatchImage )
wr.classifiedMismatches = getMismatchImage( wr, mismatchesLabels, flags );
ClusteredWarpingMismatches cwm = null;
if( clusterByError )
cwm = clusterMismatchesByError( wr.warpedSource, wr.mismatches, mismatchesLabels );
else
cwm = clusterMismatchesByType( mismatchesLabels );
double error = 0;
double count = source.getWidth() * source.getHeight();
if( (flags & HOLE_ADDITION) != 0)
error += cwm.numOfHoleAdditions;
if( (flags & HOLE_DELETION) != 0)
error += cwm.numOfHoleDeletions;
if( (flags & MERGE) != 0)
error += cwm.numOfMergers;
if( (flags & OBJECT_ADDITION) != 0)
error += cwm.numOfObjectAdditions;
if( (flags & OBJECT_DELETION) != 0)
error += cwm.numOfObjectDeletions;
if( (flags & SPLIT) != 0)
error += cwm.numOfSplits;
wr.warpingError = error / count;
return wr;
}
};
}
/**
* Classify warping mismatches as MERGE, SPLIT, HOLE_ADDITION, HOLE_DELETION, OBJECT_ADDITION, OBJECT_DELETION
*
* @param warpedLabels labels after warping (binary image)
* @param mismatches list of mismatch points after warping
* @param radius radius in pixels of the local area to look at when deciding some cases (a small radius speeds up the method a lot; use -1 for the whole image)
* @return array of mismatch classifications
*/
public int[] classifyMismatches2d(
ImagePlus warpedLabels,
ArrayList<Point3f> mismatches,
int radius)
{
final int[] pointClassification = new int[ mismatches.size() ];
int radiusToUse = radius;
if(radius <1 || radius > warpedLabels.getWidth() || radius > warpedLabels.getHeight())
radiusToUse = -1;
// Calculate components in warped labels
ImageProcessor components = Utils.connectedComponents(
new ImagePlus("8-bit warped labels", warpedLabels.getProcessor().convertToByte(true)
), 4).allRegions.getProcessor();
int n = 0;
for(Point3f p : mismatches)
{
final int x = (int) p.x;
final int y = (int) p.y;
final ArrayList<Integer> neighborhood = getNeighborhood(components, new Point(x, y), 1, 1);
//IJ.log(" mismatch ("+ p.x + ", " + p.y + ")");
// Count number of unique IDs in the neighborhood
ArrayList<Integer> uniqueId = new ArrayList<Integer>();
for( Integer neighbor : neighborhood)
{
if(!uniqueId.contains( neighbor ))
uniqueId.add( neighbor );
}
// If all surrounding pixels are background
if( uniqueId.size() == 1 && uniqueId.get(0) == 0)
{
if(components.getPixel(x, y) != 0)
{
pointClassification[ n ] = OBJECT_DELETION;
//IJ.log(" all surrounding pixels are black and the point is white -> object deletion");
}
else
{
pointClassification[ n ] = OBJECT_ADDITION;
//IJ.log(" all surrounding pixels are black and the point is black -> object addition");
}
}
// If all surrounding pixels belong to one object
else if ( uniqueId.size() == 1 && uniqueId.get(0) != 0)
{
if(components.getPixel(x, y) != 0)
{
pointClassification[ n ] = HOLE_ADDITION;
//IJ.log(" all surrounding pixels are white and the point is white -> hole addition");
}
else
{
pointClassification[ n ] = HOLE_DELETION;
//IJ.log(" all surrounding pixels are white and the point is black -> hole deletion");
}
}
// If there are background and one single object ID in the surrounding pixels
else if ( uniqueId.size() == 2 )
{
// if the point is black, that's a hole addition error (flipping it to white would create a hole)
if (components.getPixel(x, y) == 0)
{
pointClassification[ n ] = HOLE_ADDITION;
//IJ.log(" surrounding pixels are white and black and the point is black -> hole addition");
}
else // if the point is white
{
// flip pixel and apply connected components again
ByteProcessor warpedPixels2;
warpedPixels2 = (ByteProcessor) warpedLabels.getProcessor().duplicate().convertToByte(true);
Point pixelOfInterest = new Point( x, y );
if (radiusToUse != -1)
{
warpedPixels2 = new ByteProcessor( 2*radiusToUse+1, 2*radiusToUse+1 );
for(int i = x-radiusToUse, l=0; i<=x+radiusToUse; i++, l++)
for(int j = y-radiusToUse, k=0; j<=y+radiusToUse; j++, k++)
warpedPixels2.set(l, k, warpedLabels.getProcessor().getPixel(i, j) == 0 ? 0 : 255);
pixelOfInterest = new Point( radiusToUse , radiusToUse );
}
// flip pixel
warpedPixels2.set( pixelOfInterest.x, pixelOfInterest.y, 0 );
// Calculate components in the new warped labels
ImageProcessor components2 = Utils.connectedComponents(new ImagePlus("8-bit warped labels", warpedPixels2), 4).allRegions.getProcessor();
//(new ImagePlus( "components", components2)).show();
final ArrayList<Integer> neighborhood2 = getNeighborhood(components2, pixelOfInterest, 1, 1);
// Count number of unique IDs in the neighborhood of the new components
ArrayList<Integer> uniqueId2 = new ArrayList<Integer>();
for( Integer neighbor : neighborhood2)
{
if(!uniqueId2.contains( neighbor ))
uniqueId2.add( neighbor );
}
// If there are more than 2 new components then it's a split
if ( uniqueId2.size() > 2 )
{
pointClassification[ n ] = SPLIT;
//IJ.log(" all surrounding pixels are white, the point is white and second CC has more than 2 objects -> split");
}
// otherwise it deletes a hole
else
{
pointClassification[ n ] = HOLE_DELETION;
//IJ.log(" all surrounding pixels are white, the point is white and second CC has 2 objects -> hole deletion");
}
}
}
else // If there is more than one object ID in the surrounding pixels
{
if(components.getPixel(x, y) == 0)
{
pointClassification[ n ] = MERGE;
//IJ.log(" surrounding pixels have at least 2 objects and the point is black -> merge");
}
else
{
pointClassification[ n ] = SPLIT;
//IJ.log(" surrounding pixels have at least 2 objects and the point is white -> split");
}
}
n++;
}
return pointClassification;
}
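/*
 * Illustrative usage sketch, not part of the original class: the array
 * returned above can be tallied per error type using the class constants.
 * The variable names (warpedLabels2d, mismatchPoints) are assumptions.
 *
 *   int[] classes = classifyMismatches2d( warpedLabels2d, mismatchPoints, 20 );
 *   int numSplits = 0;
 *   for( int c : classes )
 *       if( c == SPLIT )
 *           numSplits++;
 */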
/**
* Classify warping mismatches as MERGE, SPLIT, HOLE_ADDITION,
* HOLE_DELETION, OBJECT_ADDITION, OBJECT_DELETION and count
* the number of false positives and false negatives
*
* @param warpedLabels labels after warping (binary image)
* @param mismatches list of mismatch points after warping
* @param falsePositives number of false positives (note: passed by value, so the
*        increments made inside this method are not visible to the caller)
* @param falseNegatives number of false negatives (also passed by value)
* @param flags bit mask selecting which error types count as false positives/negatives
* @return array of mismatch classifications
*/
public int[] classifyMismatches2d(
ImagePlus warpedLabels,
ArrayList<Point3f> mismatches,
double falsePositives,
double falseNegatives,
int flags)
{
final int[] pointClassification = new int[ mismatches.size() ];
// Calculate components in warped labels
ImageProcessor components = Utils.connectedComponents(
new ImagePlus("8-bit warped labels", warpedLabels.getProcessor().convertToByte(true)
), 4).allRegions.getProcessor();
int n = 0;
for(Point3f p : mismatches)
{
final int x = (int) p.x;
final int y = (int) p.y;
final ArrayList<Integer> neighborhood = getNeighborhood(components, new Point(x, y), 1, 1);
// Count number of unique IDs in the neighborhood
ArrayList<Integer> uniqueId = new ArrayList<Integer>();
for( Integer neighbor : neighborhood)
{
if(!uniqueId.contains( neighbor ))
uniqueId.add( neighbor );
}
// If all surrounding pixels are background
if( uniqueId.size() == 1 && uniqueId.get(0) == 0)
{
if(components.getPixel(x, y) != 0)
{
pointClassification[ n ] = OBJECT_DELETION;
if( (flags & OBJECT_DELETION) != 0 )
falseNegatives ++;
}
else
{
pointClassification[ n ] = OBJECT_ADDITION;
if( (flags & OBJECT_ADDITION) != 0 )
falsePositives ++;
}
}
// If all surrounding pixels belong to one object
else if ( uniqueId.size() == 1 && uniqueId.get(0) != 0)
{
if(components.getPixel(x, y) != 0)
{
pointClassification[ n ] = HOLE_ADDITION;
if( (flags & HOLE_ADDITION) != 0 )
falseNegatives ++;
}
else
{
pointClassification[ n ] = HOLE_DELETION;
if( (flags & HOLE_DELETION) != 0 )
falsePositives ++;
}
}
// If there are background and one single object ID in the surrounding pixels
else if ( uniqueId.size() == 2 )
{
if (components.getPixel(x, y) == 0)
{
pointClassification[ n ] = HOLE_ADDITION;
if( (flags & HOLE_ADDITION) != 0 )
falsePositives ++;
}
else
{
// flip pixel and apply connected components again
final ByteProcessor warpedPixels2 = (ByteProcessor) warpedLabels.getProcessor().duplicate().convertToByte(true);
warpedPixels2.set( x, y, warpedPixels2.get(x, y) != 0 ? 0 : 255);
// Calculate components in the new warped labels
ImageProcessor components2 = Utils.connectedComponents(new ImagePlus("8-bit warped labels", warpedPixels2), 4).allRegions.getProcessor();
final ArrayList<Integer> neighborhood2 = getNeighborhood(components2, new Point(x, y), 1, 1);
// Count number of unique IDs in the neighborhood of the new components
ArrayList<Integer> uniqueId2 = new ArrayList<Integer>();
for( Integer neighbor : neighborhood2)
{
if(!uniqueId2.contains( neighbor ))
uniqueId2.add( neighbor );
}
// If there are more than 2 new components then it's a split
if ( uniqueId2.size() > 2 )
{
pointClassification[ n ] = SPLIT;
if( (flags & SPLIT) != 0 )
falseNegatives ++;
}
// otherwise it deletes a hole
else
{
pointClassification[ n ] = HOLE_DELETION;
if( (flags & HOLE_DELETION) != 0 )
falseNegatives ++;
}
}
}
else // If there is more than one object ID in the surrounding pixels
{
if(components.getPixel(x, y) == 0)
{
pointClassification[ n ] = MERGE;
if( (flags & MERGE) != 0 )
falsePositives ++;
}
else
{
pointClassification[ n ] = SPLIT;
if( (flags & SPLIT) != 0 )
falseNegatives ++;
}
}
n++;
}
return pointClassification;
}
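/*
 * Illustrative usage sketch, not part of the original source: the flags
 * argument of the method above acts as a bit mask, so several error types
 * can be selected at once by OR-ing the class constants, for example
 * (variable names are assumptions):
 *
 *   int flags = SPLIT | MERGE;
 *   int[] classes = classifyMismatches2d( warpedLabels2d, mismatchPoints, 0, 0, flags );
 */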
/**
* Cluster the result mismatches from the warping so pixels
* belonging to the same error are only counted once.
*
* @param warpedLabels result warped labels
* @param mismatches list of non-simple points
* @param mismatchClassification array of classified mismatches
* @return number of warping mismatches after clustering by error
*/
public ClusteredWarpingMismatches clusterMismatchesByError(
ImagePlus warpedLabels,
ArrayList<Point3f> mismatches,
int [] mismatchClassification)
{
// Create the 8 possible cases out of the mismatches
// 0: object addition, 1: hole deletion with an isolated background pixel
// 2: merger, 3: hole creation by removing a background pixel
// 4: delete object, 5: hole creation by adding a background pixel
// 6: split, 7: hole deletion by removing a foreground pixel
ByteProcessor[] binaryMismatches = new ByteProcessor[ 8 ];
final int width = warpedLabels.getWidth();
final int height = warpedLabels.getHeight();
for(int i=0; i<8; i++)
binaryMismatches[ i ] = new ByteProcessor(width, height);
// corresponding connectivity for each case (to run connected components)
final int[] connectivity = new int[]{4, 4, 8, 4, 4, 8, 4, 4};
for(int i=0 ; i < mismatchClassification.length; i++)
{
final int x = (int) mismatches.get( i ).x;
final int y = (int) mismatches.get( i ).y;
switch( mismatchClassification[ i ])
{
case OBJECT_ADDITION:
binaryMismatches[ 0 ].set(x, y, 255);
break;
case HOLE_DELETION:
if( warpedLabels.getProcessor().getPixel(x, y) == 0)
binaryMismatches[ 1 ].set(x, y, 255);
else
binaryMismatches[ 7 ].set(x, y, 255);
break;
case MERGE:
binaryMismatches[ 2 ].set(x, y, 255);
break;
case HOLE_ADDITION:
if( warpedLabels.getProcessor().getPixel(x, y) == 0)
binaryMismatches[ 3 ].set(x, y, 255);
else
binaryMismatches[ 5 ].set(x, y, 255);
break;
case OBJECT_DELETION:
binaryMismatches[ 4 ].set(x, y, 255);
break;
case SPLIT:
binaryMismatches[ 6 ].set(x, y, 255);
break;
default:
}
}
// run connected components on each case
int[] componentsPerCase = new int[8];
for(int i=0; i<8; i++)
{
ImagePlus im = new ImagePlus("components case " + i, binaryMismatches[ i ]);
//im.show();
componentsPerCase[i] = Utils.connectedComponents( im, connectivity[ i ]).regionInfo.size();
}
return new ClusteredWarpingMismatches(componentsPerCase[ 0 ],
componentsPerCase[ 1 ] + componentsPerCase[ 7 ],
componentsPerCase[ 2 ],
componentsPerCase[ 3 ] + componentsPerCase[ 5 ],
componentsPerCase[ 4 ],
componentsPerCase[ 6 ]);
}
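/*
 * Illustrative sketch, not part of the original source, of how the previous
 * classification step and the clustering above are typically chained
 * (variable names are assumptions):
 *
 *   int[] classes = classifyMismatches2d( warpedLabels2d, mismatchPoints, 20 );
 *   ClusteredWarpingMismatches cwm =
 *           clusterMismatchesByError( warpedLabels2d, mismatchPoints, classes );
 */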
/**
* Cluster the result mismatches from the warping
* by types of errors.
*
* @param mismatchClassification array of classified mismatches
* @return number of warping mismatches after clustering by type
*/
public ClusteredWarpingMismatches clusterMismatchesByType(
int [] mismatchClassification)
{
// Create the 8 possible cases out of the mismatches
// 0: object addition, 1: hole deletion with an isolated background pixel
// 2: merger, 3: hole creation by removing a background pixel
// 4: delete object, 5: hole creation by adding a background pixel
// 6: split, 7: hole deletion by removing a foreground pixel
int numOfObjectAdditions = 0;
int numOfHoleDeletions = 0;
int numOfMergers = 0;
int numOfHoleAdditions = 0;
int numOfObjectDeletions = 0;
int numOfSplits = 0;
for(int i=0 ; i < mismatchClassification.length; i++)
{
switch( mismatchClassification[ i ])
{
case OBJECT_ADDITION:
numOfObjectAdditions ++;
break;
case HOLE_DELETION:
numOfHoleDeletions ++;
break;
case MERGE:
numOfMergers ++;
break;
case HOLE_ADDITION:
numOfHoleAdditions ++;
break;
case OBJECT_DELETION:
numOfObjectDeletions ++;
break;
case SPLIT:
numOfSplits ++;
break;
default:
IJ.log("Unrecognized mismatch classification!");
}
}
return new ClusteredWarpingMismatches(numOfObjectAdditions,
numOfHoleDeletions, numOfMergers,
numOfHoleAdditions, numOfObjectDeletions,
numOfSplits);
}
/**
* Get neighborhood of a pixel in a 2D image
*
* @param image 2D image
* @param p point coordinates
* @param x_offset neighborhood offset in x
* @param y_offset neighborhood offset in y
* @return corresponding neighborhood
*/
public ArrayList<Integer> getNeighborhood(
final ImageProcessor image,
final Point p,
final int x_offset,
final int y_offset)
{
final ArrayList<Integer> neighborhood = new ArrayList<Integer>();
for(int j = p.y - y_offset; j <= p.y + y_offset; j++)
for(int i = p.x - x_offset; i <= p.x + x_offset; i++)
{
if(i!=p.x || j!= p.y)
if(j>=0 && j<image.getHeight() && i>=0 && i<image.getWidth())
neighborhood.add( image.get(i, j));
}
return neighborhood;
} // end getNeighborhood
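/*
 * Illustrative usage sketch, not part of the original source: with offsets
 * (1, 1) getNeighborhood returns the 8-connected neighborhood of a pixel
 * (center excluded, clipped at the image border). Variable names are
 * assumptions.
 *
 *   final ImageProcessor ip = warpedLabels2d.getProcessor();
 *   final ArrayList<Integer> nb = getNeighborhood( ip, new Point( 10, 10 ), 1, 1 );
 *   // nb.size() is 8 for interior pixels and smaller at the image border
 */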
/**
* Calculate the number of cavities of a 3D neighborhood
*
* @param input 3D neighborhood (27 values in a single array, flattened as x + 3*y + 9*z)
* @param con connectivity (6, 18 or 26)
* @param space 1 to count in the background (complement) space, 0 to count in the foreground space
* @return number of cavities of the 3D neighborhood
*/
int nca(float[] input, int con, int space)
{
int tsum;
switch (con)
{
case 6:
tsum=((int)input[4] + (int)input[10]+(int)input[12]+(int)input[14]+(int)input[16] + (int)input[22]);
return (space==1)?(6*space - tsum):(tsum);
case 18:
tsum=((int)input[1]+(int)input[3]+(int)input[4]+(int)input[5]+(int)input[7] + (int)input[9]+(int)input[10]+(int)input[11]+(int)input[12]+(int)input[14]+(int)input[15]+(int)input[16]+(int)input[17] + (int)input[19]+(int)input[21]+(int)input[22]+(int)input[23]+(int)input[25]);
return (space==1)?(18*space - tsum):(tsum);
case 26:
tsum=((int)input[0]+(int)input[1]+(int)input[2]+(int)input[3]+(int)input[4]+(int)input[5]+(int)input[6]+(int)input[7]+(int)input[8] + (int)input[9]+(int)input[10]+(int)input[11]+(int)input[12]+(int)input[14]+(int)input[15]+(int)input[16]+(int)input[17] + (int)input[18]+(int)input[19]+(int)input[20]+(int)input[21]+(int)input[22]+(int)input[23]+(int)input[24]+(int)input[25]+(int)input[26]);
return (space==1)?(26*space - tsum):(tsum);
default:
return 0;
}
}
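/*
 * Illustrative note, not part of the original source: with the flattened
 * index layout x + 3*y + 9*z the center voxel of the 3x3x3 neighborhood sits
 * at index 13 and its six face neighbors at indices 4, 10, 12, 14, 16 and 22,
 * which is exactly the set summed in the 6-connected case above.
 */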
/**
* Count the neighborhood configurations used by the 3D simple point test
*
* @param input 3D neighborhood (27 values in a single array, flattened as x + 3*y + 9*z)
* @param ctyp configuration type ('a' or 'b')
* @param con connectivity (6, 18 or 26)
* @param space 1 for the background (complement) space, 0 for the foreground space
* @return number of matching configurations (0 for unsupported parameter combinations)
*/
int ncb(float[] input, char ctyp, int con, int space)
{
int tsum;
final int[][][] a6m = new int[][][]{{{0,1,0}, {1,5,1}, {0,1,0}},
{{0,0,0}, {0,0,0}, {0,0,0}},
{{0,0,0}, {0,0,0}, {0,0,0}}};
final int[][][] a18m = new int[][][]{{{0,0,0}, {0,1,1}, {0,0,0}},
{{0,0,0}, {0,0,1}, {0,0,0}},
{{0,0,0}, {0,0,0}, {0,0,0}}};
final int[][][] a26m = new int[][][]{{{0,0,0}, {0,1,1}, {0,1,1}},
{{0,0,0}, {0,0,1}, {0,1,1}},
{{0,0,0}, {0,0,0}, {0,0,0}}};
final int[][][] b18m = new int[][][]{{{0,1,0}, {0,1,9}, {0,1,0}},
{{0,1,1}, {0,0,1}, {0,1,1}},
{{0,0,0}, {0,0,0}, {0,0,0}}};
final int[][][] b26m = new int[][][]{{{0,0,0}, {0,1,1}, {0,1,7}},
{{0,0,0}, {0,0,1}, {0,1,1}},
{{0,0,0}, {0,0,0}, {0,0,0}}};
int[] tsuma = new int[12];
int x, y, z, i;
if(ctyp=='a'){
switch (con)
{
case 6:
for(x=0;x<3;x++){
for(y=0;y<3;y++){
for(z=0;z<3;z++){
tsuma[0]+=a6m[x][y][z]*input[x+y*3+z*3*3];
tsuma[1]+=a6m[2-x][y][z]*input[x+y*3+z*3*3];
tsuma[2]+=a6m[y][x][z]*input[x+y*3+z*3*3];
tsuma[3]+=a6m[2-y][x][z]*input[x+y*3+z*3*3];
tsuma[4]+=a6m[z][y][x]*input[x+y*3+z*3*3];
tsuma[5]+=a6m[2-z][y][x]*input[x+y*3+z*3*3];
}
}
}
tsum=0;
for(i=0;i<6;i++)
tsum += (tsuma[i]==(5-space))?1:0;
return tsum;
case 18:
for(x=0;x<3;x++){
for(y=0;y<3;y++){
for(z=0;z<3;z++){
tsuma[0]+=a18m[x][y][z]*input[x+y*3+z*3*3];
tsuma[1]+=a18m[x][y][2-z]*input[x+y*3+z*3*3];
tsuma[2]+=a18m[2-x][y][z]*input[x+y*3+z*3*3];
tsuma[3]+=a18m[2-x][y][2-z]*input[x+y*3+z*3*3];
tsuma[4]+=a18m[y][x][z]*input[x+y*3+z*3*3];
tsuma[5]+=a18m[y][x][2-z]*input[x+y*3+z*3*3];
tsuma[6]+=a18m[2-y][x][z]*input[x+y*3+z*3*3];
tsuma[7]+=a18m[2-y][x][2-z]*input[x+y*3+z*3*3];
tsuma[8]+=a18m[x][z][y]*input[x+y*3+z*3*3];
tsuma[9]+=a18m[x][z][2-y]*input[x+y*3+z*3*3];
tsuma[10]+=a18m[2-x][z][y]*input[x+y*3+z*3*3];
tsuma[11]+=a18m[2-x][z][2-y]*input[x+y*3+z*3*3];
}
}
}
tsum=0;
for(i=0;i<12;i++){
tsum += (tsuma[i]==(3-3*space))?1:0;
}
return tsum;
case 26:
for(x=0;x<3;x++){
for(y=0;y<3;y++){
for(z=0;z<3;z++){
tsuma[0]+=a26m[x][y][z]*input[x+y*3+z*3*3];
tsuma[1]+=a26m[2-x][y][z]*input[x+y*3+z*3*3];
tsuma[2]+=a26m[x][2-y][z]*input[x+y*3+z*3*3];
tsuma[3]+=a26m[x][y][2-z]*input[x+y*3+z*3*3];
tsuma[4]+=a26m[2-x][2-y][z]*input[x+y*3+z*3*3];
tsuma[5]+=a26m[x][2-y][2-z]*input[x+y*3+z*3*3];
tsuma[6]+=a26m[2-x][y][2-z]*input[x+y*3+z*3*3];
tsuma[7]+=a26m[2-x][2-y][2-z]*input[x+y*3+z*3*3];
}
}
}
tsum=0;
for(i=0;i<8;i++){
tsum += (tsuma[i]==(7-7*space))?1:0;
}
return tsum;
default:
return 0;
}
}else if(ctyp=='b'){
switch (con)
{
case 18:
for(x=0;x<3;x++){
for(y=0;y<3;y++){
for(z=0;z<3;z++){
tsuma[0]+=b18m[x][y][z]*input[x+y*3+z*3*3];
tsuma[1]+=b18m[x][y][2-z]*input[x+y*3+z*3*3];
tsuma[2]+=b18m[2-x][y][z]*input[x+y*3+z*3*3];
tsuma[3]+=b18m[2-x][y][2-z]*input[x+y*3+z*3*3];
tsuma[4]+=b18m[y][x][z]*input[x+y*3+z*3*3];
tsuma[5]+=b18m[y][x][2-z]*input[x+y*3+z*3*3];
tsuma[6]+=b18m[2-y][x][z]*input[x+y*3+z*3*3];
tsuma[7]+=b18m[2-y][x][2-z]*input[x+y*3+z*3*3];
tsuma[8]+=b18m[x][z][y]*input[x+y*3+z*3*3];
tsuma[9]+=b18m[x][z][2-y]*input[x+y*3+z*3*3];
tsuma[10]+=b18m[2-x][z][y]*input[x+y*3+z*3*3];
tsuma[11]+=b18m[2-x][z][2-y]*input[x+y*3+z*3*3];
}
}
}
tsum=0;
for(i=0;i<12;i++){
tsum += (tsuma[i]==(9-space))?1:0;
}
return tsum;
case 26:
for(x=0;x<3;x++){
for(y=0;y<3;y++){
for(z=0;z<3;z++){
tsuma[0]+=b26m[x][y][z]*input[x+y*3+z*3*3];
tsuma[1]+=b26m[2-x][y][z]*input[x+y*3+z*3*3];
tsuma[2]+=b26m[x][2-y][z]*input[x+y*3+z*3*3];
tsuma[3]+=b26m[x][y][2-z]*input[x+y*3+z*3*3];
tsuma[4]+=b26m[2-x][2-y][z]*input[x+y*3+z*3*3];
tsuma[5]+=b26m[x][2-y][2-z]*input[x+y*3+z*3*3];
tsuma[6]+=b26m[2-x][y][2-z]*input[x+y*3+z*3*3];
tsuma[7]+=b26m[2-x][2-y][2-z]*input[x+y*3+z*3*3];
}
}
}
tsum=0;
for(i=0;i<8;i++){
tsum += (tsuma[i]==7-space)?1:0;
}
return tsum;
default:
return 0;
}
}
else
return 0;
}
/**
* Calculate if a point is simple in 3D
*
* @param input 3D neighborhood (27 pixels) in a single array
* @param region adjacency (26 or 6)
* @return true if the point is simple
*/
boolean simple3d(float[] input, int region)
{
boolean simple = false;
if(region==26)
{
simple=false;
if( nca(input, 6, 1)==1 ){
simple=true;
}else if(nca(input, 26, 0)==1 ){
simple=true;
}else if( ncb(input, 'b', 26, 0)==0 ){
if( nca(input, 18, 0)==1 ){
simple=true;
}else if( (ncb(input, 'a', 6, 1)==0) && (ncb(input, 'b', 18, 0)==0) && ((nca(input, 6, 1)-ncb(input, 'a', 18, 1)+ncb(input, 'a', 26, 1))==1) ){
simple=true;
}
}
}
else if(region==6)
{
int i;
float[] input2 = new float[27];
for(i=0;i<27;i++){
input2[i] = input[i] == 1.0f ? 0.0f : 1.0f;
}
simple=false;
if( nca(input2, 6, 1)==1 ){
simple=true;
}else if(nca(input2, 26, 0)==1 ){
simple=true;
}else if( ncb(input2, 'b', 26, 0)==0 ){
if( nca(input2, 18, 0)==1 ){
simple=true;
}else if( (ncb(input2, 'a', 6, 1)==0) && (ncb(input2, 'b', 18, 0)==0) && ((nca(input2, 6, 1)-ncb(input2, 'a', 18, 1)+ncb(input2, 'a', 26, 1))==1) ){
simple=true;
}
}
}
return simple;
}
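/*
 * Illustrative usage sketch, not part of the original source: simple3d
 * expects the 27 binary values of a 3x3x3 neighborhood flattened as
 * x + 3*y + 9*z, with the voxel under test at index 13.
 *
 *   final float[] patch = new float[ 27 ];
 *   // ... fill patch with the binarized neighborhood of the voxel ...
 *   final boolean isSimple = simple3d( patch, 26 );
 *
 * A voxel reported as simple can be flipped between foreground and
 * background without changing the topology of the binary image.
 */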
/**
* Main method to calculate the warping error metrics
* from the command line
*
* @param args arguments to decide the action
*/
public static void main(String args[])
{
if (args.length<1)
{
dumpSyntax();
System.exit(1);
}
else
{
if( args[0].equals("-help") )
dumpSyntax();
else if (args[0].equals("-splitsAndMergers"))
System.out.println( splitsAndMergersCommandLine(args) );
else
dumpSyntax();
}
System.exit(0);
}
/**
* Calculate the minimum splits and mergers ratio based on the
* parameters passed on the command line
*
* @param args command line arguments
* @return warping error with minimum splits and mergers ratio
*/
static double splitsAndMergersCommandLine(String[] args)
{
if (args.length != 8)
{
dumpSyntax();
return -1;
}
final ImagePlus label = new ImagePlus( args[ 1 ] );
final ImagePlus proposal = new ImagePlus( args[ 2 ] );
final double minThreshold = Double.parseDouble( args[ 3 ] );
final double maxThreshold = Double.parseDouble( args[ 4 ] );
final double stepThreshold = Double.parseDouble( args[ 5 ] );
final boolean clusterByError = Boolean.parseBoolean( args[ 6 ]);
final int radius = Integer.parseInt( args[ 7 ]);
WarpingError we = new WarpingError(label, proposal);
we.setVerboseMode( false );
return we.getMinimumSplitsAndMergersErrorValue(minThreshold, maxThreshold, stepThreshold, clusterByError, radius );
}
/**
* Set verbose mode
* @param verbose true to display more information in the standard output
*/
public void setVerboseMode(boolean verbose)
{
this.verbose = verbose;
}
/**
* Method to print the usage syntax of the program to the command line.
*/
private static void dumpSyntax ()
{
System.out.println("Purpose: calculate warping error between proposed and original labels.\n");
System.out.println("Usage: WarpingError ");
System.out.println(" -help : show this message");
System.out.println("");
System.out.println(" -splitsAndMergers : calculate the splits and mergers ratio over a set of thresholds");
System.out.println(" labels : image with the original labels");
System.out.println(" proposal : image with the proposed labels");
System.out.println(" minThreshold : minimum threshold value to binarize the proposal");
System.out.println(" maxThreshold : maximum threshold value to binarize the proposal");
System.out.println(" stepThreshold : threshold step value to use during binarization");
System.out.println(" clusterMistakes : boolean flag to cluster or not the mistakes by type of error");
System.out.println(" radius : radius of the search neighborhood to decide simple points classification\n");
System.out.println("Examples:");
System.out.println("Calculate the splits and mergers ratio between proposed and original labels over a set of");
System.out.println("thresholds (from 0.0 to 1.0 in steps of 0.1) without clustering the mistakes and using a \n" +
"radius of 20 pixels:");
System.out.println(" WarpingError -splitsAndMergers original-labels.tif proposed-labels.tif 0.0 1.0 0.1 false 20");
}
/**
* Calculate warping error and return the related result images and values.
*
* @param binaryThreshold threshold value to binarize proposal (larger than 0 and smaller than 1)
* @param clusterByError if false, cluster topology errors by type; otherwise cluster them by type and individual mistake
* @param calculateMismatchImage flag to calculate mismatch image
* @param radius radius in pixels to use when classifying mismatches
* @return total warping error (it counts all types of mismatches as errors)
*/
public WarpingResults getWarpingResults(
double binaryThreshold,
boolean clusterByError,
boolean calculateMismatchImage,
int radius )
{
if( verbose )
IJ.log(" Warping ground truth...");
// Warp ground truth, relax original labels to proposal. Only simple
// points warping is allowed.
WarpingResults[] wrs = simplePointWarp2dMT( binaryThreshold, clusterByError, calculateMismatchImage, radius );
if(null == wrs)
return null;
WarpingResults result = new WarpingResults();
result.warpingError = 0;
ImageStack is = new ImageStack( originalLabels.getWidth(), originalLabels.getHeight() );
ImageStack is2 = calculateMismatchImage ? new ImageStack( originalLabels.getWidth(), originalLabels.getHeight()) : null;
for(int i = 0; i < wrs.length; i ++)
{
result.warpingError += wrs[ i ].warpingError;
is.addSlice("warped source slice " + (i+1), wrs[i].warpedSource.getProcessor() );
if( calculateMismatchImage )
is2.addSlice("Mismatches slice " + (i+1), wrs[i].classifiedMismatches.getProcessor() );
}
result.warpedSource = new ImagePlus ("warped source", is);
if( calculateMismatchImage )
result.classifiedMismatches = new ImagePlus( "Classified mismatches", is2);
if(wrs.length != 0)
result.warpingError /= wrs.length;
return result;
}
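/*
 * Illustrative usage sketch, not part of the original source; the file names
 * and the threshold/radius values are placeholders:
 *
 *   ImagePlus labels = new ImagePlus( "original-labels.tif" );
 *   ImagePlus proposal = new ImagePlus( "proposed-labels.tif" );
 *   WarpingError we = new WarpingError( labels, proposal );
 *   WarpingResults wr = we.getWarpingResults( 0.5, true, true, 20 );
 *   IJ.log( "Warping error = " + wr.warpingError );
 *   wr.warpedSource.show();
 */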
} // end class WarpingError
package trainableSegmentation.metrics;
/**
*
* License: GPL
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License 2
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
* Authors: Ignacio Arganda-Carreras (iarganda@mit.edu)
*/
import ij.ImagePlus;
import java.util.ArrayList;
import javax.vecmath.Point3f;
/**
* Results from simple point warping (2D)
*
*/
public class WarpingResults{
/** warped source image after 2D simple point relaxation */
public ImagePlus warpedSource;
/** warping error */
public double warpingError;
/** list of mismatch points (non-simple points) found during the warping */
public ArrayList<Point3f> mismatches;
/** image with the classified mismatches (null if it was not calculated) */
public ImagePlus classifiedMismatches = null;
}