Created
August 7, 2022 10:58
-
-
Save gleeb/e49bd674dffdd29a7d9bc2cbf2d1dd4a to your computer and use it in GitHub Desktop.
sagemaker train and serve titanic
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import numpy as np | |
import pandas as pd | |
import boto3 | |
from io import StringIO | |
from sklearn.model_selection import train_test_split | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.metrics import accuracy_score | |
from sklearn.externals import joblib | |
from sagemaker_containers.beta.framework import ( | |
content_types, encoders, env, modules, transformer, worker, server) | |
from sagemaker_inference import content_types, decoder, encoder | |
if __name__ == "__main__":
    # Training entry point: read train.csv from the training channel, clean and
    # encode it, fit a logistic-regression classifier, and save it to model-dir.
    parser = argparse.ArgumentParser()

    # Data, model, and output directories; defaults come from the SM_* environment variables.
    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    args, _ = parser.parse_known_args()

    train_path = os.path.join(args.train, "train.csv")

    # Clean and encode the raw data.
    titanic_train_data = pd.read_csv(train_path, engine="python")
    titanic_train_data = titanic_train_data.drop(columns="Cabin")
    # Fill through the frame and assign the result. Calling fillna(inplace=True)
    # on a selected column is a chained update that pandas deprecates and that
    # stops modifying the parent frame under copy-on-write.
    titanic_train_data = titanic_train_data.fillna(
        {
            "Age": titanic_train_data["Age"].mean(),
            "Embarked": titanic_train_data["Embarked"].mode()[0],
        }
    )
    titanic_train_data = titanic_train_data.replace(
        {"Sex": {"male": 0, "female": 1}, "Embarked": {"S": 0, "C": 1, "Q": 2}}
    )

    # Train on 80% of the rows; the remaining 20% is held out for evaluation.
    X = titanic_train_data.drop(columns=["PassengerId", "Name", "Ticket", "Survived"])
    Y = titanic_train_data["Survived"]
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=2)
    model = LogisticRegression(solver="lbfgs", max_iter=400)
    model.fit(X_train, Y_train)

    # Report accuracy on the held-out split so it shows up in the training log.
    print("Held-out accuracy: {:.4f}".format(accuracy_score(Y_test, model.predict(X_test))))

    # Persist the fitted model under the name model_fn loads it by.
    joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
def model_fn(model_dir):
    """Load the serialized model stored in *model_dir*.

    Args:
        model_dir: Directory that contains the ``model.joblib`` artifact.

    Returns:
        The object deserialized by ``joblib.load`` from ``model.joblib``.
    """
    artifact_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(artifact_path)
def input_fn(input_data, content_type):
    """Deserialize a request payload into an array for prediction.

    Args:
        input_data: Raw request body. For ``text/csv`` this is header-less CSV
            text, given either as ``str`` or as UTF-8 encoded bytes. For
            ``application/x-npy`` it is handed unchanged to ``encoders.decode``.
        content_type: MIME type describing ``input_data``.

    Returns:
        For ``text/csv``, the parsed rows as a NumPy array (one row per CSV
        line). For ``application/x-npy``, whatever ``encoders.decode`` returns.

    Raises:
        ValueError: If ``content_type`` is neither ``text/csv`` nor
            ``application/x-npy``.
    """
    if content_type == 'text/csv':
        # StringIO needs text, so decode a bytes payload before parsing.
        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode("utf-8")
        # No header row is expected: every line is a record.
        return pd.read_csv(StringIO(input_data), header=None).to_numpy()

    if content_type == 'application/x-npy':
        # The previous float32 cast was guarded by a UTF-8 content-type check
        # that cannot hold for application/x-npy, so the decoded array was
        # always returned as-is; that is kept here without the dead branch.
        return encoders.decode(input_data, content_type)

    raise ValueError("{} not supported by script!".format(content_type))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment