Skip to content

Instantly share code, notes, and snippets.

@gleeb
Created August 7, 2022 10:58
Show Gist options
  • Save gleeb/e49bd674dffdd29a7d9bc2cbf2d1dd4a to your computer and use it in GitHub Desktop.
Save gleeb/e49bd674dffdd29a7d9bc2cbf2d1dd4a to your computer and use it in GitHub Desktop.
SageMaker: train and serve a Titanic survival classifier
import argparse
import os
import numpy as np
import pandas as pd
import boto3
from io import StringIO
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.externals import joblib
from sagemaker_containers.beta.framework import (
content_types, encoders, env, modules, transformer, worker, server)
from sagemaker_inference import content_types, decoder, encoder
if __name__ == '__main__':
    # Training entry point: read train.csv from the training channel, clean it,
    # fit a logistic-regression classifier on the 'Survived' column and write
    # the fitted model to <model-dir>/model.joblib.
    parser = argparse.ArgumentParser()

    # Data, model, and output directories; defaults are read from the SM_*
    # environment variables when the flags are not passed explicitly.
    parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))

    # parse_known_args tolerates extra command-line arguments this script
    # does not declare.
    args, _ = parser.parse_known_args()

    train_path = os.path.join(args.train, "train.csv")

    # --- Clean the raw data ---
    titanic_train_data = pd.read_csv(train_path, engine="python")
    titanic_train_data = titanic_train_data.drop(columns='Cabin')

    # Assign the filled columns back explicitly. Calling fillna(inplace=True)
    # on a column selection is chained assignment: it warns on recent pandas
    # and leaves the frame unmodified when copy-on-write is enabled.
    titanic_train_data['Age'] = titanic_train_data['Age'].fillna(
        titanic_train_data['Age'].mean())
    titanic_train_data['Embarked'] = titanic_train_data['Embarked'].fillna(
        titanic_train_data['Embarked'].mode()[0])

    # Encode the categorical columns as integers.
    titanic_train_data = titanic_train_data.replace(
        {'Sex': {'male': 0, 'female': 1}, 'Embarked': {'S': 0, 'C': 1, 'Q': 2}})

    # --- Train ---
    X = titanic_train_data.drop(columns=['PassengerId', 'Name', 'Ticket', 'Survived'])
    Y = titanic_train_data['Survived']
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=2)

    model = LogisticRegression(solver='lbfgs', max_iter=400)
    model.fit(X_train, Y_train)

    # Report accuracy on the 20% hold-out split so it shows up in the training log.
    print("Hold-out accuracy: {:.4f}".format(accuracy_score(Y_test, model.predict(X_test))))

    # Persist the fitted classifier; model_fn loads this same file name.
    joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
def model_fn(model_dir):
    """Load and return the model stored as ``model.joblib`` in *model_dir*.

    NOTE(review): presumably called by the SageMaker serving container with
    the directory the training job wrote to -- confirm against the container
    documentation.
    """
    model_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(model_path)
def input_fn(input_data, content_type):
    """Deserialize a request payload into a NumPy array.

    Args:
        input_data: The raw request body. For ``text/csv`` this may be a
            ``str`` or UTF-8 encoded ``bytes``; for ``application/x-npy`` it
            is passed unchanged to ``encoders.decode``.
        content_type: The payload's content type. Only ``text/csv`` and
            ``application/x-npy`` are supported.

    Returns:
        For ``text/csv``, the header-less CSV parsed by pandas and converted
        with ``DataFrame.to_numpy()``. For ``application/x-npy``, whatever
        ``encoders.decode`` returns.

    Raises:
        ValueError: If *content_type* is not one of the supported types.
    """
    print("input_fn")

    if content_type == 'text/csv':
        # StringIO only accepts text, so decode a bytes payload first.
        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode('utf-8')
        # Read the raw input data as CSV; the payload carries no header row.
        return pd.read_csv(StringIO(input_data), header=None).to_numpy()

    if content_type == 'application/x-npy':
        # The original float32 cast was guarded by a UTF-8 content-type check
        # that can never hold for 'application/x-npy', so the decoded array is
        # returned as-is.
        return encoders.decode(input_data, content_type)

    raise ValueError("{} not supported by script!".format(content_type))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment