@hariby
Created June 10, 2021 13:36
import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.session.Session()
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/DEMO-pytorch-mnist'

# IAM role (name or full ARN) that SageMaker assumes; replace with your own.
role = 'AmazonSageMaker-ExecutionRole-20210101T000000'

estimator = PyTorch(entry_point="mnist.py",
                    role=role,
                    framework_version='1.6.0',
                    py_version='py3',
                    instance_count=1,
                    # 'local' runs the training container on this machine (requires Docker);
                    # switch to a managed instance type to train on SageMaker instead.
                    instance_type='local',
                    # instance_type='ml.p3.2xlarge',
                    hyperparameters={
                        'batch-size': 128,
                        'lr': 0.01,
                        'epochs': 1,
                        'backend': 'gloo'
                    })

# Build the S3 URI directly; os.path.join would insert '\' separators on Windows.
estimator.fit({'training': f's3://{bucket}/{prefix}'})
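The gist does not include `mnist.py` itself. In the SageMaker framework containers, the `hyperparameters` dict is forwarded to the entry point as command-line flags (e.g. `--batch-size 128`), and the local path of the `training` channel is exposed through the `SM_CHANNEL_TRAINING` environment variable. The following is a hypothetical, minimal sketch of how an entry point might read them; the function name `parse_args`, the `--data-dir` flag, and all default values are assumptions, not taken from the real script.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical parser for the flags the estimator above forwards.

    Flag names mirror the `hyperparameters` dict; defaults are assumptions.
    """
    parser = argparse.ArgumentParser()
    # argparse maps '--batch-size' to the attribute 'batch_size'.
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--backend", type=str, default=None)
    # Local path where the container exposes the 'training' channel.
    parser.add_argument(
        "--data-dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training"),
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Simulate the flags the estimator above would pass to the entry point.
    args = parse_args(["--batch-size", "128", "--lr", "0.01",
                       "--epochs", "1", "--backend", "gloo"])
    print(args.batch_size, args.lr, args.epochs, args.backend, args.data_dir)
```

Passing an explicit `argv` list makes the parser easy to exercise outside a container; inside one, calling `parse_args()` with no arguments reads the real `sys.argv`.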