Last active
April 23, 2020 01:12
-
-
Save furuya02/9ecbc1773aff4536f113e2ab8fa6097e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Setup" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sagemaker\n", | |
"from sagemaker import get_execution_role\n", | |
"from sagemaker.amazon.amazon_estimator import get_image_uri\n", | |
"\n", | |
"role = get_execution_role()\n", | |
"sess = sagemaker.Session()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Create Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sagemaker.model import Model\n", | |
"from sagemaker.predictor import RealTimePredictor, json_deserializer\n", | |
"\n", | |
"class ImagePredictor(RealTimePredictor):\n", | |
" def __init__(self, endpoint_name, sagemaker_session):\n", | |
" super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None, \n", | |
" deserializer=json_deserializer, content_type='image/jpeg')\n", | |
"training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version=\"latest\")\n", | |
"model_data = 's3://sagemaker-working-bucket/Sweets/output/model.tar.gz'\n", | |
"\n", | |
"model = Model(role =role,image=training_image,model_data = model_data, predictor_cls=ImagePredictor, sagemaker_session=sess)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Deploy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"object_detector = model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Detection" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import json\n", | |
"\n", | |
"def visualize_detection(img_file, dets, classes=[], thresh=0.1):\n", | |
" import random\n", | |
" import matplotlib.pyplot as plt\n", | |
" import matplotlib.image as mpimg\n", | |
"\n", | |
" img=mpimg.imread(img_file)\n", | |
" plt.imshow(img)\n", | |
" height = img.shape[0]\n", | |
" width = img.shape[1]\n", | |
" colors = dict()\n", | |
" for det in dets:\n", | |
" (klass, score, x0, y0, x1, y1) = det\n", | |
" if score < thresh:\n", | |
" continue\n", | |
" cls_id = int(klass)\n", | |
" if cls_id not in colors:\n", | |
" colors[cls_id] = (random.random(), random.random(), random.random())\n", | |
" xmin = int(x0 * width)\n", | |
" ymin = int(y0 * height)\n", | |
" xmax = int(x1 * width)\n", | |
" ymax = int(y1 * height)\n", | |
" rect = plt.Rectangle((xmin, ymin), xmax - xmin,\n", | |
" ymax - ymin, fill=False,\n", | |
" edgecolor=colors[cls_id],\n", | |
" linewidth=3.5)\n", | |
" plt.gca().add_patch(rect)\n", | |
" class_name = str(cls_id)\n", | |
" if classes and len(classes) > cls_id:\n", | |
" class_name = classes[cls_id]\n", | |
" plt.gca().text(xmin, ymin - 2,\n", | |
" '{:s} {:.3f}'.format(class_name, score),\n", | |
" bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n", | |
" fontsize=12, color='white')\n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"file_name = 'TestData_Sweets/sweet001.png'\n", | |
"\n", | |
"with open(file_name, 'rb') as image:\n", | |
" f = image.read()\n", | |
" b = bytearray(f)\n", | |
" ne = open('n.txt','wb')\n", | |
" ne.write(b)\n", | |
"object_detector.content_type = 'image/jpeg'\n", | |
"detections = object_detector.predict(b)\n", | |
"\n", | |
"print(detections)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"object_categories = ['BlackThunder','HomePie','Bisco']\n", | |
"\n", | |
"threshold = 0.2\n", | |
"visualize_detection(file_name, detections['prediction'], object_categories, threshold)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Delete EndPoint" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sagemaker.Session().delete_endpoint(object_detector.endpoint)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "conda_mxnet_p36", | |
"language": "python", | |
"name": "conda_mxnet_p36" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment