Skip to content

Instantly share code, notes, and snippets.

@Gnzlt
Created July 26, 2018 08:46
Show Gist options
  • Save Gnzlt/1417dbca22593eb168c25faae8c3186d to your computer and use it in GitHub Desktop.
Save Gnzlt/1417dbca22593eb168c25faae8c3186d to your computer and use it in GitHub Desktop.
Android image classifier for Firebase ML Kit with a TensorFlow Lite custom model
/**
 * Classifies bitmaps with a custom TensorFlow Lite model run through Firebase ML Kit.
 *
 * The model is registered under [modelName] both as a bundled asset ([modelPath]) and as a
 * cloud model. Labels are read one per line from the asset at [modelLabelPath], in model output order.
 * The input tensor is BYTE-typed with shape 1 x 299 x 299 x 3 (R, G, B per pixel). The output
 * tensor is BYTE-typed with shape 1 x (number of labels).
 *
 * During construction, a [FirebaseMLException] or a label-file [IOException] is logged rather
 * than thrown. If the interpreter could not be created, [classify] reports an
 * [IllegalStateException] through the emitter.
 *
 * @param context used to open the label asset.
 * @param modelName name under which the local and cloud model sources are registered.
 * @param modelPath asset path of the bundled model file.
 * @param modelLabelPath asset path of the newline-separated label file.
 */
class ImageClassifier(
    private val context: Context,
    private val modelName: String,
    private val modelPath: String,
    private val modelLabelPath: String
) {
    companion object {
        // Inside the companion, `ImageClassifier.javaClass` refers to the companion object itself
        // (yielding "Companion"), so name the outer class explicitly.
        private val TAG = ImageClassifier::class.java.simpleName
        private const val RESULTS_TO_SHOW = 1
        private const val DIM_BATCH_SIZE = 1
        private const val DIM_PIXEL_SIZE = 3
        private const val DIM_IMG_SIZE = 299
    }

    /** Scratch buffer of DIM_IMG_SIZE x DIM_IMG_SIZE pixels, reused by the @Synchronized [toByteBuffer]. */
    private val imageBuffer = IntArray(DIM_IMG_SIZE * DIM_IMG_SIZE)

    /** Min-heap by score holding the best [RESULTS_TO_SHOW] entries of the result being processed. */
    private val sortedLabels = PriorityQueue<MutableMap.MutableEntry<String, Float>>(RESULTS_TO_SHOW) { o1, o2 -> o1.value.compareTo(o2.value) }

    /** Labels in model output order, one per line of the label asset. */
    private val outputLabels = arrayListOf<String>()

    private lateinit var dataOptions: FirebaseModelInputOutputOptions
    private var modelInterpreter: FirebaseModelInterpreter? = null

    init {
        try {
            initLabels()
            // Must follow initLabels(): the output shape is sized by the number of labels.
            setupDataOptions()
            registerModelSources()
            setupInterpreter()
        } catch (e: FirebaseMLException) {
            Log.e(TAG, "Failed to initialise the model interpreter", e)
        }
    }

    /** Reads every line of the label asset into [outputLabels]; logs and continues on I/O failure. */
    private fun initLabels() {
        try {
            // useLines yields each line exactly once and closes the asset stream afterwards.
            context.assets.open(modelLabelPath).bufferedReader().useLines { lines ->
                outputLabels.addAll(lines)
            }
        } catch (e: IOException) {
            Log.e(TAG, "Failed to read label list", e)
        }
    }

    /** Describes the BYTE input image tensor and the BYTE per-label output tensor. */
    private fun setupDataOptions() {
        val inputDims = intArrayOf(DIM_BATCH_SIZE, DIM_IMG_SIZE, DIM_IMG_SIZE, DIM_PIXEL_SIZE)
        val outputDims = intArrayOf(DIM_BATCH_SIZE, outputLabels.size)
        dataOptions = FirebaseModelInputOutputOptions.Builder()
            .setInputFormat(0, FirebaseModelDataType.BYTE, inputDims)
            .setOutputFormat(0, FirebaseModelDataType.BYTE, outputDims)
            .build()
    }

    /**
     * Registers the bundled asset model and a cloud model under [modelName]. The cloud model has
     * updates enabled, and both its initial and update downloads require Wi-Fi.
     */
    private fun registerModelSources() {
        val localModelSource = FirebaseLocalModelSource.Builder(modelName)
            .setAssetFilePath(modelPath)
            .build()
        val downloadConditions = FirebaseModelDownloadConditions.Builder()
            .requireWifi()
            .build()
        val cloudSource = FirebaseCloudModelSource.Builder(modelName)
            .enableModelUpdates(true)
            .setInitialDownloadConditions(downloadConditions)
            .setUpdatesDownloadConditions(downloadConditions)
            .build()
        FirebaseModelManager.getInstance().apply {
            registerLocalModelSource(localModelSource)
            registerCloudModelSource(cloudSource)
        }
    }

    /** Creates the interpreter with [modelName] as both its local and cloud model name. */
    private fun setupInterpreter() {
        val modelOptions = FirebaseModelOptions.Builder()
            .setLocalModelName(modelName)
            .setCloudModelName(modelName)
            .build()
        modelInterpreter = FirebaseModelInterpreter.getInstance(modelOptions)
    }

    /**
     * Runs the model on [bitmap] and emits the best-scoring label once, then completes.
     *
     * Any exception during setup or input conversion, including a missing interpreter, is passed
     * to `onError`, as is an inference failure. The emitted label may be null if no labels
     * were loaded.
     */
    fun classify(bitmap: Bitmap): Observable<String> {
        return Observable.create({ emitter ->
            try {
                val interpreter = modelInterpreter
                    ?: throw IllegalStateException("Interpreter not initalised")
                val modelInputs = FirebaseModelInputs.Builder()
                    .add(bitmap.toByteBuffer())
                    .build()
                interpreter
                    .run(modelInputs, dataOptions)
                    .addOnFailureListener { error -> emitter.onError(error) }
                    .addOnSuccessListener { result ->
                        val labelProbArray = result.getOutput<Array<ByteArray>>(0)
                        val label = labelProbArray.getTopLabel()
                        emitter.onNext(label)
                        emitter.onCompleted()
                    }
            } catch (e: Exception) {
                emitter.onError(e)
            }
        }, Emitter.BackpressureMode.BUFFER)
    }

    /**
     * Packs this bitmap into a direct, native-order buffer of R, G, B bytes per pixel,
     * row by row, at DIM_IMG_SIZE x DIM_IMG_SIZE.
     *
     * Synchronized because it writes the shared [imageBuffer].
     */
    @Synchronized
    private fun Bitmap.toByteBuffer(): ByteBuffer {
        val imgData = ByteBuffer
            .allocateDirect(DIM_BATCH_SIZE * DIM_IMG_SIZE * DIM_IMG_SIZE * DIM_PIXEL_SIZE)
            .apply {
                order(ByteOrder.nativeOrder())
                rewind()
            }
        // imageBuffer holds exactly DIM_IMG_SIZE x DIM_IMG_SIZE pixels. Resize other bitmaps so that
        // getPixels neither overflows it nor leaves stale pixels from a previous call.
        val input = if (width == DIM_IMG_SIZE && height == DIM_IMG_SIZE) {
            this
        } else {
            Bitmap.createScaledBitmap(this, DIM_IMG_SIZE, DIM_IMG_SIZE, true)
        }
        input.getPixels(imageBuffer, 0, DIM_IMG_SIZE, 0, 0, DIM_IMG_SIZE, DIM_IMG_SIZE)
        if (input !== this) {
            // Free the temporary scaled copy; never recycle the caller's bitmap.
            input.recycle()
        }
        for (pixelVal in imageBuffer) {
            imgData.put((pixelVal shr 16 and 0xFF).toByte())
            imgData.put((pixelVal shr 8 and 0xFF).toByte())
            imgData.put((pixelVal and 0xFF).toByte())
        }
        return imgData
    }

    /**
     * Returns the label with the highest score in the first output row, or null if there are
     * no labels or no output rows.
     *
     * Synchronized because it reuses the shared [sortedLabels] queue.
     */
    @Synchronized
    private fun Array<ByteArray>.getTopLabel(): String? {
        val scores = firstOrNull() ?: return null
        // sortedLabels is shared across calls, so start empty to keep earlier results from competing.
        sortedLabels.clear()
        for (i in 0 until minOf(outputLabels.size, scores.size)) {
            // Scores are read as unsigned 0..255 bytes. Widen to Int before masking, because
            // masking Byte with Byte keeps the sign and turns values above 127 negative.
            val labelEntry = AbstractMap.SimpleEntry<String, Float>(
                outputLabels[i],
                (scores[i].toInt() and 0xff) / 255.0f
            )
            sortedLabels.add(labelEntry)
            if (sortedLabels.size > RESULTS_TO_SHOW) {
                sortedLabels.poll()
            }
        }
        // The min-heap's head is the weakest retained entry, so drain down to the strongest one.
        while (sortedLabels.size > 1) {
            sortedLabels.poll()
        }
        return sortedLabels.peek()?.key
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment