Last active
June 7, 2018 09:45
-
-
Save metal3d/863122401413197ddaf98b95d22890b7 to your computer and use it in GitHub Desktop.
SageMaker estimator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Launch a SageMaker training job using a custom (bring-your-own) container
# image. Adapt the image name, S3 locations, instance type and
# hyperparameters below to your own setup before running.
import sagemaker as sage
import boto3  # NOTE(review): not used below; kept in case other code relies on it
from sagemaker import get_execution_role

# IAM role the training job runs under.
role = get_execution_role()
sess = sage.Session()
account = sess.boto_session.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name  # or set the region by hand

# Session's default S3 bucket, resolved once and reused for both the input
# data location and the model output location below.
bucket = sess.default_bucket()

# ECR URI of the training image you pushed.
image = '{}.dkr.ecr.{}.amazonaws.com/sagemaker-test'.format(account, region)

# S3 key prefix under which you uploaded your training data.
# NOTE: this must match the path the data was actually uploaded to;
# data_location below is built from it.
prefix = 'sagemaker-test'

# You may use a bucket other than the default; once again, adapt the code
# to your needs.
data_location = 's3://{}/{}'.format(bucket, prefix)

tree = sage.estimator.Estimator(
    image,
    role,
    1,               # number of training instances
    'ml.p2.xlarge',  # instance type; p2 instances provide a GPU
    # S3 location where the trained model artifacts are written
    output_path="s3://{}/output".format(bucket),
    sagemaker_session=sess,
    # Hyperparameters handed to the container; adapt them to whatever your
    # training script reads.
    hyperparameters={
        'epochs': 10,
        'num_classes': 3,
    },
)

# Start training on the data stored under data_location.
tree.fit(data_location)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment