Skip to content

Instantly share code, notes, and snippets.

@ceshine
Created May 4, 2020 10:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ceshine/15968edd3ac3eaf86b1d2d2775db2760 to your computer and use it in GitHub Desktop.
Save ceshine/15968edd3ac3eaf86b1d2d2775db2760 to your computer and use it in GitHub Desktop.
A custom handler example for TorchServe (image classification)
import io
import os
import logging
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
# Module-level logger; HiggsClassifier.initialize() logs the loaded model path
# through it at DEBUG level.
logger = logging.getLogger(__name__)
class HiggsClassifier:
    """TorchServe custom handler for a TorchScript image classifier.

    Each request item carries the bytes of a NumPy ``.npy`` file, which is
    decoded into a PIL image. Every image is sent through the model twice
    (as-is and horizontally flipped), and the two softmax outputs are averaged.
    As a result, ``postprocess`` returns one probability row per request item.
    """

    # Per-channel normalization constants. They match the standard ImageNet
    # statistics; presumably the model was trained with them (confirm).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self):
        self.model = None
        self.device = None
        self.manifest = None
        self.initialized = False
        # Built once here rather than once per request item.
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def initialize(self, ctx):
        """Load the TorchScript model described by the TorchServe context.

        Args:
            ctx: TorchServe context object.
                ``ctx.manifest['model']['serializedFile']`` names the model
                file inside ``ctx.system_properties['model_dir']``.
                ``ctx.system_properties['gpu_id']`` selects the CUDA device.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available():
            # Fall back to the default CUDA device when no gpu_id is given,
            # instead of building the invalid device string "cuda:None".
            self.device = torch.device(
                "cuda" if gpu_id is None else "cuda:{}".format(gpu_id))
        else:
            self.device = torch.device("cpu")

        model_dir = properties.get("model_dir")
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        # map_location lets a model serialized on GPU load on a CPU-only host.
        self.model = torch.jit.load(model_pt_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()
        logger.debug('Model file %s loaded successfully', model_pt_path)
        self.initialized = True

    def preprocess(self, request):
        """Decode and normalize a batch of request items, adding flipped copies.

        Args:
            request: iterable of dicts. Each item's payload is read from the
                ``"data"`` key, falling back to ``"body"``. The payload must be
                the bytes of a ``.npy`` file holding an array that
                ``PIL.Image.fromarray`` accepts (presumably H x W x 3 uint8;
                confirm against the client).

        Returns:
            A float tensor of shape (2 * len(request), C, H, W), or ``None``
            when ``request`` is empty. Row ``2 * i`` is item ``i``, and row
            ``2 * i + 1`` is its horizontally flipped copy.

        Raises:
            ValueError: if an item has neither a ``"data"`` nor a ``"body"``
                payload.
        """
        tensors = []
        for item in request:
            payload = item.get("data")
            if payload is None:
                payload = item.get("body")
            if payload is None:
                raise ValueError(
                    "request item has neither a 'data' nor a 'body' payload")
            # allow_pickle=False is NumPy's default. It is spelled out because
            # the payload is client-supplied: never unpickle untrusted bytes.
            array = np.load(io.BytesIO(payload), allow_pickle=False)
            image = self._transform(Image.fromarray(array))
            tensors.append(image)
            # Flipping the normalized CHW tensor along the width axis equals
            # flipping the PIL image first, since ToTensor/Normalize act per
            # pixel and per channel.
            tensors.append(torch.flip(image, dims=[-1]))
        if not tensors:
            return None
        # Stack once instead of re-concatenating the growing batch per item.
        return torch.stack(tensors, dim=0)

    def inference(self, img):
        """Run the model on ``img`` (moved to the model's device) without autograd."""
        # Serving never backpropagates, so skip building the autograd graph.
        with torch.no_grad():
            return self.model(img.to(self.device))

    def postprocess(self, inference_output):
        """Average the class probabilities of each (original, flipped) pair.

        Args:
            inference_output: model logits for the batch built by
                ``preprocess``. Consecutive rows ``2 * i`` and ``2 * i + 1``
                belong to the same request item. Assumed to be
                (N, num_classes); TODO confirm the model's output shape.

        Returns:
            A list holding one list of averaged probabilities per request item.
        """
        # For 2-D logits the deprecated implicit-dim softmax chose dim=1.
        # Passing it explicitly keeps that behavior and drops the warning.
        probs = F.softmax(inference_output.detach(), dim=1)
        probs = probs.reshape(probs.size(0) // 2, 2, -1).mean(dim=1)
        return probs.cpu().numpy().tolist()
# One handler instance is shared by every call to handle(). Its model is
# loaded on the first call, via initialize().
_service = HiggsClassifier()


def handle(data, context):
    """TorchServe entry point: push one batch through the handler pipeline.

    Args:
        data: the batch of request items passed by TorchServe, or None.
        context: TorchServe context. It is only used to initialize the shared
            handler on the first call.

    Returns:
        The postprocessed predictions, or None when ``data`` is None.
    """
    if not _service.initialized:
        _service.initialize(context)
    if data is None:
        return None
    return _service.postprocess(_service.inference(_service.preprocess(data)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment