Last active
September 19, 2018 21:57
-
-
Save dorsev/0f571579f20e668d78f20cf0cc022905 to your computer and use it in GitHub Desktop.
Confusion matrix trait (interface) and implementation, with a sample app written using Java 10 features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Per-class quality metrics (accuracy, precision, recall, F-score) derived from the
 * four basic confusion-matrix counts supplied by the implementing class.
 *
 * <p>All ratio metrics return {@code 0} instead of {@code NaN} when their denominator
 * is zero, so results can always be printed or aggregated safely.
 */
interface ConfusionMatrixInterface {

    /**
     * Returns the class names covered by this matrix.
     *
     * <p>{@link #accurecy()} calls this method twice, and a stream can only be consumed
     * once, so every call must return a new stream.
     */
    Stream<String> getClassNames();

    /** Number of samples of {@code className} that were classified as {@code className}. */
    int truePositive(String className);

    /** Number of samples of other classes that were not classified as {@code className}. */
    int trueNegative(String className);

    /** Number of samples of other classes that were classified as {@code className}. */
    int falsePositive(String className);

    /** Number of samples of {@code className} that were classified as something else. */
    int falseNegative(String className);

    /**
     * Total count attributed to {@code className}. {@link #accurecy()} divides the summed
     * true positives by the sum of this value over all class names.
     */
    int sumValues(String className);

    /**
     * Overall accuracy: the sum of true positives over all classes divided by the sum of
     * {@link #sumValues(String)} over all classes.
     *
     * <p>The misspelled name is kept for backward compatibility; new code should call
     * {@link #accuracy()}.
     *
     * @return the accuracy, or {@code 0} when the matrix holds no samples
     */
    default double accurecy() {
        // Accumulate in long so that large matrices cannot overflow an int sum.
        long total = getClassNames().mapToLong(this::sumValues).sum();
        if (total == 0) {
            return 0.0;
        }
        long correct = getClassNames().mapToLong(this::truePositive).sum();
        return (double) correct / total;
    }

    /**
     * Correctly spelled alias of {@link #accurecy()}.
     *
     * @return the accuracy, or {@code 0} when the matrix holds no samples
     */
    default double accuracy() {
        return accurecy();
    }

    /**
     * Recall of a class: {@code TP / (TP + FN)}.
     *
     * @return the recall, or {@code 0} when {@code TP + FN} is zero
     */
    default double recall(String className) {
        int hits = truePositive(className);
        long actualPositives = (long) hits + falseNegative(className);
        return actualPositives == 0 ? 0.0 : (double) hits / actualPositives;
    }

    /**
     * Precision of a class: {@code TP / (TP + FP)}.
     *
     * @return the precision, or {@code 0} when {@code TP + FP} is zero
     */
    default double precision(String className) {
        int hits = truePositive(className);
        long predictedPositives = (long) hits + falsePositive(className);
        return predictedPositives == 0 ? 0.0 : (double) hits / predictedPositives;
    }

    /**
     * F1 score of a class, i.e. {@link #fScore(String, double)} with a weight of 1.
     */
    default double fScore(String className) {
        return this.fScore(className, 1.0);
    }

    /**
     * Weighted F-score of a class.
     *
     * <ul>
     *   <li>if weight = 1, this is just the harmonic mean between precision and recall</li>
     *   <li>if weight &gt; 1, you assign more value to recall (e.g. weight = 2.0)</li>
     *   <li>if weight &lt; 1, you assign more value to precision (e.g. weight = 0.5)</li>
     * </ul>
     *
     * @param className the class to score
     * @param weight    relative importance of recall over precision; should be &gt; 0
     * @return the weighted F-score, or {@code 0} when precision and recall are both zero
     */
    default double fScore(String className, double weight) {
        double precision = precision(className);
        double recall = recall(className);
        double weightSquared = weight * weight;
        double denominator = weightSquared * precision + recall;
        if (denominator == 0) {
            // Precision and recall are both zero: define the score as 0 rather than NaN.
            return 0.0;
        }
        return (1 + weightSquared) * precision * recall / denominator;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ConfusionMatrixResultsAnalyzer implements ConfusionMatrixInterface { | |
private Map<String, Map<String, Integer>> classificationTable; | |
ConfusionMatrixResultsAnalyzer(Map<String, Map<String, Integer>> resultSet) { | |
this.classificationTable = resultSet; | |
} | |
@Override | |
public Stream<String> getClassNames() { | |
return classificationTable.keySet().stream(); | |
} | |
@Override | |
public int truePositive(String className) { | |
return classificationTable.get(className).get(className); | |
} | |
@Override | |
public int trueNegative(String className) { | |
return classificationTable.entrySet().stream().filter(keyAndValue-> !keyAndValue.getKey().equalsIgnoreCase(className)) | |
.mapToInt(keyAndValue-> keyAndValue.getValue().entrySet(). | |
stream().filter(key-> !key.getKey().equalsIgnoreCase(className)). | |
mapToInt(Map.Entry::getValue).sum()) | |
.sum(); | |
} | |
@Override | |
public int falsePositive(String className) { | |
return classificationTable.entrySet().stream(). | |
filter(x-> !x.getKey().equals(className)). | |
mapToInt(x-> x.getValue().get(className)).sum(); | |
} | |
@Override | |
public int falseNegative(String className) { | |
return classificationTable.get(className).entrySet().stream().filter(keyAndValue-> !keyAndValue.getKey().equalsIgnoreCase(className)) | |
.mapToInt(Map.Entry::getValue) | |
.sum(); | |
} | |
@Override | |
public int sumValues(String className) { | |
return classificationTable.get(className).entrySet().stream().mapToInt(Map.Entry::getValue).sum(); | |
} | |
@Override | |
public String toString() { | |
return getClassNames().map(className-> String.format("precision for class %s is %s, recall : %s", className, precision(className), recall(className))) | |
.collect(Collectors.joining(System.lineSeparator())); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package confusionMatrixExamples; | |
import java.util.Map; | |
import java.util.stream.Collectors; | |
import java.util.stream.Stream; | |
public class Main { | |
public static void main(String[] args) { | |
//lets assume our classifier classifies 3 categories. A,B,C. | |
//and the following classificationTable mean the following: | |
//classificationTable.get("A") => | |
// all items in this category were expected to be "A". in a perfect classifier. | |
// meaning each item in this group was manually labeled as "A" at first, and the map inside classificationTable.get("A") is the classifier classificationTable of that group. | |
//classificationTable.get("A").get("A") ==> all "A' that were correctly classified as "A". aka TP of this class | |
//classificationTable.get("A").get("B"|"C") ==> all the A's that got classified as "B" or "C". by mistake of course. | |
//please notice, this classier does not assign every value a result and therefore all unknown values are entered into "Unknown" group | |
//if your classifier always outputs a classification, remove this group | |
var classificationResultsPerfect = Map.of("A", Map.of("A", 1, "B", 0, "C", 0, "Unknown", 0), | |
"B", Map.of("A", 0, "B", 2, "C", 0, "Unknown", 0), | |
"C", Map.of("A", 0, "B", 0, "C", 3, "Unknown", 0)); | |
//for example in this classificationTable, our classifier did amazing job and we expect precision of 100 percent. | |
var perfectPrecisionAndRecall = new ConfusionMatrixResultsAnalyzer(classificationResultsPerfect); | |
System.out.println(String.format("perfect finding %s", perfectPrecisionAndRecall)); | |
var allClassifiedAsA = Map.of("A", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0), | |
"B", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0), | |
"C", Map.of("A", 2, "B", 0, "C", 0, "Unknown", 0)); | |
var allClassifiedAsAResults = new ConfusionMatrixResultsAnalyzer(allClassifiedAsA); | |
System.out.println(String.format("all classified as A findings %s", allClassifiedAsAResults)); | |
var aNormalResultSet = Map.of("A", Map.of("A", 1, "B", 0, "C", 0, "Unknown", 2), | |
"B", Map.of("A", 0, "B", 1, "C", 0, "Unknown", 1), | |
"C", Map.of("A", 0, "B", 0, "C", 1, "Unknown", 1)); | |
var aNormalResultSetResults = new ConfusionMatrixResultsAnalyzer(aNormalResultSet); | |
System.out.println(String.format("a normal classificationTable set findings %s", aNormalResultSetResults)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment