Last active: October 28, 2020 14:34
-
-
Save danield137/d54fdf218c4f39c50fbd9289ca79cf6c to your computer and use it in GitHub Desktop.
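Bento bundle wrapper: a BentoML artifact that bundles a PyTorch model's source package, plus a script that serves it with the BentoML API server.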
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import importlib | |
import logging | |
import os | |
import shutil | |
from bentoml.exceptions import InvalidArgument, MissingDependencyException | |
from bentoml.service.artifacts import BentoServiceArtifact | |
from bentoml.service.env import BentoServiceEnv | |
logger = logging.getLogger(__name__)


class PytorchBundleArtifact(BentoServiceArtifact):
    """BentoML artifact that ships a PyTorch model as a Python package.

    Instead of serializing weights, ``save`` copies the local directory whose
    name equals the artifact ``name`` into the bundle. ``load`` imports that
    package by name and calls its module-level ``load()`` function, which must
    return a ``torch.nn.Module``.
    """

    def __init__(self, name, file_extension=".pt"):
        super().__init__(name)
        # NOTE(review): only _file_path reads this, and nothing in this class
        # calls _file_path; kept for interface compatibility.
        self._file_extension = file_extension
        # Holds the packed torch.nn.Module; None until pack()/load() runs.
        self._model = None

    def _file_path(self, base_path):
        """Return ``<base_path>/<name><file_extension>``."""
        return os.path.join(base_path, self.name + self._file_extension)

    @staticmethod
    def _import_torch():
        """Import and return the ``torch`` module.

        Raises:
            MissingDependencyException: if torch is not installed.
        """
        try:
            import torch
        except ImportError as err:
            raise MissingDependencyException(
                "torch package is required to use PytorchBundleArtifact"
            ) from err
        return torch

    def pack(self, model):  # pylint:disable=arguments-differ
        """Attach an in-memory model to this artifact.

        Args:
            model: the ``torch.nn.Module`` to hold.

        Returns:
            self, so calls can be chained.

        Raises:
            MissingDependencyException: if torch is not installed.
            InvalidArgument: if ``model`` is not a ``torch.nn.Module``.
        """
        torch = self._import_torch()
        if not isinstance(model, torch.nn.Module):
            raise InvalidArgument(
                "PytorchBundleArtifact can only pack type 'torch.nn.Module'"
            )
        self._model = model
        return self

    def load(self, path):
        """Import the bundled package saved under ``path`` and pack its model.

        ``save`` writes the package to ``<dst>/<name>``, so ``path`` (the
        directory handed back at load time) must be on ``sys.path`` for
        ``import_module(name)`` to find it. The package directory itself is
        also added, so modules inside it that import their siblings as
        top-level modules still resolve (the previous code added a
        cwd-relative ``name`` directory for that purpose).

        Args:
            path: directory that contains the saved ``<name>`` package.

        Returns:
            self, after packing the model returned by ``<name>.load()``.

        Raises:
            MissingDependencyException: if torch is not installed.
            InvalidArgument: if the package's ``load()`` does not return a
                ``torch.nn.Module``.
        """
        import sys

        torch = self._import_torch()

        package_dir = os.path.join(path, self.name)
        # Insert after sys.path[0] (as before), and skip entries that are
        # already present so repeated loads don't keep growing sys.path.
        for entry in (package_dir, path):
            if entry not in sys.path:
                sys.path.insert(1, entry)

        model_wrapper = importlib.import_module(self.name)
        model = model_wrapper.load()
        if not isinstance(model, torch.nn.Module):
            raise InvalidArgument(
                "Expecting PytorchBundleArtifact loaded object type to be "
                "'torch.nn.Module' but actually it is {}".format(type(model))
            )
        return self.pack(model)

    def set_dependencies(self, env: BentoServiceEnv):
        """Declare ``torch`` as a pip dependency of the BentoService.

        It also warns that companion packages (e.g. torchvision) must be
        added by the user.
        """
        logger.warning(
            "BentoML by default does not include spacy and torchvision package when "
            "using PytorchBundleArtifact. To make sure BentoML bundle those packages if "
            "they are required for your model, either import those packages in "
            "BentoService definition file or manually add them via "
            "`@env(pip_packages=['torchvision'])` when defining a BentoService"
        )
        env.add_pip_packages(['torch'])

    def get(self):
        """Return the packed model, or None if nothing has been packed yet."""
        return self._model

    def save(self, dst):
        """Copy the local package directory ``./<name>`` into ``<dst>/<name>``.

        The source is resolved against the current working directory, so this
        must run from the directory that contains the package.

        Returns:
            self.

        Raises:
            FileExistsError: (from ``shutil.copytree``) if ``<dst>/<name>``
                already exists.
        """
        src = os.path.abspath(self.name)
        shutil.copytree(src, os.path.join(dst, self.name))
        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from bentoml.server import start_dev_server | |
from bentoml.server.api_server import BentoAPIServer | |
from my_service.inference.service import MyClassifier | |
import model | |
# Dev launcher: build the classifier service, attach the freshly loaded model
# under the artifact name 'model', then serve it over HTTP on port 5000.
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
bento_service = MyClassifier()
packed_model = model.load()
bento_service.pack('model', packed_model)

api_server = BentoAPIServer(bento_service, port=5000)
api_server.start()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.