Created
March 21, 2020 04:47
-
-
Save maxoobot/117febcba144655d7df9a15107a49850 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import logging
import math
import os

import allspark
import tensorflow as tf
import cv2 as cv
import numpy as np

import mnist_util
class MyProcessor(allspark.BaseProcessor):
    """allspark request processor that classifies one digit image per request.

    The request body is decoded with ``cv.imdecode``. The image is converted
    by ``mnist_util.mnist_format``, reshaped to ``(1, 28, 28, 1)`` and fed to
    a Keras model. The predicted class label is returned as UTF-8 bytes.
    """

    # Saved Keras model, resolved relative to the process working directory.
    # This is a class attribute so a subclass or deployment can override it.
    model_path = 'mnist_model.h5'

    _logger = logging.getLogger(__name__)

    def initialize(self):
        """Load the Keras model once, when the processor is initialized."""
        self.model = tf.keras.models.load_model(self.model_path)

    def pre_process(self, data):
        """Decode raw request bytes and convert the image to MNIST format.

        Args:
            data: request body holding an encoded image (bytes-like).

        Returns:
            The value of ``mnist_util.mnist_format`` for the decoded image.
            ``process`` reshapes it to ``(1, 28, 28, 1)``.

        Raises:
            ValueError: if ``data`` is empty or is not a decodable image.
        """
        # frombuffer replaces the deprecated binary mode of np.fromstring and
        # wraps the request bytes without copying them.
        np_arr = np.frombuffer(data, np.uint8)
        if np_arr.size == 0:
            raise ValueError('empty request body')
        # imdecode signals an undecodable buffer by returning None, not by raising.
        img_np = cv.imdecode(np_arr, cv.IMREAD_COLOR)
        if img_np is None:
            raise ValueError('request body is not a decodable image')
        return mnist_util.mnist_format(img_np)

    def post_process(self, data):
        """Map raw model output to the UTF-8 encoded class label.

        Args:
            data: output of ``self.model.predict``. For the single-image batch
                built in ``process``, this is presumably ``(1, n_classes)``;
                TODO confirm against the model.

        Returns:
            bytes: ``str(label).encode()``, where ``label`` comes from
            ``mnist_util.mnist_class_index``.
        """
        # argmax without an axis works on the flattened array. For a batch of
        # one, the flat index is the class index.
        pred = np.argmax(data)
        category = mnist_util.mnist_class_index(pred)
        return str(category).encode()

    def process(self, data):
        """Classify the image contained in ``data``.

        Returns:
            ``(body, status)``:
            - the encoded class label with 200 on success;
            - ``"Invalid input format"`` with 400 when the request cannot be
              turned into a model input;
            - ``"Internal server error"`` with 500 when prediction fails.
        """
        # Anything that goes wrong while building the model input is the
        # client's fault. Catch Exception (not a bare except) so SystemExit
        # and KeyboardInterrupt still propagate.
        try:
            x = self.pre_process(data)
            x = x.reshape((1, 28, 28, 1))
        except Exception:
            self._logger.warning('Rejecting request: invalid input', exc_info=True)
            return "Invalid input format", 400

        # Failures past this point are server-side. Do not blame the client.
        try:
            prediction = self.model.predict(x)
            category = self.post_process(prediction)
        except Exception:
            self._logger.exception('Prediction failed')
            return "Internal server error", 500
        return category, 200
if __name__ == '__main__':
    # Script entry point: start the allspark service.
    # worker_threads sets how many requests are processed concurrently.
    MyProcessor(worker_threads=10).run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment