Skip to content

Instantly share code, notes, and snippets.

@furuya02
Last active April 23, 2020 01:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save furuya02/9ecbc1773aff4536f113e2ab8fa6097e to your computer and use it in GitHub Desktop.
Save furuya02/9ecbc1773aff4536f113e2ab8fa6097e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"\n",
"role = get_execution_role()\n",
"sess = sagemaker.Session()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.model import Model\n",
"from sagemaker.predictor import RealTimePredictor, json_deserializer\n",
"\n",
"class ImagePredictor(RealTimePredictor):\n",
" def __init__(self, endpoint_name, sagemaker_session):\n",
" super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None, \n",
" deserializer=json_deserializer, content_type='image/jpeg')\n",
"training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version=\"latest\")\n",
"model_data = 's3://sagemaker-working-bucket/Sweets/output/model.tar.gz'\n",
"\n",
"model = Model(role =role,image=training_image,model_data = model_data, predictor_cls=ImagePredictor, sagemaker_session=sess)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deploy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"object_detector = model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Detection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"def visualize_detection(img_file, dets, classes=[], thresh=0.1):\n",
" import random\n",
" import matplotlib.pyplot as plt\n",
" import matplotlib.image as mpimg\n",
"\n",
" img=mpimg.imread(img_file)\n",
" plt.imshow(img)\n",
" height = img.shape[0]\n",
" width = img.shape[1]\n",
" colors = dict()\n",
" for det in dets:\n",
" (klass, score, x0, y0, x1, y1) = det\n",
" if score < thresh:\n",
" continue\n",
" cls_id = int(klass)\n",
" if cls_id not in colors:\n",
" colors[cls_id] = (random.random(), random.random(), random.random())\n",
" xmin = int(x0 * width)\n",
" ymin = int(y0 * height)\n",
" xmax = int(x1 * width)\n",
" ymax = int(y1 * height)\n",
" rect = plt.Rectangle((xmin, ymin), xmax - xmin,\n",
" ymax - ymin, fill=False,\n",
" edgecolor=colors[cls_id],\n",
" linewidth=3.5)\n",
" plt.gca().add_patch(rect)\n",
" class_name = str(cls_id)\n",
" if classes and len(classes) > cls_id:\n",
" class_name = classes[cls_id]\n",
" plt.gca().text(xmin, ymin - 2,\n",
" '{:s} {:.3f}'.format(class_name, score),\n",
" bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n",
" fontsize=12, color='white')\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_name = 'TestData_Sweets/sweet001.png'\n",
"\n",
"with open(file_name, 'rb') as image:\n",
" f = image.read()\n",
" b = bytearray(f)\n",
" ne = open('n.txt','wb')\n",
" ne.write(b)\n",
"object_detector.content_type = 'image/jpeg'\n",
"detections = object_detector.predict(b)\n",
"\n",
"print(detections)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"object_categories = ['BlackThunder','HomePie','Bisco']\n",
"\n",
"threshold = 0.2\n",
"visualize_detection(file_name, detections['prediction'], object_categories, threshold)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Delete EndPoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sagemaker.Session().delete_endpoint(object_detector.endpoint)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_mxnet_p36",
"language": "python",
"name": "conda_mxnet_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment