@metal3d
Last active June 7, 2018 09:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save metal3d/863122401413197ddaf98b95d22890b7 to your computer and use it in GitHub Desktop.
Save metal3d/863122401413197ddaf98b95d22890b7 to your computer and use it in GitHub Desktop.
Sagemaker estimator
import sagemaker as sage
import boto3
from sagemaker import get_execution_role
role = get_execution_role()
sess = sage.Session()
account = sess.boto_session.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name  # or set the region by hand
# the ECR image you pushed (repository "sagemaker-test")
image = '{}.dkr.ecr.{}.amazonaws.com/sagemaker-test'.format(account, region)
# S3 prefix, where you uploaded
# your training data
prefix = 'aws-test'
# you may use a bucket other than the default;
# once again, adapt the code to your needs
# read the training data from the prefix defined above
data_location = 's3://{}/{}'.format(sess.default_bucket(), prefix)
tree = sage.estimator.Estimator(
    image,
    role,
    1,  # number of training instances
    # instance type: p2 instances provide a GPU
    'ml.p2.xlarge',
    # where to write the model artifacts
    output_path="s3://{}/output".format(sess.default_bucket()),
    sagemaker_session=sess,
    # adapt these hyperparameters to the ones
    # your train script reads
    hyperparameters={
        'epochs': 10,
        'num_classes': 3,
    })
# start training
tree.fit(data_location)
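The script above builds the ECR image URI and the S3 locations with plain string formatting. A minimal sketch of the same logic as reusable helpers follows; `ecr_image_uri` and `s3_uri` are hypothetical names, not part of the SageMaker SDK. The image helper makes the tag explicit (Docker resolves an untagged reference to `latest`), and the S3 helper avoids doubled slashes when joining a bucket and a prefix.

```python
def ecr_image_uri(account, region, repository, tag="latest"):
    """Build an ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>."""
    return "{}.dkr.ecr.{}.amazonaws.com/{}:{}".format(account, region, repository, tag)


def s3_uri(bucket, *parts):
    """Join a bucket name and key parts into an s3:// URI.

    Leading/trailing slashes on each part are stripped and empty parts
    are skipped, so the result never contains doubled slashes.
    """
    key = "/".join(p.strip("/") for p in parts if p.strip("/"))
    if key:
        return "s3://{}/{}".format(bucket, key)
    return "s3://{}".format(bucket)
```

With a live session these would be called as `ecr_image_uri(account, region, 'sagemaker-test')` and `s3_uri(sess.default_bucket(), prefix)`.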
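On the container side, SageMaker writes the estimator's hyperparameters to `/opt/ml/input/config/hyperparameters.json`, with every value serialised as a string (`10` arrives as `"10"`), and calling `fit()` with a single S3 URI feeds the `training` channel, mounted at `/opt/ml/input/data/training`. A sketch of how the train script might read them, assuming a hypothetical `load_hyperparameters` helper whose defaults simply mirror the values used in this gist:

```python
import json
import os

# Paths from the SageMaker training-container layout under /opt/ml.
PREFIX = "/opt/ml"
HYPERPARAMS_PATH = os.path.join(PREFIX, "input", "config", "hyperparameters.json")
# fit() called with a single S3 URI string feeds the "training" channel.
TRAINING_DIR = os.path.join(PREFIX, "input", "data", "training")
# Files written here are packaged and uploaded to the estimator's output_path.
MODEL_DIR = os.path.join(PREFIX, "model")


def load_hyperparameters(path=HYPERPARAMS_PATH):
    """Read hyperparameters and cast them back from strings to ints.

    Missing keys (or a missing file) fall back to defaults that
    mirror the values passed to the estimator in this gist.
    """
    params = {}
    if os.path.exists(path):
        with open(path) as f:
            params = json.load(f)
    return {
        "epochs": int(params.get("epochs", 10)),
        "num_classes": int(params.get("num_classes", 3)),
    }
```

Casting explicitly keeps the script correct whether a value comes from the JSON file (a string) or from the default (an int).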