@allenday
Created February 28, 2022 09:32
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "771857ee",
"metadata": {},
"outputs": [],
"source": [
"import datetime\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"from classification_models.keras import Classifiers\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f030ab29",
"metadata": {},
"outputs": [],
"source": [
"ds, info = tfds.load(\n",
" \"mnist\",\n",
" split=[\"train\",\"test\"],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
" )\n",
"NUM_CLASSES = info.features[\"label\"].num_classes\n",
"SIZE = (28,28)\n",
"BATCH_SIZE=400\n",
"EPOCHS=5\n",
"model_base='mnist-model'\n",
"thresh=0\n",
"labels = info.features['label'].names\n",
"ds_train, ds_validation = ds[0], ds[1]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "179be129",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tfds.core.DatasetInfo(\n",
" name='mnist',\n",
" full_name='mnist/3.0.1',\n",
" description=\"\"\"\n",
" The MNIST database of handwritten digits.\n",
" \"\"\",\n",
" homepage='http://yann.lecun.com/exdb/mnist/',\n",
" data_path='/home/allenday/tensorflow_datasets/mnist/3.0.1',\n",
" download_size=11.06 MiB,\n",
" dataset_size=21.00 MiB,\n",
" features=FeaturesDict({\n",
" 'image': Image(shape=(28, 28, 1), dtype=tf.uint8),\n",
" 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),\n",
" }),\n",
" supervised_keys=('image', 'label'),\n",
" disable_shuffling=False,\n",
" splits={\n",
" 'test': <SplitInfo num_examples=10000, num_shards=1>,\n",
" 'train': <SplitInfo num_examples=60000, num_shards=1>,\n",
" },\n",
" citation=\"\"\"@article{lecun2010mnist,\n",
" title={MNIST handwritten digit database},\n",
" author={LeCun, Yann and Cortes, Corinna and Burges, CJ},\n",
" journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},\n",
" volume={2},\n",
" year={2010}\n",
" }\"\"\",\n",
")\n",
"Number of classes: 10\n",
"Number of training samples: 60000\n",
"Number of validation samples: 10000\n"
]
}
],
"source": [
"print(info)\n",
"print(\"Number of classes: %d\" % NUM_CLASSES)\n",
"print(\"Number of training samples: %d\" % tf.data.experimental.cardinality(ds_train))\n",
"print(\"Number of validation samples: %d\" % tf.data.experimental.cardinality(ds_validation))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5e65f447",
"metadata": {},
"outputs": [],
"source": [
"ds_train = ds_train.map(lambda x, y: (tf.image.resize(tf.image.grayscale_to_rgb(x), SIZE), y))\n",
"ds_validation = ds_validation.map(lambda x, y: (tf.image.resize(tf.image.grayscale_to_rgb(x), SIZE), y))\n",
"\n",
"# As you fit the dataset in memory, cache it before shuffling for a better performance.\n",
"# Note: Random transformations should be applied after caching.\n",
"ds_train = ds_train.cache()\n",
"# For true randomness, set the shuffle buffer to the full dataset size.\n",
"# Note: For large datasets that can't fit in memory, use buffer_size=1000 if your system allows it.\n",
"ds_train = ds_train.shuffle(1000)\n",
"# Batch elements of the dataset after shuffling to get unique batches at each epoch.\n",
"ds_train = ds_train.batch(BATCH_SIZE)\n",
"# It is good practice to end the pipeline by prefetching for performance.\n",
"ds_train = ds_train.prefetch(tf.data.AUTOTUNE)\n",
"\n",
"ds_validation = ds_validation.cache()\n",
"ds_validation = ds_validation.batch(BATCH_SIZE)\n",
"ds_validation = ds_validation.prefetch(tf.data.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cc86a2ed",
"metadata": {},
"outputs": [],
"source": [
"ResNet18, preprocess_input = Classifiers.get('resnet18')\n",
"\n",
"scale_layer = tf.keras.layers.Rescaling(scale=1 / 127.5, offset=-1)\n",
"\n",
"\n",
"base_model = ResNet18(\n",
" input_shape=(28,28,3),\n",
" weights='imagenet',\n",
"# weights=None,\n",
" include_top=False\n",
" )\n",
"base_model.trainable = True\n",
"#base_model.trainable = False\n",
"\n",
"inputs = tf.keras.Input(shape=(28,28,3))\n",
"x = scale_layer(inputs)\n",
"\n",
"x = base_model(x, training=True)\n",
"x = tf.keras.layers.GlobalAveragePooling2D()(x)\n",
"x = tf.keras.layers.Dense(NUM_CLASSES, activation=None)(x)\n",
"outputs = tf.keras.layers.Activation(activation=\"softmax\", name=\"activation\")(x)\n",
"loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "02002aee",
"metadata": {},
"outputs": [],
"source": [
"# Define the per-epoch callbacks\n",
"logdir = \"logs/position/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"cbTensorBoard = tf.keras.callbacks.TensorBoard(log_dir = logdir, histogram_freq = 1)\n",
"cbEarlyStop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)\n",
"cbCheckPoint = tf.keras.callbacks.ModelCheckpoint(\n",
" filepath=model_base,\n",
" monitor = \"val_sparse_categorical_accuracy\",\n",
" verbose=1,\n",
" save_best_only=True,\n",
" mode='max',\n",
" initial_value_threshold=thresh\n",
" )\n",
"opt = tf.keras.optimizers.SGD(learning_rate=0.001,momentum=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "06ea64b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 28, 28, 3)] 0 \n",
"_________________________________________________________________\n",
"rescaling (Rescaling) (None, 28, 28, 3) 0 \n",
"_________________________________________________________________\n",
"model (Functional) (None, 1, 1, 512) 11186889 \n",
"_________________________________________________________________\n",
"global_average_pooling2d (Gl (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 11,192,019\n",
"Trainable params: 11,184,077\n",
"Non-trainable params: 7,942\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = tf.keras.Model(inputs, outputs)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "558affcc",
"metadata": {},
"outputs": [],
"source": [
"model.compile(\n",
" optimizer=opt,\n",
" loss=loss_function,\n",
" metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ccc6d981",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"150/150 [==============================] - 8s 31ms/step - loss: 0.4073 - sparse_categorical_accuracy: 0.8792 - val_loss: 0.1005 - val_sparse_categorical_accuracy: 0.9683\n",
"\n",
"Epoch 00001: val_sparse_categorical_accuracy improved from -inf to 0.96830, saving model to mnist-model\n",
"INFO:tensorflow:Assets written to: mnist-model/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: mnist-model/assets\n",
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5\n",
"150/150 [==============================] - 4s 24ms/step - loss: 0.0747 - sparse_categorical_accuracy: 0.9776 - val_loss: 0.0720 - val_sparse_categorical_accuracy: 0.9770\n",
"\n",
"Epoch 00002: val_sparse_categorical_accuracy improved from 0.96830 to 0.97700, saving model to mnist-model\n",
"INFO:tensorflow:Assets written to: mnist-model/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: mnist-model/assets\n",
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5\n",
"150/150 [==============================] - 4s 24ms/step - loss: 0.0469 - sparse_categorical_accuracy: 0.9859 - val_loss: 0.0608 - val_sparse_categorical_accuracy: 0.9806\n",
"\n",
"Epoch 00003: val_sparse_categorical_accuracy improved from 0.97700 to 0.98060, saving model to mnist-model\n",
"INFO:tensorflow:Assets written to: mnist-model/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: mnist-model/assets\n",
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5\n",
"150/150 [==============================] - 4s 24ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9901 - val_loss: 0.0546 - val_sparse_categorical_accuracy: 0.9817\n",
"\n",
"Epoch 00004: val_sparse_categorical_accuracy improved from 0.98060 to 0.98170, saving model to mnist-model\n",
"INFO:tensorflow:Assets written to: mnist-model/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: mnist-model/assets\n",
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5\n",
"150/150 [==============================] - 4s 24ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.0509 - val_sparse_categorical_accuracy: 0.9832\n",
"\n",
"Epoch 00005: val_sparse_categorical_accuracy improved from 0.98170 to 0.98320, saving model to mnist-model\n",
"INFO:tensorflow:Assets written to: mnist-model/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: mnist-model/assets\n",
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" category=CustomMaskWarning)\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7ffa206281d0>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" ds_train,\n",
" epochs=EPOCHS,\n",
" validation_data=ds_validation,\n",
" callbacks=[cbCheckPoint, cbEarlyStop, cbTensorBoard],\n",
" verbose=1,\n",
")"
]
},
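{
"cell_type": "code",
"execution_count": null,
"id": "reload-best-checkpoint",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original run): reload the best checkpoint written by\n",
"# cbCheckPoint above and re-check validation accuracy against the numbers in the training log.\n",
"# Assumes the 'mnist-model' SavedModel directory is still present; per the CustomMaskWarning\n",
"# above, loading may need custom_objects for the backbone's custom layers.\n",
"reloaded = tf.keras.models.load_model(model_base)\n",
"reloaded.evaluate(ds_validation, verbose=1)"
]
},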
{
"cell_type": "code",
"execution_count": 13,
"id": "90dc0b84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"true=3\n",
"prob=tf.Tensor(\n",
"[0.1004887 0.09967535 0.10066849 0.09976515 0.09914842 0.10214555\n",
" 0.09826583 0.09943706 0.09950258 0.10090292], shape=(10,), dtype=float32)\n",
"pred=5\t0.102145545\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAB6UlEQVR4nO2Uva8hURyGT8b4nEaDAoVK4aNRShQKiU5rIjT+gGkoRhQSjaChQD9RINGoFEKtnIRIRMZHQggxEjKT+N3ZQlZu7iI79la7+5TvyXnOez5yEPrPX4JCofD5fKVSSRCEj59IkuRwON40er3eTqcDj6Bp+tks/NmAUqnMZDIURanVaoTQ9Xqt1WqiKDqdzkAggBDa7/eyO8ZisVujzWZTKBSMRuMtTyQSt/yd7VMUNRwOSZK02Wz3MBQK8TwPAI1GgyAI2VIMw3Q63eeEpunz+QwAHMdptVrZxi8QBBEOhwVBAIDBYOByuf7UqFKpUqnU/dLX63Uul7NYLG/qcByPx+Pz+fzX98SyrMFgkG1UKBT5fP5u2W63vV4vEolks9nVagUAfr//aZsXXkmSEEIsy3a73Wq1OpvN7uslk0m3293r9d4pq9frNRrN59Dj8ZxOp9fv9FVTADgej1/CYDBIEATP85fLRXbNh5hMJo7jAKBcLn+PESHEMAwALJdLHH+1RRk4HI7FYgEA6XT6e4wIoel0CgCtVgvDMNmT7XZ7u90WRVEUxXq9XiwWGYYRRREALpcLSZLvNKpUKg9/5clkEo1Gf8fw4Lz7/b7VajWbzbvdbjwe38JmszkajQ6Hwzs1/zF+AAjUMpuD885eAAAAAElFTkSuQmCC\n",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=28x28 at 0x7FF9AC0C8A58>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"OFF=88\n",
"# QA on training data\n",
"a = list(ds_train.take(1))[0] \n",
"b = a[0][OFF].numpy()\n",
"image = Image.fromarray(b.astype(np.uint8))\n",
"image_np = np.array(image)/ 127.5\n",
"input_tensor = tf.convert_to_tensor(image_np)\n",
"input_tensor = image_np[tf.newaxis, ...]\n",
"input_tensor.shape\n",
"detections = model.predict(input_tensor)[0]\n",
"detections = tf.nn.softmax(detections)\n",
"maxval = tf.math.argmax(detections)\n",
"maxlab = labels[maxval]\n",
"print(\"true=\" + labels[a[1][OFF]])\n",
"print(\"prob=\" + str(detections))\n",
"print(\"pred=\" + maxlab + \"\\t\" + str(detections[maxval].numpy()))\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e08e6ad9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}