Created
September 8, 2023 01:19
-
-
Save tedhtchang/f1f70f625a2c89dd75e788cc50a72a4d to your computer and use it in GitHub Desktop.
Ray Serve-enabled KServe custom runtime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import io
import json
from typing import Dict, Optional, Union

import torch
from kserve import Model, ModelServer
from PIL import Image
from ray import serve
from torchvision import models, transforms
@serve.deployment(name="custom-model", num_replicas=1)
class AlexNetModel(Model):
    """KServe custom model serving a pretrained torchvision AlexNet through Ray Serve.

    A request carries one base64-encoded image at
    ``payload["instances"][0]["image"]["b64"]``. The response is
    ``{"predictions": [...]}`` holding the five highest softmax
    probabilities, in descending order.
    """

    # ImageNet-style preprocessing (resize, center-crop to 224, normalize with
    # the standard ImageNet mean/std). Built once here rather than per request.
    _PREPROCESS = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self):
        """Register the model under the name "custom-model" and load weights eagerly."""
        self.name = "custom-model"
        super().__init__(self.name)
        self.load()

    def load(self):
        """Load pretrained AlexNet, switch it to eval mode and mark the model ready."""
        try:
            weights = models.AlexNet_Weights.DEFAULT
        except AttributeError:
            # torchvision < 0.13 has no weights enums; use the legacy flag there.
            self.model = models.alexnet(pretrained=True)
        else:
            # `pretrained=True` is deprecated in newer torchvision; the DEFAULT
            # weights enum selects the same ImageNet weights.
            self.model = models.alexnet(weights=weights)
        self.model.eval()
        self.ready = True

    @staticmethod
    def _decode_payload(payload: Union[bytes, str, Dict]) -> Dict:
        """Return the request body as a dict.

        Accepts raw JSON as bytes or str, or an already-parsed dict.
        """
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        if isinstance(payload, str):
            payload = json.loads(payload)
        return payload

    def predict(self, payload: Union[bytes, str, Dict],
                headers: Optional[Dict[str, str]] = None) -> Dict:
        """Classify the first image in the request.

        Args:
            payload: JSON request body (bytes, str, or parsed dict) with a
                base64 image at ``["instances"][0]["image"]["b64"]``.
            headers: Request headers; unused.

        Returns:
            ``{"predictions": [p1, ..., p5]}`` with the five highest softmax
            probabilities, sorted in descending order.

        Raises:
            KeyError / IndexError: if the payload lacks the expected keys or
                has no instances.
        """
        request = self._decode_payload(payload)
        img_data = request["instances"][0]["image"]["b64"]
        raw_img_data = base64.b64decode(img_data)
        # Force 3 channels: RGBA/greyscale/palette inputs would otherwise fail
        # the 3-channel mean/std in Normalize.
        input_image = Image.open(io.BytesIO(raw_img_data)).convert("RGB")
        input_tensor = self._PREPROCESS(input_image).unsqueeze(0)
        # Inference only: skip building an autograd graph for every request.
        with torch.no_grad():
            output = self.model(input_tensor)
        # The original discarded this result and returned raw logits.
        probabilities = torch.nn.functional.softmax(output, dim=1)
        values, _top5_indices = torch.topk(probabilities, 5)
        # NOTE(review): the top-5 class indices are not returned, so clients
        # cannot tell which classes the scores belong to; consider adding them
        # if the response schema is allowed to change.
        return {"predictions": values.flatten().tolist()}
if __name__ == "__main__":
    # Serve the Ray Serve deployment under the model name clients will request.
    deployments = {"custom-model": AlexNetModel}
    ModelServer().start(deployments)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment