Created
July 9, 2020 05:19
-
-
Save ctodd/09850a042afbe416823347758a5d28b5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import boto3
import sagemaker

# Launch a SageMaker built-in Object Detection training job that reads a
# Ground Truth augmented manifest, then print the job's initial status.
#
# NOTE(review): this cell relies on names defined in earlier notebook cells
# and not shown here: BUCKET, pfx_training, training_job_name,
# num_training_samples, s3_train_data_path, s3_validation_data_path,
# labeling_job_name. Confirm they are set before running it.

role = sagemaker.get_execution_role()
sess = sagemaker.Session()

# Built-in Object Detection algorithm container for the current region.
# NOTE(review): get_image_uri is the SageMaker Python SDK v1 API; SDK v2
# replaces it with sagemaker.image_uris.retrieve("object-detection", region).
# Confirm against the installed SDK version.
training_image = sagemaker.amazon.amazon_estimator.get_image_uri(
    boto3.Session().region_name, 'object-detection', repo_version='latest')

s3_output_path = 's3://{}/{}/output'.format(BUCKET, pfx_training)


def _augmented_manifest_channel(channel_name, s3_uri, label_attribute):
    """Build one InputDataConfig channel that reads an augmented manifest.

    The training and validation channels share this exact layout; only the
    channel name and the manifest's S3 location differ.

    Args:
        channel_name: SageMaker channel name ("train" or "validation").
        s3_uri: S3 URI of the augmented manifest file for this channel.
        label_attribute: JSON field in each manifest line that holds the
            labels (here, the labeling job's name).

    Returns:
        A dict in the CreateTrainingJob ``InputDataConfig`` channel format.
    """
    return {
        "ChannelName": channel_name,
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "AugmentedManifestFile",  # NB. Augmented Manifest
                "S3Uri": s3_uri,
                "S3DataDistributionType": "FullyReplicated",
                # NB. These must correspond to the JSON field names in your
                # augmented manifest.
                "AttributeNames": ['source-ref', label_attribute],
            }
        },
        # Records are wrapped as RecordIO and streamed through Pipe mode.
        "ContentType": "application/x-recordio",
        "RecordWrapperType": "RecordIO",
        "CompressionType": "None",
    }


training_params = {
    "AlgorithmSpecification": {
        # Object Detection image URI resolved above via get_image_uri.
        "TrainingImage": training_image,
        # Pipe mode streams input from S3 instead of downloading it first.
        "TrainingInputMode": "Pipe",
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": s3_output_path,
    },
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p2.xlarge",
        "VolumeSizeInGB": 50,
    },
    "TrainingJobName": training_job_name,
    # NB. These hyperparameters are at the user's discretion and are beyond
    # the scope of this demo. SageMaker requires every value to be a string.
    "HyperParameters": {
        "base_network": "resnet-50",
        "use_pretrained_model": "1",
        "num_classes": "2",
        "mini_batch_size": "1",
        "epochs": "100",
        "learning_rate": "0.001",
        # NOTE(review): empty string sent as-is; confirm the algorithm accepts
        # it, or drop the key / set epochs such as "33,67" to use a schedule.
        "lr_scheduler_step": "",
        "lr_scheduler_factor": "0.1",
        "optimizer": "sgd",
        "momentum": "0.9",
        "weight_decay": "0.0005",
        "overlap_threshold": "0.5",
        "nms_threshold": "0.45",
        "image_shape": "300",
        "label_width": "350",
        "num_training_samples": str(num_training_samples),
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 86400,  # 24 hours
    },
    "InputDataConfig": [
        _augmented_manifest_channel("train", s3_train_data_path, labeling_job_name),
        _augmented_manifest_channel("validation", s3_validation_data_path, labeling_job_name),
    ],
}

# Now we create the SageMaker training job.
client = boto3.client(service_name='sagemaker')
client.create_training_job(**training_params)

# Confirm that the training job has started.
status = client.describe_training_job(TrainingJobName=training_job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment