@RunWith(AndroidJUnit4.class)
@LargeTest
public class MLModelTest {
    @Rule
    public ActivityTestRule<ModelTestActivity> mainActivityActivityRule = new ActivityTestRule<>(ModelTestActivity.class);
    @Test
    public void testClassificationUI() {
        ModelTestActivity activity = mainActivityActivityRule.getActivity();

public class ModelTestActivity extends AppCompatActivity {
    private ImageView ivPreview;
    private TextView tvClassification;
    private ModelClassificator modelClassificator;
    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(com.frogermcs.imageclassificationtester.test.R.layout.activity_model_test);

from PIL import Image
VAL_BATCH_DIR = "validation_batch"
!mkdir {VAL_BATCH_DIR}
# Export the batch to *.jpg files with a specific naming convention.
# Make sure they are exported at full quality; otherwise the inference
# process will return different results.
for n in range(32):

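The loop above is cut off in this listing, so what follows is only a sketch of what a full-quality JPEG export could look like, not the original code. The helper name `export_batch`, the `<index>_<label>.jpg` naming scheme, and the random stand-in images are all assumptions made for illustration; the sketch assumes Pillow and NumPy are installed.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def export_batch(images, labels, out_dir):
    """Save each HxWx3 uint8 image as a JPEG and return the written paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for n, (img, label) in enumerate(zip(images, labels)):
        # Hypothetical naming scheme: <index>_<label>.jpg
        path = os.path.join(out_dir, f"{n}_{label}.jpg")
        # quality=100 with subsampling=0 (no chroma subsampling) reduces JPEG
        # compression loss; the file is still not bit-identical to the array.
        Image.fromarray(img).save(path, quality=100, subsampling=0)
        paths.append(path)
    return paths


# Random stand-in data; the original exports a real validation batch.
fake_images = np.random.randint(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)
fake_labels = [0, 1, 2, 3]
written = export_batch(fake_images, fake_labels, tempfile.mkdtemp())
print(written)
```

Because JPEG stays lossy even at the highest quality setting, a lossless format such as PNG may be worth considering if the reloaded pixels must match the source exactly.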
public class ModelClassificator {
    private static final int MAX_CLASSIFICATION_RESULTS = 3;
    private static final float CLASSIFICATION_THRESHOLD = 0.2f;
    private final Interpreter interpreter;
    private final List<String> labels;
    private final ModelConfig modelConfig;
    public ModelClassificator(Context context,
                              ModelConfig modelConfig) throws IOException {

public class ClassificationFrameProcessor implements FrameProcessor {
    private final ModelClassificator modelClassificator;
    private final ClassificationListener classificationListener;
    public ClassificationFrameProcessor(ModelClassificator modelClassificator,
                                        ClassificationListener classificationListener) {
        this.modelClassificator = modelClassificator;
        this.classificationListener = classificationListener;
    }

public class MainActivity extends AppCompatActivity
        implements ClassificationFrameProcessor.ClassificationListener {
    private CameraView cameraView;
    private TextView tvClassification;
    private ClassificationFrameProcessor classificationFrameProcessor;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

Output, CoreML
(CPU) Prediction for Golden Retriever: golden retriever 0.611853480339
(GPU) Prediction for laptop: notebook 0.515091240406
Output, TensorFlow
Prediction for Golden Retriever: golden retriever 0.61186796
Prediction for laptop: notebook 0.51475537

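Both runtimes agree on the top label and differ only slightly in confidence. A minimal sketch for quantifying that gap, using the values printed above, might look like this. The helper `confidence_gaps` and the `TOLERANCE` value are assumptions chosen for illustration; they do not come from the original.

```python
# Confidence values copied from the output listing above.
coreml = {"golden retriever": 0.611853480339, "notebook": 0.515091240406}
tensorflow = {"golden retriever": 0.61186796, "notebook": 0.51475537}

# Hypothetical tolerance for "close enough"; tune it for your own model.
TOLERANCE = 1e-3


def confidence_gaps(a, b):
    """Return the absolute confidence difference for every label in both dicts."""
    return {label: abs(a[label] - b[label]) for label in a if label in b}


gaps = confidence_gaps(coreml, tensorflow)
for label, gap in gaps.items():
    status = "OK" if gap <= TOLERANCE else "MISMATCH"
    print(f"{label}: |diff| = {gap:.2e} ({status})")
```

With these numbers both gaps stay below the assumed tolerance. The laptop prediction is marked as a GPU run in the CoreML listing, which may account for its larger gap.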
TF_INPUT_TENSOR = 'input:0'
TF_OUTPUT_TENSOR = 'MobilenetV2/Predictions/Reshape_1:0'
with tf.Session(graph=g) as sess:
    tf_laptop_out = sess.run(TF_OUTPUT_TENSOR, feed_dict={TF_INPUT_TENSOR: img_laptop_tf})
    tf_golden_out = sess.run(TF_OUTPUT_TENSOR, feed_dict={TF_INPUT_TENSOR: img_golden_tf})
tf_laptop_out = tf_laptop_out.flatten()
tf_golden_out = tf_golden_out.flatten()
laptop_idx = np.argmax(tf_laptop_out)
golden_idx = np.argmax(tf_golden_out)

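`np.argmax` returns only the single best class index. If the runner-up classes matter as well, the same flattened output can be ranked with `np.argsort`. The sketch below assumes a hypothetical helper named `top_k` and uses made-up scores in place of a real model output such as `tf_laptop_out`.

```python
import numpy as np


def top_k(scores, k=3):
    """Return (index, score) pairs for the k highest scores, best first."""
    scores = np.asarray(scores).flatten()
    # argsort is ascending, so take the last k indices and reverse them.
    best = np.argsort(scores)[-k:][::-1]
    return [(int(i), float(scores[i])) for i in best]


# Made-up scores standing in for a model output of shape [1, num_classes].
fake_out = np.array([[0.05, 0.10, 0.60, 0.25]])
results = top_k(fake_out, k=2)
print(results)
```

The first entry of the result always matches what `np.argmax` would return on the same array.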
# Load TensorFlow frozen model
import tensorflow as tf  # TensorFlow 1.x API (available as tf.compat.v1 in TF 2.x)

TF_FROZEN_MODEL = "mobilenet_v2_1.0_224_frozen.pb"
with open(TF_FROZEN_MODEL, 'rb') as f:
    serialized_model = f.read()
tf.reset_default_graph()
graph_definition = tf.GraphDef()
graph_definition.ParseFromString(serialized_model)
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_definition, name='')

# Prepare images for TensorFlow requirements
import numpy as np

# Convert to the expected type: float32
img_laptop_tf = np.array(img_laptop).astype(np.float32)
# Set up the expected input shape: [1, 224, 224, 3]
img_laptop_tf = np.expand_dims(img_laptop_tf, axis=0)
# Scale to the expected value range: [0, 1]
img_laptop_tf = (1.0 / 255.0) * img_laptop_tf
print('Image shape:', img_laptop_tf.shape)
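The three preparation steps above (cast, add a batch dimension, scale) can be wrapped in one function so that every image goes through the same path. This is only a sketch: the name `preprocess` is hypothetical, the input is a random stand-in rather than a real photo, and it assumes the image is already 224x224 RGB. It divides by 255 instead of multiplying by 1/255, which is equivalent up to floating-point rounding.

```python
import numpy as np


def preprocess(img):
    """Turn an HxWx3 uint8 image into a [1, H, W, 3] float32 batch in [0, 1]."""
    batch = np.asarray(img).astype(np.float32)  # expected type: float32
    batch = np.expand_dims(batch, axis=0)       # add the batch dimension
    return batch / 255.0                        # scale pixel values to [0, 1]


# Random stand-in for a 224x224 RGB image; the original loads a real photo.
fake_img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = preprocess(fake_img)
print("Image shape:", batch.shape, "dtype:", batch.dtype)
```

Calling the same function for both the laptop and the golden retriever image avoids the two inputs drifting apart if one step is later changed.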