{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Rice Leaf Diseases: Implementing a Custom Model\n",
"\n",
"This notebook will continue our exploration of the [Rice Leaf Diseases dataset from Kaggle](https://www.kaggle.com/vbookshelf/rice-leaf-diseases). \n",
"\n",
"In this notebook, we will do the following:\n",
"\n",
"1. Create a CNN-based classifier network similar to [1].\n",
"3. Train the network directly on our rice image dataset (training partition), using lots of image augmentation to prevent overfitting.\n",
"4. Predict on the testing partition, and display results with a confusion matrix and classification report.\n",
"\n",
"[1]: Wick, C. & Puppe, F. Leaf Identification Using a Deep Convolutional Neural Network. arXiv:1712.00967 [cs] (2017).\n",
"\n",
"## Setting up your environment\n",
"Before running this notebook, you will need to make sure you have [downloaded](https://www.kaggle.com/vbookshelf/rice-leaf-diseases) the dataset and extracted the files. This notebook assumes the image data is extracted in the same directory as this notebook, and that the top-level data directory is named \"rice_leaf_diseases\". You can edit the code if those assumptions do not hold on your own setup.\n",
"\n",
"## Python environment\n",
"You will need the following packages installed to execute the code shown in this notebook:\n",
"\n",
"* [Matplotlib](https://matplotlib.org/)\n",
"* [Numpy](https://numpy.org/)\n",
"* [Pandas](https://pandas.pydata.org/)\n",
"* [Scikit-Learn](https://scikit-learn.org/)\n",
"* [Tensorflow](https://www.tensorflow.org/)\n",
"\n",
"## Local directory structure\n",
"This notebook assumes you have the rice leaf images in the following directory structure:\n",
"\n",
" rice_leaf_diseases/\n",
" Bacterial leaf blight/\n",
" Brown spot/\n",
" Leaf smut/\n",
"\n",
"The \"ground truth\" file \"_rice_leaf_diseases_ground_truth.csv_\" was created in the previous (\"rice_image_EDA\") notebook. If you don't have this file, you can create it by running the first several cells in that notebook, which is available [by clicking here](https://gist.github.com/jcausey-astate/207ba4d65126abe0482b740b41117f9e)."
],
"metadata": {}
},
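{
"cell_type": "markdown",
"source": [
"If you would rather rebuild the ground-truth file from the directory layout itself, the next cell gives a minimal sketch. It assumes the directory structure shown above and the `image_path` / `class` column names used later in this notebook; the label normalization is a guess on our part, so the class names it produces may not exactly match those from the EDA notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Minimal sketch: rebuild \"rice_leaf_diseases_ground_truth.csv\" by walking the\n",
"# class directories. The lowercase/underscore label normalization below is an\n",
"# assumption; the EDA notebook remains the authoritative source for this file.\n",
"import os\n",
"import pandas as pd\n",
"\n",
"rows = []\n",
"for class_dir in sorted(os.listdir(\"rice_leaf_diseases\")):\n",
"    full_dir = os.path.join(\"rice_leaf_diseases\", class_dir)\n",
"    if not os.path.isdir(full_dir):\n",
"        continue  # skip stray files at the top level\n",
"    label = class_dir.strip().lower().replace(\" \", \"_\")\n",
"    for fname in sorted(os.listdir(full_dir)):\n",
"        rows.append({\"image_path\": os.path.join(full_dir, fname), \"class\": label})\n",
"\n",
"pd.DataFrame(rows).to_csv(\"rice_leaf_diseases_ground_truth.csv\", index=False)"
],
"outputs": [],
"metadata": {}
},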
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import StratifiedShuffleSplit\n",
"import sklearn.metrics as metrics\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"import matplotlib.pyplot as plt"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"First, we will create a function that will actually define the structure of the network. We use Wick et al. (2017) as a guide, and fill in the missing details with common choices (activations, amount of dropout, etc.)."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"def build_leaf_classifier_model(input_shape=(256,256,3), n_classes=3):\n",
" '''\n",
" This function builds a Keras CNN similar to the one described in Wick et al. (2017).\n",
" '''\n",
" lrelu = keras.layers.LeakyReLU()\n",
" input = keras.layers.Input(shape=input_shape)\n",
" x = keras.layers.Conv2D(40, (3,3), activation=lrelu, padding='same', name=\"conv1\")(input)\n",
" x = keras.layers.MaxPool2D((2,2), name='max_pool1')(x)\n",
" x = keras.layers.Conv2D(40, (4,4), activation=lrelu, padding='same', name='conv2')(x)\n",
" x = keras.layers.MaxPool2D((2,2), name='max_pool2')(x)\n",
" x = keras.layers.Conv2D(80, (4,4), activation=lrelu, padding='same', name='conv3')(x)\n",
" x = keras.layers.MaxPool2D((2,2), name='max_pool3')(x)\n",
" x = keras.layers.Conv2D(160, (4,4), activation=lrelu, padding='same', name='conv4')(x)\n",
" x = keras.layers.MaxPool2D((2,2), name='max_pool4')(x)\n",
" x = keras.layers.Flatten()(x)\n",
" x = keras.layers.Dense(500, activation=lrelu)(x)\n",
" x = keras.layers.Dropout(0.20)(x)\n",
" x = keras.layers.Dense(n_classes, activation='softmax')(x)\n",
" return keras.Model(inputs=input, outputs=x)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Let's create an instance of the model and see the summary:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"leaf_model = build_leaf_classifier_model()\n",
"leaf_model.summary()"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 256, 256, 3)] 0 \n",
"_________________________________________________________________\n",
"conv1 (Conv2D) (None, 256, 256, 40) 1120 \n",
"_________________________________________________________________\n",
"max_pool1 (MaxPooling2D) (None, 128, 128, 40) 0 \n",
"_________________________________________________________________\n",
"conv2 (Conv2D) (None, 128, 128, 40) 25640 \n",
"_________________________________________________________________\n",
"max_pool2 (MaxPooling2D) (None, 64, 64, 40) 0 \n",
"_________________________________________________________________\n",
"conv3 (Conv2D) (None, 64, 64, 80) 51280 \n",
"_________________________________________________________________\n",
"max_pool3 (MaxPooling2D) (None, 32, 32, 80) 0 \n",
"_________________________________________________________________\n",
"conv4 (Conv2D) (None, 32, 32, 160) 204960 \n",
"_________________________________________________________________\n",
"max_pool4 (MaxPooling2D) (None, 16, 16, 160) 0 \n",
"_________________________________________________________________\n",
"flatten (Flatten) (None, 40960) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 500) 20480500 \n",
"_________________________________________________________________\n",
"dropout (Dropout) (None, 500) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 3) 1503 \n",
"=================================================================\n",
"Total params: 20,765,003\n",
"Trainable params: 20,765,003\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"metadata": {}
},
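{
"cell_type": "markdown",
"source": [
"A quick check on where those parameters come from: the flattened feature map has $16 \\times 16 \\times 160 = 40960$ values, so the first dense layer alone costs $(40960 + 1) \\times 500 = 20{,}480{,}500$ weights. That dwarfs even the largest convolutional layer, `conv4`, with $(4 \\times 4 \\times 80 + 1) \\times 160 = 204{,}960$."
],
"metadata": {}
},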
{
"cell_type": "markdown",
"source": [
"This model has almost 21 million parameters, which is a bit high for training directly with such a small amount of training data available, but it will give us a good baseline for how well the model can do.\n",
"\n",
"We will train the model from scratch using more or less the same strategy we used with the previous transfer learning example (except for the transfer learning part, of course).\n",
"\n",
"Let's load the dataset and use `StratifiedShuffleSplit` to create an 80/20 train/test split:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"ground_truth_df = pd.read_csv(\"rice_leaf_diseases_ground_truth.csv\")\n",
"ground_truth_df.head()\n",
"sss = StratifiedShuffleSplit(1, test_size=0.20, random_state=2021)\n",
"train_indices, test_indices = list(sss.split(ground_truth_df.values, ground_truth_df['class'].values))[0]\n",
"print(f\"Training set has {train_indices.shape[0]} samples.\")\n",
"print(f\"Test set has {test_indices.shape[0]} samples.\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training set has 96 samples.\n",
"Test set has 24 samples.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now we will set up our image data generators for training and testing.\n",
"\n",
"The training generator will not do any augmentation. It will only scale the image to contain values in the $[0,1]$ numeric range.\n",
"\n",
"The test generator will perform several kinds of augmentation, in addition to rescaling the image to $[0,1]$: \n",
"\n"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"# Declare the generator (we will take default initial paramters for now)\n",
"test_generator = keras.preprocessing.image.ImageDataGenerator(\n",
" rescale=1.0/255.0\n",
")\n",
"train_generator = keras.preprocessing.image.ImageDataGenerator(\n",
" rescale=1.0/255.0,\n",
" horizontal_flip=True,\n",
" vertical_flip=True,\n",
" rotation_range=60,\n",
" width_shift_range=0.20,\n",
" height_shift_range=0.10,\n",
" zoom_range=0.2,\n",
" brightness_range=(0.5, 1.25),\n",
")"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The following cell will create a \"flow\" object that will actually become the sequence of values we feed into the model for training and testing. \n",
"\n",
"The main difference between the two \"flows\" is that the \"train\" version must use only the _training_ rows from the ground truth, and the \"test\" version uses only the _test_ rows."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"BATCH_SIZE = 32\n",
"train_dataflow = train_generator.flow_from_dataframe(\n",
" ground_truth_df.iloc[train_indices,:],\n",
" x_col = \"image_path\",\n",
" y_col = \"class\",\n",
" target_size = (256,256),\n",
" shuffle = True,\n",
" seed = 2021,\n",
" batch_size=BATCH_SIZE,\n",
")\n",
"test_dataflow = test_generator.flow_from_dataframe(\n",
" ground_truth_df.iloc[test_indices,:],\n",
" x_col = \"image_path\",\n",
" y_col = \"class\",\n",
" target_size = (256,256),\n",
" shuffle = False,\n",
" seed = 2021,\n",
" batch_size=BATCH_SIZE,\n",
")\n",
"class_to_int = test_dataflow.class_indices\n",
"image_classes = sorted(class_to_int.keys())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found 96 validated image filenames belonging to 3 classes.\n",
"Found 24 validated image filenames belonging to 3 classes.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now we can compile and train. We will train for 30 epochs here."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"leaf_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])\n",
"leaf_model.fit(\n",
" train_dataflow,\n",
" validation_data=test_dataflow,\n",
" epochs=30,\n",
" batch_size=BATCH_SIZE\n",
")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/30\n",
"3/3 [==============================] - 8s 3s/step - loss: 3.7391 - acc: 0.3854 - val_loss: 3.2986 - val_acc: 0.3333\n",
"Epoch 2/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 2.7144 - acc: 0.2812 - val_loss: 1.4589 - val_acc: 0.3333\n",
"Epoch 3/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.2251 - acc: 0.3542 - val_loss: 1.1923 - val_acc: 0.4167\n",
"Epoch 4/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.1523 - acc: 0.3854 - val_loss: 1.0745 - val_acc: 0.3750\n",
"Epoch 5/30\n",
"3/3 [==============================] - 7s 2s/step - loss: 1.1059 - acc: 0.4375 - val_loss: 1.0091 - val_acc: 0.5000\n",
"Epoch 6/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.0757 - acc: 0.4062 - val_loss: 1.1515 - val_acc: 0.4167\n",
"Epoch 7/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.1265 - acc: 0.3958 - val_loss: 1.0421 - val_acc: 0.3750\n",
"Epoch 8/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.1556 - acc: 0.3438 - val_loss: 1.0186 - val_acc: 0.4583\n",
"Epoch 9/30\n",
"3/3 [==============================] - 7s 2s/step - loss: 1.0556 - acc: 0.3646 - val_loss: 1.0765 - val_acc: 0.4167\n",
"Epoch 10/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.0648 - acc: 0.4375 - val_loss: 1.0254 - val_acc: 0.4167\n",
"Epoch 11/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.0499 - acc: 0.4167 - val_loss: 1.0139 - val_acc: 0.5417\n",
"Epoch 12/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 1.0388 - acc: 0.4688 - val_loss: 0.9831 - val_acc: 0.4583\n",
"Epoch 13/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9707 - acc: 0.4792 - val_loss: 0.9808 - val_acc: 0.3750\n",
"Epoch 14/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9742 - acc: 0.5625 - val_loss: 0.9954 - val_acc: 0.3750\n",
"Epoch 15/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9603 - acc: 0.5938 - val_loss: 0.9246 - val_acc: 0.4583\n",
"Epoch 16/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9592 - acc: 0.5729 - val_loss: 0.9061 - val_acc: 0.5833\n",
"Epoch 17/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9333 - acc: 0.5208 - val_loss: 1.0532 - val_acc: 0.3750\n",
"Epoch 18/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9671 - acc: 0.5312 - val_loss: 0.9136 - val_acc: 0.5833\n",
"Epoch 19/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9243 - acc: 0.4896 - val_loss: 0.8898 - val_acc: 0.6250\n",
"Epoch 20/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9488 - acc: 0.4688 - val_loss: 0.8645 - val_acc: 0.7083\n",
"Epoch 21/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9294 - acc: 0.5208 - val_loss: 0.8928 - val_acc: 0.4583\n",
"Epoch 22/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9029 - acc: 0.5625 - val_loss: 0.8931 - val_acc: 0.4167\n",
"Epoch 23/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9208 - acc: 0.5521 - val_loss: 0.8971 - val_acc: 0.6250\n",
"Epoch 24/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8931 - acc: 0.5729 - val_loss: 0.8689 - val_acc: 0.4583\n",
"Epoch 25/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8407 - acc: 0.6667 - val_loss: 0.8673 - val_acc: 0.6667\n",
"Epoch 26/30\n",
"3/3 [==============================] - 7s 2s/step - loss: 0.8480 - acc: 0.6562 - val_loss: 0.8648 - val_acc: 0.5417\n",
"Epoch 27/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8545 - acc: 0.6146 - val_loss: 0.8436 - val_acc: 0.5833\n",
"Epoch 28/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8443 - acc: 0.6250 - val_loss: 0.8668 - val_acc: 0.5000\n",
"Epoch 29/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8514 - acc: 0.5104 - val_loss: 0.8505 - val_acc: 0.5417\n",
"Epoch 30/30\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8074 - acc: 0.6042 - val_loss: 0.8556 - val_acc: 0.5417\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x1976f2d00>"
]
},
"metadata": {},
"execution_count": 7
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now, we have trained the model for a bit - let's see how the accuracy actually looks with our previous metrics:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"gt_classes = test_dataflow.classes\n",
"print(gt_classes)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[2, 2, 2, 1, 0, 1, 1, 0, 2, 2, 1, 0, 0, 0, 1, 1, 2, 1, 2, 1, 0, 0, 2, 0]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"predictions = leaf_model.predict(test_dataflow)\n",
"predicted_classes = list(np.argmax(predictions, axis=1))\n",
"print(predicted_classes)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[1, 2, 1, 1, 0, 1, 1, 1, 2, 1, 1, 0, 1, 2, 2, 1, 2, 1, 2, 1, 1, 1, 0, 1]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"metrics.ConfusionMatrixDisplay(\n",
" metrics.confusion_matrix(gt_classes, predicted_classes), \n",
" display_labels=image_classes\n",
").plot()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x197f460d0>"
]
},
"metadata": {},
"execution_count": 10
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV4AAAELCAYAAACRaO5eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhcklEQVR4nO3deZhcVbnv8e+vMxASEkJIQAhwQgBBBESeGEUQAyoITugBFYer4hEwKMM1OB8H9CjRc1UUByIqAgGPhlkUAkIM5BDISOgkAkoCSJiSEAgkZOh+7x97NRRtV3V1umrvStfv8zz76apdq9Z+a6fz9qq11l5bEYGZmeWnpegAzMyajROvmVnOnHjNzHLmxGtmljMnXjOznDnxmpnlzInXzKyXJO0raWHJ9qyks8qW9zxeM7PakdQPeBR4fUQ81FUZt3jNzGrrLcA/yiVdgP45BtM0BgwaEgOHjCg6jIbVf9TGokNofA/3KzqChvfs+sdWRsSoLX3/MUcOiVWr26oqO2/RhsXACyW7pkTElDLFPwhcUak+J946GDhkBAcce1bRYTSsHU8p2xCwpO2sHYoOoeFNX3Bur36RVq5u466bdquq7IBd/vFCRIzrrpykgcC7gS9VKufEa2ZNKmiL9lpXeiwwPyKeqFTIidfMmlIA7dR8csFJdNPNAE68ZtbE2qldi1fSEOBtwKndlXXiNbOmFARtNZxOGxHPAztWU9aJ18yaUgCbatji7QknXjNrWnXo462KE6+ZNaWAmnY19IQTr5k1rWI6Gpx4zaxJBUGbuxrMzPITAZsKWiPMidfMmpRoQ4Uc2YnXzJpSAO1u8ZqZ5cstXjOzHAVOvGZmuWsPJ14zs9y0IzZSzILzTrxm1rTc4jUzy5H7eM3Mcifaopj7/TrxmllTyu5A4cRrZpYrdzWYmeUoQmwKz2owM8tNNrjmrgYzsxx5cM3MLFceXDMzK0CbL6AwM8tPIDZF7VKgpOHARcABZA3qkyPizq7KOvGaWVOqw+Da+cCNEXGCpIHA4HIFnXjNrCkFqllXg6TtgSOAjwNExEZgY7nyxfQsm5k1gHZaqtqqsCfwFPAbSQskXSRpSLnCbvH2YTtt/xxfO+k2RgxdR4S4dvar+P0dBxYdVsPZ9IEnYbCgBdRP9J8ysuiQGsbZZ81m/PgVrFkziE9PPK7ocGoqgp5MJxspaW7J8ykRMaXkeX/gEOCzEXGXpPOBLwL/2VVlTrx9WFu7+PH1b+D+R0cxeJuN/Oasq7j7gd1Y/sQORYfWcPr/cEc03F8AO7v5lrFcd/0rmfS52UWHUgeivfpLhldGxLgKr/8T+GdE3JWeTyNLvF2q22+apDGSWutVfxEkfbnoGHpi1doh3P/oKADWbRjI8ieGM2rY8wVHZVuT1tadWLt2YNFh1EUAG6N/VVu3dUU8Djwiad+06y3AknLlC23xSuoXEW1FxtBDXwa+U3QQW+IVO6zllaNXsfjhnYoOpfEINp+zCgT93jWElneVHYy2PiRQrRdC/ywwNc1oeBD4RLmC9f5u1V/SVElLJU2TNFjSckmTJc0HTpR0kqR7JbVKmgwg6URJP0iPz5T0YHo8VtKs9Hi5pG9Kmp/ev1+5ICS9WdLCtC2QNFTSBEkzJd0g6T5Jv5DUksp3FdN5wLapjqn1PW21te3ATXz3Y9P50bWHsm5D32y99Eb/n+zIgF+Oov/kEbRd8zzt92woOiTLSRstVW3ViIiFETEuIg6KiOMj4ulyZeudePcFfhYRrwKeBSam/asi4hBgJjAZOAo4GHidpOOB24E3pbJvAlZJGp0ezyypf2Wq5+fApApxTAJOj4iDUx3r0/7xZH+l9gf2At4nadeuYoqILwLrI+LgiPhw5wNIOkXSXElzN7/QOF/n+7W08Z2PTeem+fvw19axRYfTkDQqW6FKO/Sj5fBBxNJNBUdkeQigPVqq2mqt3on3kYiYlR5fBhyeHv9P+vk6YEZEPBURm4GpwBGpv2Q7SUOB3YHLyebIvYksKXe4Kv2cB4ypEMcs4AeSzgCGp2MB3B0RD6bujitSfF3G1N0HjYgp6a/duP6Dys4iyVnwlff/lYeeGM7vZh5UdDANKda3E+vaX3o8dwPa02POzUG0VbnVWr1/w6LM82qahP9L1kdyH1myPRk4FPhcSZmO74RtVPgsEXGepBuA44BZko7pJr4+4aAxj3PsuAf4+4oR/PbsaQD84s/jufNvexQcWQN5up3N/5m+EbZBy1sG0fL6QcXG1EC+8PlZHHTQkwwbtoFLL7mGSy87kOnT9yo6rJroaPEWod6Jdw9Jh6brlT8E3AG8tuT1u4EfSxoJPA2cBPwkvXY7cG7aFgBHkn3Vf6anQUjaKyLuBe6V9DpgP2ANMF7SnsBDwAeAKd3EtEnSgIjYKr6LLlq+C4dOOrXoMBqadu3PgF+NKjqMhjX5e4cVHULdFLkQer3T/X3A6ZKWAjuQ9cW+KCIeI5vrdhtwDzAvIq5NL99O1s0wM3UFPEKWuLfEWWmgbBGwCfhz2j8HuABYCiwDru4mpinAoq1tcM3MutYWLVVttVa3Fm9ELCdrWXY2plO5K8j6Vzu//x/wUudKRBzd6fUxJY/nAhMqxPLZzvskATwbEe/sony5mL4AfKHcccxs65Gtx+tlIc3McuQ7UNSEpE8AZ3baPSsiTu9cNiJmADNyCMvMGlA2uOYWb69FxG+A3xQdh5k1vmwhdN9l2MwsV77nmplZjrJlId3VYGaWK/fxmpnlKFudzF0NZma5qsc6DNVw4jWzphSIze2e1WBmlitfuWZmliPPajAzK4AH18zMclSHe65VzYnXzJqW+3jNzHIU4FkNZma5Cnc1mJnlqtYLoUtaDqwluwfk5ogYV66sE6+ZNa06tHiPjIiV3RVy4jWzpuSF0M3McpZdMlzTebwBTJcUwIURMaVcQSdeM2taPejjHSlpbsnzKV0k1sMj4lFJOwE3S/pbRMzsqjInXjNrTtGjroaVlQbLACLi0fTzSUlXA+OBLhNvMdfLmZkVrKOPt5qtO5KGSBra8Rg4GmgtV94tXjNrWjUcXNsZuFoSZHn18oi4sVxhJ14za0q1XKshIh4EXlNteSdeM2tabV6dzMwsP9GzwbWacuI1s6YVTrxmZnnyIjlmZrlzi7cP6bf6eYZdPrvoMBrWH/97YdEhNLxj+VDRIfR5XqvBzCxvvtmlmVm+Anc1mJnlzINrZma5iyjmuE68Zta03NVgZpajCGir7ULoVXPiNbOm5a4GM7OcuavBzCxHgZx4zczyVlBPgxOvmTWpcFeDmVnuot2J18wsVw03q0HST6jQBRIRZ9QlIjOzHDTqWg1zc4vCzCxvATRa4o2I35Y+lzQ4ItbVPyQzs3wU1dXQ7fVykg6VtAT4W3r+Gkk/q3tkZmZ1JaK9uq3qGqV+khZI+mOlctVcqPwj4BhgFUBE3AMcUXUkZmaNKqrcqncmsLS7QlWtEBERj3Ta1dajUMzMGk2ax1vNVg1JuwHvAC7qrmw108kekfRGICQNoMqMbmbW8
Grbx/sj4PPA0O4KVtPiPQ04HRgNrAAOTs/NzLZyqnJjpKS5JdspL6tFeifwZETMq+ao3bZ4I2Il8OEefBIzs61D9S3elRExrsLrhwHvlnQcMAgYJumyiPhIV4WrmdUwVtL1kp6S9KSkayWNrTpcM7NGFEC7qtu6qyriSxGxW0SMAT4I3Fou6UJ1XQ2XA78HdgF2Bf4AXFHF+8zMGlpEdVutVZN4B0fEpRGxOW2XkTWlzcy2brWfTkZEzIiId1YqU2mthhHp4Z8lfRH4XQrhA8CfehaKmVkDarRLhoF5ZIm2I7JTS14L4Ev1CsrMLA9qtNXJImLPPAMxM8vVFnQj1EpV6/FKOgDYn5K+3Yi4pF5BmZnVX3UzFuqh28Qr6evABLLE+yfgWOAOwInXzLZujbo6GXAC8Bbg8Yj4BPAaYPu6RmVmloc6zGqoRjVdDesjol3SZknDgCeB3WsfitXDuAnPctq3VtCvJfjzFSP4/QU7Fx1SQ3nk79vwndPGvPj88YcH8tFzHud9n3qquKAayNlnzWb8+BWsWTOIT088ruhwaqsRF0IvMVfScOCXZDMdngPurGdQVhstLcHp33mUL31wLCsfG8BP/vQAs2/anocf8DTsDrvvvYGf33IfAG1t8OFDXs1hx64pNqgGcvMtY7nu+lcy6XOziw6lLoqa1dBtV0NETIyINRHxC+BtwMdSl0OPSXpuS96X3nuipKWSbivz+sclXVDpuJJ2lTRtS+OUdLyk/XsSd5H2fe06ViwfyOMPb8PmTS3MuHY4hx7zTNFhNayFtw9ll3/bwM67bSo6lIbR2roTa9cOLDqM+mm0rgZJh1R6LSLm1z6cij4JfCoi7tjSCiJiBVmf9ZY6HvgjsKQXdeRmx1ds4qkVL/2nWfnYAPY7xHdvKmfGtcOZcPyaosOwHDXcPF7g/1V4LYCjenNgSecA7we2Aa6OiK+n/deQ9SEPAs6PiCmSvgYcDvxK0nURcU6ZaneXNINsCcvLIuKbnY45BvhjRBwgaTBwMXAAcB/ZOhSnR8TcVPa/gHcC64H3AHsB7wbeLOmrwL9HxD9K6j4FOAVgEIN7cWasCJs2itnTt+fkLz9WdCiWp0br442II+t1UElHA/sA48mujLtO0hERMRM4OSJWS9oWmCPpyog4V9JRwKSOxFjGeLJEui6994YK5ScCT0fE/mme8sKS14YAsyPiK5K+R9bS/rak68gS9790V0TEFGAKwDCNKOjv6MutenwAo3bd+OLzkbtsYuVjAwqMqHHNuXUoex+4jh1GbS46FMtLgRdQVHXrnzo4Om0LgPnAfmSJGOAMSfcAs8lavvt0WUPXbo6IVRGxHriKrJVczuFk608QEa3AopLXNpJ1KUA2oDimBzE0jPsWDmb0nhvZefcN9B/QzoT3rGH2dM8E7MqMa3ZwN0MzarQ+3joT8N2IuPBlO6UJwFuBQyNiXeo26MkQfOdTtKWnbFPEi4vBtVHceeqV9jbx06+M5juXP0hLP5j+uxE8dL9nNHT2wroW5t8+lDO/1/nWgvaFz8/ioIOeZNiwDVx6yTVcetmBTJ++V9Fh1Uwj9vHW003AtyRNjYjnJI0GNpFdmPF0Srr7AW/oYb1vS6uqrScbCDu5QtlZZH3Mt6WZCgdWUf9aqrifUiOZc+sw5tw6rOgwGtqgwe1MW9xadBgNafL3Dis6hPpqL+aw1dyBQpI+kga4kLSHpPG9OWhETCdbYP1OSfcC08gS2o1Af0lLgfPIuht64m7gSrJugyu76Q/+GTBK0hLg28BioLu5Vr8DzpG0QFLf+bNv1oQU1W+1Vk2L92dkfxeOAs4la/VdCbyupweLiO1KHp8PnN9FsWPLvHdCN3VfTDZLoexxI2I52eAbwAvARyLihZREbwEe6iLOaWR/GIiIWWRrVphZX9BosxpKvD4iDpG0ACAinpbUF2ZUDybrZhhA1uc8MSI2dvMeM+tLGriPd5OkfqQQJY2isJ4RkHQMMLnT7mUR8d6e1BMRa4FKdw01sz6ukQfXfgxcDeyULio4AfhqXaOqICJuIhucMzPrnUZNvBExVdI8sqUhBRwfEUvrHpmZWT0FqKDv7tUshL4H2ZVg15fui4iH6xmYmVndNWqLF7iBl256OQjYk2xtg1fXMS4zs7qrVR+vpEHATLK1Z/oD0zrWn+lKNV0NL7uwIK1aNrGXcZqZ9SUbgKPSBWEDgDsk/TkiurwWocdXrkXEfEmv722UZmaFq1GLNy0x0LGO94C0la29mj7e/1vytAU4BFjRixjNzIpX48G1NO12HrA38NOIuKtc2WpWJxtasm1D1uf7nhrEaWZWrOpXJxspaW7Jdsq/VBXRFhEHA7sB49Nys12q2OJNGXxoREzaks9kZtaoRI8G11ZGRFUXXEXEmnSLsrcDXa6+VLbFK6l/RLQBfXx5IjNrWjVaj1fSqHRTYNJNHN4G/K1c+Uot3rvJ+nMXpjsv/AF4/sV4I67qPhwzswZV25XHdgF+m3oJWoDfR8QfyxWuZlbDIGAV2epkHfN5g+wOD2ZmW6/azWpYBLy22vKVEu9OaUZDKy8l3BePs2XhmZk1jka8ZLgfsB0vT7gdnHjNbOvXgJcMPxYR5+YWiZlZngq8y3ClxFvM0uxmZjlpxPV435JbFGZmRWi0xBsRq/MMxMwsb43Y4jUz67uCwm5i5sRrZk1JFDeQ5cRrZs3LXQ1mZvlyH6+ZWd6ceM3MctTIdxk2M+uz3OI1M8uX+3jNzPLmxNt3aNtBtOy3f9FhNKz9f/7GokNoeC+ctaHoEBrfx3pfhVu8ZmZ5atDVyczM+izhWQ1mZvlzi9fMLF+KYjKvE6+ZNSf38ZqZ5c+zGszMclbU4FpLMYc1M2sAUeXWDUm7S7pN0hJJiyWdWam8W7xm1pyipl0Nm4HPRcR8SUOBeZJujoglXRV2i9fMmleNWrwR8VhEzE+P1wJLgdHlyrvFa2ZNSfSoxTtS0tyS51MiYkqX9UpjgNcCd5WrzInXzJpX9fN4V0bEuO4KSdoOuBI4KyKeLVfOidfMmlONF0KXNIAs6U6NiKsqlXXiNbOmVavEK0nAr4ClEfGD7sp7cM3MmleNBteAw4CPAkdJWpi248oVdovXzJpWraaTRcQdZON1VXHiNbPmFPRkcK2mnHjNrGl5rQYzsxx5IXQzs7xFuKvBzCxv7mowM8ubE6+ZWb7c4jUzy1MAbe7jNTPLlVu8ZmZ586wGM7N8ucVrZpYn397dzCxf2R0o3NVgZpYreVaDmVmO3NVg9XD2WbMZP34Fa9YM4tMTy67J3NQG9tvMJe+5loH92ujf0s70B8dywZzxRYfVeNqD3b+xlLYdBrLi7L2LjqZGvFaD1cHNt4zluutfyaTPzS46lIa1sa0fJ1/3btZtHkD/ljYuO/4aZj68B4ueeEXRoTWU4dOfZNOug2hZX9ByXnVS1KyGwm79I+m5Xrz3RElLJd1Wy5i6OeZwSRPzOl4ttLbuxNq1A4sOo8GJdZsHANC/pZ3+Le0QVd9IoCn0X72RIfc8wzNHjCw6lNrrWKGsu63G
ttYW7yeBT6XbbeRlODAR+FmOx7QctKidaSdMY4/tn+Hy1gNY9OTORYfUUEZe/ggrPzC6z7V2a32X4Z5oiJtdSjpH0hxJiyR9s2T/NZLmSVos6ZS072vA4cCvJH2/TH2vlnR3uuHcIkn7SBoj6W+SLpZ0v6Spkt4qaZakBySNT+/9hqRJJXW1ShoDnAfslers8ri2dWqPFt73h/dz5CX/hwN3epK9R6wqOqSGMWThGtqGDWDDmCFFh1If7VHdVmOFt3glHQ3sA4wnm1p3naQjImImcHJErJa0LTBH0pURca6ko4BJETG3TLWnAedHxFRJA4F+wM7A3sCJwMnAHOBDZEn83cCXgeMrhPpF4ICIOLjM5zgFOAVg0IDtq/781jjWbtyGux8dzZt2f4S/r96x6HAawqAHnmfIgjUMuecZtKmdlhfa2PnCZTxx6p5Fh1YTzTyP9+i0LUjPtyNLxDOBMyS9N+3fPe2vpjlyJ/AVSbsBV0XEA9lt71kWEfcCSFoM/CUiQtK9wJjefIiImAJMAdh+8K4FddlbT+0waD2b21tYu3Ebtum3mTfu/ggXLXht0WE1jFUnjmbViaMB2HbpWna48Yk+k3SBpp7VIOC7EXHhy3ZKE4C3AodGxDpJM4BB1VQYEZdLugt4B/AnSacCDwIbSoq1lzxv56VzsZmXd8FUdcxG9IXPz+Kgg55k2LANXHrJNVx62YFMn75X0WE1lFGD1/Hdo26lpaWdFgU3/n1v/vrQmKLDsjwE2f/8GpD0a+CdwJMRcUB35Rsh8d4EfEvS1Ih4TtJoYBOwPfB0Srr7AW+otkJJY4EHI+LHkvYADiJLvNVYTnYCkXQI0PHnfS0wtNoYGsHk7x1WdAgN7/7VO/Lv004sOoytwvpXDWX9q7aq/wIViahlV8PFwAXAJdUULnxwLSKmA5cDd6av/NPIEtyNQH9JS8kGtnoyGfX9QKukhcABVHkykiuBEakr4jPA/SnOVcCsNNjmwTWzvqC9vbqtG2lManW1hy2sxRsR25U8Ph84v4tix5Z574Ru6j6PLFmXWk2WhDvKfLzk8fKO1yJiPVmfc1f1fqjScc1sK9KzroaRkkoH86ekcZ0t0ghdDWZmhehBV8PKiBhXq+Nu1YlX0jHA5E67l0XEe7sqb2b2Mk08q2GLRcRNZINzZmY9VNwiOYUPrpmZFSKo2VoNkq4gu35gX0n/lPTJSuW36havmVlv1Goh9Ig4qSflnXjNrHm5j9fMLEdBXRbAqYYTr5k1Kd+Bwswsf068ZmY5c+I1M8tRBLS1FXJoJ14za15u8ZqZ5cizGszMCuAWr5lZzpx4zcxy5ME1M7MCuMVrZpYzJ14zszyFZzWYmeUqIKJG93fvISdeM2tebvGameXIsxrMzArgwTUzs3xFu/t4zcxy5IXQzczy5UVyzMzyFUAUNLjWUshRzcyKFgHRXt1WBUlvl3SfpL9L+mKlsm7xmlnTihp1NUjqB/wUeBvwT2COpOsiYklX5d3iNbPmVbsW73jg7xHxYERsBH4HvKdcYUVBo3p9maSngIeKjqPESGBl0UE0OJ+jyhrx/PxbRIza0jdLupHsc1VjEPBCyfMpETGlpK4TgLdHxH+k5x8FXh8Rn+mqMnc11EFvfhnqQdLciBhXdByNzOeosr54fiLi7UUd210NZma99yiwe8nz3dK+Ljnxmpn13hxgH0l7ShoIfBC4rlxhdzU0hyndF2l6PkeV+fxUEBGbJX0GuAnoB/w6IhaXK+/BNTOznLmrwcwsZ068ZmY5c+I1M8uZE29BJI2R1Fp0HLUk6ct1qve5Xrz3RElLJd1W5vWPS7qg0nEl7Spp2pbGKel4Sfv3JO4tUc/zVA+ShkuamNfxGokTbwNL139vTeqSeHvpk8CnIuLILa0gIlZExAm9iOF4oO6Jt5d6fZ62wHDAiddy11/S1NTSmCZpsKTlkiZLmg+cKOkkSfdKapU0GV5snfwgPT5T0oPp8VhJs9Lj5ZK+KWl+ev9+5YKQ9GZJC9O2QNJQSRMkzZR0Q1px6ReSWlL5rmI6D9g21TG1XidM0jmS5khaJOmbJfuvkTRP0mJJp6R9XwMOB34l6fsVqt1d0gxJD0j6ehfHfPHbSfo3+r2kJZKulnSXpHElZf9L0j2SZkvaWdIbgXcD30/nZq8anYqKan2eJL1a0t3pMyyStE86L3+TdLGk+9Pv8lslzUrncnx67zckTSqpq1XSGOA8YK9UZ6V/n74nIrwVsAFjyJYEPSw9/zUwCVgOfD7t2xV4GBhFNuf6VrLW0yuAOanMNLLJ26OBjwHfTfuXA59NjycCF1WI5fqSOLZLx5pAdm36WLJ5iTcDJ5SLKb33uTqdq+fSz6PJ5pOKrNHwR+CI9NqI9HNboBXYMT2fAYyrUPfHgceAHUveO67TcccArenxJODC9PgAYHNJ+QDelR5/D/hqenwxcEIOv1P1PE8/AT6cHg9M7x+TPv+B6TjzyH6PRbZAzDWp/DeASSV1tab3vnhem21zi7dYj0TErPT4MrJWB8D/pJ+vA2ZExFMRsRmYSvYf6HFgO0lDyS5TvBw4AngTcHtJ/Veln/PIfsnLmQX8QNIZwPB0LIC7I1ttqQ24IsXXZUxb8Nm3xNFpWwDMB/YD9kmvnSHpHmA22TnZp8saunZzRKyKiPVk5+zwCmUPJ1t5iohoBRaVvLaRLMlB9+e8nupxnu4EvizpC2SL06xP+5dFxL0R0Q4sBv4SWXa9l+I+f8PzlWvF6nz1Ssfz56t47/8CnwDuI0u2JwOHAp8rKbMh/Wyjwr91RJwn6QbgOGCWpGO6ia8oImvRX/iyndIE4K3AoRGxTtIMstWkqlWrz7kpJR3o5pzXWc3PU0RcLuku4B3AnySdCjzIS79jAO0lz9t56fNv5uXdmj35t+mT3OIt1h6SDk2PPwTc0en1u4E3SxqZBtpOAv6aXrud7GvvTLKWzZHAhoh4pqdBSNortVomk3VbdPQHj1d27XkL8IEUX6WYNkka0NPj98BNwMmStktxj5a0E7A98HRKJvsBb+hhvW+TNELStmRdObMqlJ0FvD8df3+yr9ndWQsM7WFMvVHz8yRpLPBgRPwYuBY4qAfxLAcOSfUcAuyZ9ud9XhqGE2+x7gNOl7QU2AH4eemLEfEY8EXgNuAeYF5EXJtevp3sq+LM1BXwCP+auKt1VhrwWARsAv6c9s8BLgCWAsuAq7uJaQqwqF6DaxExnaxb5U5J95L1bw8FbiQbqFxKNmAzu4dV3w1cSdZtcGVEzK1Q9mfAKElLgG+Tfb3u7o/d74Bz0sBl3QfX6nSe3g+0SlpI1rd9SQ/eeyUwQtJi4DPA/SnOVWTfsFqbbXDNazVYl9LX0kkR8c6CQ2koqZU/ICJeSEn0FmDfyO46YFYV9/Ga9cxg4LbUpSJgopOu9ZRbvE1E0ieAMzvtnhURpxcRT57SgOHkTruXRcR7i4inUfk85cOJ18wsZx5cMzPLmROvmVnOnHitEJLa0jX6rZL+IGlwL+q6WNnttZF0kSqsBKZsDYo3bsExlkv6l1uBl9v
fqUyPVg3rvLaB9T1OvFaU9RFxcEQcQHap7WmlL0raohk3EfEfEbGkQpEJQI8Tr1ktOfFaI7gd2Du1Rm+XdB2wRFI/Sd8vWWXrVABlLlC2atotwE4dFSlbZWxcevx2Zauz3SPpL2lFrNOAs1Nr+02SRkm6Mh1jjqTD0nt3lDRd2SpeF5FNHatIXaz8VfLaD9P+v0galfbtJenG9J7bVWEFOetbPI/XCpVatseSXVUF2aWlB0TEspS8nomI10nahuwqp+nAa4F9yda43RlYQrYqVmm9o4Bfki0qtEzSiIhYLekXZKt4/Xcqdznww4i4Q9IeZJfbvgr4OnBHRJwr6R1k69V25+R0jG2BOZKuTFdnDQHmRsTZypZg/DrZFVxTgNMi4gFJrye7Ku6oLTiNtpVx4rWibJsuP4Wsxfsrsi6AuyNiWdp/NHBQR/8t2VoD+5CthnZFulR6haRbu6j/DWSXUy8DiIjVZeJ4K7C/9GKDdlha4+AI4H3pvTdIerqKz3SGpI75rh0rf60iWzCmY8W5y4Cr0jHeCPyh5NjbVHEM6wOceK0o6yPi4NIdKQGVrswmsjWFb+pU7rgaxtECvCEiXugilqqpZyt/RTrums7nwJqD+3itkd0EfDpdnoukV0oaQrYi2wdSH/AuZCuzdTYbOELSnum9I9L+zitiTQc+2/FE0sHp4UyyFeOQdCzZIkaVVFr5q4VsEXlSnXdExLPAMkknpmNI0mu6OYb1EU681sguIuu/na/s1jsXkn1Luxp4IL12Cdki3S8TEU8Bp5B9rb+Hl77qXw+8t2NwDTgDGJcG75bw0uyKb5Il7sVkXQ4PdxNrpZW/nidbYrOVrA/33LT/w8AnU3yLye7aYE3AlwybmeXMLV4zs5w58ZqZ5cyJ18wsZ068ZmY5c+I1M8uZE6+ZWc6ceM3Mcvb/AVXZnUURepQwAAAAAElFTkSuQmCC"
},
"metadata": {
"needs_background": "light"
}
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"print(metrics.classification_report(gt_classes, predicted_classes, target_names=image_classes))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" precision recall f1-score support\n",
"\n",
" brown_spot 0.67 0.25 0.36 8\n",
" leaf_blight 0.47 0.88 0.61 8\n",
" leaf_smut 0.67 0.50 0.57 8\n",
"\n",
" accuracy 0.54 24\n",
" macro avg 0.60 0.54 0.51 24\n",
"weighted avg 0.60 0.54 0.51 24\n",
"\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The performance still isn't great, but we can see that it is in the same range as our previous attempt at transfer learning with a very large model. We might be able to squeeze more performance out of this one by using transfer learning from a similar data set...\n",
"\n",
"What do you think?\n",
"\n",
"---\n",
"\n",
"_Note:_ You may also notice that there was one epoch where the validation performance reached ~70%! If we had saved weights at that moment, we would have had a better performing model -- on this test set... Does that mean it would perform better on new rice images? Probably not. Be careful about selecting a model based on one epoch where the validation performance is high. It could have just been a \"lucky\" set of weights on that epoch. To _really_ solve this problem, we need to do two things. 1) Use a third partition for the validation set during training. Only evaluate the test set after we've chosen what we think is our best model. 2) Repeat this process using cross-validation to see how well it generalized. We acknowledge that cross-validation is time and cost-prohibitive, but it is best-practice when you can afford to do it."
],
"metadata": {}
}
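,
{
"cell_type": "markdown",
"source": [
"The cell below is a minimal sketch of point (1), with point (2) indicated in comments. It reuses `ground_truth_df` from above. The names `train_flow` and `val_flow` are hypothetical -- you would build them with `flow_from_dataframe` exactly as before -- and the split sizes, checkpoint file name, and callback settings are illustrative assumptions, not prescriptions."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: train/validation/test split plus a checkpoint that keeps the weights\n",
"# from the best *validation* epoch, so model selection never touches the test set.\n",
"# 'train_flow' and 'val_flow' are hypothetical flows built like 'train_dataflow'.\n",
"from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit\n",
"\n",
"labels = ground_truth_df['class'].values\n",
"\n",
"# 1) Hold out 20% as the test set, then 20% of the remainder for validation.\n",
"outer = StratifiedShuffleSplit(1, test_size=0.20, random_state=2021)\n",
"trainval_idx, test_idx = next(outer.split(ground_truth_df.values, labels))\n",
"inner = StratifiedShuffleSplit(1, test_size=0.20, random_state=2021)\n",
"rel_train, rel_val = next(inner.split(trainval_idx.reshape(-1, 1), labels[trainval_idx]))\n",
"train_idx, val_idx = trainval_idx[rel_train], trainval_idx[rel_val]\n",
"\n",
"# Keep only the best-so-far weights, judged on validation accuracy.\n",
"checkpoint = keras.callbacks.ModelCheckpoint(\n",
"    'best_leaf_model.h5', monitor='val_acc', save_best_only=True\n",
")\n",
"# leaf_model.fit(train_flow, validation_data=val_flow, epochs=30, callbacks=[checkpoint])\n",
"\n",
"# 2) Cross-validation wraps the whole build/train/evaluate process:\n",
"# for tr, te in StratifiedKFold(5, shuffle=True, random_state=2021).split(ground_truth_df.values, labels):\n",
"#     ...build fresh flows and a fresh model, train with the callback, record test metrics..."
],
"outputs": [],
"metadata": {}
}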
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}