-
-
Save dtt101/72e70f2c916237da15caaccf22b1a6be to your computer and use it in GitHub Desktop.
Deploying your own artificial intelligence model with Amazon SageMaker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%writefile Dockerfile
FROM python:3.7-buster

# Tell SageMaker this container does NOT support multi-model endpoints.
LABEL com.amazonaws.sagemaker.capabilities.multi-models=false
# Let the container use the SAGEMAKER_BIND_TO_PORT environment variable if present.
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true

# multi-model-server needs a JVM. The apt lists are removed in the SAME layer
# as the install: a later, separate RUN cannot shrink a layer already written.
RUN apt-get update -y \
    && apt-get -y install --no-install-recommends default-jdk \
    && rm -rf /var/lib/apt/lists/*

RUN pip --no-cache-dir install multi-model-server sagemaker-inference sagemaker-training
RUN pip --no-cache-dir install pandas numpy scipy scikit-learn gensim sagemaker
RUN pip --no-cache-dir install octis

ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE
# Put /opt/ml/code on the Python module search path so the model server can
# import "serving.handler". Any pre-existing PYTHONPATH (not the shell PATH) is
# appended only when non-empty, so no empty entry (= current directory) is added.
ENV PYTHONPATH="/opt/ml/code${PYTHONPATH:+:${PYTHONPATH}}"

# Default AWS region; override at build time with --build-arg AWS_DEFAULT_REGION=...
ARG AWS_DEFAULT_REGION=us-east-1
ENV AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}

COPY main.py /opt/ml/code/main.py
COPY train.py /opt/ml/code/train.py
COPY handler.py /opt/ml/code/serving/handler.py

# main.py dispatches on its first argument ("train" or "serve").
ENTRYPOINT ["python", "/opt/ml/code/main.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%writefile handler.py | |
import os | |
import sys | |
import joblib | |
from sagemaker_inference.default_inference_handler import DefaultInferenceHandler | |
from sagemaker_inference.default_handler_service import DefaultHandlerService | |
from sagemaker_inference import content_types, errors, transformer, encoder, decoder | |
class HandlerService(DefaultHandlerService, DefaultInferenceHandler):
    """Model-server handler for the topic-modeling pipeline.

    The class is both the handler service and its own inference handler:
    ``__init__`` builds a ``Transformer`` whose default inference handler is
    ``self``, so requests are routed through the ``default_*_fn`` hooks below.
    """

    def __init__(self):
        pipeline_transformer = transformer.Transformer(default_inference_handler=self)
        super().__init__(transformer=pipeline_transformer)

    def default_model_fn(self, model_dir):
        """Load the serialized pipeline from disk.

        Args:
            model_dir: directory containing ``pipeline_topic_modeling.joblib``.

        Returns:
            The object deserialized from that file.
        """
        print('estou em model_fn')
        model_filename = os.path.join(model_dir, "pipeline_topic_modeling.joblib")
        # joblib deserializes with pickle: only load artifacts you produced yourself.
        # The context manager closes the handle instead of leaking it.
        with open(model_filename, "rb") as model_file:
            return joblib.load(model_file)

    def default_input_fn(self, input_data, content_type):
        """Decode a ``text/csv`` request body into a single-row array.

        Args:
            input_data: raw request payload.
            content_type: request content type; only ``"text/csv"`` is accepted.

        Returns:
            The decoded payload reshaped to one row, i.e. shape ``(1, n)``.

        Raises:
            errors.UnsupportedFormatError: for any other content type. It is an
                ``Exception`` subclass that the inference toolkit reports as a
                client error rather than a generic 500.
        """
        print('estou em input_fn')
        if content_type != "text/csv":
            raise errors.UnsupportedFormatError(content_type)
        return decoder.decode(input_data, content_type).reshape(1, -1)

    def default_predict_fn(self, mo_data, model):
        """Score one message and return the index of its best topic.

        Args:
            mo_data: output of ``default_input_fn``; only element ``[0][0]``
                (the message text) is used.
            model: the pipeline loaded by ``default_model_fn``; its
                ``transform`` must accept a DataFrame with a
                ``messageText_mo`` column.

        Returns:
            A 0-d numpy array holding the index of the highest-scoring topic
            for the message.
        """
        print('estou em predict_fn')
        import numpy as np
        import pandas as pd

        raw_text = mo_data[0][0]
        # The CSV decoder appears to yield byte strings for text fields (hence
        # the decode); anything else is converted with str() instead of failing.
        if isinstance(raw_text, bytes):
            message = raw_text.decode('utf-8')
        else:
            message = str(raw_text)

        scores = model.transform(pd.DataFrame({'messageText_mo': [message]}))

        # Exactly one message is scored, so only the first row matters.
        first_row = next(iter(scores), None)
        if first_row is None:
            raise ValueError("model.transform returned no rows")
        # Descending argsort, first element: same tie-breaking as before.
        best_topic = np.argsort(first_row)[::-1][0]
        return np.array(best_topic)

    def default_output_fn(self, prediction, accept):
        """Convert the prediction into a plain Python value for the response.

        Args:
            prediction: numpy array returned by ``default_predict_fn``.
            accept: requested response type; only ``"text/csv"`` is accepted.

        Returns:
            ``prediction.tolist()``. This is a plain Python value, not
            CSV-encoded bytes.

        Raises:
            errors.UnsupportedFormatError: for any other accept type.
        """
        print('estou em output_fn')
        if accept != "text/csv":
            raise errors.UnsupportedFormatError(accept)
        return prediction.tolist()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%writefile main.py | |
import train | |
import argparse | |
import sys | |
import os | |
import traceback | |
from sagemaker_inference import model_server | |
from sagemaker_training import environment | |
if __name__ == "__main__":
    # The container is started with a single mode argument: "train" or "serve".
    if len(sys.argv) < 2 or sys.argv[1] not in ("serve", "train"):
        raise Exception("Invalid argument: you must inform 'train' for training mode or 'serve' predicting mode")

    if sys.argv[1] == "train":
        env = environment.Environment()
        parser = argparse.ArgumentParser()
        parser.add_argument("--topk", type=int, default=3)
        parser.add_argument("--bucket", type=str, default='topicmodeling')
        parser.add_argument("--subfolder", type=str, default='sagemaker')
        parser.add_argument("--metric", type=str, default='Diversity')
        # Defaults are evaluated eagerly. Indexing ["train"] raised KeyError when
        # no "train" channel was configured, even if --train was given explicitly.
        parser.add_argument("--train", type=str, default=env.channel_input_dirs.get("train"))
        parser.add_argument("--model-dir", type=str, default=env.model_dir)
        parser.add_argument("--output-dir", type=str, default=env.output_dir)
        # Extra hyperparameters passed by the training platform are ignored.
        args, _unknown_args = parser.parse_known_args()
        try:
            train.start(args)
        except Exception:
            # Best effort: write the traceback to <output_dir>/failure, the file
            # SageMaker reads to report why a training job failed. A write error
            # must not mask the original exception, which is always re-raised.
            try:
                with open(os.path.join(env.output_dir, "failure"), "w") as failure_file:
                    failure_file.write(traceback.format_exc())
            except OSError:
                pass
            raise
    else:
        model_server.start_model_server(handler_service="serving.handler")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.