Skip to content

Instantly share code, notes, and snippets.

@dorsev
Last active September 19, 2018 21:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dorsev/0f571579f20e668d78f20cf0cc022905 to your computer and use it in GitHub Desktop.
Save dorsev/0f571579f20e668d78f20cf0cc022905 to your computer and use it in GitHub Desktop.
Confusion matrix interface and implementation, with a sample app written using Java 10 features
/**
 * Read-only view of a multi-class confusion matrix, plus the standard metrics derived from it.
 *
 * <p>Implementations supply the raw per-class counts; accuracy, precision, recall and F-score are
 * computed by the default methods. Every ratio whose denominator would be zero is reported as
 * {@code 0} rather than {@code NaN}.
 */
interface ConfusionMatrixInterface {
    /**
     * Returns the names of all classes in the matrix. Must return a fresh stream on every call,
     * because the default methods may traverse the class names more than once.
     */
    Stream<String> getClassNames();

    /** Returns the true-positive count for {@code className}. */
    int truePositive(String className);

    /** Returns the true-negative count for {@code className}. */
    int trueNegative(String className);

    /** Returns the false-positive count for {@code className}. */
    int falsePositive(String className);

    /** Returns the false-negative count for {@code className}. */
    int falseNegative(String className);

    /** Returns the total count recorded for {@code className}; used as the accuracy denominator. */
    int sumValues(String className);

    /**
     * Returns the overall accuracy: the sum of true positives over all classes divided by the sum
     * of {@link #sumValues(String)} over all classes.
     *
     * @return accuracy in {@code [0, 1]}, or {@code 0} when the matrix holds no counts
     */
    default double accuracy() {
        // long sums so large matrices cannot overflow the int range
        long total = getClassNames().mapToLong(this::sumValues).sum();
        if (total == 0) {
            return 0;
        }
        long correct = getClassNames().mapToLong(this::truePositive).sum();
        return (double) correct / total;
    }

    /**
     * Misspelled alias of {@link #accuracy()}, kept for backward compatibility.
     *
     * @deprecated use {@link #accuracy()} instead.
     */
    @Deprecated
    default double accurecy() {
        return accuracy();
    }

    /**
     * Returns recall, TP / (TP + FN), for {@code className}.
     *
     * @return recall, or {@code 0} when the class has no true positives and no false negatives
     */
    default double recall(String className) {
        int truePositives = truePositive(className);
        int denominator = truePositives + falseNegative(className);
        return denominator == 0 ? 0 : (double) truePositives / denominator;
    }

    /**
     * Returns precision, TP / (TP + FP), for {@code className}.
     *
     * @return precision, or {@code 0} when nothing was predicted as this class
     */
    default double precision(String className) {
        int truePositives = truePositive(className);
        int denominator = truePositives + falsePositive(className);
        return denominator == 0 ? 0 : (double) truePositives / denominator;
    }

    /** Returns the balanced F1 score (harmonic mean of precision and recall) for {@code className}. */
    default double fScore(String className) {
        return this.fScore(className, 1.0);
    }

    /**
     * Returns the weighted F-score for {@code className}:
     * {@code (1 + w^2) * P * R / (w^2 * P + R)}.
     * <ul>
     *   <li>if weight = 1, this is the harmonic mean of precision and recall;</li>
     *   <li>if weight &gt; 1, recall is weighted more heavily (e.g. weight = 2.0);</li>
     *   <li>if weight &lt; 1, precision is weighted more heavily (e.g. weight = 0.5).</li>
     * </ul>
     * Callers should pass weight &gt; 0; only {@code weight * weight} is used, so its sign is
     * irrelevant.
     *
     * @return the F-score, or {@code 0} when precision and recall are both zero (instead of the
     *     {@code NaN} a raw 0/0 division would produce)
     */
    default double fScore(String className, double weight) {
        double precision = precision(className);
        double recall = recall(className);
        double weightSquared = weight * weight;
        double denominator = weightSquared * precision + recall;
        if (denominator == 0) {
            return 0;
        }
        return (1 + weightSquared) * precision * recall / denominator;
    }
}
/**
 * {@link ConfusionMatrixInterface} backed by a nested map of raw classification counts.
 *
 * <p>The outer key is the expected class; the inner map counts how many items of that class were
 * assigned to each predicted label. Inner maps may contain extra columns that are not rows
 * themselves (for example an "Unknown" bucket).
 *
 * <p>Class names are compared case-sensitively everywhere, matching the map lookups. A cell that
 * is absent from the table counts as 0. An unknown class name behaves like a class with no
 * samples.
 */
class ConfusionMatrixResultsAnalyzer implements ConfusionMatrixInterface {
    /** Expected class -> (predicted label -> count). Stored by reference, not copied. */
    private final Map<String, Map<String, Integer>> classificationTable;

    ConfusionMatrixResultsAnalyzer(Map<String, Map<String, Integer>> resultSet) {
        this.classificationTable = resultSet;
    }

    /** Returns the row for expected class {@code className}, or an empty row if it has none. */
    private Map<String, Integer> row(String className) {
        return classificationTable.getOrDefault(className, Map.of());
    }

    /** Returns the count for {@code predicted} in {@code row}, treating a missing cell as 0. */
    private static int count(Map<String, Integer> row, String predicted) {
        return row.getOrDefault(predicted, 0);
    }

    @Override
    public Stream<String> getClassNames() {
        return classificationTable.keySet().stream();
    }

    /** Items of {@code className} that were classified as {@code className} (the diagonal cell). */
    @Override
    public int truePositive(String className) {
        return count(row(className), className);
    }

    /**
     * Sum over every other row of the cells not labelled {@code className}. Extra columns such as
     * "Unknown" in those rows are included.
     */
    @Override
    public int trueNegative(String className) {
        return classificationTable.entrySet().stream()
                .filter(entry -> !entry.getKey().equals(className))
                .mapToInt(entry -> entry.getValue().entrySet().stream()
                        .filter(column -> !column.getKey().equals(className))
                        .mapToInt(Map.Entry::getValue)
                        .sum())
                .sum();
    }

    /** Items of every other class that were labelled {@code className}. */
    @Override
    public int falsePositive(String className) {
        return classificationTable.entrySet().stream()
                .filter(entry -> !entry.getKey().equals(className))
                .mapToInt(entry -> count(entry.getValue(), className))
                .sum();
    }

    /** Items of {@code className} given any other label, including extra columns like "Unknown". */
    @Override
    public int falseNegative(String className) {
        return row(className).entrySet().stream()
                .filter(column -> !column.getKey().equals(className))
                .mapToInt(Map.Entry::getValue)
                .sum();
    }

    /** Total number of items whose expected class is {@code className} (the row sum). */
    @Override
    public int sumValues(String className) {
        return row(className).values().stream().mapToInt(Integer::intValue).sum();
    }

    /** One line per class reporting its precision and recall. */
    @Override
    public String toString() {
        return getClassNames()
                .map(className -> String.format("precision for class %s is %s, recall : %s",
                        className, precision(className), recall(className)))
                .collect(Collectors.joining(System.lineSeparator()));
    }
}
package confusionMatrixExamples;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class Main {
public static void main(String[] args) {
//lets assume our classifier classifies 3 categories. A,B,C.
//and the following classificationTable mean the following:
//classificationTable.get("A") =>
// all items in this category were expected to be "A". in a perfect classifier.
// meaning each item in this group was manually labeled as "A" at first, and the map inside classificationTable.get("A") is the classifier classificationTable of that group.
//classificationTable.get("A").get("A") ==> all "A' that were correctly classified as "A". aka TP of this class
//classificationTable.get("A").get("B"|"C") ==> all the A's that got classified as "B" or "C". by mistake of course.
//please notice, this classier does not assign every value a result and therefore all unknown values are entered into "Unknown" group
//if your classifier always outputs a classification, remove this group
var classificationResultsPerfect = Map.of("A", Map.of("A", 1, "B", 0, "C", 0, "Unknown", 0),
"B", Map.of("A", 0, "B", 2, "C", 0, "Unknown", 0),
"C", Map.of("A", 0, "B", 0, "C", 3, "Unknown", 0));
//for example in this classificationTable, our classifier did amazing job and we expect precision of 100 percent.
var perfectPrecisionAndRecall = new ConfusionMatrixResultsAnalyzer(classificationResultsPerfect);
System.out.println(String.format("perfect finding %s", perfectPrecisionAndRecall));
var allClassifiedAsA = Map.of("A", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0),
"B", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0),
"C", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0));
var allClassifiedAsAResults = new ConfusionMatrixResultsAnalyzer(allClassifiedAsA);
System.out.println(String.format("all classified as A findings %s", allClassifiedAsAResults));
var aNormalResultSet = Map.of("A", Map.of("A", 1, "B", 0, "C", 0, "Unknown", 2),
"B", Map.of("A", 0, "B", 1, "C", 0, "Unknown", 1),
"C", Map.of("A", 0, "B", 0, "C", 1, "Unknown", 1));
var aNormalResultSetResults = new ConfusionMatrixResultsAnalyzer(aNormalResultSet);
System.out.println(String.format("a normal classificationTable set findings %s", aNormalResultSetResults));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment