Skip to content

Instantly share code, notes, and snippets.

@spirosdim
Created July 8, 2022 13:21
Show Gist options
  • Save spirosdim/76a6b101d3fca08de76c00fa7f1fceb3 to your computer and use it in GitHub Desktop.
Save spirosdim/76a6b101d3fca08de76c00fa7f1fceb3 to your computer and use it in GitHub Desktop.
AWS SageMaker inference script using HuggingFaceModel
from pathlib import Path
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModel
class Tagger(nn.Module):
    """
    Minimal multi-label classifier used only for inference.

    A transformer encoder (architecture built from the checkpoint's config,
    weights NOT downloaded here) is followed by one linear layer and a
    sigmoid, giving one independent score in [0, 1] per label.
    The trained weights are expected to be loaded afterwards with
    ``load_state_dict``.

    Args:
        bert_model_name: Hugging Face checkpoint name whose config defines the
            encoder architecture. Defaults to ``'distilbert-base-uncased'``.
        label_names: Short label identifiers; their count sets the size of the
            classifier output. Defaults to the six tags
            ``['ml', 'cs', 'ph', 'mth', 'bio', 'fin']``.
    """

    DEFAULT_BERT_MODEL_NAME = 'distilbert-base-uncased'
    DEFAULT_LABEL_NAMES = ('ml', 'cs', 'ph', 'mth', 'bio', 'fin')

    def __init__(self, bert_model_name=None, label_names=None):
        super().__init__()
        if bert_model_name is None:
            bert_model_name = self.DEFAULT_BERT_MODEL_NAME
        if label_names is None:
            label_names = self.DEFAULT_LABEL_NAMES
        self.bert_model_name = bert_model_name
        # Stored as a fresh list so callers' sequences are never shared/mutated.
        self.label_names = list(label_names)

        # from_config builds the architecture only (randomly initialised);
        # the fine-tuned weights come from the state_dict loaded later.
        config = AutoConfig.from_pretrained(self.bert_model_name)
        self.bert = AutoModel.from_config(config)
        self.classifier = nn.Linear(
            self.bert.config.hidden_size, len(self.label_names)
        )
        self.sigmoid_fnc = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return per-label sigmoid scores for each sequence in the batch.

        Args:
            input_ids: Token id tensor passed positionally to the encoder.
            attention_mask: Mask tensor passed to the encoder as
                ``attention_mask``.

        Returns:
            Tensor with one row per input sequence and one column per label.
        """
        encoder_out = self.bert(input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first token ([CLS]) as sequence summary.
        cls_embedding = encoder_out.last_hidden_state[:, 0]
        logits = self.classifier(cls_embedding)
        return self.sigmoid_fnc(logits)
def model_fn(model_dir):
    """Load the tagger model and its tokenizer from a model directory.

    Args:
        model_dir: Directory that contains the serialized weights file
            ``model.pth`` (a state_dict for ``Tagger``).

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in evaluation mode.
    """
    model = Tagger()
    weights_path = Path(model_dir) / 'model.pth'
    # NOTE(security): torch.load unpickles the file; only load trusted artifacts.
    with open(weights_path, 'rb') as f:
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts; predict_fn feeds CPU tensors, so the model stays on CPU.
        state_dict = torch.load(f, map_location='cpu')
    model.load_state_dict(state_dict)
    # Disable dropout once at load time rather than relying on every caller.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model.bert_model_name)
    return model, tokenizer
def predict_fn(data, model_and_tokenizer):
    """Score one abstract against the six subject labels.

    Args:
        data: Request payload. The text is taken from ``data["inputs"]``
            (the key is removed from ``data``); if the key is absent,
            ``data`` itself is handed to the tokenizer.
        model_and_tokenizer: ``(model, tokenizer)`` pair as returned by
            ``model_fn``.

    Returns:
        Dict mapping each human-readable label name to its score rounded to
        three decimals. Only the first row of the model output is reported,
        so for a batched input the remaining rows are dropped.
    """
    # extract model and tokenizer
    model, tokenizer = model_and_tokenizer

    # tokenize the request text
    abstract = data.pop("inputs", data)
    encoded_input = tokenizer(
        abstract, padding=True, truncation=True, return_tensors='pt'
    )

    # forward pass without gradient tracking
    model.eval()
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Convert the first row to plain floats once instead of once per label.
    first_row = model_output.tolist()[0]

    label_nms = ['Machine Learning', 'Computer Science', 'Physics', 'Mathematics', 'Biology', 'Finance-Economics']
    # Index access (not zip) keeps a label/output count mismatch loud.
    return {
        label: round(first_row[idx], 3)
        for idx, label in enumerate(label_nms)
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment