Skip to content

Instantly share code, notes, and snippets.

@fernandocamargoai
Created May 7, 2020 21:33
Show Gist options
  • Save fernandocamargoai/f431556921ccb33dc6aced390ff1a915 to your computer and use it in GitHub Desktop.
Save fernandocamargoai/f431556921ccb33dc6aced390ff1a915 to your computer and use it in GitHub Desktop.
consolidation.py
import json
from typing import NamedTuple, Tuple
import PIL.Image
import numpy as np
from bentoml import BentoService, api, env, ver, artifacts
from bentoml.artifact import KerasModelArtifact, TextFileArtifact, TensorflowSavedModelArtifact
from bentoml.handlers import ImageHandler
import tensorflow as tf
from tensorflow.keras import Model
from prometheus_client import Histogram, Counter, Gauge
from datalife.preprocessing import IMAGE_PREPROCESSING_FUNCTIONS
from datalife.utils import img_to_array, trim_border, resize_keeping_aspect_ratio, apply_clahe
class ModelConfig(NamedTuple):
    """Preprocessing settings stored next to the packed model.

    Field order matters: ``json.dumps`` writes a ``ModelConfig`` (a tuple) as a
    positional JSON array, and the service rebuilds it by unpacking that array
    positionally.
    """

    # Target size for resizing (PIL's ``Image.resize`` reads it as (width, height)).
    input_shape: Tuple[int, int]
    # When equal to "rgb", the preprocessed image is converted to RGB.
    image_color_mode: str
    # Data-format string passed to img_to_array and to the preprocessing function.
    image_data_format: str
    # Key into IMAGE_PREPROCESSING_FUNCTIONS selecting the preprocessing function.
    image_preprocessing: str
    # Extra entries merged into the params dict given to the preprocessing function.
    image_preprocessing_extra_params: dict
    # Resampling filter passed to the resize call.
    interpolation_method: int
    # Use resize_keeping_aspect_ratio instead of a plain resize.
    keep_aspect_ratio: bool
    # Run apply_clahe on the resized image.
    apply_clahe: bool
    # Run trim_border on the image before resizing.
    trim_image_border: bool
@ver(major=1, minor=0)
@artifacts([
    TensorflowSavedModelArtifact("model"),
    TextFileArtifact("model_config", ".json"),
])
@env(conda_dependencies=["tensorflow=1.15.0", "numpy=1.17.2", "pillow=6.1.0", "scikit-image=0.16.2"],
     conda_channels="conda-forge")
class ConsolidationClassification(BentoService):
    """BentoML service serving the consolidation classification model.

    Artifacts:
        model: TensorFlow SavedModel that ``predict`` calls on a batch of one.
        model_config: JSON text describing a ``ModelConfig``, i.e. the
            preprocessing pipeline applied before inference.
    """

    @property
    def model(self) -> tf.MetaGraphDef:
        """The packed ``model`` artifact."""
        return self.artifacts.model

    @property
    def model_config(self) -> ModelConfig:
        """The ``model_config`` artifact parsed into a ``ModelConfig``.

        Parsed on first access and cached on the instance.
        """
        if not hasattr(self, "_model_config"):
            self._model_config = self._parse_model_config(self.artifacts.model_config)
        return self._model_config

    @staticmethod
    def _parse_model_config(model_config_json: str) -> ModelConfig:
        """Build a ``ModelConfig`` from its JSON text.

        ``json.dumps`` on a ``ModelConfig`` (a tuple) yields a positional JSON
        array, which is unpacked positionally. A JSON object is matched by
        field name instead; unpacking it with ``*`` would silently feed the
        dict's *keys* into the fields.

        Raises:
            TypeError: if the JSON is neither an array nor an object, or its
                fields do not match ``ModelConfig``.
        """
        raw = json.loads(model_config_json)
        if isinstance(raw, dict):
            config = ModelConfig(**raw)
        elif isinstance(raw, list):
            config = ModelConfig(*raw)
        else:
            raise TypeError(
                f"model_config must be a JSON array or object, got {type(raw).__name__}")
        # JSON has no tuple type, so input_shape comes back as a list; restore
        # the Tuple[int, int] declared by ModelConfig.
        return config._replace(input_shape=tuple(config.input_shape))

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Apply the configured image transforms and convert the result to an array.

        The steps run in this order, each gated by ``model_config``:
        border trim, resize to ``input_shape`` (optionally keeping the aspect
        ratio), CLAHE, and RGB conversion when ``image_color_mode == "rgb"``.
        """
        config = self.model_config
        img = PIL.Image.fromarray(image)
        if config.trim_image_border:
            img = trim_border(img)
        # NOTE(review): PIL's resize takes (width, height); confirm input_shape
        # is stored in that order, since Keras shapes are often (height, width).
        if config.keep_aspect_ratio:
            img = resize_keeping_aspect_ratio(img, config.input_shape,
                                              config.interpolation_method)
        else:
            img = img.resize(config.input_shape, config.interpolation_method)
        if config.apply_clahe:
            img = apply_clahe(img)
        # Uploads are decoded with pilmode="L" (see predict); convert so the
        # model receives three channels when it expects RGB input.
        if config.image_color_mode == "rgb":
            img = img.convert("RGB")
        return img_to_array(img, config.image_data_format)

    @api(ImageHandler, input_names=("image",), pilmode="L")
    def predict(self, image: np.ndarray) -> tf.Tensor:
        """Preprocess one uploaded image and run the model on it.

        Args:
            image: the uploaded image, decoded by the handler with pilmode "L".

        Returns:
            The result of calling the ``model`` artifact on a batch of one
            preprocessed image.
        """
        config = self.model_config
        img_array = self._preprocess_image(image)
        preprocess = IMAGE_PREPROCESSING_FUNCTIONS[config.image_preprocessing]
        # Add a leading batch axis so the model is called on a batch of one.
        batch = np.expand_dims(img_array.astype('float32'), axis=0)
        # The extras may be stored as JSON null; treat that as "no extras"
        # instead of crashing on ** unpacking of None.
        params = {"image_data_format": config.image_data_format,
                  **(config.image_preprocessing_extra_params or {})}
        return self.model(preprocess(batch, params))
def pack_consolidation_model(model: Model, model_config: ModelConfig) -> ConsolidationClassification:
    """Create a ``ConsolidationClassification`` service with both artifacts packed.

    Args:
        model: the trained model, packed as the ``model`` artifact.
        model_config: preprocessing settings, packed as JSON text under
            ``model_config``. ``json.dumps`` writes a ``ModelConfig`` (a tuple)
            as a positional JSON array, so field names are not stored.

    Returns:
        The service instance with ``model`` and ``model_config`` packed.
    """
    service = ConsolidationClassification()
    service.pack("model", model)
    config_json = json.dumps(model_config)
    service.pack("model_config", config_json)
    return service
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment