from abc import ABC
import json
import logging
import os
import torch
from transformers import BertModel, BertTokenizer
from torch import nn
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
class OrderClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        n_classes: Number of output classes (size of the final linear layer).
        path: Name or directory passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, path):
        super().__init__()
        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.

        Returns:
            The output of the linear head applied to the (dropout-regularized)
            pooled encoder output.
        """
        # NOTE(review): the call arguments were missing from the source and are
        # reconstructed from this method's own parameters - confirm.
        encoder_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Read the pooled output by position instead of tuple-unpacking:
        # indexing works both for a plain tuple and for dict-like output
        # objects, whereas unpacking a dict-like object yields its keys.
        pooled_output = encoder_output[1]
        output = self.drop(pooled_output)
        return self.out(output)
class TransformersClassifierHandler(BaseHandler, ABC):
    """Transformers text classifier handler.

    Takes a text payload as input and returns the predicted class, based on
    the serialized model and tokenizer found in the model directory. When an
    ``index_to_name.json`` file is present, the class index is translated to
    its name.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the model, the tokenizer and the optional class-name mapping.

        Args:
            ctx: Serving context; its ``manifest`` and ``system_properties``
                (``model_dir``, ``gpu_id``) are read here.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")

        # Fall back to CPU when no GPU id was assigned, instead of building
        # the invalid device string "cuda:None".
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        # Read model serialize/pt file.
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code - only serve model archives from a trusted source.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects fully pickled modules - confirm
        # against the PyTorch version used for serving.
        model_path = os.path.join(model_dir, "best_model.bin")
        logger.info("model: %s", model_path)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model = torch.load(model_path, map_location=self.device)
        logger.info("load model success")

        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        logger.info("load tokenizer success")

        # NOTE(review): these two calls were missing from the source and are
        # reconstructed from the surrounding "eval success" message - confirm.
        # eval() matters: it disables dropout so predictions are deterministic.
        self.model.to(self.device)
        self.model.eval()
        logger.info("eval success")

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))

        # Read the mapping file, index to object name. Always define the
        # attribute so inference() can test it even when the file is absent.
        self.mapping = None
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path, encoding="utf-8") as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, data):
        """Very basic preprocessing code - only tokenizes.

        Extend with your own preprocessing steps as needed.

        Args:
            data: Request batch; only the first entry is read, taking its
                ``data`` field or, when that is ``None``, its ``body`` field.

        Returns:
            The tokenizer's ``encode_plus`` output for the received text.

        Raises:
            ValueError: If the request carries neither ``data`` nor ``body``.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        if text is None:
            raise ValueError("Request contains neither 'data' nor 'body'.")

        # The payload may already be a str; only raw bytes need decoding.
        if isinstance(text, (bytes, bytearray)):
            sentences = text.decode('utf-8')
        else:
            sentences = text
        logger.info("Received text: '%s'", sentences)

        # NOTE(review): the encode_plus arguments were missing from the source;
        # these are a minimal reconstruction - confirm against the settings
        # used at training time (e.g. maximum length / padding).
        inputs = self.tokenizer.encode_plus(
            sentences,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return inputs

    def inference(self, inputs):
        """Predict the class of a text using a trained transformer model.

        Args:
            inputs: Tokenizer output produced by ``preprocess``.

        Returns:
            A one-element list holding the predicted class index, or its name
            when a class-name mapping was loaded.
        """
        # NOTE: This makes the assumption that your model expects text to be
        # tokenized with "input_ids" and "attention_mask".
        # If your transformer model expects different tokenization, adapt this
        # code to suit its expected input format.
        # NOTE(review): the call arguments were missing from the source; they
        # are reconstructed to match OrderClassifier.forward(input_ids,
        # attention_mask) - confirm.
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logits = self.model(
                inputs["input_ids"].to(self.device),
                attention_mask=inputs["attention_mask"].to(self.device),
            )
        prediction = logits[0].argmax().item()
        logger.info("Model predicted: '%s'", prediction)

        if self.mapping:
            prediction = self.mapping[str(prediction)]

        return [prediction]

    def postprocess(self, inference_output):
        """Return the inference output unchanged."""
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output
# Module-level handler instance used by handle(), which checks its
# `initialized` flag before processing a request.
_service = TransformersClassifierHandler()
def handle(data, context):
    """Entry point: run one request batch through the module-level handler.

    Args:
        data: Incoming request batch, or ``None``.
        context: Serving context, passed to the handler's ``initialize``.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise the handler's
        post-processed inference output.

    Raises:
        Any exception raised while initializing or processing propagates
        unchanged to the caller.
    """
    if not _service.initialized:
        # NOTE(review): the body of this branch was missing from the source;
        # reconstructed as the initialization the guard implies - confirm.
        _service.initialize(context)

    if data is None:
        return None

    data = _service.preprocess(data)
    data = _service.inference(data)
    # The original wrapped everything in `except Exception as e: raise e`,
    # which only re-raised; errors now propagate directly.
    return _service.postprocess(data)
