-
-
Save kiransair/2a25968f47e48f430cc5c16886d01096 to your computer and use it in GitHub Desktop.
TF_Forum_22518.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMlUNT5AedoVtw8sytPOgNm", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/kiransair/2a25968f47e48f430cc5c16886d01096/tf_forum_22518.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [
"# Pin the TFMOT version this notebook was run with; %pip installs into the running kernel's environment.\n",
"%pip install tensorflow-model-optimization==0.7.5"
],
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "h3kD5LCQh-wQ", | |
"outputId": "d7483297-9522-43af-bf0d-24ed59a80305" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Requirement already satisfied: tensorflow-model-optimization==0.7.5 in /usr/local/lib/python3.10/dist-packages (0.7.5)\n", | |
"Requirement already satisfied: absl-py~=1.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.4.0)\n", | |
"Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (0.1.8)\n", | |
"Requirement already satisfied: numpy~=1.23 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.23.5)\n", | |
"Requirement already satisfied: six~=1.14 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.16.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import tensorflow_model_optimization as tfmot" | |
], | |
"metadata": { | |
"id": "pnuOmbHWiG2g" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import tensorflow as tf" | |
], | |
"metadata": { | |
"id": "7uO09aMiiNak" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import keras" | |
], | |
"metadata": { | |
"id": "W7ckxnzUlGUE" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2nKs2L9Ih6E7", | |
"outputId": "1cf60d74-e694-458d-ddf4-ed7dd5132fa0" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"1688/1688 [==============================] - 32s 18ms/step - loss: 0.2877 - accuracy: 0.9201 - val_loss: 0.1137 - val_accuracy: 0.9683\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<keras.src.callbacks.History at 0x7d320199cdc0>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
], | |
"source": [
"# Load MNIST dataset\n",
"mnist = keras.datasets.mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# Normalize the input images so that each pixel value is between 0 and 1.\n",
"train_images = train_images / 255.0\n",
"test_images = test_images / 255.0\n",
"\n",
"# Seed Python, NumPy and TensorFlow RNGs so weight init and the train/val shuffle are repeatable.\n",
"keras.utils.set_random_seed(42)\n",
"\n",
"# Define the model architecture.\n",
"model = keras.Sequential([\n",
"  keras.layers.InputLayer(input_shape=(28, 28)),\n",
"  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
"  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n",
"  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
"  keras.layers.Flatten(),\n",
"  keras.layers.Dense(10)\n",
"])\n",
"\n",
"# Train the digit classification model.\n",
"# The final Dense layer has no activation, so the loss must be told it receives logits.\n",
"model.compile(optimizer='adam',\n",
"              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"              metrics=['accuracy'])\n",
"\n",
"model.fit(\n",
"  train_images,\n",
"  train_labels,\n",
"  epochs=1,\n",
"  validation_split=0.1,\n",
")"
]
}, | |
{ | |
"cell_type": "code", | |
"source": [
"# `tfmot` and `tf` come from the import cells at the top of the notebook.\n",
"quantize_model = tfmot.quantization.keras.quantize_model\n",
"\n",
"# q_aware stands for quantization aware.\n",
"q_aware_model = quantize_model(model)\n",
"\n",
"# `quantize_model` returns a new, uncompiled model, so it requires a recompile.\n",
"q_aware_model.compile(optimizer='adam',\n",
"              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"              metrics=['accuracy'])\n",
"\n",
"q_aware_model.summary()"
],
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "N0Y_i4q-iFTX", | |
"outputId": "92cf67f4-9169-4d9b-f6f9-ba3a27df0841" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Model: \"sequential\"\n", | |
"_________________________________________________________________\n", | |
" Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
" quantize_layer (QuantizeLa (None, 28, 28) 3 \n", | |
" yer) \n", | |
" \n", | |
" quant_reshape (QuantizeWra (None, 28, 28, 1) 1 \n", | |
" pperV2) \n", | |
" \n", | |
" quant_conv2d (QuantizeWrap (None, 26, 26, 12) 147 \n", | |
" perV2) \n", | |
" \n", | |
" quant_max_pooling2d (Quant (None, 13, 13, 12) 1 \n", | |
" izeWrapperV2) \n", | |
" \n", | |
" quant_flatten (QuantizeWra (None, 2028) 1 \n", | |
" pperV2) \n", | |
" \n", | |
" quant_dense (QuantizeWrapp (None, 10) 20295 \n", | |
" erV2) \n", | |
" \n", | |
"=================================================================\n", | |
"Total params: 20448 (79.88 KB)\n", | |
"Trainable params: 20410 (79.73 KB)\n", | |
"Non-trainable params: 38 (152.00 Byte)\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [
"# Fine-tune the quantization-aware model on a small slice: the first 1000 of the 60000 training images.\n",
"train_images_subset = train_images[:1000]\n",
"train_labels_subset = train_labels[:1000]\n",
"\n",
"q_aware_model.fit(\n",
"    train_images_subset,\n",
"    train_labels_subset,\n",
"    batch_size=500,\n",
"    epochs=1,\n",
"    validation_split=0.1,\n",
")"
],
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "42IVHAbBlT8F", | |
"outputId": "968bfd4e-3dc3-4321-fa39-a52f905d18d3" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2/2 [==============================] - 2s 375ms/step - loss: 0.1394 - accuracy: 0.9567 - val_loss: 0.1476 - val_accuracy: 0.9600\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<keras.src.callbacks.History at 0x7d320164a2c0>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [
"# Where the converted flatbuffer is written.\n",
"OUTPUT_MODEL_PATH = 'mnist_qat_int8.tflite'\n",
"\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"\n",
"# Request integer-only model I/O. The converter may warn that statistics for quantized\n",
"# inputs were not specified. For a QAT model the input range is presumably taken from the\n",
"# trained quantize_layer rather than from calibration data -- NOTE(review): confirm.\n",
"converter.inference_input_type = tf.int8\n",
"converter.inference_output_type = tf.int8\n",
"\n",
"quantized_tflite_model = converter.convert()\n",
"\n",
"with open(OUTPUT_MODEL_PATH, 'wb') as f:\n",
"    f.write(quantized_tflite_model)\n",
"\n",
"# Sanity check: the converted model really exposes int8 input and output tensors.\n",
"interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)\n",
"input_details = interpreter.get_input_details()[0]\n",
"output_details = interpreter.get_output_details()[0]\n",
"assert input_details['dtype'] == tf.int8.as_numpy_dtype, input_details['dtype']\n",
"assert output_details['dtype'] == tf.int8.as_numpy_dtype, output_details['dtype']\n",
"\n",
"# Float images in [0, 1] must be quantized with this (scale, zero_point) before being fed in.\n",
"input_details['quantization']"
],
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LvvVTkW2oDm-", | |
"outputId": "77de1061-c223-413a-a3fd-4540118d31b0" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py:953: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.\n", | |
" warnings.warn(\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment