Skip to content

Instantly share code, notes, and snippets.

@wendykan
Created July 10, 2018 01:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wendykan/3ef30f74355f7da5ea9dfabcee4c13da to your computer and use it in GitHub Desktop.
Save wendykan/3ef30f74355f7da5ea9dfabcee4c13da to your computer and use it in GitHub Desktop.
Production C# code for OpenImagesVisualRelations metric: https://www.kaggle.com/c/google-ai-open-images-visual-relationship-track#Evaluation
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
namespace Kaggle.Metrics.Utilities
{
public static class OpenImagesUtil
{
const double EPSILON = 1e-7;
/// <summary>
/// Computes the intersection-over-union of two axis-aligned boxes.
/// </summary>
/// <param name="b1">First box.</param>
/// <param name="b2">Second box.</param>
/// <returns>
/// Intersection area divided by (area of b1 + area of b2 - intersection area + EPSILON),
/// or 0.0 when the boxes do not overlap with positive area.
/// </returns>
public static double IntersectionOverUnion(OpenImagesLabelBox b1, OpenImagesLabelBox b2)
{
    getTwoBoxesIntersectionBoundaries(b1, b2, out double xLow, out double xHigh, out double yLow, out double yHigh);

    bool hasOverlap = xLow < xHigh && yLow < yHigh;
    if (!hasOverlap)
    {
        // Touching edges or disjoint boxes contribute no area.
        return 0.0;
    }

    double overlapArea = (xHigh - xLow) * (yHigh - yLow);
    double firstArea = (b1.xMax - b1.xMin) * (b1.yMax - b1.yMin);
    double secondArea = (b2.xMax - b2.xMin) * (b2.yMax - b2.yMin);

    // EPSILON keeps the denominator strictly positive.
    return overlapArea / (firstArea + secondArea - overlapArea + EPSILON);
}
/// <summary>
/// Computes which fraction of <paramref name="b1"/> is covered by <paramref name="gob"/>.
/// </summary>
/// <param name="b1">Box whose area is used as the denominator.</param>
/// <param name="gob">Box intersected with <paramref name="b1"/>.</param>
/// <returns>
/// Intersection area divided by (area of b1 + EPSILON), or 0.0 when the boxes
/// do not overlap with positive area.
/// </returns>
public static double IntersectionOverArea(OpenImagesLabelBox b1, OpenImagesLabelBox gob)
{
    getTwoBoxesIntersectionBoundaries(b1, gob, out double xLow, out double xHigh, out double yLow, out double yHigh);

    bool hasOverlap = xLow < xHigh && yLow < yHigh;
    if (!hasOverlap)
    {
        return 0.0;
    }

    double overlapArea = (xHigh - xLow) * (yHigh - yLow);
    double referenceArea = (b1.xMax - b1.xMin) * (b1.yMax - b1.yMin);
    return overlapArea / (referenceArea + EPSILON);
}
/// <summary>
/// Computes the IoU between the enclosing boxes of two relation triplets.
/// For each triplet the smallest axis-aligned box containing both BoxA and BoxB
/// is built, and the two resulting boxes are compared with
/// <see cref="IntersectionOverUnion"/>.
/// </summary>
/// <param name="t1">First triplet.</param>
/// <param name="t2">Second triplet.</param>
/// <returns>The IoU of the two enclosing boxes; 0.0 when they do not overlap.</returns>
public static double IntersectionOverUnionConvex(OpenImagesRelationTriplet t1, OpenImagesRelationTriplet t2)
{
    // Builds the box that spans both members of a triplet.
    OpenImagesLabelBox EnclosingBox(OpenImagesRelationTriplet triplet)
    {
        getTwoBoxesConvexBoundaries(triplet.BoxA, triplet.BoxB, out double xLow, out double xHigh, out double yLow, out double yHigh);
        return new OpenImagesLabelBox { xMin = xLow, xMax = xHigh, yMin = yLow, yMax = yHigh };
    }

    var firstHull = EnclosingBox(t1);
    var secondHull = EnclosingBox(t2);
    return IntersectionOverUnion(firstHull, secondHull);
}
/// <summary>
/// Computes interpolated average precision (area under the smoothed
/// precision/recall curve) for a ranked list of predictions.
/// </summary>
/// <remarks>
/// References:
/// original PASCAL VOC paper http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.157.5766&amp;rep=rep1&amp;type=pdf
/// university class video https://www.youtube.com/watch?v=yjCMEjoc_ZI
/// scikit-learn discussion https://github.com/scikit-learn/scikit-learn/issues/4577
/// </remarks>
/// <param name="isCorrectEnum">
/// Correctness flags of the predictions, already ordered from most to least confident.
/// </param>
/// <param name="numGT">Number of ground-truth items used as the recall denominator.</param>
/// <returns>
/// The interpolated AP. Returns 0.0 when there are no predictions or when
/// <paramref name="numGT"/> is not positive (recall is undefined in that case and the
/// division would otherwise produce NaN or Infinity).
/// </returns>
public static double calculateAP(List<bool> isCorrectEnum, int numGT)
{
    if (isCorrectEnum == null || isCorrectEnum.Count == 0 || numGT <= 0)
    {
        return 0.0;
    }

    int count = isCorrectEnum.Count;

    // Index 0 is the leading (precision 0, recall 0) pad; index count + 1 is the
    // trailing (precision 0, recall 1) pad.
    var precision = new double[count + 2];
    var recall = new double[count + 2];

    int truePositives = 0;
    for (int i = 0; i < count; i++)
    {
        if (isCorrectEnum[i])
        {
            truePositives++;
        }
        // i + 1 predictions have been seen so far (true positives + false positives).
        precision[i + 1] = truePositives / (double)(i + 1);
        recall[i + 1] = truePositives / (double)numGT;
    }
    precision[count + 1] = 0.0;
    recall[count + 1] = 1.0;

    // Smoothing: each precision becomes the max precision at any later (higher-recall) point.
    for (int i = count; i >= 0; i--)
    {
        precision[i] = Math.Max(precision[i], precision[i + 1]);
    }

    // Sum the rectangles under the smoothed curve.
    double ap = 0.0;
    for (int i = 1; i < precision.Length; i++)
    {
        ap += (recall[i] - recall[i - 1]) * precision[i];
    }
    return ap;
}
/// <summary>
/// Replaces every character that the case-insensitive pattern <c>[^a-z0-9]</c>
/// matches with an underscore.
/// </summary>
/// <param name="name">Text to sanitize.</param>
/// <returns>The sanitized text, same length as the input.</returns>
public static string SanitizeName(string name)
{
    const string disallowedPattern = "[^a-z0-9]";
    const string replacement = "_";
    var matcher = new Regex(disallowedPattern, RegexOptions.IgnoreCase);
    return matcher.Replace(name, replacement);
}
/// <summary>
/// Splits <paramref name="newtext"/> on <paramref name="splitChar"/> and drops
/// empty segments (leading, trailing and repeated separators produce nothing).
/// </summary>
/// <param name="newtext">Text to split.</param>
/// <param name="splitChar">Separator character.</param>
/// <returns>The non-empty segments in their original order.</returns>
public static IEnumerable<string> CustomSplit(string newtext, char splitChar)
{
    var separators = new[] { splitChar };
    string[] segments = newtext.Split(separators, StringSplitOptions.RemoveEmptyEntries);
    return new List<string>(segments);
}
/// <summary>
/// Computes the candidate bounds of the intersection of two boxes: the larger of
/// the two minima and the smaller of the two maxima on each axis. The result is
/// a real rectangle only when left is less than right and bottom is less than top;
/// callers must check that.
/// </summary>
private static void getTwoBoxesIntersectionBoundaries(OpenImagesLabelBox b1, OpenImagesLabelBox b2, out double left, out double right, out double bottom, out double top)
{
    // Upper bounds: the tighter (smaller) of the two maxima.
    top = Math.Min(b1.yMax, b2.yMax);
    right = Math.Min(b1.xMax, b2.xMax);

    // Lower bounds: the tighter (larger) of the two minima.
    bottom = Math.Max(b1.yMin, b2.yMin);
    left = Math.Max(b1.xMin, b2.xMin);
}
/// <summary>
/// Computes the bounds of the smallest axis-aligned box that contains both
/// input boxes: the smaller of the two minima and the larger of the two maxima
/// on each axis.
/// </summary>
private static void getTwoBoxesConvexBoundaries(OpenImagesLabelBox b1, OpenImagesLabelBox b2, out double left, out double right, out double bottom, out double top)
{
    // Upper bounds: the looser (larger) of the two maxima.
    top = Math.Max(b1.yMax, b2.yMax);
    right = Math.Max(b1.xMax, b2.xMax);

    // Lower bounds: the looser (smaller) of the two minima.
    bottom = Math.Min(b1.yMin, b2.yMin);
    left = Math.Min(b1.xMin, b2.xMin);
}
}
/// <summary>
/// A relationship between two labelled boxes: (BoxA, BoxB, RelationLabel), with an
/// optional confidence and a flag recording whether it has been matched.
/// </summary>
public class OpenImagesRelationTriplet : ICloneable
{
    /// <summary>First box of the relationship.</summary>
    public OpenImagesLabelBox BoxA;

