public class ModelClassificator { | |
private static final int MAX_CLASSIFICATION_RESULTS = 3; | |
private static final float CLASSIFICATION_THRESHOLD = 0.2f; | |
private final Interpreter interpreter; | |
private final List<String> labels; | |
private final ModelConfig modelConfig; | |
public ModelClassificator(Context context, | |
ModelConfig modelConfig) throws IOException { | |
ByteBuffer model = AssetsUtils.loadFile(context, modelConfig.getModelFilename()); | |
this.interpreter = new Interpreter(model); | |
this.labels = AssetsUtils.loadLines(context, modelConfig.getLabelsFilename()); | |
this.modelConfig = modelConfig; | |
} | |
public List<ClassificationResult> process(Bitmap bitmap) { | |
Bitmap toClassify = ThumbnailUtils.extractThumbnail( | |
bitmap, modelConfig.getInputWidth(), modelConfig.getInputHeight() | |
); | |
ByteBuffer byteBufferToClassify = bitmapToModelsMatchingByteBuffer(toClassify); | |
float[][] result = new float[1][labels.size()]; | |
interpreter.run(byteBufferToClassify, result); | |
/* ... */ | |
return getSortedResult(result); | |
} | |
private ByteBuffer bitmapToModelsMatchingByteBuffer(Bitmap bitmap) { /* ... */ } | |
private float[] pixelToChannelValues(int pixel) { /* ... */ } | |
private List<ClassificationResult> getSortedResult(float[][] resultsArray) { /* ... */ } | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment