Last active
April 13, 2020 07:28
-
-
Save furuya02/1429744465506d6080813cafc8fe9579 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
" ## CreateDataset\n", | |
" Ground Truth の出力(output.manifest)から SageMaker用のDataSet(train,validation)を作成する" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"bucket_name = \"sagemaker-working-bucket-001\" \n", | |
"prefix = 'object-detection-with-ground-truth'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Ground Truthの出力(output.manifest)をダウンロードする\n", | |
"\n", | |
"inputManifestPath = 's3://sagemaker-xxxxxxx/GroundTruth-output/AHIRU-Project/manifests/output/output.manifest'\n", | |
"!aws s3 cp $inputManifestPath \"./output.manifest\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# output.manifestをtrain及びvalidationに分割する\n", | |
"\n", | |
"import json \n", | |
"\n", | |
"# 1件のデータの表現するクラス(定義されているラベルを把握するために使用する)\n", | |
"class Data():\n", | |
" def __init__(self, src):\n", | |
" self.src = src\n", | |
" # プロジェクト名の取得\n", | |
" for key in src.keys():\n", | |
" index = key.rfind(\"-metadata\")\n", | |
" if(index!=-1):\n", | |
" self.projectName = key[0:index]\n", | |
"\n", | |
" cls_map = src[self.projectName + \"-metadata\"][\"class-map\"]\n", | |
"\n", | |
" # アノテーション一覧からクラスIDの指定を取得する\n", | |
" self.annotations = []\n", | |
" for annotation in src[self.projectName][\"annotations\"]:\n", | |
" id = annotation['class_id']\n", | |
" self.annotations.append({\n", | |
" \"label\":cls_map[str(id)]\n", | |
" })\n", | |
" \n", | |
" # 指定されたラベルを含むかどうか\n", | |
" def exsists(self, label):\n", | |
" for annotation in self.annotations:\n", | |
" if(annotation[\"label\"] == label):\n", | |
" return True\n", | |
" return False\n", | |
"\n", | |
"# 全てのJSONデータを読み込む\n", | |
"def getDataList(inputPath):\n", | |
" dataList = []\n", | |
" with open(inputPath, 'r') as f:\n", | |
" srcList = f.read().split('\\n')\n", | |
" for src in srcList:\n", | |
" if(src != ''):\n", | |
" dataList.append(Data(json.loads(src)))\n", | |
" return dataList\n", | |
"\n", | |
"# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n", | |
"def getLabel(dataList):\n", | |
" labels = {}\n", | |
" for data in dataList:\n", | |
" for annotation in data.annotations:\n", | |
" label = annotation[\"label\"]\n", | |
" if(label in labels):\n", | |
" labels[label] += 1\n", | |
" else:\n", | |
" labels[label] = 1\n", | |
" # ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n", | |
" labels = sorted(labels.items(), key=lambda x:x[1])\n", | |
" return labels\n", | |
"\n", | |
"# dataListをラベルを含むものと、含まないものに分割する\n", | |
"def deviedDataList(dataList, label):\n", | |
" targetList = []\n", | |
" unTargetList = []\n", | |
" for data in dataList:\n", | |
" if(data.exsists(label)):\n", | |
" targetList.append(data)\n", | |
" else:\n", | |
" unTargetList.append(data)\n", | |
" return (targetList, unTargetList)\n", | |
"\n", | |
"\n", | |
"# 学習用と検証用の分割比率\n", | |
"ratio = 0.8 # 80%対、20%に分割する\n", | |
"# GroundTruthの出力\n", | |
"inputManifestFile = './output.manifest'\n", | |
"# SageMaker用の出力\n", | |
"outputTrainFile = './train'\n", | |
"outputValidationFile = './validation'\n", | |
"\n", | |
"dataList = getDataList(inputManifestFile)\n", | |
"projectName = dataList[0].projectName\n", | |
"print(\"全データ: {}件 \".format(len(dataList)))\n", | |
"\n", | |
"# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n", | |
"labels = getLabel(dataList)\n", | |
"for i,label in enumerate(labels):\n", | |
" print(\"[{}]{}: {}件 \".format(i, label[0], label[1]))\n", | |
"\n", | |
"# 保存済みリスト\n", | |
"storedList = [] \n", | |
"\n", | |
"# 学習及び検証用の出力\n", | |
"train = ''\n", | |
"validation = ''\n", | |
"\n", | |
"# ラベルの数の少ないものから優先して分割する\n", | |
"for i,label in enumerate(labels):\n", | |
" print(\"{} => \".format(label[0]))\n", | |
" # dataListをラベルが含まれるものと、含まないものに分割する\n", | |
" (targetList, unTargetList) = deviedDataList(dataList, label[0])\n", | |
" # 保存済みリストから、当該ラベルで既に保存済の件数をカウントする\n", | |
" (include, notInclude) = deviedDataList(storedList, label[0])\n", | |
" storedCounst = len(include)\n", | |
" # train用に必要な件数\n", | |
" count = int(label[1] * ratio) - storedCounst\n", | |
" print(\"train :{}\".format(count))\n", | |
" # train側への保存\n", | |
" for i in range(count):\n", | |
" data = targetList.pop()\n", | |
" train += json.dumps(data.src) + '\\n'\n", | |
" storedList.append(data)\n", | |
" # validation側への保存\n", | |
" print(\"validation :{} \".format(len(targetList)))\n", | |
" for data in targetList:\n", | |
" validation += json.dumps(data.src) + '\\n'\n", | |
" storedList.append(data)\n", | |
"\n", | |
" dataList = unTargetList\n", | |
" print(\"残り:{}件\".format(len(dataList)))\n", | |
"\n", | |
"with open(outputTrainFile, mode='w') as f:\n", | |
" f.write(train)\n", | |
"\n", | |
"with open(outputValidationFile, mode='w') as f:\n", | |
" f.write(validation)\n", | |
"\n", | |
"num_training_samples = len(train.split('\\n'))\n", | |
"print(\"projectName:{} num_training_samples:{}\".format(projectName, num_training_samples))\n", | |
" \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# 分割した train及び、validationをS3にアップロードする\n", | |
"s3_train = \"s3://{}/{}/train\".format(bucket_name, prefix)\n", | |
"!aws s3 cp $outputTrainFile $s3_train\n", | |
"s3_validation = \"s3://{}/{}/validation\".format(bucket_name, prefix)\n", | |
"!aws s3 cp $outputValidationFile $s3_validation\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Traning\n", | |
"object-detectionで、モデルの作成を行う" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"import boto3\n", | |
"import sagemaker\n", | |
"from sagemaker import get_execution_role\n", | |
"\n", | |
"role = get_execution_role()\n", | |
"sess = sagemaker.Session()\n", | |
"\n", | |
"# 学習用のコンテナ取得\n", | |
"training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'object-detection', repo_version='latest')\n", | |
"\n", | |
"# モデルを出力するバケットの指定\n", | |
"s3_output_path = \"s3://{}/{}/output\".format(bucket_name, prefix)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# パラメータの作成\n", | |
"import time\n", | |
"from time import gmtime, strftime\n", | |
"\n", | |
"job_name_prefix = 'groundtruth-to-sagemaker'\n", | |
"timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n", | |
"job_name = job_name_prefix + timestamp\n", | |
"\n", | |
"training_params = {\n", | |
" \"AlgorithmSpecification\": {\n", | |
" \"TrainingImage\": training_image, \n", | |
" \"TrainingInputMode\": \"Pipe\"\n", | |
" },\n", | |
" \"RoleArn\": role,\n", | |
" \"OutputDataConfig\": {\n", | |
" \"S3OutputPath\": s3_output_path\n", | |
" },\n", | |
" \"ResourceConfig\": {\n", | |
" \"InstanceCount\": 1, \n", | |
" \"InstanceType\": \"ml.p3.2xlarge\",\n", | |
" \"VolumeSizeInGB\": 50\n", | |
" },\n", | |
" \"TrainingJobName\": job_name,\n", | |
" \"HyperParameters\": {\n", | |
" \"base_network\": \"resnet-50\",\n", | |
" \"use_pretrained_model\": \"1\",\n", | |
" \"num_classes\": \"2\",\n", | |
" \"mini_batch_size\": \"10\",\n", | |
" \"epochs\": \"200\",\n", | |
" \"learning_rate\": \"0.001\",\n", | |
" \"lr_scheduler_step\": \"3,6\",\n", | |
" \"lr_scheduler_factor\": \"0.1\",\n", | |
" \"optimizer\": \"rmsprop\",\n", | |
" \"momentum\": \"0.9\",\n", | |
" \"weight_decay\": \"0.0005\",\n", | |
" \"overlap_threshold\": \"0.5\",\n", | |
" \"nms_threshold\": \"0.45\",\n", | |
" \"image_shape\": \"300\",\n", | |
" \"label_width\": \"350\",\n", | |
" \"num_training_samples\": str(num_training_samples)\n", | |
" },\n", | |
" \"StoppingCondition\": {\n", | |
" \"MaxRuntimeInSeconds\": 86400\n", | |
" },\n", | |
" \"InputDataConfig\": [\n", | |
" {\n", | |
" \"ChannelName\": \"train\",\n", | |
" \"DataSource\": {\n", | |
" \"S3DataSource\": {\n", | |
" \"S3DataType\": \"AugmentedManifestFile\", \n", | |
" \"S3Uri\": s3_train,\n", | |
" \"S3DataDistributionType\": \"FullyReplicated\",\n", | |
" \"AttributeNames\": [\"source-ref\",projectName]\n", | |
" }\n", | |
" },\n", | |
" \"ContentType\": \"application/x-recordio\",\n", | |
" \"RecordWrapperType\": \"RecordIO\",\n", | |
" \"CompressionType\": \"None\"\n", | |
" },\n", | |
" {\n", | |
" \"ChannelName\": \"validation\",\n", | |
" \"DataSource\": {\n", | |
" \"S3DataSource\": {\n", | |
" \"S3DataType\": \"AugmentedManifestFile\", \n", | |
" \"S3Uri\": s3_validation,\n", | |
" \"S3DataDistributionType\": \"FullyReplicated\",\n", | |
" \"AttributeNames\": [\"source-ref\",projectName]\n", | |
" }\n", | |
" },\n", | |
" \"ContentType\": \"application/x-recordio\",\n", | |
" \"RecordWrapperType\": \"RecordIO\",\n", | |
" \"CompressionType\": \"None\"\n", | |
" }\n", | |
" ]\n", | |
"}\n", | |
" \n", | |
"print('Training job name: {}'.format(job_name))\n", | |
"print('\\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Traning\n", | |
"client = boto3.client(service_name='sagemaker')\n", | |
"client.create_training_job(**training_params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# 学習ジョブの状態取得\n", | |
"status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", | |
"print('Training job current status: {}'.format(status))\n", | |
" \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Create Endpoint" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# model名\n", | |
"modelName = 'object-detection-2020-04-10-01-21-35-598'\n", | |
"# Endpoint configurations \n", | |
"configName = 'sampleConfig'\n", | |
"# Endpoint\n", | |
"endPointName = 'sampleEndPoint'\n", | |
"\n", | |
"response = client.create_endpoint_config(\n", | |
" EndpointConfigName = configName,\n", | |
" ProductionVariants=[\n", | |
" {\n", | |
" 'VariantName': 'VariantName',\n", | |
" 'ModelName': modelName,\n", | |
" 'InitialInstanceCount': 1,\n", | |
" 'InstanceType': 'ml.m4.4xlarge'\n", | |
" },\n", | |
" ]\n", | |
")\n", | |
"\n", | |
"print(response)\n", | |
"\n", | |
"response = client.create_endpoint(\n", | |
" EndpointName=endPointName,\n", | |
" EndpointConfigName=configName\n", | |
")\n", | |
"\n", | |
"print(response)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Detection" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import boto3\n", | |
"import json\n", | |
"\n", | |
"client = boto3.client('sagemaker-runtime')\n", | |
"\n", | |
"def visualize_detection(img_file, dets, classes=[], thresh=0.6):\n", | |
" import random\n", | |
" import matplotlib.pyplot as plt\n", | |
" import matplotlib.image as mpimg\n", | |
"\n", | |
" img=mpimg.imread(img_file)\n", | |
" plt.imshow(img)\n", | |
" height = img.shape[0]\n", | |
" width = img.shape[1]\n", | |
" colors = dict()\n", | |
" for det in dets:\n", | |
" (klass, score, x0, y0, x1, y1) = det\n", | |
" print(\"{} {}\".format(klass,score))\n", | |
" if score < thresh:\n", | |
" continue\n", | |
" cls_id = int(klass)\n", | |
" if cls_id not in colors:\n", | |
" colors[cls_id] = (random.random(), random.random(), random.random())\n", | |
" xmin = int(x0 * width)\n", | |
" ymin = int(y0 * height)\n", | |
" xmax = int(x1 * width)\n", | |
" ymax = int(y1 * height)\n", | |
" rect = plt.Rectangle((xmin, ymin), xmax - xmin,\n", | |
" ymax - ymin, fill=False,\n", | |
" edgecolor=colors[cls_id],\n", | |
" linewidth=3.5)\n", | |
" plt.gca().add_patch(rect)\n", | |
" class_name = str(cls_id)\n", | |
" if classes and len(classes) > cls_id:\n", | |
" class_name = classes[cls_id]\n", | |
" plt.gca().text(xmin, ymin - 2,\n", | |
" '{:s} {:.3f}'.format(class_name, score),\n", | |
" bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n", | |
" fontsize=12, color='white')\n", | |
" #plt.rcParams['figure.figsize'] = (50 ,50)\n", | |
" plt.show()\n", | |
"\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"file_name = 'ahiru_test_004.jpg'\n", | |
"endPoint = 'sampleEndPoint';\n", | |
"\n", | |
"categories = ['DOG', 'AHIRU']\n", | |
"\n", | |
"\n", | |
"with open(file_name, 'rb') as image:\n", | |
" f = image.read()\n", | |
" b = bytearray(f)\n", | |
"\n", | |
"endpoint_response = client.invoke_endpoint(\n", | |
" EndpointName=endPoint,\n", | |
" Body=b,\n", | |
" ContentType='image/jpeg'\n", | |
")\n", | |
"results = endpoint_response['Body'].read()\n", | |
"detections = json.loads(results)\n", | |
"\n", | |
"thresh = 0.3\n", | |
"visualize_detection(file_name, detections[\"prediction\"], categories, thresh)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Delete EndPoint" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from boto3.session import Session\n", | |
"import boto3\n", | |
"\n", | |
"client = boto3.client('sagemaker')\n", | |
"\n", | |
"configName = 'sampleConfig'\n", | |
"endPointName = 'sampleEndPoint'\n", | |
"\n", | |
"\n", | |
"response = client.delete_endpoint(\n", | |
" EndpointName=endPointName\n", | |
")\n", | |
"print(response)\n", | |
"\n", | |
"response = client.delete_endpoint_config(\n", | |
" EndpointConfigName=configName\n", | |
")\n", | |
"print(response)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "conda_mxnet_p36", | |
"language": "python", | |
"name": "conda_mxnet_p36" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment