Skip to content

Instantly share code, notes, and snippets.

Created July 17, 2020 01:05
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save fancyerii/a2ecc6d1696a6c03542a9c42e3b9083e to your computer and use it in GitHub Desktop.
from abc import ABC
import json
import logging
import os
import torch
from transformers import BertModel, BertTokenizer
from torch import nn
from ts.torch_handler.base_handler import BaseHandler
# Module-level logger used by the model and handler code in this file.
logger = logging.getLogger(__name__)
class OrderClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        n_classes: number of output classes (size of the final linear layer).
        path: model name or directory passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class scores (logits) for a batch of token ids.

        ``input_ids`` and ``attention_mask`` are forwarded unchanged to BERT;
        presumably both are (batch, seq_len) tensors from the tokenizer -- confirm.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index rather than tuple-unpack: transformers >= 4 returns a
        # ModelOutput by default, and unpacking that yields its *keys*
        # (strings) instead of tensors. Position 1 is the pooled output in
        # both the legacy tuple form and the ModelOutput form.
        pooled_output = outputs[1]
        output = self.drop(pooled_output)
        return self.out(output)
class TransformersClassifierHandler(BaseHandler, ABC):
    """Transformers text classifier handler class.

    This handler takes a text (UTF-8 bytes or a string) as input and returns
    the predicted class, using the serialized transformers checkpoint
    ``best_model.bin`` from the model directory. When ``index_to_name.json``
    is present, the predicted index is mapped to a class name.
    """

    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False
        # Optional index -> class-name mapping; stays None when no
        # index_to_name.json is found in the model directory.
        self.mapping = None

    def initialize(self, ctx):
        """Load the model, tokenizer and optional label mapping.

        Args:
            ctx: serving context; ``ctx.system_properties`` is read for
                ``model_dir`` and ``gpu_id``.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available():
            # Avoid building the invalid device string "cuda:None" when no
            # gpu_id was assigned; fall back to the default CUDA device.
            self.device = torch.device("cuda" if gpu_id is None else "cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        # Read model serialize/pt file (a fully pickled nn.Module).
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted checkpoints. On torch >= 2.6 loading a whole module also
        # requires weights_only=False.
        model_path = os.path.join(model_dir, "best_model.bin")
        logger.info("model: %s", model_path)
        self.model = torch.load(model_path, map_location=self.device)
        self.model.to(self.device)
        # eval() disables dropout so predictions are deterministic.
        self.model.eval()
        logger.info("load model success")

        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        logger.info("load tokenizer success")
        logger.debug('Transformer model from path %s loaded successfully', model_dir)

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path, encoding="utf-8") as f:
                self.mapping = json.load(f)
        else:
            self.mapping = None
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, data):
        """Tokenize the text of the first request in the batch.

        Very basic preprocessing code - only tokenizes. Extend with your own
        preprocessing steps as needed.

        Args:
            data: request batch; the text is read from the ``"data"`` field of
                ``data[0]``, falling back to its ``"body"`` field.

        Returns:
            The ``encode_plus`` output with PyTorch tensors.

        Raises:
            ValueError: if the request has neither a ``"data"`` nor a
                ``"body"`` field.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        if text is None:
            raise ValueError("Request has neither a 'data' nor a 'body' field")
        # Decode bytes payloads as UTF-8; pass an already-decoded str through.
        if isinstance(text, (bytes, bytearray)):
            sentences = text.decode('utf-8')
        else:
            sentences = text
        logger.info("Received text: '%s'", sentences)
        inputs = self.tokenizer.encode_plus(
            sentences,
            add_special_tokens=True,
            # Truncate to the tokenizer's maximum length; longer inputs
            # would otherwise fail inside BERT.
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        return inputs

    def inference(self, inputs):
        """Predict the class of a text using the trained transformer model.

        Args:
            inputs: tokenizer output from ``preprocess`` containing
                ``"input_ids"`` and ``"attention_mask"`` tensors.

        Returns:
            A one-element list with the class name when a mapping is loaded,
            otherwise with the predicted class index.
        """
        # NOTE: assumes the loaded model's forward takes input_ids and
        # attention_mask (as OrderClassifier does) and returns logits whose
        # first row belongs to this request. If your transformer model expects
        # different tokenization, adapt this code to its input format.
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        # Gradients are not needed at inference time; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
        prediction = logits[0].argmax().item()
        logger.info("Model predicted: '%s'", prediction)

        if self.mapping:
            prediction = self.mapping[str(prediction)]

        return [prediction]

    def postprocess(self, inference_output):
        """Return the inference output unchanged."""
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output
# Module-level handler instance shared by every call to handle(); handle()
# checks its `initialized` flag before serving requests.
_service = TransformersClassifierHandler()
def handle(data, context):
    """Serving entry point: lazily initialize, then run the handler pipeline.

    Args:
        data: request batch passed through to ``preprocess``, or None.
        context: serving context used for one-time initialization.

    Returns:
        The post-processed predictions, or None when ``data`` is None.
    """
    try:
        if not _service.initialized:
            _service.initialize(context)

        if data is None:
            return None

        data = _service.preprocess(data)
        data = _service.inference(data)
        data = _service.postprocess(data)

        return data
    except Exception:
        # Log the traceback, then re-raise unchanged; a bare `raise` keeps the
        # original traceback instead of adding a new frame like `raise e`.
        logger.exception("Error while handling inference request")
        raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment