Skip to content

Instantly share code, notes, and snippets.

@akirasosa
Last active September 14, 2018 02:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save akirasosa/812e81f14f300323df98e42ca5825604 to your computer and use it in GitHub Desktop.
Save akirasosa/812e81f14f300323df98e42ca5825604 to your computer and use it in GitHub Desktop.
Benchmark a TensorFlow model on Android.
// Android application module written in Kotlin (with the Kotlin Android extensions).
apply(plugin: 'com.android.application')
apply(plugin: 'kotlin-android')
apply(plugin: 'kotlin-android-extensions')

// Assets are read from <module>/assets instead of the default src/main/assets;
// the main source set in the android block is pointed at this directory.
project.ext.ASSET_DIR = "${projectDir}/assets".toString()
android {
    compileSdkVersion(26)

    defaultConfig {
        applicationId("tfexample.myapp.com.myapplication")
        minSdkVersion(21)
        targetSdkVersion(26)
        versionCode(1)
        versionName("1.0")
        testInstrumentationRunner("android.support.test.runner.AndroidJUnitRunner")
    }

    buildTypes {
        release {
            // Release builds are not shrunk/obfuscated; the ProGuard files are listed
            // so enabling minification later only requires flipping this flag.
            minifyEnabled(false)
            proguardFiles(getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro')
        }
    }

    sourceSets {
        main {
            // Package assets (e.g. the .pb model) from the module-level ASSET_DIR.
            assets.srcDirs = [project.ext.ASSET_DIR]
        }
    }
}
dependencies {
    // Any local jars dropped into libs/.
    implementation(fileTree(dir: 'libs', include: ['*.jar']))

    // NOTE(review): kotlin-stdlib-jre7 was superseded by kotlin-stdlib-jdk7 in Kotlin 1.2;
    // switch the artifact once kotlin_version >= 1.2 is confirmed for this project.
    implementation("org.jetbrains.kotlin:kotlin-stdlib-jre7:$kotlin_version")
    implementation('com.android.support:appcompat-v7:26.1.0')
    implementation('com.android.support.constraint:constraint-layout:1.0.2')

    // TensorFlow Mobile runtime providing TensorFlowInferenceInterface.
    implementation('org.tensorflow:tensorflow-android:1.4.0')

    testImplementation('junit:junit:4.12')
    androidTestImplementation('com.android.support.test:runner:1.0.1')
    androidTestImplementation('com.android.support.test.espresso:espresso-core:3.0.1')
}
package tfexample.myapp.com.myapplication
import android.os.Bundle
import android.support.v7.app.AppCompatActivity
import android.util.Log
import org.tensorflow.contrib.android.TensorFlowInferenceInterface
import kotlin.system.measureTimeMillis
/** Asset URI of the frozen TensorFlow graph that [MainActivity] benchmarks on startup. */
const val modelName = "file:///android_asset/mobile_unet_160_100_100.pb"
/**
 * Activity that benchmarks a frozen TensorFlow graph on the device.
 *
 * On creation it loads [modelName] from the APK assets, runs repeated inferences on a
 * 160x160 input and logs the per-run times (milliseconds) and their average.
 */
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        // Graph loading plus repeated inference is slow; running it here on the UI thread
        // would freeze the screen (and risk an ANR), so benchmark on a worker thread.
        Thread(Runnable { measureInference(modelName, 160, 160) }, "tf-benchmark").start()
    }

    /**
     * Times [iterations] inference runs of the graph at [modelPath] and logs the results.
     *
     * One untimed warm-up run is performed first so it is excluded from the measurements.
     * The interface is always closed afterwards, even if a run throws.
     *
     * @param modelPath asset URI of the frozen graph (e.g. `file:///android_asset/....pb`).
     * @param height input height fed to the graph.
     * @param width input width fed to the graph.
     * @param iterations number of timed runs; must be positive.
     * @param inputName name of the graph's input node.
     * @param outputName name of the graph's output node.
     * @throws IllegalArgumentException if [iterations] is not positive.
     */
    private fun measureInference(
        modelPath: String,
        height: Long,
        width: Long,
        iterations: Int = 10,
        inputName: String = "input_1",
        outputName: String = "output_0"
    ) {
        // An empty run list would make average() return NaN, so reject it up front.
        require(iterations > 0) { "iterations must be positive: $iterations" }

        val inferenceInterface = TensorFlowInferenceInterface(assets, modelPath)
        try {
            // Input is fed with shape [1, height, width, 3]; the output buffer holds one
            // value per pixel. NOTE(review): confirm the output size against the graph.
            val inputs = FloatArray((width * height * 3).toInt())
            val outputs = FloatArray((width * height).toInt())
            val outputNames = arrayOf(outputName)

            // One full feed -> run -> fetch round trip.
            fun runOnce() {
                inferenceInterface.feed(inputName, inputs, 1, height, width, 3)
                inferenceInterface.run(outputNames)
                inferenceInterface.fetch(outputName, outputs)
            }

            // Warm-up: the first run is excluded (presumably it carries one-off setup cost).
            runOnce()

            val times = List(iterations) { measureTimeMillis { runOnce() } }
            Log.d(TAG, times.toString())
            Log.d(TAG, times.average().toString())
        } finally {
            // Free the resources held by the inference interface.
            inferenceInterface.close()
        }
    }

    companion object {
        private const val TAG = "MainActivity"
    }
}
@pribadihcr
Copy link

Hi
Where can I download the pretrained model?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment