Skip to content

Instantly share code, notes, and snippets.

@furuya02
Last active April 22, 2020 00:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save furuya02/c54298df21bbd117d9bd16cba2de3943 to your computer and use it in GitHub Desktop.
Save furuya02/c54298df21bbd117d9bd16cba2de3943 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"\n",
"# Configuration shared by the cells below.\n",
"bucket = 'sagemaker-bucket'  # custom bucket name.\n",
"prefix = 'my-sample'  # S3 key prefix for the input channels and the training output.\n",
"\n",
"# Execution role passed to the Estimator, and the SDK session used throughout.\n",
"role = get_execution_role()\n",
"print(role)\n",
"sess = sagemaker.Session()\n",
"\n",
"# Container image URI of the built-in 'object-detection' algorithm in this session's region.\n",
"training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version=\"latest\")\n",
"print(training_image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset\n",
"\n",
"S3 input channels for the training job: train/validation images and their annotations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import urllib.request\n",
"\n",
"# DataSet\n",
"# S3 key prefixes of the four input channels, all under `prefix`.\n",
"train_channel = prefix + '/train'\n",
"validation_channel = prefix + '/validation'\n",
"train_annotation_channel = prefix + '/train_annotation'\n",
"validation_annotation_channel = prefix + '/validation_annotation'\n",
"\n",
"# Full S3 URIs of those channels in `bucket`.\n",
"s3_train_data = 's3://{}/{}'.format(bucket, train_channel)\n",
"s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)\n",
"s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)\n",
"s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)\n",
"\n",
"\n",
"def make_s3_channel(s3_uri, content_type='image/jpeg'):\n",
"    \"\"\"Return a training input that reads every object under the S3 prefix `s3_uri`.\n",
"\n",
"    The input uses distribution='FullyReplicated' and s3_data_type='S3Prefix'.\n",
"    The four channels in this cell differ only in their URI, so they share this\n",
"    single definition instead of four copy-pasted s3_input(...) calls.\n",
"    \"\"\"\n",
"    return sagemaker.session.s3_input(s3_uri, distribution='FullyReplicated',\n",
"                                      content_type=content_type, s3_data_type='S3Prefix')\n",
"\n",
"\n",
"train_data = make_s3_channel(s3_train_data)\n",
"validation_data = make_s3_channel(s3_validation_data)\n",
"# NOTE(review): the annotation channels keep content_type='image/jpeg' exactly as\n",
"# before; confirm this against the object-detection channel requirements.\n",
"train_annotation = make_s3_channel(s3_train_annotation)\n",
"validation_annotation = make_s3_channel(s3_validation_annotation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model\n",
"# Artifacts of an earlier training job, passed below as the 'model' channel\n",
"# (presumably to continue training from those weights - confirm).\n",
"# The URI is built from `bucket`/`prefix` with the same 's3://{bucket}/{prefix}/output'\n",
"# layout used as the Estimator's output_path, instead of repeating them as literals,\n",
"# so editing the config cell cannot leave this pointing at the wrong bucket.\n",
"# (After training in this kernel, od_model.model_data gives the same kind of URI.)\n",
"previous_job_name = 'object-detection-2020-04-21-20-51-50-507'\n",
"s3_model_data = 's3://{}/{}/output/{}/output/model.tar.gz'.format(bucket, prefix, previous_job_name)\n",
"model_data = sagemaker.session.s3_input(s3_model_data, distribution='FullyReplicated',\n",
"                                        content_type='application/x-sagemaker-model', s3_data_type='S3Prefix')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Channel name -> input, passed to od_model.fit() in the Training section.\n",
"data_channels = {\n",
"    'train': train_data,\n",
"    'validation': validation_data,\n",
"    'train_annotation': train_annotation,\n",
"    'validation_annotation': validation_annotation,\n",
"    'model': model_data,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training artifacts are written under s3://{bucket}/{prefix}/output.\n",
"s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)\n",
"\n",
"# One ml.p3.2xlarge instance, 50 GB volume, at most 360000 s (100 h) of training.\n",
"od_model = sagemaker.estimator.Estimator(\n",
"    training_image,\n",
"    role,\n",
"    train_instance_count=1,\n",
"    train_instance_type='ml.p3.2xlarge',\n",
"    train_volume_size=50,\n",
"    train_max_run=360000,\n",
"    input_mode='File',\n",
"    output_path=s3_output_location,\n",
"    sagemaker_session=sess,\n",
")\n",
"\n",
"od_model.set_hyperparameters(\n",
"    base_network='resnet-50',\n",
"    # use_pretrained_model=1,\n",
"    num_classes=3,  ### label count ###\n",
"    mini_batch_size=16,\n",
"    epochs=10,  ### epoch count ###\n",
"    learning_rate=0.001,\n",
"    # NOTE(review): lr_scheduler_step='10' equals `epochs`, so the learning-rate\n",
"    # reduction may never take effect within this run - confirm this is intended.\n",
"    lr_scheduler_step='10',\n",
"    lr_scheduler_factor=0.1,\n",
"    optimizer='sgd',\n",
"    momentum=0.9,\n",
"    weight_decay=0.0005,\n",
"    overlap_threshold=0.5,\n",
"    nms_threshold=0.45,\n",
"    image_shape=512,\n",
"    label_width=600,\n",
"    num_training_samples=1808,  ### data count ###\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Launch the training job on all channels and stream its logs into the notebook.\n",
"# wait=True is the SDK default: the cell blocks until the job finishes.\n",
"od_model.fit(inputs=data_channels, wait=True, logs=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_mxnet_p36",
"language": "python",
"name": "conda_mxnet_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment