

@Geoyi
Last active December 21, 2017 16:58

Train a model with MXNet on Amazon SageMaker

Amazon SageMaker is a new service from Amazon Web Services (AWS) that lets users build, train, deploy, and scale machine learning models. It is fairly straightforward to use. Here are a few steps to follow if you want to use it to train an image classifier with MXNet:

  • Go to the AWS console.
  • Log in to your account and open the SageMaker home page.
  • Create a notebook instance. You can choose from three instance types: ml.t2.medium, ml.m4.xlarge, and ml.p2.xlarge. We recommend the p2 instance (a GPU machine) for training this image classifier.
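If you prefer scripting to clicking through the console, the same notebook instance can be created with boto3's SageMaker client. The sketch below is a hypothetical helper, not part of this gist: the function names are made up for illustration, and only `boto3.client("sagemaker")` and its `create_notebook_instance()` call (with the required `NotebookInstanceName`, `InstanceType`, and `RoleArn` arguments) are real boto3 APIs. It assumes you already have an IAM role that SageMaker can assume.

```python
# Hypothetical helpers; only boto3.client("sagemaker") and
# create_notebook_instance() are real boto3 APIs.


def notebook_instance_request(name, role_arn, instance_type="ml.p2.xlarge"):
    """Build the required arguments for SageMaker's CreateNotebookInstance."""
    return {
        "NotebookInstanceName": name,  # name shown in the SageMaker console
        "InstanceType": instance_type,  # the p2 GPU machine recommended above
        "RoleArn": role_arn,  # IAM role SageMaker assumes for the notebook
    }


def create_notebook_instance(name, role_arn, instance_type="ml.p2.xlarge"):
    """Create the notebook instance; needs boto3 and AWS credentials."""
    import boto3  # imported here so the request builder works without boto3

    client = boto3.client("sagemaker")
    response = client.create_notebook_instance(
        **notebook_instance_request(name, role_arn, instance_type)
    )
    return response["NotebookInstanceArn"]
```

For example, `create_notebook_instance("building-classifier", "<your-sagemaker-role-arn>")` starts the instance in the Pending state; once it is InService you can open it from the console as usual.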

Once your p2 notebook instance is set up, you are ready to train a building classifier. Specifically, you will learn how to plug your own script into the Amazon SageMaker MXNet Estimator and train the building classifier we prepared.
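To make "plug your own script into the Estimator" concrete, here is a hypothetical skeleton of an entry-point script. It is not the gist's mx_lenet.py, whose contents may differ. It assumes the legacy (2017-era) SageMaker MXNet container, which imports the script and calls its `train()` function with keyword arguments such as `channel_input_dirs`, `hyperparameters`, and `num_gpus`. The hyperparameter names and defaults (including `num_classes=2`) are illustrative assumptions, and the data loading and training loop are left for you to adapt from mx_lenet.py.

```python
# mx_lenet_sketch.py: a hypothetical skeleton, NOT the gist's mx_lenet.py.
# Assumes the legacy SageMaker MXNet container calls train() with keywords.

DEFAULT_HYPERPARAMETERS = {
    "num_classes": 2,  # hypothetical; set to your number of building classes
    "learning_rate": 0.1,  # hypothetical default
    "batch_size": 64,  # hypothetical default
    "epochs": 10,  # hypothetical default
}


def parse_hyperparameters(hyperparameters):
    """Merge user values over the defaults, casting known keys to the default's type."""
    merged = dict(DEFAULT_HYPERPARAMETERS)
    for key, value in (hyperparameters or {}).items():
        if key in DEFAULT_HYPERPARAMETERS:
            merged[key] = type(DEFAULT_HYPERPARAMETERS[key])(value)
        else:
            merged[key] = value  # unknown keys are passed through unchanged
    return merged


def build_lenet(num_classes):
    """Classic LeNet layout expressed with MXNet Gluon layers."""
    from mxnet.gluon import nn  # requires mxnet; imported lazily

    net = nn.HybridSequential()
    net.add(
        nn.Conv2D(channels=20, kernel_size=5, activation="tanh"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=50, kernel_size=5, activation="tanh"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(500, activation="tanh"),
        nn.Dense(num_classes),
    )
    return net


def train(channel_input_dirs, hyperparameters, num_gpus, **kwargs):
    """Entry point assumed to be called by the legacy SageMaker MXNet container."""
    import mxnet as mx

    hp = parse_hyperparameters(hyperparameters)
    ctx = mx.gpu() if num_gpus > 0 else mx.cpu()  # use the GPU on ml.p2.xlarge
    net = build_lenet(hp["num_classes"])
    net.initialize(mx.init.Xavier(), ctx=ctx)
    train_dir = channel_input_dirs["training"]  # default channel for fit("s3://...")
    # ... load images from train_dir and run the training loop from mx_lenet.py here ...
    return net
```

Keeping the hyperparameter handling and network definition in separate functions lets you check them locally before paying for a GPU training job.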

Training a LeNet building classifier using the MXNet Estimator:

  • Prepare your own training script. You can start from our mx_lenet.py and modify it slightly.
  • Run the script on SageMaker via an MXNet Estimator; you can use the Jupyter notebook SageMaker_mx-lenet.ipynb directly. The estimator needs:
    • Your entry_point, which is the prepared script mx_lenet.py;
    • Your SageMaker role, which you can obtain with get_execution_role();
    • The train_instance_type; we used, and recommend, the GPU instance ml.p2.xlarge;
    • The train_instance_count, set to 1 here, which means the LeNet is trained on a single machine. SageMaker also supports training a model on multiple machines.
    • Finally, pass your training data from an S3 bucket to mxnet_estimator.fit().
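The estimator steps above might look roughly like the following notebook code. This is a sketch, assuming the SageMaker Python SDK v1 argument names this gist uses (`train_instance_count`, `train_instance_type`); SDK v2 renamed them to `instance_count` and `instance_type` and requires `framework_version` and `py_version`. The helper names `estimator_kwargs` and `launch_training` are hypothetical; `get_execution_role`, `sagemaker.mxnet.MXNet`, and `fit()` are the SDK's own.

```python
# Sketch of the launch cell, assuming SageMaker Python SDK v1 argument names.
# estimator_kwargs() and launch_training() are hypothetical helper names.


def estimator_kwargs(role, instance_type="ml.p2.xlarge", instance_count=1):
    """Keyword arguments for sagemaker.mxnet.MXNet (SDK v1 names)."""
    return {
        "entry_point": "mx_lenet.py",  # your prepared training script
        "role": role,  # from get_execution_role()
        "train_instance_count": instance_count,  # 1 = train on a single machine
        "train_instance_type": instance_type,  # GPU instance recommended above
    }


def launch_training(s3_training_data):
    """Run inside a SageMaker notebook; requires the sagemaker package."""
    if not s3_training_data.startswith("s3://"):
        raise ValueError("training data must be an s3:// URI")
    from sagemaker import get_execution_role
    from sagemaker.mxnet import MXNet

    role = get_execution_role()
    mxnet_estimator = MXNet(**estimator_kwargs(role))
    # A plain S3 URI string is delivered to the script as the "training" channel.
    mxnet_estimator.fit(s3_training_data)
    return mxnet_estimator
```

Usage, with your own bucket path in place of the placeholder: `launch_training("s3://<your-bucket>/<training-data-prefix>")`. The URI check runs before the SDK import, so a mistyped path fails fast instead of after a training job has been requested.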