Skip to content

Instantly share code, notes, and snippets.

@ctodd
Created July 9, 2020 05:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ctodd/09850a042afbe416823347758a5d28b5 to your computer and use it in GitHub Desktop.
Save ctodd/09850a042afbe416823347758a5d28b5 to your computer and use it in GitHub Desktop.
import time
import sagemaker
# Resolve the execution role and the SageMaker session for this notebook.
role = sagemaker.get_execution_role()
sess = sagemaker.Session()

# Look up the built-in object-detection training image for the current region.
training_image = sagemaker.amazon.amazon_estimator.get_image_uri(
    boto3.Session().region_name,
    'object-detection',
    repo_version='latest',
)

# Model artifacts are written under <BUCKET>/<pfx_training>/output.
s3_output_path = 's3://{}/{}/output'.format(BUCKET, pfx_training)


def _augmented_manifest_channel(channel_name, manifest_uri):
    """Return one InputDataConfig entry that reads an augmented manifest.

    The "train" and "validation" channels differ only in their channel name
    and manifest location, so both are built here.
    """
    s3_data_source = {
        "S3DataType": "AugmentedManifestFile",
        "S3Uri": manifest_uri,
        "S3DataDistributionType": "FullyReplicated",
        # NB. These must correspond to the JSON field names in the
        # augmented manifest.
        "AttributeNames": ['source-ref', labeling_job_name],
    }
    return {
        "ChannelName": channel_name,
        "DataSource": {"S3DataSource": s3_data_source},
        "ContentType": "application/x-recordio",
        "RecordWrapperType": "RecordIO",
        "CompressionType": "None",
    }


# NB. These hyperparameters are at the user's discretion and are beyond the
# scope of this demo. Every value is passed as a string.
hyperparameters = {
    "base_network": "resnet-50",
    "use_pretrained_model": "1",
    "num_classes": "2",
    "mini_batch_size": "1",
    "epochs": "100",
    "learning_rate": "0.001",
    "lr_scheduler_step": "",
    "lr_scheduler_factor": "0.1",
    "optimizer": "sgd",
    "momentum": "0.9",
    "weight_decay": "0.0005",
    "overlap_threshold": "0.5",
    "nms_threshold": "0.45",
    "image_shape": "300",
    "label_width": "350",
    "num_training_samples": str(num_training_samples),
}

# One channel per dataset split, each backed by its own augmented manifest.
input_channels = [
    _augmented_manifest_channel("train", s3_train_data_path),
    _augmented_manifest_channel("validation", s3_validation_data_path),
]

# Full request payload for the CreateTrainingJob call below.
training_params = {
    "AlgorithmSpecification": {
        # Image URI resolved above for the current region.
        "TrainingImage": training_image,
        "TrainingInputMode": "Pipe",
    },
    "RoleArn": role,
    "OutputDataConfig": {"S3OutputPath": s3_output_path},
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p2.xlarge",
        "VolumeSizeInGB": 50,
    },
    "TrainingJobName": training_job_name,
    "HyperParameters": hyperparameters,
    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
    "InputDataConfig": input_channels,
}

# Submit the training job to SageMaker.
client = boto3.client(service_name='sagemaker')
client.create_training_job(**training_params)

# Read the job back once and report the status it currently has.
job_description = client.describe_training_job(TrainingJobName=training_job_name)
status = job_description['TrainingJobStatus']
print('Training job current status: {}'.format(status))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment