Skip to content

Instantly share code, notes, and snippets.

@Muhammad-Yunus
Last active August 27, 2023 05:40
Show Gist options
  • Save Muhammad-Yunus/b5f3288615dd13df1117c7f66e96594b to your computer and use it in GitHub Desktop.
Save Muhammad-Yunus/b5f3288615dd13df1117c7f66e96594b to your computer and use it in GitHub Desktop.
keras_01_mnist.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Muhammad-Yunus/b5f3288615dd13df1117c7f66e96594b/keras_01_mnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lvo0t7XVIkWZ"
},
"source": [
"### Parameters\n",
"\n",
"- The batch size, number of training epochs and location of the data files is defined here.\n",
"- Data files are hosted in a Google Cloud Storage (GCS) bucket which is why their address starts with `gs://`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cCpkS9C_H7Tl"
},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"EPOCHS = 10\n",
"#edit\n",
"training_images_file = 'gs://mnist-public/train-images-idx3-ubyte'\n",
"training_labels_file = 'gs://mnist-public/train-labels-idx1-ubyte'\n",
"validation_images_file = 'gs://mnist-public/t10k-images-idx3-ubyte'\n",
"validation_labels_file = 'gs://mnist-public/t10k-labels-idx1-ubyte'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qpiJj8ym0v0-"
},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AoilhmYe1b5t"
},
"outputs": [],
"source": [
"import os, re, math, json, shutil, pprint\n",
"import PIL.Image, PIL.ImageFont, PIL.ImageDraw\n",
"import IPython.display as display\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from matplotlib import pyplot as plt\n",
"print(\"Tensorflow version \" + tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qhdz68Xm3Z4Z"
},
"outputs": [],
"source": [
"#@title visualization utilities [RUN ME]\n",
"\"\"\"\n",
"This cell contains helper functions used for visualization\n",
"and downloads only. You can skip reading it. There is very\n",
"little useful Keras/Tensorflow code here.\n",
"\"\"\"\n",
"\n",
"# Matplotlib config\n",
"plt.ioff()\n",
"plt.rc('image', cmap='gray_r')\n",
"plt.rc('grid', linewidth=1)\n",
"plt.rc('xtick', top=False, bottom=False, labelsize='large')\n",
"plt.rc('ytick', left=False, right=False, labelsize='large')\n",
"plt.rc('axes', facecolor='F8F8F8', titlesize=\"large\", edgecolor='white')\n",
"plt.rc('text', color='a8151a')\n",
"plt.rc('figure', facecolor='F0F0F0', figsize=(16,9))\n",
"# Matplotlib fonts\n",
"MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), \"mpl-data/fonts/ttf\")\n",
"\n",
"# pull a batch from the datasets. This code is not very nice, it gets much better in eager mode (TODO)\n",
"def dataset_to_numpy_util(training_dataset, validation_dataset, N):\n",
"\n",
" # get one batch from each: 10000 validation digits, N training digits\n",
" batch_train_ds = training_dataset.unbatch().batch(N)\n",
"\n",
" # eager execution: loop through datasets normally\n",
" if tf.executing_eagerly():\n",
" for validation_digits, validation_labels in validation_dataset:\n",
" validation_digits = validation_digits.numpy()\n",
" validation_labels = validation_labels.numpy()\n",
" break\n",
" for training_digits, training_labels in batch_train_ds:\n",
" training_digits = training_digits.numpy()\n",
" training_labels = training_labels.numpy()\n",
" break\n",
"\n",
" else:\n",
" v_images, v_labels = validation_dataset.make_one_shot_iterator().get_next()\n",
" t_images, t_labels = batch_train_ds.make_one_shot_iterator().get_next()\n",
" # Run once, get one batch. Session.run returns numpy results\n",
" with tf.Session() as ses:\n",
" (validation_digits, validation_labels,\n",
" training_digits, training_labels) = ses.run([v_images, v_labels, t_images, t_labels])\n",
"\n",
" # these were one-hot encoded in the dataset\n",
" validation_labels = np.argmax(validation_labels, axis=1)\n",
" training_labels = np.argmax(training_labels, axis=1)\n",
"\n",
" return (training_digits, training_labels,\n",
" validation_digits, validation_labels)\n",
"\n",
"# create digits from local fonts for testing\n",
"def create_digits_from_local_fonts(n):\n",
" font_labels = []\n",
" img = PIL.Image.new('LA', (28*n, 28), color = (0,255)) # format 'LA': black in channel 0, alpha in channel 1\n",
" font1 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'DejaVuSansMono-Oblique.ttf'), 25)\n",
" font2 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'STIXGeneral.ttf'), 25)\n",
" d = PIL.ImageDraw.Draw(img)\n",
" for i in range(n):\n",
" font_labels.append(i%10)\n",
" d.text((7+i*28,0 if i<10 else -4), str(i%10), fill=(255,255), font=font1 if i<10 else font2)\n",
" font_digits = np.array(img.getdata(), np.float32)[:,0] / 255.0 # black in channel 0, alpha in channel 1 (discarded)\n",
" font_digits = np.reshape(np.stack(np.split(np.reshape(font_digits, [28, 28*n]), n, axis=1), axis=0), [n, 28*28])\n",
" return font_digits, font_labels\n",
"\n",
"# utility to display a row of digits with their predictions\n",
"def display_digits(digits, predictions, labels, title, n):\n",
" fig = plt.figure(figsize=(13,3))\n",
" digits = np.reshape(digits, [n, 28, 28])\n",
" digits = np.swapaxes(digits, 0, 1)\n",
" digits = np.reshape(digits, [28, 28*n])\n",
" plt.yticks([])\n",
" plt.xticks([28*x+14 for x in range(n)], predictions)\n",
" #plt.grid(b=None)\n",
" for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):\n",
" if predictions[i] != labels[i]: t.set_color('red') # bad predictions in red\n",
" plt.imshow(digits)\n",
" plt.grid(None)\n",
" plt.title(title)\n",
" display.display(fig)\n",
"\n",
"# utility to display multiple rows of digits, sorted by unrecognized/recognized status\n",
"def display_top_unrecognized(digits, predictions, labels, n, lines):\n",
" idx = np.argsort(predictions==labels) # sort order: unrecognized first\n",
" for i in range(lines):\n",
" display_digits(digits[idx][i*n:(i+1)*n], predictions[idx][i*n:(i+1)*n], labels[idx][i*n:(i+1)*n],\n",
" \"{} sample validation digits out of {} with bad predictions in red and sorted first\".format(n*lines, len(digits)) if i==0 else \"\", n)\n",
"\n",
"def plot_learning_rate(lr_func, epochs):\n",
" xx = np.arange(epochs+1, dtype=np.float)\n",
" y = [lr_decay(x) for x in xx]\n",
" fig, ax = plt.subplots(figsize=(9, 6))\n",
" ax.set_xlabel('epochs')\n",
" ax.set_title('Learning rate\\ndecays from {:0.3g} to {:0.3g}'.format(y[0], y[-2]))\n",
" ax.minorticks_on()\n",
" ax.grid(True, which='major', axis='both', linestyle='-', linewidth=1)\n",
" ax.grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)\n",
" ax.step(xx,y, linewidth=3, where='post')\n",
" display.display(fig)\n",
"\n",
"class PlotTraining(tf.keras.callbacks.Callback):\n",
" def __init__(self, sample_rate=1, zoom=1):\n",
" self.sample_rate = sample_rate\n",
" self.step = 0\n",
" self.zoom = zoom\n",
" self.steps_per_epoch = 60000//BATCH_SIZE\n",
"\n",
" def on_train_begin(self, logs={}):\n",
" self.batch_history = {}\n",
" self.batch_step = []\n",
" self.epoch_history = {}\n",
" self.epoch_step = []\n",
" self.fig, self.axes = plt.subplots(1, 2, figsize=(16, 7))\n",
" plt.ioff()\n",
"\n",
" def on_batch_end(self, batch, logs={}):\n",
" if (batch % self.sample_rate) == 0:\n",
" self.batch_step.append(self.step)\n",
" for k,v in logs.items():\n",
" # do not log \"batch\" and \"size\" metrics that do not change\n",
" # do not log training accuracy \"acc\"\n",
" if k=='batch' or k=='size':# or k=='acc':\n",
" continue\n",
" self.batch_history.setdefault(k, []).append(v)\n",
" self.step += 1\n",
"\n",
" def on_epoch_end(self, epoch, logs={}):\n",
" plt.close(self.fig)\n",
" self.axes[0].cla()\n",
" self.axes[1].cla()\n",
"\n",
" self.axes[0].set_ylim(0, 1.2/self.zoom)\n",
" self.axes[1].set_ylim(1-1/self.zoom/2, 1+0.1/self.zoom/2)\n",
"\n",
" self.epoch_step.append(self.step)\n",
" for k,v in logs.items():\n",
" # only log validation metrics\n",
" if not k.startswith('val_'):\n",
" continue\n",
" self.epoch_history.setdefault(k, []).append(v)\n",
"\n",
" display.clear_output(wait=True)\n",
"\n",
" for k,v in self.batch_history.items():\n",
" self.axes[0 if k.endswith('loss') else 1].plot(np.array(self.batch_step) / self.steps_per_epoch, v, label=k)\n",
"\n",
" for k,v in self.epoch_history.items():\n",
" self.axes[0 if k.endswith('loss') else 1].plot(np.array(self.epoch_step) / self.steps_per_epoch, v, label=k, linewidth=3)\n",
"\n",
" self.axes[0].legend()\n",
" self.axes[1].legend()\n",
" self.axes[0].set_xlabel('epochs')\n",
" self.axes[1].set_xlabel('epochs')\n",
" self.axes[0].minorticks_on()\n",
" self.axes[0].grid(True, which='major', axis='both', linestyle='-', linewidth=1)\n",
" self.axes[0].grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)\n",
" self.axes[1].minorticks_on()\n",
" self.axes[1].grid(True, which='major', axis='both', linestyle='-', linewidth=1)\n",
" self.axes[1].grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)\n",
" display.display(self.fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lz1Zknfk4qCx"
},
"source": [
"### tf.data.Dataset: parse files and prepare training and validation datasets\n",
"Please read the [best practices for building](https://www.tensorflow.org/guide/performance/datasets) input pipelines with tf.data.Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZE8dgyPC1_6m"
},
"outputs": [],
"source": [
"AUTO = tf.data.experimental.AUTOTUNE\n",
"\n",
"def read_label(tf_bytestring):\n",
" # Encode label into one hot encoding with depth 10\n",
" label = tf.io.decode_raw(tf_bytestring, tf.uint8)\n",
" label = tf.reshape(label, [])\n",
" label = tf.one_hot(label, 10)\n",
" return label\n",
"\n",
"def read_image(tf_bytestring):\n",
" # The image is then converted to floating point values between 0 and 1.\n",
" # We could reshape it here as a 2D image but actually we keep it as a flat array of pixels of size 28*28 because that is what our initial dense layer expects.\n",
" image = tf.io.decode_raw(tf_bytestring, tf.uint8)\n",
" image = tf.cast(image, tf.float32)/256.0\n",
" image = tf.reshape(image, [28*28])\n",
" return image\n",
"\n",
"def load_dataset(image_file, label_file):\n",
" # Use tf.data.Dataset API to load the MNIST dataset form the data files.\n",
" # We apply this function to the dataset using .map and obtain a dataset of images/labels\n",
"\n",
" # load image dataset and convert into flat array\n",
" imagedataset = tf.data.FixedLengthRecordDataset(image_file, 28*28, header_bytes=16)\n",
" imagedataset = imagedataset.map(read_image, num_parallel_calls=16)\n",
"\n",
" # load label dataset and encode with one hot encoding\n",
" labelsdataset = tf.data.FixedLengthRecordDataset(label_file, 1, header_bytes=8)\n",
" labelsdataset = labelsdataset.map(read_label, num_parallel_calls=16)\n",
"\n",
" # .zip images and labels together\n",
" dataset = tf.data.Dataset.zip((imagedataset, labelsdataset))\n",
" return dataset\n",
"\n",
"def get_training_dataset(image_file, label_file, batch_size):\n",
"\n",
" # load dataset of pairs (image, label)\n",
" dataset = load_dataset(image_file, label_file)\n",
"\n",
" # this small dataset can be entirely cached in RAM, for TPU this is important to get good performance from such a small dataset\n",
" dataset = dataset.cache()\n",
"\n",
" # shuffle helps the model to learn from different patterns in each epoch\n",
" # and prevents it from overfitting to specific patterns in the training data.\n",
" dataset = dataset.shuffle(5000, reshuffle_each_iteration=True)\n",
"\n",
" # As soon as all the entries are read from the dataset and you try to read the next element, the dataset will throw an error.\n",
" # That's where ds.repeat() comes into play. It will re-initialize the dataset\n",
" dataset = dataset.repeat()\n",
"\n",
" # pulls multiple images and labels together into a mini-batch.\n",
" # drop_remainder is important on TPU, batch size must be fixed\n",
" dataset = dataset.batch(batch_size, drop_remainder=True)\n",
"\n",
" # fetch next batches while training on the current one (-1: autotune prefetch buffer size)\n",
" dataset = dataset.prefetch(AUTO)\n",
" return dataset\n",
"\n",
"def get_validation_dataset(image_file, label_file):\n",
" dataset = load_dataset(image_file, label_file)\n",
" dataset = dataset.cache()\n",
" dataset = dataset.batch(10000, drop_remainder=True)\n",
" dataset = dataset.repeat()\n",
" return dataset\n",
"\n",
"# instantiate the datasets\n",
"training_dataset = get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)\n",
"validation_dataset = get_validation_dataset(validation_images_file, validation_labels_file)\n",
"\n",
"# For TPU, we will need a function that returns the dataset\n",
"training_input_fn = lambda: get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)\n",
"validation_input_fn = lambda: get_validation_dataset(validation_images_file, validation_labels_file)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_fXo6GuvL3EB"
},
"source": [
"### Let's have a look at the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ4tjPKvL2eh"
},
"outputs": [],
"source": [
"# get N training sample\n",
"N = 25\n",
"\n",
"# convert dataset to numpy array via dataset_to_numpy_util()\n",
"(training_digits, training_labels, validation_digits, validation_labels) = dataset_to_numpy_util(training_dataset, validation_dataset, N)\n",
"\n",
"# show N digit from training dataset\n",
"display_digits(training_digits, training_labels, training_labels, \"training digits and their labels\", N)\n",
"\n",
"# show N digit from validation dataset\n",
"display_digits(validation_digits[:N], validation_labels[:N], validation_labels[:N], \"validation digits and their labels\", N)"
]
},
{
"cell_type": "code",
"source": [
"# generate N sample digit (image & label) from local font just for testing purposes\n",
"font_digits, font_labels = create_digits_from_local_fonts(N)\n",
"\n",
"# show N sample digit from local font\n",
"display_digits(font_digits, font_labels, font_labels, \"sample digit from local font with their labels\", N)"
],
"metadata": {
"id": "km-YWHIPklF-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIc0oqiD40HC"
},
"source": [
"# Keras model\n",
"\n",
"- All of our models will be straight sequences of layers so we can use the `tf.keras.Sequential` style to create them.\n",
"- `tf.keras.layers.Input` can be used to define it. Here, input vectors are flat vectors of pixel values of length 28*28.\n",
"- Initially here, it's a single `tf.keras.layers.Dense` layer. It has 10 neurons because we are classifying handwritten digits into 10 classes.\n",
"- It uses \"softmax\" activation because it is the last layer in a classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "56y8UNFQIVwj"
},
"outputs": [],
"source": [
"model = tf.keras.Sequential(\n",
" [\n",
" tf.keras.layers.Input(shape=(28*28,)),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
" ])"
]
},
{
"cell_type": "markdown",
"source": [
"- Configuring the model is done in Keras using the `model.compile` function.\n",
"- Here we use the basic optimizer '`sgd`' (Stochastic Gradient Descent).\n",
"- A classification model requires a cross-entropy loss function, called '`categorical_crossentropy`' in Keras.\n",
"- Finally, we ask the model to compute the '`accuracy`' metric, which is the percentage of correctly classified images."
],
"metadata": {
"id": "RFkd3GjCmIzn"
}
},
{
"cell_type": "code",
"source": [
"model.compile(optimizer='sgd',\n",
" loss='categorical_crossentropy',\n",
" metrics=['accuracy'])"
],
"metadata": {
"id": "dGecQbbGmJKu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"- Keras offers the very nice `model.summary()` utility that prints the details of the model you have created.\n",
"- Your kind instructor has added the `PlotTraining()` utility (defined in the \"visualization utilities\" cell) which will display various training curves during the training."
],
"metadata": {
"id": "NTzmRv4imfnn"
}
},
{
"cell_type": "code",
"source": [
"\n",
"# print model layers\n",
"model.summary()\n",
"\n",
"# instantiate utility callback that displays training curves\n",
"plot_training = PlotTraining(sample_rate=10, zoom=1)"
],
"metadata": {
"id": "INHngyzvme-z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CuhDh8ao8VyB"
},
"source": [
"# Train and validate the model\n",
"- This is where the training happens, by calling `model.fit` and passing in both the training and validation datasets.\n",
"- By default, Keras runs a round of validation at the end of each epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TTwH_P-ZJ_xx"
},
"outputs": [],
"source": [
"steps_per_epoch = 60000//BATCH_SIZE # 60,000 items in this dataset\n",
"\n",
"print(\"Steps per epoch: \", steps_per_epoch)\n",
"print(\"Total Epoch: \", EPOCHS)\n",
"print(\"Batch Size: \", BATCH_SIZE)\n",
"\n",
"history = model.fit(\n",
" training_dataset,\n",
" steps_per_epoch = steps_per_epoch,\n",
" epochs = EPOCHS,\n",
" validation_data = validation_dataset,\n",
" validation_steps = 1,\n",
" callbacks = [plot_training] # set custom callback\n",
" )"
]
},
{
"cell_type": "markdown",
"source": [
"- In Keras, it is possible to add custom behaviors during training by using `callbacks`.\n",
"- That is how the dynamically updating training plot was implemented for this workshop."
],
"metadata": {
"id": "uVIDKwM6nKP4"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "9jFVovcUUVs1"
},
"source": [
"### Visualize predictions\n",
"\n",
"- Once the model is trained, we can get predictions from it by calling `model.predict()`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w12OId8Mz7dF"
},
"outputs": [],
"source": [
"# recognize digits from local fonts\n",
"probabilities = model.predict(font_digits, steps=1)\n",
"predicted_labels = np.argmax(probabilities, axis=1)\n",
"display_digits(font_digits, predicted_labels, font_labels, \"predictions from local fonts (bad predictions in red)\", N)"
]
},
{
"cell_type": "code",
"source": [
"# recognize validation digits\n",
"probabilities = model.predict(validation_digits, steps=1)\n",
"predicted_labels = np.argmax(probabilities, axis=1)\n",
"display_top_unrecognized(validation_digits, predicted_labels, validation_labels, N, 7)"
],
"metadata": {
"id": "7UhdiPO0oJ1f"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVY1pBg5ydH-"
},
"source": [
"## License"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hleIN5-pcr0N"
},
"source": [
"\n",
"\n",
"---\n",
"\n",
"\n",
"author: Martin Gorner<br>\n",
"twitter: @martin_gorner\n",
"\n",
"\n",
"---\n",
"\n",
"\n",
"Copyright 2019 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"\n",
"\n",
"---\n",
"\n",
"\n",
"This is not an official Google product but sample code provided for an educational purpose\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NoNIRjoRPkB3"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "keras_01_mnist.ipynb",
"provenance": [],
"include_colab_link": true
},
"environment": {
"name": "tf22-gpu.2-2.m47",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/tf22-gpu.2-2:m47"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment