Created
September 1, 2020 16:53
-
-
Save Bryanx/b839e3ceea0f9647ffbc5f90e3091742 to your computer and use it in GitHub Desktop.
tflite_interpreter.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import 'dart:io'; | |
import 'dart:math'; | |
import 'package:image/image.dart' as img; | |
import 'package:tflite_flutter/tflite_flutter.dart'; | |
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; | |
/// Runs a TensorFlow Lite model on an image file and prints the labelled
/// predictions.
///
/// The interpreter is created on the first call to [predictImage] and reused
/// by later calls. Call [close] when the instance is no longer needed.
class TfLiteInterpreter {
  /// Asset path handed to [Interpreter.fromAsset].
  static const MODEL_PATH = "path_to_model.tflite";

  /// Asset path handed to [FileUtil.loadLabels].
  static const LABELS_PATH = "path_to_labels.tflite";

  Interpreter _interpreter;
  List<int> _inputShape;
  List<int> _outputShape;
  TfLiteType _outputType = TfLiteType.uint8;
  TensorImage _inputImage;
  TensorBuffer _outputBuffer;

  /// Normalisation applied as the last pre-processing step.
  ///
  /// NOTE(review): mean/stddev of 127.5 presumably maps 0..255 pixel values
  /// to -1..1 for a float32 model; confirm against the model actually used.
  NormalizeOp get preProcessNormalizeOp => NormalizeOp(127.5, 127.5);

  /// Loads the model if necessary, then classifies the image at [imgPath]
  /// and prints the resulting label/score map.
  ///
  /// Rethrows whatever error prevented the interpreter from being created
  /// and throws a [FormatException] when the file cannot be decoded as an
  /// image.
  Future<void> predictImage(String imgPath) async {
    final image = File(imgPath);
    await _loadModel();
    await _predict(image);
  }

  /// Releases the interpreter, if one was created.
  ///
  /// A later call to [predictImage] loads the model again.
  void close() {
    _interpreter?.close();
    _interpreter = null;
  }

  /// Creates the interpreter from [MODEL_PATH] and caches the tensor shapes
  /// and the output buffer. Does nothing when a model is already loaded.
  Future<void> _loadModel() async {
    // Reuse the existing interpreter instead of allocating a new one (and
    // leaking the previous one) on every prediction.
    if (_interpreter != null) return;

    Interpreter interpreter;
    try {
      interpreter = await Interpreter.fromAsset(MODEL_PATH);
      final inputTensor = interpreter.getInputTensor(0);
      final outputTensor = interpreter.getOutputTensor(0);

      _inputShape = inputTensor.shape; // [1, 257, 257, 3]
      print('input_shape:');
      print(_inputShape);
      print(inputTensor.type); //TfLiteType.float32

      _outputShape = outputTensor.shape;
      print('output_shape:');
      print(_outputShape);
      _outputType = outputTensor.type;
      print(_outputType);

      _outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);

      // Publish the interpreter only once everything above succeeded, so a
      // half-configured instance is never cached.
      _interpreter = interpreter;
    } catch (e) {
      print('Unable to create interpreter, Caught Exception: $e');
      interpreter?.close();
      // Propagate: continuing would only fail later on a null interpreter
      // and hide the real cause.
      rethrow;
    }
  }

  /// Decodes [image], pre-processes it, runs inference and prints the
  /// predictions keyed by the labels loaded from [LABELS_PATH].
  Future<void> _predict(File image) async {
    // Read the image as bytes for TensorImage.
    final imageInput = img.decodeImage(await image.readAsBytes());
    if (imageInput == null) {
      throw FormatException('Unable to decode image', image.path);
    }

    // This will be the tensor that will be used for prediction.
    _inputImage = TensorImage.fromImage(imageInput);
    _inputImage = _preProcess();

    _interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
    print('output buffer shape and type');
    print(_outputBuffer.getShape());
    print(_outputBuffer.getDataType());

    final labels = await FileUtil.loadLabels(LABELS_PATH);
    final tensorLabel = TensorLabel.fromList(labels, _outputBuffer);
    final doubleMap = tensorLabel.getMapWithFloatValue();
    print('predictions:\n$doubleMap');
  }

  /// Centre-crops [_inputImage] to a square, resizes it to the model's input
  /// height and width, and applies [preProcessNormalizeOp].
  TensorImage _preProcess() {
    final cropSize = min(_inputImage.height, _inputImage.width);
    return ImageProcessorBuilder()
        .add(ResizeWithCropOrPadOp(cropSize, cropSize))
        .add(ResizeOp(
            _inputShape[1], _inputShape[2], ResizeMethod.NEAREST_NEIGHBOUR))
        // Previously omitted, leaving the normalisation getter unused.
        .add(preProcessNormalizeOp)
        .build()
        .process(_inputImage);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment