Created
May 7, 2020 21:33
-
-
Save fernandocamargoai/f431556921ccb33dc6aced390ff1a915 to your computer and use it in GitHub Desktop.
consolidation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from typing import NamedTuple, Tuple | |
import PIL.Image | |
import numpy as np | |
from bentoml import BentoService, api, env, ver, artifacts | |
from bentoml.artifact import KerasModelArtifact, TextFileArtifact, TensorflowSavedModelArtifact | |
from bentoml.handlers import ImageHandler | |
import tensorflow as tf | |
from tensorflow.keras import Model | |
from prometheus_client import Histogram, Counter, Gauge | |
from datalife.preprocessing import IMAGE_PREPROCESSING_FUNCTIONS | |
from datalife.utils import img_to_array, trim_border, resize_keeping_aspect_ratio, apply_clahe | |
ModelConfig = NamedTuple(
    "ModelConfig",
    [
        # Target size handed to the resize step of the service's preprocessing.
        ("input_shape", Tuple[int, int]),
        # "rgb" makes the service convert the image to RGB before array conversion.
        ("image_color_mode", str),
        # Forwarded to img_to_array and to the selected preprocessing function.
        ("image_data_format", str),
        # Key into IMAGE_PREPROCESSING_FUNCTIONS.
        ("image_preprocessing", str),
        # Extra entries merged into the params dict given to that preprocessing function.
        ("image_preprocessing_extra_params", dict),
        # Passed as the resample argument of the resize step.
        ("interpolation_method", int),
        # When true, resize via resize_keeping_aspect_ratio instead of a plain resize.
        ("keep_aspect_ratio", bool),
        # When true, the CLAHE helper is applied after resizing.
        ("apply_clahe", bool),
        # When true, trim_border runs before resizing.
        ("trim_image_border", bool),
    ],
)
ModelConfig.__doc__ = """Image preprocessing settings stored alongside a packed model.

The config is serialized with ``json.dumps`` (a NamedTuple becomes a JSON
array), so the field order above is part of the stored format.
"""
@ver(major=1, minor=0)
@artifacts([
    TensorflowSavedModelArtifact("model"),
    TextFileArtifact("model_config", ".json"),
])
@env(
    conda_dependencies=["tensorflow=1.15.0", "numpy=1.17.2", "pillow=6.1.0", "scikit-image=0.16.2"],
    # BentoML documents ``conda_channels`` as a list of channel names.
    conda_channels=["conda-forge"],
)
class ConsolidationClassification(BentoService):
    """BentoML service that preprocesses one image and runs the packed TensorFlow model on it.

    Artifacts:
        model: the TensorFlow SavedModel used for inference.
        model_config: JSON text with the ``ModelConfig`` values that drive preprocessing.
    """

    @property
    def model(self) -> tf.MetaGraphDef:
        """Return the packed/loaded ``model`` artifact."""
        return self.artifacts.model

    @property
    def model_config(self) -> ModelConfig:
        """Return the ``model_config`` artifact parsed into a ``ModelConfig`` (parsed once, then cached)."""
        if not hasattr(self, "_model_config"):
            self._model_config = self._parse_model_config(self.artifacts.model_config)
        return self._model_config

    @staticmethod
    def _parse_model_config(model_config_json: str) -> ModelConfig:
        """Build a ``ModelConfig`` from its JSON text.

        ``json.dumps`` on the NamedTuple stores a positional JSON array; a JSON
        object keyed by field name is accepted as well.

        Raises:
            ValueError: if the text is not valid JSON.
            TypeError: if the JSON value does not match the ``ModelConfig`` fields.
        """
        raw = json.loads(model_config_json)
        config = ModelConfig(**raw) if isinstance(raw, dict) else ModelConfig(*raw)
        # JSON has no tuple type, so input_shape comes back as a list; restore the
        # declared Tuple[int, int].
        if isinstance(config.input_shape, list):
            config = config._replace(input_shape=tuple(config.input_shape))
        return config

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Apply the configured preprocessing to one decoded image and return it as an array.

        Steps, each controlled by ``model_config``:
          1. optional border trimming;
          2. resizing to ``input_shape``, optionally keeping the aspect ratio;
          3. optional CLAHE;
          4. optional conversion to RGB;
          5. conversion to an array laid out per ``image_data_format``.
        """
        config = self.model_config
        img = PIL.Image.fromarray(image)
        if config.trim_image_border:
            img = trim_border(img)
        if config.keep_aspect_ratio:
            img = resize_keeping_aspect_ratio(img, config.input_shape, config.interpolation_method)
        else:
            # NOTE(review): PIL's resize takes (width, height); confirm input_shape is
            # stored in that order rather than Keras' (height, width).
            img = img.resize(config.input_shape, config.interpolation_method)
        if config.apply_clahe:
            # This calls the module-level apply_clahe helper, not the config flag of the same name.
            img = apply_clahe(img)
        if config.image_color_mode == "rgb":
            # The API decodes uploads as grayscale ("L"), so expand to 3 channels here.
            img = img.convert("RGB")
        return img_to_array(img, config.image_data_format)

    @api(ImageHandler, input_names=("image",), pilmode="L")
    def predict(self, image: np.ndarray) -> tf.Tensor:
        """Preprocess one uploaded image and run the model on it as a batch of one.

        The handler decodes the upload in PIL mode ``"L"`` (grayscale).

        Raises:
            KeyError: if ``image_preprocessing`` is not a key of IMAGE_PREPROCESSING_FUNCTIONS.
        """
        config = self.model_config
        batch = np.expand_dims(self._preprocess_image(image).astype("float32"), axis=0)
        preprocess = IMAGE_PREPROCESSING_FUNCTIONS[config.image_preprocessing]
        # Treat a null/empty extra-params entry as "no extra params" instead of
        # crashing on ``**None``.
        params = {
            "image_data_format": config.image_data_format,
            **(config.image_preprocessing_extra_params or {}),
        }
        return self.model(preprocess(batch, params))
def pack_consolidation_model(model: Model, model_config: ModelConfig) -> ConsolidationClassification:
    """Bundle a trained model and its preprocessing config into a new service instance.

    ``ModelConfig`` is a NamedTuple, so ``json.dumps`` stores it as a positional
    JSON array; the field order of ``ModelConfig`` is therefore part of the format.

    Args:
        model: the model packed as the ``model`` artifact.
        model_config: preprocessing settings packed as the ``model_config`` artifact.

    Returns:
        A ``ConsolidationClassification`` with both artifacts packed.
    """
    service = ConsolidationClassification()
    service.pack("model", model)
    config_json = json.dumps(model_config)
    service.pack("model_config", config_json)
    return service
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment