{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
" ## CreateDataset\n",
" Ground Truth の出力(output.manifest)から SageMaker用のDataSet(train,validation)を作成する"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bucket_name = \"sagemaker-working-bucket-001\" \n",
"prefix = 'object-detection-with-ground-truth'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ground Truthの出力(output.manifest)をダウンロードする\n",
"\n",
"inputManifestPath = 's3://sagemaker-xxxxxxx/GroundTruth-output/AHIRU-Project/manifests/output/output.manifest'\n",
"!aws s3 cp $inputManifestPath \"./output.manifest\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# output.manifestをtrain及びvalidationに分割する\n",
"\n",
"import json \n",
"\n",
"# 1件のデータの表現するクラス(定義されているラベルを把握するために使用する)\n",
"class Data():\n",
" def __init__(self, src):\n",
" self.src = src\n",
" # プロジェクト名の取得\n",
" for key in src.keys():\n",
" index = key.rfind(\"-metadata\")\n",
" if(index!=-1):\n",
" self.projectName = key[0:index]\n",
"\n",
" cls_map = src[self.projectName + \"-metadata\"][\"class-map\"]\n",
"\n",
" # アノテーション一覧からクラスIDの指定を取得する\n",
" self.annotations = []\n",
" for annotation in src[self.projectName][\"annotations\"]:\n",
" id = annotation['class_id']\n",
" self.annotations.append({\n",
" \"label\":cls_map[str(id)]\n",
" })\n",
" \n",
" # 指定されたラベルを含むかどうか\n",
" def exsists(self, label):\n",
" for annotation in self.annotations:\n",
" if(annotation[\"label\"] == label):\n",
" return True\n",
" return False\n",
"\n",
"# 全てのJSONデータを読み込む\n",
"def getDataList(inputPath):\n",
" dataList = []\n",
" with open(inputPath, 'r') as f:\n",
" srcList = f.read().split('\\n')\n",
" for src in srcList:\n",
" if(src != ''):\n",
" dataList.append(Data(json.loads(src)))\n",
" return dataList\n",
"\n",
"# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n",
"def getLabel(dataList):\n",
" labels = {}\n",
" for data in dataList:\n",
" for annotation in data.annotations:\n",
" label = annotation[\"label\"]\n",
" if(label in labels):\n",
" labels[label] += 1\n",
" else:\n",
" labels[label] = 1\n",
" # ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n",
" labels = sorted(labels.items(), key=lambda x:x[1])\n",
" return labels\n",
"\n",
"# dataListをラベルを含むものと、含まないものに分割する\n",
"def deviedDataList(dataList, label):\n",
" targetList = []\n",
" unTargetList = []\n",
" for data in dataList:\n",
" if(data.exsists(label)):\n",
" targetList.append(data)\n",
" else:\n",
" unTargetList.append(data)\n",
" return (targetList, unTargetList)\n",
"\n",
"\n",
"# 学習用と検証用の分割比率\n",
"ratio = 0.8 # 80%対、20%に分割する\n",
"# GroundTruthの出力\n",
"inputManifestFile = './output.manifest'\n",
"# SageMaker用の出力\n",
"outputTrainFile = './train'\n",
"outputValidationFile = './validation'\n",
"\n",
"dataList = getDataList(inputManifestFile)\n",
"projectName = dataList[0].projectName\n",
"print(\"全データ: {}件 \".format(len(dataList)))\n",
"\n",
"# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)\n",
"labels = getLabel(dataList)\n",
"for i,label in enumerate(labels):\n",
" print(\"[{}]{}: {}件 \".format(i, label[0], label[1]))\n",
"\n",
"# 保存済みリスト\n",
"storedList = [] \n",
"\n",
"# 学習及び検証用の出力\n",
"train = ''\n",
"validation = ''\n",
"\n",
"# ラベルの数の少ないものから優先して分割する\n",
"for i,label in enumerate(labels):\n",
" print(\"{} => \".format(label[0]))\n",
" # dataListをラベルが含まれるものと、含まないものに分割する\n",
" (targetList, unTargetList) = deviedDataList(dataList, label[0])\n",
" # 保存済みリストから、当該ラベルで既に保存済の件数をカウントする\n",
" (include, notInclude) = deviedDataList(storedList, label[0])\n",
" storedCounst = len(include)\n",
" # train用に必要な件数\n",
" count = int(label[1] * ratio) - storedCounst\n",
" print(\"train :{}\".format(count))\n",
" # train側への保存\n",
" for i in range(count):\n",
" data = targetList.pop()\n",
" train += json.dumps(data.src) + '\\n'\n",
" storedList.append(data)\n",
" # validation側への保存\n",
" print(\"validation :{} \".format(len(targetList)))\n",
" for data in targetList:\n",
" validation += json.dumps(data.src) + '\\n'\n",
" storedList.append(data)\n",
"\n",
" dataList = unTargetList\n",
" print(\"残り:{}件\".format(len(dataList)))\n",
"\n",
"with open(outputTrainFile, mode='w') as f:\n",
" f.write(train)\n",
"\n",
"with open(outputValidationFile, mode='w') as f:\n",
" f.write(validation)\n",
"\n",
"num_training_samples = len(train.split('\\n'))\n",
"print(\"projectName:{} num_training_samples:{}\".format(projectName, num_training_samples))\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 分割した train及び、validationをS3にアップロードする\n",
"s3_train = \"s3://{}/{}/train\".format(bucket_name, prefix)\n",
"!aws s3 cp $outputTrainFile $s3_train\n",
"s3_validation = \"s3://{}/{}/validation\".format(bucket_name, prefix)\n",
"!aws s3 cp $outputValidationFile $s3_validation\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Traning\n",
"object-detectionで、モデルの作成を行う"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import boto3\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"\n",
"role = get_execution_role()\n",
"sess = sagemaker.Session()\n",
"\n",
"# 学習用のコンテナ取得\n",
"training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'object-detection', repo_version='latest')\n",
"\n",
"# モデルを出力するバケットの指定\n",
"s3_output_path = \"s3://{}/{}/output\".format(bucket_name, prefix)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# パラメータの作成\n",
"import time\n",
"from time import gmtime, strftime\n",
"\n",
"job_name_prefix = 'groundtruth-to-sagemaker'\n",
"timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
"job_name = job_name_prefix + timestamp\n",
"\n",
"training_params = {\n",
" \"AlgorithmSpecification\": {\n",
" \"TrainingImage\": training_image, \n",
" \"TrainingInputMode\": \"Pipe\"\n",
" },\n",
" \"RoleArn\": role,\n",
" \"OutputDataConfig\": {\n",
" \"S3OutputPath\": s3_output_path\n",
" },\n",
" \"ResourceConfig\": {\n",
" \"InstanceCount\": 1, \n",
" \"InstanceType\": \"ml.p3.2xlarge\",\n",
" \"VolumeSizeInGB\": 50\n",
" },\n",
" \"TrainingJobName\": job_name,\n",
" \"HyperParameters\": {\n",
" \"base_network\": \"resnet-50\",\n",
" \"use_pretrained_model\": \"1\",\n",
" \"num_classes\": \"2\",\n",
" \"mini_batch_size\": \"10\",\n",
" \"epochs\": \"200\",\n",
" \"learning_rate\": \"0.001\",\n",
" \"lr_scheduler_step\": \"3,6\",\n",
" \"lr_scheduler_factor\": \"0.1\",\n",
" \"optimizer\": \"rmsprop\",\n",
" \"momentum\": \"0.9\",\n",
" \"weight_decay\": \"0.0005\",\n",
" \"overlap_threshold\": \"0.5\",\n",
" \"nms_threshold\": \"0.45\",\n",
" \"image_shape\": \"300\",\n",
" \"label_width\": \"350\",\n",
" \"num_training_samples\": str(num_training_samples)\n",
" },\n",
" \"StoppingCondition\": {\n",
" \"MaxRuntimeInSeconds\": 86400\n",
" },\n",
" \"InputDataConfig\": [\n",
" {\n",
" \"ChannelName\": \"train\",\n",
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"AugmentedManifestFile\", \n",
" \"S3Uri\": s3_train,\n",
" \"S3DataDistributionType\": \"FullyReplicated\",\n",
" \"AttributeNames\": [\"source-ref\",projectName]\n",
" }\n",
" },\n",
" \"ContentType\": \"application/x-recordio\",\n",
" \"RecordWrapperType\": \"RecordIO\",\n",
" \"CompressionType\": \"None\"\n",
" },\n",
" {\n",
" \"ChannelName\": \"validation\",\n",
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"AugmentedManifestFile\", \n",
" \"S3Uri\": s3_validation,\n",
" \"S3DataDistributionType\": \"FullyReplicated\",\n",
" \"AttributeNames\": [\"source-ref\",projectName]\n",
" }\n",
" },\n",
" \"ContentType\": \"application/x-recordio\",\n",
" \"RecordWrapperType\": \"RecordIO\",\n",
" \"CompressionType\": \"None\"\n",
" }\n",
" ]\n",
"}\n",
" \n",
"print('Training job name: {}'.format(job_name))\n",
"print('\\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Traning\n",
"client = boto3.client(service_name='sagemaker')\n",
"client.create_training_job(**training_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 学習ジョブの状態取得\n",
"status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
"print('Training job current status: {}'.format(status))\n",
" \n",
" "
]
},
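{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: block until the training job completes or stops, instead of\n",
"# re-running the status cell above. Assumes `client` and `job_name` from the\n",
"# previous cells; uses the boto3 SageMaker waiter 'training_job_completed_or_stopped'.\n",
"waiter = client.get_waiter('training_job_completed_or_stopped')\n",
"waiter.wait(TrainingJobName=job_name)\n",
"print(client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus'])"
]
},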
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Endpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model名\n",
"modelName = 'object-detection-2020-04-10-01-21-35-598'\n",
"# Endpoint configurations \n",
"configName = 'sampleConfig'\n",
"# Endpoint\n",
"endPointName = 'sampleEndPoint'\n",
"\n",
"response = client.create_endpoint_config(\n",
" EndpointConfigName = configName,\n",
" ProductionVariants=[\n",
" {\n",
" 'VariantName': 'VariantName',\n",
" 'ModelName': modelName,\n",
" 'InitialInstanceCount': 1,\n",
" 'InstanceType': 'ml.m4.4xlarge'\n",
" },\n",
" ]\n",
")\n",
"\n",
"print(response)\n",
"\n",
"response = client.create_endpoint(\n",
" EndpointName=endPointName,\n",
" EndpointConfigName=configName\n",
")\n",
"\n",
"print(response)\n"
]
},
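{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: wait until the endpoint is InService before invoking it.\n",
"# Assumes `client` and `endPointName` from the cell above; uses the boto3\n",
"# SageMaker waiter 'endpoint_in_service'.\n",
"waiter = client.get_waiter('endpoint_in_service')\n",
"waiter.wait(EndpointName=endPointName)\n",
"print(client.describe_endpoint(EndpointName=endPointName)['EndpointStatus'])"
]
},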
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Detection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"import json\n",
"\n",
"client = boto3.client('sagemaker-runtime')\n",
"\n",
"def visualize_detection(img_file, dets, classes=[], thresh=0.6):\n",
" import random\n",
" import matplotlib.pyplot as plt\n",
" import matplotlib.image as mpimg\n",
"\n",
" img=mpimg.imread(img_file)\n",
" plt.imshow(img)\n",
" height = img.shape[0]\n",
" width = img.shape[1]\n",
" colors = dict()\n",
" for det in dets:\n",
" (klass, score, x0, y0, x1, y1) = det\n",
" print(\"{} {}\".format(klass,score))\n",
" if score < thresh:\n",
" continue\n",
" cls_id = int(klass)\n",
" if cls_id not in colors:\n",
" colors[cls_id] = (random.random(), random.random(), random.random())\n",
" xmin = int(x0 * width)\n",
" ymin = int(y0 * height)\n",
" xmax = int(x1 * width)\n",
" ymax = int(y1 * height)\n",
" rect = plt.Rectangle((xmin, ymin), xmax - xmin,\n",
" ymax - ymin, fill=False,\n",
" edgecolor=colors[cls_id],\n",
" linewidth=3.5)\n",
" plt.gca().add_patch(rect)\n",
" class_name = str(cls_id)\n",
" if classes and len(classes) > cls_id:\n",
" class_name = classes[cls_id]\n",
" plt.gca().text(xmin, ymin - 2,\n",
" '{:s} {:.3f}'.format(class_name, score),\n",
" bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n",
" fontsize=12, color='white')\n",
" #plt.rcParams['figure.figsize'] = (50 ,50)\n",
" plt.show()\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_name = 'ahiru_test_004.jpg'\n",
"endPoint = 'sampleEndPoint';\n",
"\n",
"categories = ['DOG', 'AHIRU']\n",
"\n",
"\n",
"with open(file_name, 'rb') as image:\n",
" f = image.read()\n",
" b = bytearray(f)\n",
"\n",
"endpoint_response = client.invoke_endpoint(\n",
" EndpointName=endPoint,\n",
" Body=b,\n",
" ContentType='image/jpeg'\n",
")\n",
"results = endpoint_response['Body'].read()\n",
"detections = json.loads(results)\n",
"\n",
"thresh = 0.3\n",
"visualize_detection(file_name, detections[\"prediction\"], categories, thresh)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Delete EndPoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from boto3.session import Session\n",
"import boto3\n",
"\n",
"client = boto3.client('sagemaker')\n",
"\n",
"configName = 'sampleConfig'\n",
"endPointName = 'sampleEndPoint'\n",
"\n",
"\n",
"response = client.delete_endpoint(\n",
" EndpointName=endPointName\n",
")\n",
"print(response)\n",
"\n",
"response = client.delete_endpoint_config(\n",
" EndpointConfigName=configName\n",
")\n",
"print(response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_mxnet_p36",
"language": "python",
"name": "conda_mxnet_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}