Skip to content

Instantly share code, notes, and snippets.

@briandw
Created December 3, 2017 23:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save briandw/ab3ff8f4a8bdf4f1e8f11834adb827d8 to your computer and use it in GitHub Desktop.
Save briandw/ab3ff8f4a8bdf4f1e8f11834adb827d8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data-loading helpers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"#extract the bit pattern encoded in the first 8 hex digits of a filename\n",
"def decodeFileName(hexString):\n",
"    \"\"\"Decode the first 8 hex digits of `hexString` into a 25-element bool array.\n",
"\n",
"    Element k is bit k of the 32-bit value those digits spell (element 0 is the\n",
"    least-significant bit of the 8th digit); bits 25-31 are discarded.\n",
"    Raises IndexError if hexString is shorter than 8 characters and ValueError\n",
"    if any of the first 8 characters is not a hex digit.\n",
"    \"\"\"\n",
"    # plain `bool`: the `np.bool` alias was removed in NumPy 1.24 (AttributeError)\n",
"    number = np.zeros(25, dtype=bool)\n",
"    bit = 0\n",
"    for i in reversed(range(8)):  # least-significant hex digit first\n",
"        nibble = int(hexString[i], 16)\n",
"        for _ in range(4):\n",
"            if bit < 25:\n",
"                number[bit] = nibble & 0x1\n",
"            nibble >>= 1\n",
"            bit += 1\n",
"\n",
"    return number\n",
"\n",
"#take a single-channel image and return an NP float array\n",
"def img_to_float_array(img):\n",
"    \"\"\"Return `img` as a float32 array of shape (height, width, 1) divided by 255.\n",
"\n",
"    Maps 8-bit pixel values into [0, 1]. Assumes a single-channel image\n",
"    (e.g. PIL mode \"L\"); multi-channel input cannot be reshaped and raises ValueError.\n",
"    \"\"\"\n",
"    # asarray already yields float32, so no second astype copy is needed\n",
"    img_array = np.asarray(img, dtype='float32') / 255.\n",
"    # append a trailing channel axis\n",
"    return img_array.reshape((img_array.shape[0], img_array.shape[1], 1))\n",
"\n",
"#pick distinct file names from a directory at random\n",
"def random_files(numMatches, directory):\n",
"    \"\"\"Return `numMatches` distinct file names chosen at random from `directory`.\n",
"\n",
"    Names are returned without the directory prefix. Entries starting with \".\"\n",
"    are skipped, as `ls` does. Raises ValueError if fewer than `numMatches`\n",
"    candidates exist (or if numMatches is negative).\n",
"    \"\"\"\n",
"    # pure Python instead of `!ls | sort -R | tail`: no IPython-only syntax and no\n",
"    # shell interpolation of `directory`; sorting first makes the draw repeatable\n",
"    # under random.seed, since os.listdir order is arbitrary\n",
"    candidates = sorted(name for name in os.listdir(directory) if not name.startswith('.'))\n",
"    if numMatches > len(candidates):\n",
"        raise ValueError(\"requested {} files but {} contains only {}\".format(\n",
"            numMatches, directory, len(candidates)))\n",
"    return random.sample(candidates, numMatches)\n",
" \n",
"#load a random batch of images and their labels\n",
"def loadImageBatch(batchSize=32, test=False):\n",
"    \"\"\"Load `batchSize` randomly chosen images and the labels encoded in their names.\n",
"\n",
"    Reads from ./images2/test/ when `test` is true, otherwise ./images2/train/.\n",
"    Each file name must start with 8 hex digits, decoded by decodeFileName into\n",
"    a 25-element label row.\n",
"\n",
"    Returns (images, labels): float64 arrays of shape\n",
"    (batchSize,) + img_to_float_array(image).shape and (batchSize, 25).\n",
"    Raises ValueError if batchSize < 1 or too few files are available.\n",
"    \"\"\"\n",
"    if batchSize < 1:\n",
"        raise ValueError(\"batchSize must be at least 1, got {}\".format(batchSize))\n",
"\n",
"    directory = \"./images2/test/\" if test else \"./images2/train/\"\n",
"    imageFiles = random_files(batchSize, directory)\n",
"    if len(imageFiles) < batchSize:\n",
"        raise ValueError(\"only {} files available in {}\".format(len(imageFiles), directory))\n",
"\n",
"    labels = np.zeros((batchSize, 25))\n",
"    images = None\n",
"    for i, fileName in enumerate(imageFiles[:batchSize]):\n",
"        labels[i] = decodeFileName(fileName[0:8])\n",
"        # `with` closes the file; a bare Image.open leaks one handle per image\n",
"        with Image.open(directory + fileName) as img:\n",
"            pixels = img_to_float_array(img)\n",
"        # size the batch from the first image rather than a global `imageSize`,\n",
"        # which no cell of this notebook defines\n",
"        if images is None:\n",
"            images = np.zeros((batchSize,) + pixels.shape)\n",
"        images[i] = pixels\n",
"\n",
"    return (images, labels)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from keras.callbacks import TensorBoard\n",
"import glob, os\n",
"\n",
"batch_size = 32\n",
"num_runs = 30000  # each run fits one epoch on a freshly sampled batch\n",
"\n",
"# `model` must already be built and compiled; no cell of this notebook defines it,\n",
"# so stop here with a clear message rather than after loading data\n",
"if 'model' not in globals():\n",
"    raise NameError(\"`model` is not defined: build and compile the Keras model first\")\n",
"\n",
"#tensorboard callback\n",
"tensorboard = TensorBoard(log_dir='/tmp/LEDS',\n",
"                          histogram_freq=0,\n",
"                          write_graph=True,\n",
"                          write_images=False)\n",
"\n",
"for i in range(num_runs):\n",
"    training_batch = loadImageBatch(batch_size)\n",
"    validation_batch = loadImageBatch(batch_size, test=True)\n",
"    # initial_epoch/epochs give each run its own epoch index; with epochs=1 alone\n",
"    # every call restarts at epoch 0 and TensorBoard logs all runs at the same step\n",
"    model.fit(training_batch[0],\n",
"              training_batch[1],\n",
"              initial_epoch=i,\n",
"              epochs=i + 1,\n",
"              batch_size=batch_size,\n",
"              validation_data=validation_batch,\n",
"              callbacks=[tensorboard])\n",
"    print(\"run {}\".format(i))\n",
"\n",
"model.save_weights(\"./averagePoolWeights.h5\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment