Skip to content

Instantly share code, notes, and snippets.

@jhabr
Last active July 24, 2023 21:48
Show Gist options
  • Save jhabr/aced0ee37d6a585616e3edce47706235 to your computer and use it in GitHub Desktop.
Huggingface Distilbert Conversion to TFLite using TF 2.x and Usage on Android
from transformers import TFDistilBertForSequenceClassification, DistilBertConfig
import tensorflow as tf

# --- Load the pretrained model ---------------------------------------------
# num_labels=10 gives DistilBERT its own 10-way classification head, so the
# model's first output is already the class logits.
config = DistilBertConfig.from_pretrained('distilbert-base-cased', num_labels=10)
model = TFDistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-cased', from_pt=True, config=config
)

# --- Build a fixed-shape Keras wrapper --------------------------------------
# TFLite conversion needs static shapes, hence the fixed sequence length and
# batch_size=1 on both inputs.
SEQUENCE_LENGTH = 100
NO_CLASSES = 10
input_ids = tf.keras.layers.Input(shape=(SEQUENCE_LENGTH,), dtype=tf.int32, name="input_ids", batch_size=1)
attention_mask = tf.keras.layers.Input(shape=(SEQUENCE_LENGTH,), dtype=tf.int32, name="attention_mask", batch_size=1)
inputs = [input_ids, attention_mask]

# The model already carries a trained num_labels-wide head; apply a softmax
# activation over its logits directly. (Stacking a fresh Dense layer here
# would add randomly-initialized, untrained weights on top of the logits and
# yield meaningless predictions.)
logits = model(input_ids, attention_mask)[0]
outputs = tf.keras.layers.Activation('softmax', name='probabilities')(logits)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# --- Export the model to TFLite ---------------------------------------------
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True
converter.allow_custom_ops = True
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # native TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,     # fall back to TF ops not in TFLite
]
# For conversion with FP16 quantization:
# converter.target_spec.supported_types = [tf.float16]
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()

# Context manager ensures the file is flushed and closed even on error
# (the original `open(...).write(...)` leaked the handle).
with open("distilbert_L100.tflite", 'wb') as f:
    f.write(tflite_model)
import android.content.Context
import org.tensorflow.lite.Interpreter
import timber.log.Timber
import java.io.FileInputStream
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import kotlin.system.measureTimeMillis
/*
make sure to include the following in the app/build.gradle dependencies:
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.5.0'
*/
/**
 * Runs a TFLite DistilBERT text classifier bundled in the app's assets.
 *
 * The model file and label list are loaded once at construction time, and the
 * interpreter is configured with 5 threads and NNAPI acceleration enabled.
 */
class TextClassifier(private val context: Context) {

    // Memory-mapped model file; the mapping stays valid for the lifetime of
    // the buffer even after the backing stream is closed.
    private val model: MappedByteBuffer = loadModel()

    // One label per line, index-aligned with the model's output vector.
    private val labels: List<String> = loadLabels()

    private val interpreter: Interpreter = Interpreter(
        model,
        Interpreter.Options().apply {
            setNumThreads(5)
            setUseNNAPI(true)
        }
    )

    /**
     * Memory-maps [MODEL_PATH] from the app assets.
     *
     * The FileInputStream is closed via [use] (the original code leaked it);
     * a MappedByteBuffer remains valid after its channel is closed.
     */
    private fun loadModel(): MappedByteBuffer {
        val fileDescriptor = context.assets.openFd(MODEL_PATH)
        FileInputStream(fileDescriptor.fileDescriptor).use { stream ->
            return stream.channel.map(
                FileChannel.MapMode.READ_ONLY,
                fileDescriptor.startOffset,
                fileDescriptor.declaredLength
            )
        }
    }

    /** Reads [LABELS_PATH] (one label per line) from the app assets. */
    private fun loadLabels(): List<String> =
        context.assets.open(LABELS_PATH).bufferedReader().useLines { it.toList() }

    /**
     * Runs the interpreter on the given inputs and logs the class probabilities.
     *
     * input example:
     *
     *   inputIds = intArrayOf(102, 4845, 103)
     *   attentionMask = intArrayOf(1, 1, 1)
     *   input = arrayOf(inputIds, attentionMask)
     *
     * NOTE(review): the converted model declares shape [1, 100] for both
     * inputs — confirm callers pad/truncate to the full sequence length.
     */
    fun runInference(input: Array<IntArray>) {
        // Output 0: [1, labels.size] probability vector.
        val output = mapOf(0 to Array(1) { FloatArray(labels.size) })
        interpreter.runForMultipleInputsOutputs(input, output)
        val probabilities = output[0]?.get(0)
        // contentToString() prints the actual values; interpolating the array
        // directly would only log its identity hash code.
        Timber.d("Probabilities Results: ${probabilities?.contentToString()}")
    }

    companion object {
        private const val MODEL_PATH = "distilbert_L100.tflite"
        private const val LABELS_PATH = "labels.txt"
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment