Last active August 18, 2020 21:42
The SageMaker notebook that builds and deploys an MXNet model for counting the shapes in an image
"cells": [
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tunable parameters\n",
"dataset_size = 12000  # number of generated samples\n",
"img_size = 128\n",
"shape_size = 8\n",
"max_rows = max_cols = 14  # upper bounds on rows / shapes per row in each generated grid\n",
"target_shape = \"circle\"  # the shape whose occurrences are counted\n",
"# Fixed parameters (do not modify)\n",
"shapes = [\"circle\", \"triangle\", \"square\"]\n",
"img_height = img_size\n",
"img_width = img_size"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from random import randint, seed\n",
"from shutil import rmtree\n",
"\n",
"# Make sure the required folder structure is in place (and empty)\n",
"rmtree(\"data/raw\", True)\n",
"os.makedirs(\"data/raw\", exist_ok=True)\n",
"\n",
"# Make sure data generation is reproducible\n",
"seed(42)\n",
"\n",
"with open(\"data/dataset.csv\", \"w+\") as dataset:\n",
"    dataset.write(\"img_path,target\\n\")\n",
"    for i in range(1, dataset_size + 1):\n",
"        filename = str(i).zfill(6) + \".shape\"\n",
"        # Random grid: 1..max_rows rows, each holding 1..max_cols randomly chosen shapes\n",
"        rows = [\n",
"            [shapes[randint(0, len(shapes) - 1)] for _ in range(randint(1, max_cols))]\n",
"            for _ in range(randint(1, max_rows))\n",
"        ]\n",
"        target_number_of_shapes = sum(row.count(target_shape) for row in rows)\n",
"        # Shaper DSL source: shapes within a row separated by ',', rows separated by '|'\n",
"        file_content = \"img_dim:{},shp_dim:{}>>>{}<<<\".format(\n",
"            img_size, shape_size, \"|\".join(\",\".join(row) for row in rows))\n",
"        with open(\"./data/raw/{}\".format(filename), \"w\") as shape_file:\n",
"            shape_file.write(file_content)\n",
"        dataset.write(\"./raw/{}.png,{:d}\\n\".format(filename, target_number_of_shapes))"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
" java -cp bin/shaper-all.jar com.cosminsanda.shaper.compiler.Shaper2Image \\\n",
" --source-dir data/raw\"\"\", shell=True);"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pandas import read_csv\n",
"\n",
"df = read_csv(\"data/dataset.csv\")\n",
"# About 5/6 of the rows go to training; the remaining rows are split evenly into validation and test\n",
"train = df.sample(frac=.8333, random_state=42)\n",
"validation = df.loc[~df.index.isin(train.index), :].sample(frac=.5, random_state=42)\n",
"# Test set: every row that is in neither the training nor the validation split\n",
"test = df.loc[~df.index.isin(train.index) & ~df.index.isin(validation.index), :]"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import mxnet as mx\n",
"import numpy as np\n",
"\n",
"def transform(row):\n",
"    \"\"\"Load the image referenced by a dataset row and pair it with its label.\n",
"\n",
"    row: a dataset row with an img_path column (relative to ./data) and a target column.\n",
"    Returns (img, label): img is a channels-first NDArray scaled by 1/255,\n",
"    label is the target as np.float32.\n",
"    Raises FileNotFoundError when OpenCV cannot read the image.\n",
"    \"\"\"\n",
"    path = \"./data/{}\".format(row[\"img_path\"])\n",
"    img = cv2.imread(path)\n",
"    if img is None:\n",
"        # cv2.imread reports a missing/unreadable file by returning None instead of raising\n",
"        raise FileNotFoundError(\"Could not read image: {}\".format(path))\n",
"    img = mx.nd.array(img)\n",
"    img = img.astype(np.float32)\n",
"    # OpenCV yields height x width x channels; move channels to the first axis\n",
"    img = mx.nd.transpose(img, (2, 0, 1))\n",
"    img = img / 255\n",
"    label = np.float32(row[\"target\"])\n",
"    return img, label"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Materialize (image, label) pairs for the training and validation splits\n",
"train_nd = [transform(record) for _idx, record in train.iterrows()]\n",
"validation_nd = [transform(record) for _idx, record in validation.iterrows()]"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pickle import dump\n",
"\n",
"def save_to_disk(data, type):\n",
"    \"\"\"Pickle data into data/pickles/<type>/data.p, creating the folder if needed.\n",
"\n",
"    type is the split name (e.g. train or validation) and becomes the sub-folder name.\n",
"    exist_ok=True keeps the cell re-runnable without deleting data/pickles first.\n",
"    \"\"\"\n",
"    folder = \"data/pickles/{}\".format(type)\n",
"    os.makedirs(folder, exist_ok=True)\n",
"    with open(\"{}/data.p\".format(folder), \"wb\") as out:\n",
"        dump(data, out)"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist each split in its own folder under data/pickles\n",
"save_to_disk(data=train_nd, type=\"train\")\n",
"save_to_disk(data=validation_nd, type=\"validation\")"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"\n",
"# SageMaker session; used below to upload the pickled splits to S3\n",
"sagemaker_session = sagemaker.Session()"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the local data/pickles tree to S3; inputs holds the S3 location returned by upload_data\n",
"inputs = sagemaker_session.upload_data(\n",
"    path=\"data/pickles\",\n",
"    bucket=\"adt-adhoc\",\n",
"    key_prefix=\"cosmin/sagemaker/demo\",\n",
")"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from shutil import copy2, rmtree\n",
"\n",
"# Rebuild ./test with one sub-folder per target count, each holding the matching test images\n",
"rmtree(\"./test\", True)\n",
"for _, row in test.iterrows():\n",
"    target_count = row[\"target\"]\n",
"    os.makedirs(\"test/{}\".format(target_count), exist_ok=True)\n",
"    copy2(\"./data/{}\".format(row[\"img_path\"]), \"./test/{}\".format(target_count))"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from shutil import rmtree\n",
"\n",
"# Local data is no longer needed: the splits were uploaded to S3 and test images copied to ./test\n",
"rmtree(\"data\", ignore_errors=True)"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = sagemaker.mxnet.MXNet(\"\", \n",
" role=sagemaker.get_execution_role(), \n",
" train_instance_count=1, \n",
" train_instance_type=\"ml.p2.xlarge\",\n",
" hyperparameters={\"epochs\": 5},\n",
" py_version=\"py3\")"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deploy the estimator's model to an endpoint backed by one ml.m4.xlarge instance\n",
"predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Low-level runtime client used below to invoke the deployed endpoint directly\n",
"sagemaker_runtime_client = boto3.client(service_name=\"sagemaker-runtime\")"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import json\n",
"import os\n",
"\n",
"# For each ./test/<count> folder, send all of its images to the endpoint in one JSON batch\n",
"# and report how many predictions equal the folder's count.\n",
"for directory in os.listdir(\"./test\"):\n",
"    batch = []\n",
"    for file in os.listdir(\"./test/{}\".format(directory)):\n",
"        with open(\"./test/{}/{}\".format(directory, file), \"rb\") as image_file:\n",
"            # Encode the image bytes as base64 text; raw bytes are not JSON-serializable\n",
"            batch.append(base64.b64encode(image_file.read()).decode(\"utf-8\"))\n",
"    binary_json = json.dumps(batch).encode(\"utf-8\")\n",
"    response = sagemaker_runtime_client.invoke_endpoint(\n",
"        EndpointName=predictor.endpoint,\n",
"        Body=binary_json,\n",
"        ContentType=\"application/json\",\n",
"        Accept=\"application/json\"\n",
"    )[\"Body\"].read()\n",
"    # json.loads accepts bytes directly; its encoding argument was removed in Python 3.9\n",
"    individual_predictions = json.loads(response)\n",
"    total = len(individual_predictions)\n",
"    # NOTE(review): int() truncates; if the endpoint returns raw float counts, rounding may be intended -- TODO confirm\n",
"    detected = sum(1 for prediction in individual_predictions if int(prediction) == int(directory))\n",
"    accuracy = detected / total if total else 0.0\n",
"    print(\"\"\"Images with {} {}s:\n",
"    Total: {}\n",
"    Detected: {}\n",
"    Accuracy: {:0.2f}\n",
"    \"\"\".format(directory, target_shape, str(total), str(detected), accuracy))"
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"nbformat": 4,
"nbformat_minor": 2
