Created
September 20, 2018 12:38
-
-
Save tomekdz/ef09cb59b233d8b1324df00eda11afad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Reduces the raw network output to a map of detected class index to confidence.
 *
 * Only the first element of [data] is read. It is walked as a grid of rows and cells
 * (13 x 13 x 285 according to the original author's note); each cell is read as
 * `NUM_BOXES_PER_BLOCK` consecutive boxes of `NUM_CLASSES + 5` floats. For every box:
 *  - the value at index 4 is passed through `sigmoid` and used as the box confidence,
 *  - the following `NUM_CLASSES` values are passed to `softmax` and the largest one wins,
 *  - the product of both is kept when it exceeds `DETECTION_THRESHOLD`.
 *
 * The first four values of each box are not read by this function.
 *
 * When several boxes report the same class, the highest confidence is kept, so the
 * result does not depend on the order in which the grid is traversed.
 *
 * @param data raw output tensor; only `data[0]` is used.
 * @return detected class index mapped to its best confidence; empty when [data] is empty
 *   or nothing exceeds the threshold.
 * @throws IllegalArgumentException if a grid cell holds fewer floats than
 *   `NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)`.
 */
fun processOutput(data: Array<Array<Array<FloatArray>>>): Map<Int, Float> {
    val grid = data.firstOrNull() ?: return emptyMap()

    val boxSize = NUM_CLASSES + 5
    val cellSize = NUM_BOXES_PER_BLOCK * boxSize
    val resultsMap = mutableMapOf<Int, Float>()

    // Index the nested arrays directly instead of copying everything into a boxed
    // List<Float>; grid dimensions come from the data rather than a fixed 13 x 13.
    for (row in grid) {
        for (cell in row) {
            require(cell.size >= cellSize) {
                "Grid cell holds ${cell.size} values, expected at least $cellSize"
            }
            for (b in 0 until NUM_BOXES_PER_BLOCK) {
                val offset = boxSize * b
                val confidence = sigmoid(cell[offset + 4])

                val classes = FloatArray(NUM_CLASSES) { c -> cell[offset + 5 + c] }
                // NOTE(review): softmax's return value is not used, so it is presumably
                // expected to normalise `classes` in place - confirm against its definition.
                softmax(classes)

                var detectedClass = -1
                var maxClass = 0f
                for (c in 0 until NUM_CLASSES) {
                    if (classes[c] > maxClass) {
                        detectedClass = c
                        maxClass = classes[c]
                    }
                }

                val confidenceInClass = maxClass * confidence
                if (confidenceInClass > DETECTION_THRESHOLD) {
                    // Keep the strongest detection per class instead of the last one seen.
                    val previous = resultsMap[detectedClass]
                    if (previous == null || confidenceInClass > previous) {
                        resultsMap[detectedClass] = confidenceInClass
                    }
                }
            }
        }
    }
    return resultsMap
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment