Created
October 7, 2022 13:15
-
-
Save samhita-alla/37cfcc91fe1e771debf5180e03772880 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
# Tag components for the ONNX model this service uses; they are combined as
# "<name>:<version>" both when the runnable loads the model and when the
# runner declares it.
ONNX_MODEL: str = "onnx_geolocator"
VERSION: str = "latest"
class GeoLocatorRunnable(bentoml.Runnable):
    """
    Custom BentoML runnable used to fetch multiple outputs from the ONNX model.

    The model is loaded once per runnable instance, and each call returns every
    output the model produces.
    """

    # Resources this runnable may be scheduled on: an NVIDIA GPU or the CPU.
    SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
    SUPPORTS_CPU_MULTI_THREADING = True

    def __init__(self):
        super().__init__()
        # Load the model instance; the returned object provides the
        # ``get_inputs``/``run`` calls used below.
        self.model = bentoml.onnx.load_model(f"{ONNX_MODEL}:{VERSION}")
        # The input name does not change for a loaded model, so resolve it once
        # here rather than on every request. This also surfaces a model without
        # inputs at load time instead of at the first prediction.
        self._input_name = self.model.get_inputs()[0].name

    @bentoml.Runnable.method(batchable=False)
    def fetch_multiple_onnx_outputs(self, images) -> List[np.ndarray]:
        """Run the model on ``images`` and return all of its outputs.

        Args:
            images: Value fed to the model's first input after conversion with
                ``to_numpy``. NOTE(review): presumably a batch of preprocessed
                images -- confirm against the caller.

        Returns:
            The list returned by ``run(None, ...)``, with one entry per model output.
        """
        ort_inputs = {self._input_name: to_numpy(images)}
        # ``None`` as the output-name list requests every model output.
        return self.model.run(None, ort_inputs)
# Initialize the custom BentoML runner and the service that exposes it.
geolocator_runner = bentoml.Runner(
    GeoLocatorRunnable,
    models=[
        # Same "<name>:<version>" tag that GeoLocatorRunnable loads in __init__.
        bentoml.onnx.get(":".join((ONNX_MODEL, VERSION))),
    ],
)
svc = bentoml.Service(
    "geolocator",
    runners=[geolocator_runner],
)
... | |
@svc.api(input=Image(), output=Text(), route="predict-image")
def predict_image(image: PILImage) -> str:
    """Handle the ``predict-image`` route.

    The uploaded image is preprocessed with ``img_processor``, and the result is
    passed to ``predict_helper``, whose string result is returned as text.
    """
    # NOTE(review): img_processor's result is passed as ``image_dir``, which
    # suggests a path-like value -- confirm against its definition.
    return predict_helper(image_dir=img_processor(img_data=image))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment