Created August 14, 2022 01:13
A Metaflow step for training a model using AWS SageMaker
@step
def model_training(self):
    """
    Model training
    - first, look up the Docker image for the required algorithm, in this case linear learner
    - create an estimator with the specified parameters
    - set the static hyperparameters; SageMaker automatically calculates any that are set to 'auto'
    - calling fit() starts the training process, which runs up to the specified number of epochs
    - then save the model data location for the next steps
    - note that we have to specify an instance type for training, which may be different from the endpoint instance type
    """
    import boto3
    import sagemaker
    from sagemaker import image_uris

    # Resolve the linear learner container image for the current AWS region
    image = image_uris.retrieve(region=boto3.Session().region_name, framework="linear-learner")

    # S3 base folder that will receive the training artifacts
    self.output_location = f"s3://{self.bucket}/{self.prefix}/output"
    print(f"training artifacts will be uploaded to: {self.output_location}")

    session = sagemaker.Session()
    linear = sagemaker.estimator.Estimator(
        image,
        self.role,
        instance_count=1,
        instance_type="ml.c4.xlarge",  # training instance; the endpoint may use a different type
        output_path=self.output_location,
        sagemaker_session=session,
    )

    # Static hyperparameters; feature_dim is the number of input features per record
    linear.set_hyperparameters(
        epochs=10,
        feature_dim=784,
        predictor_type="binary_classifier",
        mini_batch_size=200,
    )

    # Blocks until the training job finishes
    linear.fit({"train": self.s3_train_data})

    # after an Estimator fit, the model will have been persisted under the defined S3 output location base folder
    self.model_data = linear.model_data
    print(f"Estimator model data: {self.model_data}")
    self.next(self.create_sagemaker_model)
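
The step above builds an S3 output location, passes a handful of hyperparameters to the estimator, and then reads back `model_data`. A minimal, standard-library-only sketch of those three pieces follows, so they can be checked without launching a training job. The helper names (`training_output_location`, `expected_model_artifact`, `linear_learner_hyperparameters`) are hypothetical and not part of the gist or of the SageMaker SDK. The `<output_path>/<job name>/output/model.tar.gz` layout is an assumption about where SageMaker training jobs usually place the artifact; after a real run, `linear.model_data` is the authoritative value.

```python
# Hypothetical helpers (not part of the gist or the SageMaker SDK) that mirror
# the plain-Python parts of the training step.

VALID_PREDICTOR_TYPES = {"binary_classifier", "multiclass_classifier", "regressor"}


def training_output_location(bucket: str, prefix: str) -> str:
    """Build the same S3 base folder as the f-string in the step."""
    return f"s3://{bucket.strip('/')}/{prefix.strip('/')}/output"


def expected_model_artifact(output_location: str, job_name: str) -> str:
    """Guess where the model artifact lands for a given training job name.

    Assumption: artifacts are written to
    <output_path>/<job name>/output/model.tar.gz.
    """
    return f"{output_location.rstrip('/')}/{job_name}/output/model.tar.gz"


def linear_learner_hyperparameters(
    feature_dim: int,
    epochs: int = 10,
    mini_batch_size: int = 200,
    predictor_type: str = "binary_classifier",
    **overrides,
) -> dict:
    """Collect the static hyperparameters used in the step.

    Extra keyword arguments (for example learning_rate="auto") are passed
    through unchanged, so values left to SageMaker stay as the string 'auto'.
    """
    if predictor_type not in VALID_PREDICTOR_TYPES:
        raise ValueError(f"unsupported predictor_type: {predictor_type!r}")
    if feature_dim <= 0 or epochs <= 0 or mini_batch_size <= 0:
        raise ValueError("feature_dim, epochs and mini_batch_size must be positive")
    params = {
        "epochs": epochs,
        "feature_dim": feature_dim,
        "predictor_type": predictor_type,
        "mini_batch_size": mini_batch_size,
    }
    params.update(overrides)
    return params


if __name__ == "__main__":
    # Placeholder bucket, prefix and job name, for illustration only.
    location = training_output_location("my-bucket", "sagemaker/demo")
    print(location)
    print(expected_model_artifact(location, "linear-learner-demo-job"))
    print(linear_learner_hyperparameters(784, learning_rate="auto"))
```

Inside the step, the resulting dictionary could be handed over with `linear.set_hyperparameters(**params)`, which keeps the hyperparameter values in one place if the flow later exposes them as Metaflow parameters.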