    /// <summary>Second box of the relationship.</summary>
    public OpenImagesLabelBox BoxB;

    /// <summary>Label of the relationship between the two boxes.</summary>
    public string RelationLabel;

    /// <summary>Confidence attached to this triplet.</summary>
    public double Confidence;

    /// <summary>Whether this triplet has already been matched.</summary>
    public bool isMatched;

    /// <summary>
    /// Builds the lookup key "labelA-labelB-relation" from the two box labels and
    /// the relation label.
    /// </summary>
    public string getABRKey() => string.Concat(BoxA.Label, "-", BoxB.Label, "-", RelationLabel);

    /// <summary>
    /// Returns a shallow copy: value fields are copied, while BoxA and BoxB still
    /// reference the same box objects as the original.
    /// </summary>
    public object Clone() => MemberwiseClone();
}
/// <summary>
/// A labelled axis-aligned bounding box with coordinates xMin/yMin/xMax/yMax.
/// </summary>
public class OpenImagesLabelBox : ICloneable
{
    /// <summary>Class label of the box.</summary>
    public string Label;

    /// <summary>Confidence attached to the box.</summary>
    public double Confidence;

    public double xMin;
    public double yMin;
    public double xMax;
    public double yMax;

    /// <summary>Group-of indicator; takes part in <see cref="Equals(OpenImagesLabelBox)"/>.</summary>
    public int isGroupOf;

    /// <summary>Whether this box has already been matched.</summary>
    public bool isMatched;

    /// <summary>Returns a shallow copy of this box.</summary>
    public object Clone()
    {
        return this.MemberwiseClone();
    }

    /// <summary>
    /// Validates the coordinates: min must be strictly below max on each axis and
    /// every coordinate must lie in [0, 1].
    /// </summary>
    /// <returns>Always true when the box is valid.</returns>
    /// <exception cref="Exception">Thrown with a descriptive message when any check fails.</exception>
    public bool isValid()
    {
        // Every ordered comparison with NaN is false, so NaN would slip through all
        // of the range checks below unless it is rejected explicitly.
        if (double.IsNaN(xMin) || double.IsNaN(yMin) || double.IsNaN(xMax) || double.IsNaN(yMax))
            throw new Exception(String.Format("Your box has a NaN coordinate: XMin {0}, YMin {1}, XMax {2}, YMax {3}", xMin, yMin, xMax, yMax));

        if (xMin >= xMax)
            throw new Exception(String.Format("Your box's XMin, {0}, >= XMax, {1}", xMin, xMax));
        if (yMin >= yMax)
            throw new Exception(String.Format("Your box's YMin, {0}, >= YMax, {1}", yMin, yMax));

        // Same order as before: XMin, XMax, YMin, YMax, each checked against 0 then 1.
        RequireUnitRange("XMin", xMin);
        RequireUnitRange("XMax", xMax);
        RequireUnitRange("YMin", yMin);
        RequireUnitRange("YMax", yMax);
        return true;
    }

    /// <summary>
    /// Field-wise equality on Label, the four coordinates and isGroupOf.
    /// Confidence and isMatched are intentionally not compared.
    /// </summary>
    /// <param name="b">Box to compare with; null is never equal.</param>
    public bool Equals(OpenImagesLabelBox b)
    {
        if (b == null)
            return false;

        return Label == b.Label
            && xMin == b.xMin
            && yMin == b.yMin
            && xMax == b.xMax
            && yMax == b.yMax
            && isGroupOf == b.isGroupOf;
    }

    // Throws when a single coordinate lies outside [0, 1].
    private static void RequireUnitRange(string coordinateName, double value)
    {
        if (value < 0)
            throw new Exception(String.Format("Your box's {0}, {1}, < 0", coordinateName, value));
        if (value > 1)
            throw new Exception(String.Format("Your box's {0}, {1}, > 1", coordinateName, value));
    }
}
}
using System;
using System.Linq;
using System.Collections.Generic;
using Kaggle.DataFrames;
using Kaggle.Metrics.Utilities;
using System.Collections.Concurrent;
namespace Kaggle.Metrics.Custom
{
[PublishedEvaluationAlgorithm("OpenImagesVisualRelations", Name = "OpenImagesVisualRelations",
Description = "OpenImagesVisualRelations metric for Open Images (by Google AI). The metric is the mean of 3 different measurements: mAP for each relation, mean Recall per image, and mAP of union of two boxes per relation",
IsPublic = false, IsMax = true)]
public class OpenImagesVisualRelations : DataFrameEvaluationAlgorithm<OpenImagesVisualRelations.Solution, OpenImagesVisualRelations.Submission, OpenImagesVisualRelations.Parameters>
{
/// <summary>
/// Settings read by GetPredictions when scoring a submission.
/// </summary>
public class Parameters : EvaluationAlgorithmParameters
{
/// <summary>Minimum overlap score (compared with &gt;=) a prediction needs against a ground-truth triplet to count as a match.</summary>
public double DefaultThreshold { get; set; }
/// <summary>Number of highest-confidence predictions kept per image when counting true positives for the recall term.</summary>
public int MaxKPredictionsPerImage { get; set; }
/// <summary>Weight applied to the per-relation mAP term in the final score.</summary>
public double RelationMAPWeight { get; set; }
/// <summary>Weight applied to the recall term in the final score.</summary>
public double ImageRecallWeight { get; set; }
/// <summary>Weight applied to the enclosing-box (phrase) mAP term in the final score.</summary>
public double PhraseMAPWeight { get; set; }
}
/* PredictionString layout, as parsed by GetSubmissionTripletsFromString and
 * GetSolutionTripletsFromString. Values are separated by single spaces and
 * records follow each other without any other delimiter.
 *
 * submission (12 values per record):
 *   confidence labelA xMinA yMinA xMaxA yMaxA labelB xMinB yMinB xMaxB yMaxB relationLabel
 *
 * solution (11 values per record):
 *   labelA xMinA yMinA xMaxA yMaxA labelB xMinB yMinB xMaxB yMaxB relationLabel
 *   A record whose xMinA or xMinB equals -1 is treated as a negative label
 *   instead of a ground-truth triplet.
 */
/// <summary>Ground-truth data frame: one row per image.</summary>
public class Solution : TypedDataFrame
{
// Key column identifying the image.
[Series(IsKey = true)]
public Series<string> ImageId { get; set; }
// Space-separated ground-truth records (11 values each); see the layout above.
public Series<string> PredictionString { get; set; }
//public Series<string> Usage { get; set; }
}
/// <summary>Participant data frame: one row per image.</summary>
public class Submission : TypedDataFrame
{
// Key column identifying the image.
[Series(IsKey = true)]
public Series<string> ImageId { get; set; }
// Space-separated predicted records, 12 values each:
// confidence labelA xMinA yMinA xMaxA yMaxA labelB xMinB yMinB xMaxB yMaxB relationLabel
public Series<string> PredictionString { get; set; }
}
/// <summary>
/// Scores a submission against the solution by delegating to <see cref="GetPredictions"/>.
/// </summary>
/// <returns>The weighted metric value computed by GetPredictions.</returns>
public override double EvaluateSubmissionSubset(Solution solution, Submission submission, Parameters parameters, IDictionary<string, object> additionalDetails)
    => GetPredictions(solution, submission, parameters, additionalDetails);
// One scored prediction, produced by GetPredictions for every submitted triplet
// that is not ignored, and later grouped by relation or by image.
class ImageRelationPrediction
{
// True when the prediction matched a not-yet-matched ground-truth triplet.
public bool IsCorrect;
// Image the prediction belongs to.
public string ImageId;
// Relation label of the predicted triplet; used to group predictions for AP.
public string RelationLabel;
// Key returned by the triplet's getABRKey() (labelA-labelB-relation).
public string ABRKey;
// Confidence of the predicted triplet; used for ranking.
public double Confidence;
}
/// <summary>
/// Scores a submission against the solution. Three terms are computed and combined
/// with the weights in <paramref name="param"/>:
/// (1) mean AP per relation, where a prediction matches when both of its boxes overlap
///     the ground-truth boxes (min of the two IoUs) by at least the threshold;
/// (2) recall over the top-K predictions of each image (total true positives divided
///     by the total number of ground-truth triplets);
/// (3) mean AP per relation using the IoU of the boxes enclosing each triplet.
/// </summary>
/// <param name="solution">Ground-truth rows.</param>
/// <param name="submission">Predicted rows.</param>
/// <param name="param">Threshold, top-K and term weights.</param>
/// <param name="additionalDetails">
/// Optional sink for per-relation APs and the three aggregate terms; may be null.
/// </param>
/// <returns>The weighted sum of the three terms.</returns>
/// <exception cref="Exception">
/// Thrown when a submission row does not contain a multiple of 12 values. Because rows
/// are processed in parallel the exception reaches the caller wrapped by PLINQ.
/// </exception>
protected static double GetPredictions(Solution solution, Submission submission, Parameters param, IDictionary<string, object> additionalDetails)
{
    var spaceSeps = new[] { ' ' };
    var rowCount = solution.RowCount;
    var defaultThreshold = param.DefaultThreshold;
    var numGTBoxesPerRelation = new ConcurrentDictionary<string, int>();
    var detectedPredictionsStack = new ConcurrentStack<ImageRelationPrediction>();
    var detectedPredictionsStackForConvexBox = new ConcurrentStack<ImageRelationPrediction>();

    // NOTE(review): rows are addressed by the same index in both frames, so this
    // presumably relies on the framework aligning submission rows with solution rows - confirm.
    Enumerable.Range(0, rowCount).AsParallel().ForAll(i =>
    {
        var imageId = solution.ImageId[i];
        var solutionElevenlets = (solution.PredictionString[i] ?? "").Split(spaceSeps, StringSplitOptions.RemoveEmptyEntries);
        var submissionTwelvelets = (submission.PredictionString[i] ?? "").Split(spaceSeps, StringSplitOptions.RemoveEmptyEntries);
        if (submissionTwelvelets.Length % 12 != 0)
        {
            throw new Exception(String.Format("Image {0} has an incorrectly formatted predictions.", submission.ImageId[i]));
        }

        // Ground truth must be read even when nothing was predicted, so that the
        // per-relation ground-truth counts are complete.
        var posLabels = new HashSet<string>();
        var negLabels = new HashSet<string>();
        var solTripletsDictionary = GetSolutionTripletsFromString(numGTBoxesPerRelation, solutionElevenlets, imageId, posLabels, negLabels);

        // Independent copies of the triplets so that the isMatched flags of the
        // enclosing-box metric do not interfere with those of the per-box metric.
        var solTripletsDictionaryForConvexMatching = solTripletsDictionary.ToDictionary(
            kvpair => kvpair.Key,
            kvpair => kvpair.Value.Select(item => (OpenImagesRelationTriplet)item.Clone()).ToList());

        if (submissionTwelvelets.Length == 0)
        {
            return;
        }

        // Highest confidence first, so that the most confident prediction claims a ground-truth triplet.
        var subTripletsSorted = GetSubmissionTripletsFromString(imageId, submissionTwelvelets)
            .OrderByDescending(x => x.Confidence);

        foreach (var subTriplet in subTripletsSorted)
        {
            var subABRKey = subTriplet.getABRKey();
            var relationPrediction = CreatePrediction(imageId, subTriplet, subABRKey);
            var convexBoxPrediction = CreatePrediction(imageId, subTriplet, subABRKey);
            var labelA = subTriplet.BoxA.Label;
            var labelB = subTriplet.BoxB.Label;

            if (negLabels.Contains(labelA) || negLabels.Contains(labelB))
            {
                // Either A or B is a verified-absent label: false positive for every metric
                // (IsCorrect is false by default).
                detectedPredictionsStack.Push(relationPrediction);
                detectedPredictionsStackForConvexBox.Push(convexBoxPrediction);
                continue;
            }

            if (!posLabels.Contains(labelA) || !posLabels.Contains(labelB))
            {
                // A and B are not both verified-present: the prediction is ignored.
                continue;
            }

            // Metrics #1 and #2: both boxes must overlap their ground-truth counterparts.
            relationPrediction.IsCorrect = TryClaimBestMatch(
                solTripletsDictionary, subABRKey, defaultThreshold,
                solTriplet => Math.Min(
                    OpenImagesUtil.IntersectionOverUnion(subTriplet.BoxA, solTriplet.BoxA),
                    OpenImagesUtil.IntersectionOverUnion(subTriplet.BoxB, solTriplet.BoxB)));
            detectedPredictionsStack.Push(relationPrediction);

            // Metric #3: only the boxes enclosing each triplet must overlap.
            convexBoxPrediction.IsCorrect = TryClaimBestMatch(
                solTripletsDictionaryForConvexMatching, subABRKey, defaultThreshold,
                solTriplet => OpenImagesUtil.IntersectionOverUnionConvex(subTriplet, solTriplet));
            detectedPredictionsStackForConvexBox.Push(convexBoxPrediction);
        }
    });

    var numRelationsWithGT = numGTBoxesPerRelation.Count;

    // Metric #1: AP for each relation, averaged over all relations present in the ground truth.
    var APPerRelation = ComputeAPPerRelation(detectedPredictionsStack, numGTBoxesPerRelation, "Relation-AP-", additionalDetails);
    var mAPRelation = numRelationsWithGT == 0 ? 0.0 : APPerRelation.Values.Sum() / numRelationsWithGT;

    // Metric #2: true positives among the K most confident predictions of each image,
    // divided by the total number of ground-truth triplets.
    var totalTP = detectedPredictionsStack
        .GroupBy(x => x.ImageId)
        .Sum(imageGroup => imageGroup
            .OrderByDescending(pred => pred.Confidence)
            .Take(param.MaxKPredictionsPerImage)
            .Count(pred => pred.IsCorrect));
    var totalGT = numGTBoxesPerRelation.Values.Sum();
    var avgRecall = totalGT == 0 ? 0.0 : (double)totalTP / totalGT;

    // Metric #3: AP for each relation using the enclosing boxes.
    var APPerRelationConvexBox = ComputeAPPerRelation(detectedPredictionsStackForConvexBox, numGTBoxesPerRelation, "Phrases_AP-", additionalDetails);
    var mAPRelationConvexBox = numRelationsWithGT == 0 ? 0.0 : APPerRelationConvexBox.Values.Sum() / numRelationsWithGT;

    if (additionalDetails != null)
    {
        additionalDetails["mAPRelation"] = mAPRelation;
        additionalDetails["avgRecall"] = avgRecall;
        additionalDetails["mAPRelationConvexBox"] = mAPRelationConvexBox;
    }
    return (param.RelationMAPWeight * mAPRelation + param.ImageRecallWeight * avgRecall + param.PhraseMAPWeight * mAPRelationConvexBox);
}

// Creates the (initially incorrect) scoring record for one submitted triplet.
private static ImageRelationPrediction CreatePrediction(string imageId, OpenImagesRelationTriplet subTriplet, string abrKey)
{
    return new ImageRelationPrediction
    {
        ImageId = imageId,
        Confidence = subTriplet.Confidence,
        RelationLabel = subTriplet.RelationLabel,
        ABRKey = abrKey
    };
}

// Finds the ground-truth triplet with the highest overlap under the given key (first one
// wins on ties). Returns true only when that overlap reaches the threshold and the triplet
// was not matched before; a triplet that reaches the threshold is marked as matched.
private static bool TryClaimBestMatch(Dictionary<string, List<OpenImagesRelationTriplet>> solTripletsByKey, string abrKey, double threshold, Func<OpenImagesRelationTriplet, double> overlapWith)
{
    List<OpenImagesRelationTriplet> candidates;
    if (!solTripletsByKey.TryGetValue(abrKey, out candidates))
    {
        return false;
    }

    OpenImagesRelationTriplet best = null;
    var bestOverlap = double.NegativeInfinity;
    foreach (var candidate in candidates)
    {
        var overlap = overlapWith(candidate);
        if (overlap > bestOverlap)
        {
            bestOverlap = overlap;
            best = candidate;
        }
    }

    // Written as a negated >= so that a NaN on either side means "no match".
    if (best == null || !(bestOverlap >= threshold))
    {
        return false;
    }

    // A second prediction landing on an already matched triplet is a false positive.
    var alreadyMatched = best.isMatched;
    best.isMatched = true;
    return !alreadyMatched;
}

// Computes the AP of every relation that has ground truth. Relations without ground truth
// are skipped. When additionalDetails is not null each AP is also stored there under the
// sanitized name detailPrefix + relation.
private static Dictionary<string, double> ComputeAPPerRelation(IEnumerable<ImageRelationPrediction> predictions, IDictionary<string, int> numGTBoxesPerRelation, string detailPrefix, IDictionary<string, object> additionalDetails)
{
    var apPerRelation = new Dictionary<string, double>();
    foreach (var relationGroup in predictions.GroupBy(x => x.RelationLabel))
    {
        var relation = relationGroup.Key;
        int numGTforRelation;
        if (!numGTBoxesPerRelation.TryGetValue(relation, out numGTforRelation) || numGTforRelation == 0)
        {
            continue;
        }

        // Predictions arrive in thread-scheduling order; the ImageId tie-break makes the
        // ranking (and therefore the AP) reproducible when confidences are equal.
        var isCorrectEnum = relationGroup
            .OrderByDescending(x => x.Confidence)
            .ThenBy(x => x.ImageId, StringComparer.Ordinal)
            .Select(x => x.IsCorrect)
            .ToList();

        double ap = OpenImagesUtil.calculateAP(isCorrectEnum, numGTforRelation);
        apPerRelation.Add(relation, ap);
        if (additionalDetails != null)
        {
            // Indexer instead of Add: two relations may sanitize to the same name.
            additionalDetails[OpenImagesUtil.SanitizeName(detailPrefix + relation)] = ap;
        }
    }
    return apPerRelation;
}
/// <summary>
/// Parses the space-split values of one submission row into relation triplets.
/// Each record holds 12 values:
/// confidence labelA xMinA yMinA xMaxA yMaxA labelB xMinB yMinB xMaxB yMaxB relationLabel.
/// </summary>
/// <param name="imageId">Image the row belongs to; only used in error messages.</param>
/// <param name="submissionTwelvelets">The row split into individual values.</param>
/// <returns>One triplet per record, in input order.</returns>
/// <exception cref="Exception">
/// Thrown when the number of values is not a multiple of 12, when a number cannot be
/// parsed, or (by isValid) when a box has invalid coordinates.
/// </exception>
private static List<OpenImagesRelationTriplet> GetSubmissionTripletsFromString(string imageId, string[] submissionTwelvelets)
{
    if (submissionTwelvelets.Length % 12 != 0)
    {
        throw new Exception(String.Format("Error parsing predictions for Image {0}. Expected a multiple of 12 values but found {1}.", imageId, submissionTwelvelets.Length));
    }

    // The file format is machine-readable, so numbers are parsed independently of the
    // server's regional settings.
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    var subBoxesList = new List<OpenImagesRelationTriplet>(submissionTwelvelets.Length / 12);
    for (int subIndx = 0; subIndx < submissionTwelvelets.Length; subIndx += 12)
    {
        try
        {
            double conf = double.Parse(submissionTwelvelets[subIndx], invariant);
            var subABox = ParseSubmissionBox(submissionTwelvelets, subIndx + 1, invariant);
            var subBBox = ParseSubmissionBox(submissionTwelvelets, subIndx + 6, invariant);
            var subRLabel = submissionTwelvelets[subIndx + 11];
            var subTriplet = new OpenImagesRelationTriplet { BoxA = subABox, BoxB = subBBox, RelationLabel = subRLabel, Confidence = conf };

            // isValid throws on bad coordinates, so only valid triplets get here.
            if (subABox.isValid() && subBBox.isValid())
                subBoxesList.Add(subTriplet);
        }
        catch (Exception e) when (e is FormatException || e is OverflowException)
        {
            throw new Exception(String.Format("Error parsing predictions for Image {0}. {1}", imageId, e.Message));
        }
    }
    return subBoxesList;
}

// Reads "label xMin yMin xMax yMax" starting at labelIndex.
private static OpenImagesLabelBox ParseSubmissionBox(string[] values, int labelIndex, IFormatProvider format)
{
    return new OpenImagesLabelBox
    {
        Label = values[labelIndex],
        xMin = double.Parse(values[labelIndex + 1], format),
        yMin = double.Parse(values[labelIndex + 2], format),
        xMax = double.Parse(values[labelIndex + 3], format),
        yMax = double.Parse(values[labelIndex + 4], format)
    };
}
/// <summary>
/// Parses the space-split values of one solution row. Each record holds 11 values:
/// labelA xMinA yMinA xMaxA yMaxA labelB xMinB yMinB xMaxB yMaxB relationLabel.
/// Records where neither xMin equals -1 are ground-truth triplets; the others only
/// register a negative label.
/// </summary>
/// <param name="numGTBoxesPerRelation">Shared per-relation ground-truth counter, incremented for every ground-truth triplet.</param>
/// <param name="solutionElevenlets">The row split into individual values.</param>
/// <param name="imageId">Image the row belongs to; only used in error messages.</param>
/// <param name="posLabels">Receives both box labels of every ground-truth triplet.</param>
/// <param name="negLabels">Receives the label of the box whose xMin is -1.</param>
/// <returns>Ground-truth triplets grouped by their getABRKey() value.</returns>
/// <exception cref="Exception">Thrown when the number of values is not a multiple of 11.</exception>
/// <exception cref="FormatException">Thrown when a coordinate cannot be parsed.</exception>
private static Dictionary<string, List<OpenImagesRelationTriplet>> GetSolutionTripletsFromString(ConcurrentDictionary<string, int> numGTBoxesPerRelation,
    string[] solutionElevenlets, string imageId, HashSet<string> posLabels, HashSet<string> negLabels)
{
    // Checked up front so that a malformed row fails before the shared counts are touched.
    if (solutionElevenlets.Length % 11 != 0)
    {
        throw new Exception(String.Format("Solution for Image {0} is incorrectly formatted: expected a multiple of 11 values but found {1}.", imageId, solutionElevenlets.Length));
    }

    // The file format is machine-readable, so numbers are parsed independently of the
    // server's regional settings.
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    var solTripletsDictionary = new Dictionary<string, List<OpenImagesRelationTriplet>>();
    for (int solIndx = 0; solIndx < solutionElevenlets.Length; solIndx += 11)
    {
        var solABox = ParseSolutionBox(solutionElevenlets, solIndx, invariant);
        var solBBox = ParseSolutionBox(solutionElevenlets, solIndx + 5, invariant);
        var solRLabel = solutionElevenlets[solIndx + 10];

        if (solABox.xMin != -1 && solBBox.xMin != -1) // not negative labels
        {
            var solTriplet = new OpenImagesRelationTriplet { BoxA = solABox, BoxB = solBBox, RelationLabel = solRLabel, isMatched = false };
            numGTBoxesPerRelation.AddOrUpdate(solRLabel, 1, (key, oldValue) => oldValue + 1);
            posLabels.Add(solABox.Label);
            posLabels.Add(solBBox.Label);

            var solABRKey = solTriplet.getABRKey();
            List<OpenImagesRelationTriplet> sameKeyTriplets;
            if (!solTripletsDictionary.TryGetValue(solABRKey, out sameKeyTriplets))
            {
                sameKeyTriplets = new List<OpenImagesRelationTriplet>();
                solTripletsDictionary.Add(solABRKey, sameKeyTriplets);
            }
            sameKeyTriplets.Add(solTriplet);
        }
        else if (solABox.xMin == -1)
        {
            // NOTE(review): when both boxes carry -1 only label A is registered as
            // negative (unchanged behavior) - confirm that this is intended.
            negLabels.Add(solABox.Label);
        }
        else
        {
            // Only box B carries the -1 marker.
            negLabels.Add(solBBox.Label);
        }
    }
    return solTripletsDictionary;
}

// Reads "label xMin yMin xMax yMax" starting at labelIndex.
private static OpenImagesLabelBox ParseSolutionBox(string[] values, int labelIndex, IFormatProvider format)
{
    return new OpenImagesLabelBox
    {
        Label = values[labelIndex],
        xMin = double.Parse(values[labelIndex + 1], format),
        yMin = double.Parse(values[labelIndex + 2], format),
        xMax = double.Parse(values[labelIndex + 3], format),
        yMax = double.Parse(values[labelIndex + 4], format)
    };
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment