Created
July 17, 2020 01:05
-
-
Save fancyerii/a2ecc6d1696a6c03542a9c42e3b9083e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC
import json
import logging
import os

import torch
from torch import nn
from transformers import BertModel, BertTokenizer
from ts.torch_handler.base_handler import BaseHandler
# Module-level logger shared by the handler class and the entry point below.
logger = logging.getLogger(__name__)
class OrderClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        n_classes: Number of output classes of the linear head.
        path: Name or directory passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the classification logits for a batch of token ids.

        Args:
            input_ids: Token id tensor forwarded to the BERT encoder.
            attention_mask: Optional mask forwarded to the encoder; ``None``
                lets the encoder use its own default.
            token_type_ids: Optional segment ids forwarded to the encoder.
                Accepted because the handler in this file calls the model
                with this keyword.

        Returns:
            The output of the linear head applied to the pooled encoding.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index rather than tuple-unpack: recent transformers releases return
        # a dict-like ModelOutput whose iteration yields keys, not tensors.
        # Integer indexing works for both that object and the legacy tuple.
        pooled_output = outputs[1]
        return self.out(self.drop(pooled_output))
class TransformersClassifierHandler(BaseHandler, ABC):
    """Transformers text classifier handler.

    Takes one text as input and returns its predicted class, based on the
    serialized model found in the model directory. The class name is returned
    when ``index_to_name.json`` is present, the raw class index otherwise.
    """

    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False
        # Always defined so inference() can test it even when the optional
        # index_to_name.json file is absent.
        self.mapping = None

    def initialize(self, ctx):
        """Load the model, the tokenizer and the optional label mapping.

        Args:
            ctx: Serving context. ``ctx.manifest`` is stored, and
                ``ctx.system_properties`` is read for "model_dir" and
                "gpu_id".
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        # Fall back to CPU when no GPU id was assigned: "cuda:None" is not a
        # valid device string.
        use_gpu = torch.cuda.is_available() and gpu_id is not None
        self.device = torch.device("cuda:" + str(gpu_id) if use_gpu else "cpu")

        # Read model serialize/pt file.
        model_path = os.path.join(model_dir, "best_model.bin")
        logger.info("model: %s", model_path)
        # NOTE(security): torch.load unpickles arbitrary objects; only serve
        # model archives that come from a trusted source.
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model = torch.load(model_path, map_location=self.device)
        logger.info("load model success")
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        logger.info("load tokenizer success")
        self.model.to(self.device)
        self.model.eval()
        logger.info("eval success")
        logger.debug("Transformer model from path %s loaded successfully", model_dir)

        # Read the mapping file, index to object name.
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, data):
        """Tokenize the text of the first request.

        Very basic preprocessing: only tokenizes. Extend with your own
        preprocessing steps as needed. Only ``data[0]`` is read.

        Args:
            data: List of request dicts; the text is taken from the "data"
                key of the first entry, falling back to "body".

        Returns:
            The tokenizer's ``encode_plus`` output, requested as PyTorch
            tensors.

        Raises:
            ValueError: If the request carries neither "data" nor "body".
        """
        request = data[0]
        text = request.get("data")
        if text is None:
            text = request.get("body")
        if text is None:
            raise ValueError("Request contains neither 'data' nor 'body'")

        # Decode raw byte payloads; a payload that is already a str is used
        # as is instead of failing on a missing .decode().
        if isinstance(text, (bytes, bytearray)):
            sentences = text.decode('utf-8')
        else:
            sentences = text
        logger.info("Received text: '%s'", sentences)

        inputs = self.tokenizer.encode_plus(
            sentences,
            add_special_tokens=True,
            return_tensors="pt"
        )
        return inputs

    def inference(self, inputs):
        """Predict the class of a tokenized text with the loaded model.

        Args:
            inputs: Tokenizer output holding "input_ids" and
                "attention_mask" tensors.

        Returns:
            A one-element list with the mapped class name, or the class
            index when no mapping was loaded.
        """
        # NOTE: The model is called with "input_ids" and "attention_mask",
        # which matches OrderClassifier.forward in this file. Segment ids are
        # not passed: the input is a single sentence, so the encoder default
        # applies. Adapt this call if your model expects other inputs.
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        # Serving only needs the forward pass; skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask)[0]
        prediction = logits.argmax().item()
        logger.info("Model predicted: '%s'", prediction)

        if self.mapping:
            prediction = self.mapping[str(prediction)]

        return [prediction]

    def postprocess(self, inference_output):
        """Return the inference output unchanged."""
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output
# Single handler instance shared by every call to handle().
_service = TransformersClassifierHandler()


def handle(data, context):
    """Run one request through the shared handler.

    Args:
        data: Request batch handed to ``preprocess``.
        context: Serving context, used to initialize the handler on the
            first call.

    Returns:
        The post-processed inference output, or ``None`` when ``data`` is
        ``None``.
    """
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    # Exceptions propagate to the caller unchanged; the previous
    # `except Exception as e: raise e` wrapper only re-raised them.
    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)
    return data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment