Skip to content

Instantly share code, notes, and snippets.

@maxoobot
Created March 21, 2020 04:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxoobot/117febcba144655d7df9a15107a49850 to your computer and use it in GitHub Desktop.
Save maxoobot/117febcba144655d7df9a15107a49850 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import logging
import math
import os

import allspark
import tensorflow as tf
import cv2 as cv
import numpy as np

import mnist_util
class MyProcessor(allspark.BaseProcessor):
    """Allspark processor that classifies one encoded image per request.

    The request body is decoded with ``cv.imdecode``, converted by
    ``mnist_util.mnist_format``, and fed to the Keras model loaded from
    ``mnist_model.h5``. The response body is the class label chosen by
    ``mnist_util.mnist_class_index``, encoded as bytes.
    """

    def initialize(self):
        """Load the Keras model file once, at processor initialization."""
        self.model = tf.keras.models.load_model('mnist_model.h5')

    def pre_process(self, data):
        """Decode the raw request bytes and convert the image to MNIST format.

        Args:
            data: raw request body holding an encoded image.

        Returns:
            Whatever ``mnist_util.mnist_format`` produces for the decoded image.

        Raises:
            ValueError: if ``data`` cannot be decoded as an image.
        """
        # np.frombuffer replaces the deprecated binary mode of np.fromstring;
        # it wraps the bytes without copying them.
        np_arr = np.frombuffer(data, np.uint8)
        img_np = cv.imdecode(np_arr, cv.IMREAD_COLOR)
        if img_np is None:
            # cv.imdecode reports undecodable input by returning None rather
            # than raising; fail here with a clear message instead of deep
            # inside mnist_util.
            raise ValueError('request body is not a decodable image')
        return mnist_util.mnist_format(img_np)

    def post_process(self, data):
        """Turn the model output into the encoded class label.

        Args:
            data: model prediction. Presumably shape (1, n_classes) for the
                single-image batch built in ``process``; TODO confirm.

        Returns:
            The class label string from ``mnist_util.mnist_class_index``,
            UTF-8 encoded.
        """
        # np.argmax without an axis works on the flattened array; with a
        # batch of one this is the index of the highest-scoring output.
        pred = np.argmax(data)
        category = mnist_util.mnist_class_index(pred)
        return str(category).encode()

    def process(self, data):
        """Run decode -> predict -> label for a single request.

        Args:
            data: raw request body holding an encoded image.

        Returns:
            ``(label_bytes, 200)`` on success, or
            ``("Invalid input format", 400)`` if any step fails.
        """
        try:
            batch = self.pre_process(data)
            # The model is fed a batch of one 28x28 single-channel image.
            batch = batch.reshape((1, 28, 28, 1))
            prediction = self.model.predict(batch)
            category = self.post_process(prediction)
            return category, 200
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate. The traceback is logged instead of silently
            # discarded.
            # NOTE(review): model-side failures are also reported as 400 here,
            # as in the original; consider a 5xx for those.
            logging.getLogger(__name__).exception('failed to process request')
            return "Invalid input format", 400
if __name__ == '__main__':
    # Start the service; worker_threads sets how many requests are
    # processed concurrently.
    MyProcessor(worker_threads=10).run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment