Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save uchidama/6389fe5567f80381feddf305d096a41f to your computer and use it in GitHub Desktop.
Save uchidama/6389fe5567f80381feddf305d096a41f to your computer and use it in GitHub Desktop.
A digit drawn on an HTML canvas is predicted by a convolutional model trained on MNIST. https://github.com/uchidama/HandWriteDigitOnHtmlCanvasPrediction
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"'''\n",
"IMPORT MODULES\n",
"'''\n",
"from __future__ import print_function\n",
"import keras\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras import backend as K\n",
"from keras.models import load_model\n",
"\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"%matplotlib inline\n",
"\n",
"import random\n",
"\n",
"from IPython.core.display import HTML"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create a canvas to write a digit\n",
"1. Write a digit and push \"Save image to variable\".\n",
"2. Strings that are written on base64 are used on Python."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<canvas id=\"canvas\" height=\"300px\" width=\"300px\" style=\"border: 1px solid;\"></canvas>\n",
"<p>\n",
" <button id=\"clear\">Clear</button>\n",
" <button id=\"submit\">Save image to variable</button>\n",
"</p>\n",
"<p id=\"msg\"></p>\n",
"<script>\n",
" var kernel = IPython.notebook.kernel;\n",
"\n",
" var config = {\n",
" \"linesize\": 7,\n",
" \"linecolor\": \"#000000\"\n",
" }\n",
"\n",
" var mouse = {\n",
" \"X\": null,\n",
" \"Y\": null,\n",
" }\n",
"\n",
" var clear = document.getElementById(\"clear\");\n",
" var submit = document.getElementById(\"submit\");\n",
" var canvas = document.getElementById(\"canvas\");\n",
" var ctx = canvas.getContext(\"2d\");\n",
"\n",
" clear.addEventListener(\"click\", function(){\n",
" ctx.clearRect(0, 0, canvas.width, canvas.height);\n",
" });\n",
"\n",
" submit.addEventListener(\"click\", function(){\n",
" var variable_value = 'base64_img';\n",
" kernel.execute(variable_value + \" = '\" + canvas.toDataURL() + \"'\");\n",
" msg.textContent = \"Success: \" + \"image -> \" + variable_value;\n",
" });\n",
"\n",
" canvas.addEventListener(\"mouseup\", drawEnd, false);\n",
" canvas.addEventListener(\"mouseout\", drawEnd, false);\n",
" \n",
" canvas.addEventListener(\"mousemove\", function(e){\n",
" if (e.buttons === 1 || e.witch === 1) {\n",
" var rect = e.target.getBoundingClientRect();\n",
" var X = e.clientX - rect.left;\n",
" var Y = e.clientY - rect.top;\n",
" draw(X, Y);\n",
" };\n",
" });\n",
" \n",
" canvas.addEventListener(\"mousedown\", function(e){\n",
" if (e.button === 0) {\n",
" var rect = e.target.getBoundingClientRect();\n",
" var X = e.clientX - rect.left;\n",
" var Y = e.clientY - rect.top;\n",
" draw(X, Y);\n",
" }\n",
" });\n",
"\n",
" function draw(X, Y) {\n",
" ctx.beginPath();\n",
" if (mouse.X === null) {\n",
" ctx.moveTo(X, Y);\n",
" } else {\n",
" ctx.moveTo(mouse.X, mouse.Y);\n",
" }\n",
" ctx.lineTo(X, Y);\n",
" \n",
" ctx.lineCap = \"round\";\n",
" ctx.lineWidth = config.linesize * 2;\n",
" ctx.strokeStyle = config.linecolor;\n",
" ctx.stroke();\n",
"\n",
" mouse.X = X;\n",
" mouse.Y = Y;\n",
" };\n",
" \n",
" function drawEnd() {\n",
" mouse.X = null;\n",
" mouse.Y = null;\n",
" }\n",
"</script>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML('''\n",
"<canvas id=\"canvas\" height=\"300px\" width=\"300px\" style=\"border: 1px solid;\"></canvas>\n",
"<p>\n",
" <button id=\"clear\">Clear</button>\n",
" <button id=\"submit\">Save image to variable</button>\n",
"</p>\n",
"<p id=\"msg\"></p>\n",
"<script>\n",
" var kernel = IPython.notebook.kernel;\n",
"\n",
" var config = {\n",
" \"linesize\": 7,\n",
" \"linecolor\": \"#000000\"\n",
" }\n",
"\n",
" var mouse = {\n",
" \"X\": null,\n",
" \"Y\": null,\n",
" }\n",
"\n",
" var clear = document.getElementById(\"clear\");\n",
" var submit = document.getElementById(\"submit\");\n",
" var canvas = document.getElementById(\"canvas\");\n",
" var ctx = canvas.getContext(\"2d\");\n",
"\n",
" clear.addEventListener(\"click\", function(){\n",
" ctx.clearRect(0, 0, canvas.width, canvas.height);\n",
" });\n",
"\n",
" submit.addEventListener(\"click\", function(){\n",
" var variable_value = 'base64_img';\n",
" kernel.execute(variable_value + \" = '\" + canvas.toDataURL() + \"'\");\n",
" msg.textContent = \"Success: \" + \"image -> \" + variable_value;\n",
" });\n",
"\n",
" canvas.addEventListener(\"mouseup\", drawEnd, false);\n",
" canvas.addEventListener(\"mouseout\", drawEnd, false);\n",
" \n",
" canvas.addEventListener(\"mousemove\", function(e){\n",
" if (e.buttons === 1 || e.witch === 1) {\n",
" var rect = e.target.getBoundingClientRect();\n",
" var X = e.clientX - rect.left;\n",
" var Y = e.clientY - rect.top;\n",
" draw(X, Y);\n",
" };\n",
" });\n",
" \n",
" canvas.addEventListener(\"mousedown\", function(e){\n",
" if (e.button === 0) {\n",
" var rect = e.target.getBoundingClientRect();\n",
" var X = e.clientX - rect.left;\n",
" var Y = e.clientY - rect.top;\n",
" draw(X, Y);\n",
" }\n",
" });\n",
"\n",
" function draw(X, Y) {\n",
" ctx.beginPath();\n",
" if (mouse.X === null) {\n",
" ctx.moveTo(X, Y);\n",
" } else {\n",
" ctx.moveTo(mouse.X, mouse.Y);\n",
" }\n",
" ctx.lineTo(X, Y);\n",
" \n",
" ctx.lineCap = \"round\";\n",
" ctx.lineWidth = config.linesize * 2;\n",
" ctx.strokeStyle = config.linecolor;\n",
" ctx.stroke();\n",
"\n",
" mouse.X = X;\n",
" mouse.Y = Y;\n",
" };\n",
" \n",
" function drawEnd() {\n",
" mouse.X = null;\n",
" mouse.Y = null;\n",
" }\n",
"</script>\n",
"''')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# To use on Python\n",
"Base64 strings are used on Python. \n",
"Python codes split base64 strings and create image by it. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11d89ee10>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"I think this digit is a 4\n"
]
}
],
"source": [
"import base64\n",
"import numpy as np\n",
"from io import BytesIO\n",
"from PIL import Image\n",
"from PIL import ImageOps\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def CreateMnistDataArray(image):\n",
" ret = np.zeros(28*28).reshape((1,28,28,1))\n",
" width, height = image.size\n",
" for y in range(height):\n",
" \n",
" for x in range(width):\n",
" r, g, b,a = image.getpixel((x, y))\n",
" if a == 0:\n",
" ret[0][y][x][0] = 0.0\n",
" else:\n",
" ret[0][y][x][0] = 1.0\n",
" \n",
" return ret\n",
" \n",
"\n",
"base64_img = base64_img.split(\",\")[-1]\n",
"\n",
"img = Image.open(BytesIO(base64.b64decode(base64_img))).resize((28,28))\n",
"\n",
"plt.imshow(np.asarray(img))\n",
"plt.show()\n",
"\n",
"#print(type(img)) # <class 'PIL.PngImagePlugin.PngImageFile'>\n",
"#print(img.size) # (320, 240) \n",
"#print(img.mode) # RGBA\n",
"\n",
"mnist_type_data = CreateMnistDataArray(img)\n",
"\n",
"'''\n",
"LOAD MODEL AND PREDICT\n",
"'''\n",
"model = load_model('model_mnist_cnn.h5')\n",
"\n",
"ret = model.predict(mnist_type_data, batch_size=1) # OK\n",
"#print(\"predict ret:\", ret)\n",
"\n",
"bestnum = 0.0\n",
"bestclass = 0\n",
"for n in [0,1,2,3,4,5,6,7,8,9]:\n",
" if bestnum < ret[0][n]:\n",
" bestnum = ret[0][n]\n",
" bestclass = n\n",
"\n",
"print(\"I think this digit is a \", bestclass)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment