Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save hariby/37a9ea4b859d35096d87a38b521e42ac to your computer and use it in GitHub Desktop.
Save hariby/37a9ea4b859d35096d87a38b521e42ac to your computer and use it in GitHub Desktop.
Use the high-level APIs of the SageMaker Python SDK for Pipe-mode augmented-manifest training in object_detection_tutorial.ipynb
# object_detection_tutorial.ipynb
#
# Estimator for the object detection training image. input_mode='Pipe' makes
# SageMaker stream the input channels to the container instead of using File
# mode.
# NOTE(review): the train_* keyword names are the SageMaker Python SDK v1
# spelling; SDK v2 renamed them (instance_count, instance_type, volume_size,
# max_run) -- confirm which SDK version the notebook kernel has installed.
od_model = sagemaker.estimator.Estimator(
    training_image,
    role,
    train_instance_count=1,               # a single training instance
    train_instance_type='ml.p3.2xlarge',
    train_volume_size=50,                 # EBS volume size in GB (SDK v1 docs)
    train_max_run=360000,                 # timeout in seconds, i.e. 100 h (SDK v1 docs)
    input_mode='Pipe',
    output_path=s3_output_path,
    sagemaker_session=sess,
)
# Hyperparameters for the object detection training job. They are passed in
# the same order as before, grouped by concern.
# NOTE(review): num_classes and num_training_samples must describe the dataset
# referenced by the manifests -- confirm they match before launching a job.
od_model.set_hyperparameters(
    # Model: backbone network, pretrained initialisation, number of classes.
    base_network='resnet-50',
    use_pretrained_model=1,
    num_classes=80,
    # Optimisation and learning-rate schedule.
    mini_batch_size=16,
    epochs=30,
    learning_rate=0.001,
    lr_scheduler_step='10',   # a string; the algorithm docs describe a comma-separated list of epochs
    lr_scheduler_factor=0.1,
    optimizer='sgd',
    momentum=0.9,
    weight_decay=0.0005,
    # Detection thresholds.
    overlap_threshold=0.5,
    nms_threshold=0.45,
    # Input geometry, label padding width and dataset size.
    image_shape=512,
    label_width=600,
    num_training_samples=4452,
)
# The train and validation channels share every setting except their S3
# location, so build both through one helper. Each channel reads an
# AugmentedManifestFile whose records are RecordIO-wrapped, restricted to the
# fields named in attribute_names.
# NOTE(review): attribute_names is defined outside this cell; for object
# detection it is presumably [image-ref attribute, label attribute] -- confirm.
def _augmented_manifest_channel(s3_path):
    """Return an augmented-manifest training input for *s3_path*.

    The input is FullyReplicated, has content type
    'application/x-recordio', uses RecordIO record wrapping, and selects the
    manifest attributes listed in the notebook-level ``attribute_names``.
    """
    return sagemaker.session.s3_input(
        s3_path,
        distribution='FullyReplicated',
        content_type='application/x-recordio',
        record_wrapping='RecordIO',
        s3_data_type='AugmentedManifestFile',
        attribute_names=attribute_names,
    )


train_data = _augmented_manifest_channel(s3_train_data_path)
validation_data = _augmented_manifest_channel(s3_validation_data_path)

# Map channel names to their inputs and start the training job, showing its logs.
data_channels = dict(train=train_data, validation=validation_data)
od_model.fit(inputs=data_channels, logs=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment