Skip to content

Instantly share code, notes, and snippets.

@kiransair
Created March 15, 2024 09:53
Show Gist options
  • Save kiransair/0f0872c64fedd4a77fd6579842f314c1 to your computer and use it in GitHub Desktop.
Save kiransair/0f0872c64fedd4a77fd6579842f314c1 to your computer and use it in GitHub Desktop.
TF_Forum_23085.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPVp9woTP8Lp7W1beiD3yHG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kiransair/0f0872c64fedd4a77fd6579842f314c1/tf_forum_23085.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PsmJli3xJ_42",
"outputId": "a3f5871b-90dd-401b-f6f5-40174bb844da"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.\n"
]
}
],
"source": [
"# Notebook-wide imports: standard library first, then third-party packages.\n",
"import logging\n",
"import pathlib\n",
"\n",
"# Raise the TensorFlow logger to DEBUG *before* importing TensorFlow so that\n",
"# import-time debug messages (such as the Cloud TPU client notice) are shown.\n",
"logging.getLogger(\"tensorflow\").setLevel(logging.DEBUG)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras"
]
},
{
"cell_type": "code",
"source": [
"# Seed the Python, NumPy and TensorFlow RNGs so that training is reproducible.\n",
"SEED = 42\n",
"keras.utils.set_random_seed(SEED)\n",
"\n",
"# Load MNIST dataset\n",
"mnist = keras.datasets.mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# Normalize the input images so that each pixel value is between 0 and 1, and\n",
"# keep them as float32, the dtype the TFLite interpreters are fed with below.\n",
"train_images = (train_images / 255.0).astype(np.float32)\n",
"test_images = (test_images / 255.0).astype(np.float32)\n",
"\n",
"# Define the model architecture. `keras.Input` works in both Keras 2 and\n",
"# Keras 3 (where `InputLayer(input_shape=...)` is deprecated).\n",
"model = keras.Sequential([\n",
"    keras.Input(shape=(28, 28)),\n",
"    keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
"    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),\n",
"    keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
"    keras.layers.Flatten(),\n",
"    # Raw logits: the loss below is configured with from_logits=True.\n",
"    keras.layers.Dense(10),\n",
"])\n",
"\n",
"# Train the digit classification model. The History object is assigned so\n",
"# its repr is not echoed as the cell output.\n",
"model.compile(optimizer='adam',\n",
"              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"              metrics=['accuracy'])\n",
"history = model.fit(\n",
"    train_images,\n",
"    train_labels,\n",
"    epochs=1,\n",
"    validation_data=(test_images, test_labels),\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yO70mHplKFve",
"outputId": "177d2a57-8209-4758-87ae-a0af69bb2f0b"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11490434/11490434 [==============================] - 0s 0us/step\n",
"1875/1875 [==============================] - 36s 19ms/step - loss: 0.2908 - accuracy: 0.9194 - val_loss: 0.1409 - val_accuracy: 0.9575\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7f70b8489c30>"
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"# Baseline conversion of the trained Keras model, without any optimizations\n",
"# (quantization is applied separately in a later cell).\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"tflite_model = converter.convert()  # serialized model bytes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-L_OK9jmc93w",
"outputId": "ea6aa61d-a2d7-482c-c511-95845b5f029b"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmpqjec_r4i/assets\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Output directory for the exported .tflite files (created if missing).\n",
"tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
"tflite_models_dir.mkdir(parents=True, exist_ok=True)"
],
"metadata": {
"id": "_jhJrOugdABe"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Save the unquantized model; the displayed value is the number of bytes written.\n",
"tflite_model_file = tflite_models_dir / \"mnist_model.tflite\"\n",
"tflite_model_file.write_bytes(tflite_model)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6umm3mw4dBmQ",
"outputId": "90a240cc-0b3e-4150-e5ca-59a04ec6d1f2"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"84820"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"# Post-training quantization with the default optimization set. A dedicated\n",
"# converter is used instead of mutating `converter` from the float conversion\n",
"# cell, so that object keeps describing the unquantized model no matter which\n",
"# cells have been (re-)run.\n",
"quant_converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_quant_model = quant_converter.convert()\n",
"\n",
"# The displayed value is the size in bytes of the quantized model file.\n",
"tflite_model_quant_file = tflite_models_dir / \"mnist_model_quant.tflite\"\n",
"tflite_model_quant_file.write_bytes(tflite_quant_model)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RyIS2ls5dDMd",
"outputId": "458467c8-5eb6-4b4b-d845-c26bc6a702af"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmpj2jrid0z/assets\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"24064"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"# Interpreter for the unquantized model, loaded from the file saved above.\n",
"interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
"interpreter.allocate_tensors()"
],
"metadata": {
"id": "b-WBeH0udFkk"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Interpreter for the quantized model, loaded from its saved file.\n",
"interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))\n",
"interpreter_quant.allocate_tensors()"
],
"metadata": {
"id": "jpgmxJVSdHVc"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Single-image sanity check with the unquantized interpreter.\n",
"# Pre-processing: add a batch dimension and cast to float32 to match the\n",
"# model's input data format.\n",
"test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)\n",
"\n",
"input_index = interpreter.get_input_details()[0][\"index\"]\n",
"output_index = interpreter.get_output_details()[0][\"index\"]\n",
"\n",
"interpreter.set_tensor(input_index, test_image)\n",
"interpreter.invoke()\n",
"predictions = interpreter.get_tensor(output_index)\n",
"\n",
"# Show (predicted digit, true label) instead of silently discarding the result.\n",
"int(np.argmax(predictions[0])), int(test_labels[0])"
],
"metadata": {
"id": "FjnCfRY1dIst"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def predict_digit(tflite_interpreter, image):\n",
"    \"\"\"Return the digit a TFLite MNIST classifier predicts for one image.\n",
"\n",
"    Args:\n",
"        tflite_interpreter: A `tf.lite.Interpreter` whose tensors are allocated.\n",
"        image: One 28x28 image with pixel values normalized to [0, 1].\n",
"\n",
"    Returns:\n",
"        The index (as a Python int) of the largest output logit.\n",
"    \"\"\"\n",
"    in_idx = tflite_interpreter.get_input_details()[0][\"index\"]\n",
"    out_idx = tflite_interpreter.get_output_details()[0][\"index\"]\n",
"\n",
"    # Pre-processing: add the batch dimension and convert to float32 to match\n",
"    # the model's input data format.\n",
"    batch = np.expand_dims(image, axis=0).astype(np.float32)\n",
"    tflite_interpreter.set_tensor(in_idx, batch)\n",
"\n",
"    # Run inference.\n",
"    tflite_interpreter.invoke()\n",
"\n",
"    # Post-processing: drop the batch dimension and take the highest logit.\n",
"    # get_tensor returns a copy, unlike tensor(), which exposes the\n",
"    # interpreter's internal buffer.\n",
"    logits = tflite_interpreter.get_tensor(out_idx)[0]\n",
"    return int(np.argmax(logits))\n",
"\n",
"\n",
"# Compare the unquantized and quantized models on the first test image.\n",
"{\n",
"    \"float\": predict_digit(interpreter, test_images[0]),\n",
"    \"quantized\": predict_digit(interpreter_quant, test_images[0]),\n",
"    \"label\": int(test_labels[0]),\n",
"}"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xkgNtNzPddsv",
"outputId": "7ca31275-9deb-4c2a-a2a9-5c89baa928ce"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"7\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Take the image straight from the dataset instead of from `test_image`, whose\n",
"# value depends on which inference code above ran last.\n",
"image_array = test_images[0]"
],
"metadata": {
"id": "xlAJFMx7duVt"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Plotting library for the image preview below.\n",
"from matplotlib import pyplot as plt"
],
"metadata": {
"id": "_tmWRn4Vdy68"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Grayscale preview of the first test image with its ground-truth label.\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(image_array, cmap=\"gray\")\n",
"ax.set_title(f\"First MNIST test image (label {test_labels[0]})\")\n",
"ax.axis(\"off\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
},
"id": "kXueRJ8qeHoq",
"outputId": "5d9fb936-5e74-4527-e309-996413360b71"
},
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f7087d7f400>"
]
},
"metadata": {},
"execution_count": 19
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "I_nfFqhGdPdP"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment