@jaeyow
Created August 14, 2022 01:13
A Metaflow step for training a model using AWS SageMaker
```python
# A step method of a Metaflow FlowSpec subclass. self.bucket, self.prefix, self.role and
# self.s3_train_data are assumed to be set earlier in the flow (previous steps or flow parameters).
@step
def model_training(self):
    """
    Model training
    - first, retrieve the Docker image URI for the required built-in algorithm, in this case Linear Learner
    - create an estimator with the specified parameters
    - set the static hyperparameters; hyperparameters left unset use the algorithm's defaults,
      and SageMaker calculates those whose default is 'auto'
    - calling fit() launches the training job, which runs for up to the specified number of epochs
    - finally, save the S3 location of the trained model artifacts for the next steps
    - note that the training instance type is specified here and may differ from the
      instance type used later for the endpoint
    """
    import boto3
    import sagemaker
    from sagemaker import image_uris

    # Linear Learner container image for the current AWS region
    image = image_uris.retrieve(framework="linear-learner", region=boto3.Session().region_name)
    self.output_location = f"s3://{self.bucket}/{self.prefix}/output"
    print(f"training artifacts will be uploaded to: {self.output_location}")

    session = sagemaker.Session()
    linear = sagemaker.estimator.Estimator(
        image,
        self.role,
        instance_count=1,
        instance_type="ml.c4.xlarge",
        output_path=self.output_location,
        sagemaker_session=session,
    )
    linear.set_hyperparameters(
        epochs=10,
        feature_dim=784,
        predictor_type="binary_classifier",
        mini_batch_size=200,
    )
    # fit() waits for the training job to finish by default
    linear.fit({"train": self.s3_train_data})

    # after fit() completes, the model artifacts (model.tar.gz) are persisted under the S3 output_path
    self.model_data = linear.model_data
    print(f"Estimator model data: {self.model_data}")
    self.next(self.create_sagemaker_model)
```
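The `model_data` URI stored for the next step points into the `output_path` given to the estimator. As a minimal sketch, assuming SageMaker's usual layout `<output_path>/<training-job-name>/output/model.tar.gz`, the hypothetical helpers below rebuild that URI and split it into bucket and key (for example, to download the artifact with an S3 client). The bucket, prefix and training job name are placeholder values, not taken from the original flow.

```python
from typing import Tuple
from urllib.parse import urlparse


def expected_model_artifact_uri(output_location: str, training_job_name: str) -> str:
    """Hypothetical helper: S3 URI where SageMaker is expected to write model.tar.gz.

    Assumes the layout <output_path>/<training-job-name>/output/model.tar.gz.
    """
    return f"{output_location.rstrip('/')}/{training_job_name}/output/model.tar.gz"


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Hypothetical helper: split an s3://bucket/key URI into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


# Placeholder values standing in for self.bucket, self.prefix and the training job name
bucket, prefix = "my-bucket", "linear-learner-demo"
output_location = f"s3://{bucket}/{prefix}/output"
model_data = expected_model_artifact_uri(output_location, "linear-learner-2022-08-14-01-13-00-000")
print(model_data)
# s3://my-bucket/linear-learner-demo/output/linear-learner-2022-08-14-01-13-00-000/output/model.tar.gz
print(split_s3_uri(model_data))
```

In the real step, `linear.model_data` already returns this URI, so a helper like this is only useful for checks or for code that runs without the estimator object.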