@spirosdim
Last active July 7, 2022 13:30
AWS SageMaker inference using PyTorchModel
import torch
import torch.nn as nn
import json
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')  # loaded once at container start


class Tagger(nn.Module):
    """
    Minimal model class, just for inference.
    """
    def __init__(self):
        super().__init__()
        self.bert_model_name = 'distilbert-base-uncased'
        self.label_names = ['ml', 'cs', 'ph', 'mth', 'bio', 'fin']
        config = AutoConfig.from_pretrained(self.bert_model_name)
        self.bert = AutoModel.from_config(config)  # architecture only; the weights are loaded in model_fn
        self.classifier = nn.Linear(self.bert.config.hidden_size, len(self.label_names))
        self.sigmoid_fnc = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.last_hidden_state[:, 0])  # take the output from the [CLS] token
        output = self.sigmoid_fnc(output)
        return output
def model_fn(model_dir):
    # SageMaker calls this once to load the model from the unpacked model.tar.gz
    model = Tagger()
    with open(Path(model_dir) / 'model.pth', 'rb') as f:
        # note: pass map_location='cpu' to torch.load if the weights were saved on GPU
        # and the endpoint runs on a CPU instance
        model.load_state_dict(torch.load(f))
    return model
def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        data = json.loads(request_body)
        if isinstance(data['text'], str):
            abstract = data['text']
        else:
            raise ValueError("Unsupported input type. 'text' must be a string, but got {}".format(type(data['text'])))
        encoded_input = tokenizer(abstract, padding=True, truncation=True, return_tensors='pt')
        return encoded_input['input_ids'].long(), encoded_input['attention_mask'].long()
    raise ValueError("Unsupported content type: {}".format(request_content_type))
def predict_fn(input_data, model):
    # forward pass
    model.eval()
    with torch.no_grad():
        model_output = model(input_data[0], input_data[1])
    label_nms = ['Machine Learning', 'Computer Science', 'Physics', 'Mathematics', 'Biology', 'Finance-Economics']
    # collect the per-label probabilities in a dictionary
    get_result = {}
    for i in range(len(label_nms)):
        get_result[label_nms[i]] = round(model_output.tolist()[0][i], 3)
    # return a dictionary, which is JSON serializable
    return get_result
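
How the script above is used: SageMaker's PyTorch serving container imports it as the inference entry point and calls model_fn, input_fn, and predict_fn for each request. Below is a minimal deployment sketch with sagemaker.pytorch.PyTorchModel; the S3 path, IAM role, entry-point filename, framework/Python versions, and instance type are placeholder assumptions, not values from this gist.

# Minimal deployment sketch (assumed names and paths; adjust to your setup).
# Assumes model.pth and the script above (saved as inference.py) are packaged
# into a model.tar.gz on S3.
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

role = sagemaker.get_execution_role()  # or an explicit IAM role ARN

pytorch_model = PyTorchModel(
    model_data='s3://my-bucket/model/model.tar.gz',  # assumed S3 location
    role=role,
    entry_point='inference.py',   # the script above
    framework_version='1.10',     # assumed PyTorch version
    py_version='py38',            # assumed Python version
)

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.large',  # assumed instance type
)

# Invoke the endpoint with a JSON payload matching what input_fn expects.
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()
print(predictor.predict({'text': 'An abstract about deep learning for physics.'}))

# Clean up when done.
# predictor.delete_endpoint()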