Skip to content

Instantly share code, notes, and snippets.

@kiransair
Created February 9, 2024 08:12
Show Gist options
  • Save kiransair/2a25968f47e48f430cc5c16886d01096 to your computer and use it in GitHub Desktop.
Save kiransair/2a25968f47e48f430cc5c16886d01096 to your computer and use it in GitHub Desktop.
TF_Forum_22518.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMlUNT5AedoVtw8sytPOgNm",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kiransair/2a25968f47e48f430cc5c16886d01096/tf_forum_22518.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"pip install tensorflow-model-optimization==0.7.5"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h3kD5LCQh-wQ",
"outputId": "d7483297-9522-43af-bf0d-24ed59a80305"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: tensorflow-model-optimization==0.7.5 in /usr/local/lib/python3.10/dist-packages (0.7.5)\n",
"Requirement already satisfied: absl-py~=1.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.4.0)\n",
"Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (0.1.8)\n",
"Requirement already satisfied: numpy~=1.23 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.23.5)\n",
"Requirement already satisfied: six~=1.14 in /usr/local/lib/python3.10/dist-packages (from tensorflow-model-optimization==0.7.5) (1.16.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import tensorflow_model_optimization as tfmot"
],
"metadata": {
"id": "pnuOmbHWiG2g"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import tensorflow as tf"
],
"metadata": {
"id": "7uO09aMiiNak"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import keras"
],
"metadata": {
"id": "W7ckxnzUlGUE"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2nKs2L9Ih6E7",
"outputId": "1cf60d74-e694-458d-ddf4-ed7dd5132fa0"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1688/1688 [==============================] - 32s 18ms/step - loss: 0.2877 - accuracy: 0.9201 - val_loss: 0.1137 - val_accuracy: 0.9683\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7d320199cdc0>"
]
},
"metadata": {},
"execution_count": 5
}
],
"source": [
"# Load MNIST dataset\n",
"mnist = keras.datasets.mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# Normalize the input image so that each pixel value is between 0 to 1.\n",
"train_images = train_images / 255.0\n",
"test_images = test_images / 255.0\n",
"\n",
"# Define the model architecture.\n",
"model = keras.Sequential([\n",
" keras.layers.InputLayer(input_shape=(28, 28)),\n",
" keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
" keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n",
" keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
" keras.layers.Flatten(),\n",
" keras.layers.Dense(10)\n",
"])\n",
"\n",
"# Train the digit classification model\n",
"model.compile(optimizer='adam',\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(\n",
" train_images,\n",
" train_labels,\n",
" epochs=1,\n",
" validation_split=0.1,\n",
")"
]
},
{
"cell_type": "code",
"source": [
"import tensorflow_model_optimization as tfmot\n",
"\n",
"quantize_model = tfmot.quantization.keras.quantize_model\n",
"\n",
"# q_aware stands for for quantization aware.\n",
"q_aware_model = quantize_model(model)\n",
"\n",
"# `quantize_model` requires a recompile.\n",
"q_aware_model.compile(optimizer='adam',\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"\n",
"q_aware_model.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "N0Y_i4q-iFTX",
"outputId": "92cf67f4-9169-4d9b-f6f9-ba3a27df0841"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" quantize_layer (QuantizeLa (None, 28, 28) 3 \n",
" yer) \n",
" \n",
" quant_reshape (QuantizeWra (None, 28, 28, 1) 1 \n",
" pperV2) \n",
" \n",
" quant_conv2d (QuantizeWrap (None, 26, 26, 12) 147 \n",
" perV2) \n",
" \n",
" quant_max_pooling2d (Quant (None, 13, 13, 12) 1 \n",
" izeWrapperV2) \n",
" \n",
" quant_flatten (QuantizeWra (None, 2028) 1 \n",
" pperV2) \n",
" \n",
" quant_dense (QuantizeWrapp (None, 10) 20295 \n",
" erV2) \n",
" \n",
"=================================================================\n",
"Total params: 20448 (79.88 KB)\n",
"Trainable params: 20410 (79.73 KB)\n",
"Non-trainable params: 38 (152.00 Byte)\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"train_images_subset = train_images[0:1000] # out of 60000\n",
"train_labels_subset = train_labels[0:1000]\n",
"\n",
"q_aware_model.fit(train_images_subset, train_labels_subset,\n",
" batch_size=500, epochs=1, validation_split=0.1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "42IVHAbBlT8F",
"outputId": "968bfd4e-3dc3-4321-fa39-a52f905d18d3"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2/2 [==============================] - 2s 375ms/step - loss: 0.1394 - accuracy: 0.9567 - val_loss: 0.1476 - val_accuracy: 0.9600\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7d320164a2c0>"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n",
"\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"\n",
"converter.inference_input_type = tf.int8\n",
"\n",
"converter.inference_output_type = tf.int8\n",
"\n",
"quantized_tflite_model = converter.convert()\n",
"\n",
"with open('output_model_path', 'wb') as f:\n",
" f.write(quantized_tflite_model)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LvvVTkW2oDm-",
"outputId": "77de1061-c223-413a-a3fd-4540118d31b0"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py:953: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.\n",
" warnings.warn(\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment