Skip to content

Instantly share code, notes, and snippets.

@danield137
Last active October 28, 2020 14:34
Show Gist options
  • Save danield137/d54fdf218c4f39c50fbd9289ca79cf6c to your computer and use it in GitHub Desktop.
Save danield137/d54fdf218c4f39c50fbd9289ca79cf6c to your computer and use it in GitHub Desktop.
Bento bundle wrapper
import importlib
import logging
import os
import shutil
from bentoml.exceptions import InvalidArgument, MissingDependencyException
from bentoml.service.artifacts import BentoServiceArtifact
from bentoml.service.env import BentoServiceEnv
# Module-level logger; PytorchBundleArtifact.set_dependencies emits its warning through it.
logger = logging.getLogger(__name__)
class PytorchBundleArtifact(BentoServiceArtifact):
    """Artifact that ships a PyTorch model as a directory of Python loader code.

    Rather than serializing weights, ``save`` copies the directory
    ``./<name>`` (resolved against the current working directory) into the
    saved bundle. ``load`` imports that directory back as the module
    ``<name>`` and calls its ``load()`` function, which must return a
    ``torch.nn.Module``.
    """

    def __init__(self, name, file_extension=".pt"):
        super().__init__(name)
        # Only consulted by ``_file_path``, which save/load here do not call;
        # kept so existing constructor calls keep working.
        self._file_extension = file_extension
        self._model = None

    def _file_path(self, base_path):
        """Return ``<base_path>/<name><file_extension>``."""
        return os.path.join(base_path, self.name + self._file_extension)

    @staticmethod
    def _import_torch():
        """Import and return ``torch``, raising MissingDependencyException if absent."""
        try:
            import torch
        except ImportError as err:
            raise MissingDependencyException(
                "torch package is required to use PytorchBundleArtifact"
            ) from err
        return torch

    def pack(self, model):  # pylint:disable=arguments-differ
        """Hold ``model`` in memory so that ``get`` can return it.

        Args:
            model: the model instance to hold.

        Returns:
            self, so calls can be chained.

        Raises:
            MissingDependencyException: if torch is not installed.
            InvalidArgument: if ``model`` is not a ``torch.nn.Module``.
        """
        torch = self._import_torch()
        if not isinstance(model, torch.nn.Module):
            raise InvalidArgument(
                "PytorchBundleArtifact can only pack type 'torch.nn.Module'"
            )
        self._model = model
        return self

    def load(self, path):
        """Import the bundled loader package found under ``path`` and pack its model.

        Args:
            path: directory holding the saved artifact, presumably the same
                directory that was passed to ``save`` (which writes the
                bundle to ``<path>/<name>``).

        Returns:
            self, with the loaded model packed.

        Raises:
            MissingDependencyException: if torch is not installed.
            InvalidArgument: if the bundle's ``load()`` does not return a
                ``torch.nn.Module``.
        """
        import sys

        torch = self._import_torch()

        bundle_dir = os.path.join(path, self.name)
        # ``path`` makes the package ``<name>`` importable from the saved
        # bundle regardless of the working directory. ``bundle_dir`` lets
        # modules inside the bundle import their siblings by bare name and
        # also covers a ``<name>/<name>.py`` layout. Entries already present
        # are skipped so repeated loads do not keep growing sys.path.
        # Inserting at index 1 leaves sys.path[0] in front, and ``path`` ends
        # up ahead of ``bundle_dir``.
        for entry in (bundle_dir, path):
            if entry not in sys.path:
                sys.path.insert(1, entry)

        model_wrapper = importlib.import_module(self.name)
        model = model_wrapper.load()
        if not isinstance(model, torch.nn.Module):
            raise InvalidArgument(
                "Expecting PytorchBundleArtifact loaded object type to be "
                "'torch.nn.Module' but actually it is {}".format(type(model))
            )
        return self.pack(model)

    def set_dependencies(self, env: BentoServiceEnv):
        """Log a note about optional packages and register ``torch`` as a pip dependency."""
        logger.warning(
            "BentoML by default does not include spacy and torchvision package when "
            "using PytorchBundleArtifact. To make sure BentoML bundle those packages if "
            "they are required for your model, either import those packages in "
            "BentoService definition file or manually add them via "
            "`@env(pip_packages=['torchvision'])` when defining a BentoService"
        )
        env.add_pip_packages(['torch'])

    def get(self):
        """Return the packed model, or None if nothing has been packed or loaded."""
        return self._model

    def save(self, dst):
        """Copy the loader package ``./<name>`` into ``<dst>/<name>``.

        The source directory is resolved against the current working
        directory; the in-memory packed model is not consulted.

        Returns:
            self.
        """
        src = os.path.abspath(self.name)
        target = os.path.join(dst, self.name)
        # Compiled bytecode from the packing environment is regenerated on
        # import anyway; leaving it out keeps the bundle clean.
        shutil.copytree(
            src, target, ignore=shutil.ignore_patterns("__pycache__", "*.pyc")
        )
        return self
from bentoml.server import start_dev_server
from bentoml.server.api_server import BentoAPIServer
from my_service.inference.service import MyClassifier
import model
def main(port=5000):
    """Pack a freshly loaded model into MyClassifier and serve it over HTTP.

    Args:
        port: TCP port for the BentoML API server (defaults to 5000, the value
            previously hard-coded).
    """
    bento_service = MyClassifier()
    # 'model' must match an artifact name declared on MyClassifier; presumably
    # a PytorchBundleArtifact('model') — confirm in my_service.inference.service.
    bento_service.pack('model', model.load())
    api_server = BentoAPIServer(bento_service, port=port)
    api_server.start()


# Guarded so that importing this module does not start a server as a side effect.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment