{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Rice Leaf Diseases: Deep Learning and Transfer Learning (Part 1)\n",
"This notebook will continue our exploration of the [Rice Leaf Diseases dataset from Kaggle](https://www.kaggle.com/vbookshelf/rice-leaf-diseases).\n",
"\n",
"## Setting up your environment\n",
"Before running this notebook, you will need to make sure you have [downloaded](https://www.kaggle.com/vbookshelf/rice-leaf-diseases) the dataset and extracted the files. This notebook assumes the image data is extracted in the same directory as this notebook, and that the top-level data directory is named \"rice_leaf_diseases\". You can edit the code if those assumptions do not hold on your own setup.\n",
"\n",
"## Python environment\n",
"You will need the following packages installed to execute the code shown in this notebook:\n",
"\n",
"* [Matplotlib](https://matplotlib.org/)\n",
"* [Numpy](https://numpy.org/)\n",
"* [Pandas](https://pandas.pydata.org/)\n",
"* [Pillow](https://pillow.readthedocs.io/en/stable/)\n",
"* [Scikit-Learn](https://scikit-learn.org/)\n",
"* [Tensorflow](https://www.tensorflow.org/)\n",
"\n",
"## Local directory structure\n",
"This notebook assumes you have the rice leaf images in the following directory structure:\n",
"\n",
" rice_leaf_diseases/\n",
" Bacterial leaf blight/\n",
" Brown spot/\n",
" Leaf smut/\n",
"\n",
"The \"ground truth\" file \"_rice_leaf_diseases_ground_truth.csv_\" was created in the previous (\"rice_image_EDA\") notebook. If you don't have this file, you can create it by running the first several cells in that notebook, which is available [by clicking here](https://gist.github.com/jcausey-astate/207ba4d65126abe0482b740b41117f9e)."
],
"metadata": {}
},
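{
"cell_type": "markdown",
"source": [
"Before going further, it is worth confirming that the files are where this notebook expects them. The cell below is an optional sanity check (a minimal sketch assuming the directory layout shown above); adjust the paths if your setup differs."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import os\n",
"\n",
"# Optional sanity check: confirm the assumed directory layout and ground truth file exist.\n",
"expected_dirs = [\n",
"    \"rice_leaf_diseases/Bacterial leaf blight\",\n",
"    \"rice_leaf_diseases/Brown spot\",\n",
"    \"rice_leaf_diseases/Leaf smut\",\n",
"]\n",
"for d in expected_dirs:\n",
"    print(f\"{d}: {'OK' if os.path.isdir(d) else 'MISSING'}\")\n",
"gt_ok = os.path.isfile(\"rice_leaf_diseases_ground_truth.csv\")\n",
"print(f\"rice_leaf_diseases_ground_truth.csv: {'OK' if gt_ok else 'MISSING'}\")"
],
"outputs": [],
"metadata": {}
},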
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import sys, os\n",
"import tensorflow as tf, tensorflow.keras as keras\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import StratifiedShuffleSplit\n",
"import sklearn.metrics as metrics"
],
"outputs": [],
"metadata": {}
},
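{
"cell_type": "markdown",
"source": [
"If you run into trouble reproducing the results, it can help to know which library versions you have installed. The optional cell below just prints them (a small sketch; it only uses the modules imported above, plus `sklearn`)."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import sklearn\n",
"\n",
"# Optional: print library versions for reproducibility/debugging.\n",
"print(\"Python:\", sys.version.split()[0])\n",
"print(\"TensorFlow:\", tf.__version__)\n",
"print(\"Pandas:\", pd.__version__)\n",
"print(\"Numpy:\", np.__version__)\n",
"print(\"Scikit-Learn:\", sklearn.__version__)"
],
"outputs": [],
"metadata": {}
},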
{
"cell_type": "markdown",
"source": [
"We want to train a deep learning CNN-based model that should be capable of looking at the images of rice leaves and focusing on features that will allow the model to discriminate between the different types of disease. \n",
"\n",
"To train such a model from scratch would require a large amount of computational time and (ideally) thousands of training images. The Kaggle dataset only has ~120 images, split into three classes (about 40 per class), so we cannot hope to train a cutting-edge model from scratch.\n",
"\n",
"To overcome this hurdle, we will make use of the fact that many of the \"features\" we see in natural images are common across images of different classes of \"things\" - color, shape, texture, etc. We will use a model that has already been trained on a different natural image dataset, and we will simply replace the layers responsible for performing the classification task. The already-trained parameters for the rest of the model will perform feature extraction for our new classifier, which we will have to train. Using features learned on a similar (but different) dataset to simplify application of a machine learning model on a new dataset is called *transfer learning*, and it is an important tool for researchers."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"First, let's load up the ground truth data we created in the previos (\"rice_image_EDA\") notebook. If you don't have this file, you can create it by running the first several cells in that notebook, which is available [by clicking here](https://gist.github.com/jcausey-astate/207ba4d65126abe0482b740b41117f9e)."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"ground_truth_df = pd.read_csv(\"rice_leaf_diseases_ground_truth.csv\")\n",
"ground_truth_df.head()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" class image_path\n",
"0 leaf_blight rice_leaf_diseases/Bacterial leaf blight/DSC_0...\n",
"1 leaf_blight rice_leaf_diseases/Bacterial leaf blight/DSC_0...\n",
"2 leaf_blight rice_leaf_diseases/Bacterial leaf blight/DSC_0...\n",
"3 leaf_blight rice_leaf_diseases/Bacterial leaf blight/DSC_0...\n",
"4 leaf_blight rice_leaf_diseases/Bacterial leaf blight/DSC_0..."
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>class</th>\n",
" <th>image_path</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>leaf_blight</td>\n",
" <td>rice_leaf_diseases/Bacterial leaf blight/DSC_0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>leaf_blight</td>\n",
" <td>rice_leaf_diseases/Bacterial leaf blight/DSC_0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>leaf_blight</td>\n",
" <td>rice_leaf_diseases/Bacterial leaf blight/DSC_0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>leaf_blight</td>\n",
" <td>rice_leaf_diseases/Bacterial leaf blight/DSC_0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>leaf_blight</td>\n",
" <td>rice_leaf_diseases/Bacterial leaf blight/DSC_0...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
]
},
"metadata": {},
"execution_count": 2
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Just like we did in the EDA notebook previously, we will split the data into a training and testing set using `StratifiedShuffleSplit` from Scikit-Learn. It will randomize the data and make sure the test set's class distribution approximately matches the training set. We will use an 80% (train) / 20% (test) split."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"sss = StratifiedShuffleSplit(1, test_size=0.20, random_state=2021)\n",
"train_indices, test_indices = list(sss.split(ground_truth_df.values, ground_truth_df['class'].values))[0]\n",
"print(f\"Training set has {train_indices.shape[0]} samples.\")\n",
"print(f\"Test set has {test_indices.shape[0]} samples.\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training set has 96 samples.\n",
"Test set has 24 samples.\n"
]
}
],
"metadata": {}
},
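{
"cell_type": "markdown",
"source": [
"As a quick (optional) check, we can confirm that the stratified split kept the class proportions roughly equal on both sides by counting the labels in each set:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Optional: verify that the class distribution is similar in the training and test sets.\n",
"print(\"Training set class counts:\")\n",
"print(ground_truth_df['class'].iloc[train_indices].value_counts())\n",
"print(\"\\nTest set class counts:\")\n",
"print(ground_truth_df['class'].iloc[test_indices].value_counts())"
],
"outputs": [],
"metadata": {}
},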
{
"cell_type": "markdown",
"source": [
"## Creating the model\n",
"\n",
"Now, let's create the model. We will use the Xception model, which is available in [Keras Applications](https://keras.io/api/applications/xception/). This model provides a nice tradeoff between overall number of parameters and accuracy, with a reasonable default input image size ($299 \\times 299$ pixels).\n",
"\n",
"Because we will be using transfer learning, we will initialize the model with pre-trained weights that were trained on the ImageNet dataset. ImageNet is a large dataset of natural images, so it should provide a good base model for transfer learning to our leaf images."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"xception_base = keras.applications.Xception(\n",
" include_top = False, \n",
" weights=\"imagenet\", \n",
" input_shape=(299,299,3), \n",
" classes=3\n",
")"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The `include_top = False` argument specifies that we do not want the classification layer from the original model (this is often referred to as the \"head\" or \"top\" of a deep learning model). The `weights` parameter specifies the pre-trained weights we want (we are using ImageNet). We explicitly set the `input_shape` to specify we will use $299 \\times 299$ pixel images, with 3 color channels. Finally, the `classes` parameter is set to 3 to reflect our three rice diseases: Leaf blight, Brown spot, and Leaf smut.\n",
"\n",
"We can see the structure of the model by using its `summary()` method. To see this, just un-comment the following line of code (it is commented here because the output is lengthy). We can see that the model has around 20 million trainable parameters."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"# xception_base.summary()"
],
"outputs": [],
"metadata": {}
},
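{
"cell_type": "markdown",
"source": [
"If you prefer a single number to scanning the full summary, you can (optionally) compute the trainable-parameter count directly, as sketched below. At this point every weight in the base model is still trainable; we will freeze them shortly."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Optional: count the trainable parameters in the base model directly.\n",
"trainable_count = sum(\n",
"    tf.keras.backend.count_params(w) for w in xception_base.trainable_weights\n",
")\n",
"print(f\"Trainable parameters in xception_base: {trainable_count:,}\")"
],
"outputs": [],
"metadata": {}
},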
{
"cell_type": "markdown",
"source": [
"### Adding a custom classification layer\n",
"\n",
"Since we set `include_top` to `False` when we created the model, our model does not have a classifer layer... We will add that now.\n",
"\n",
"We will also set the base model's parameters to be non-trainable so that the training process will not try to modify them. This is common in transfer learning -- if we allowed the base model's parameters to be trained, we would make its performance _worse_ because it would receive loss feedback from the new layer we are about to add that will be randomly initialized (and thus not very good at predicting the class at first). By locking the base weights as non-trainable, we protect them while we allow the new layer's weights to train in isolation. This is the \"trick\" that allows transfer learning to be so effective - we are training only a fraction of the total parameters in the model."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"xception_base.trainable = False # mark the base model as non-trainable to prevent changing the pre-trained weights.and\n",
"\n",
"leaf_classifier_input = keras.layers.Input(shape=(299,299,3))\n",
"# We will pass the inputs to the base model, and mark it's `training` parameter as False so that it is running\n",
"# in \"inference\" mode. See (https://keras.io/guides/transfer_learning/) for details.\n",
"intermediate_values = xception_base(leaf_classifier_input, training=False)\n",
"# We need to reshape the features from the base model (which are 4-dimensional CNN outputs) into a vector that\n",
"# can easily be fed into a fully-connected (\"Dense\") layer. We will do this by using a `GlobalAveragePooling2D`\n",
"# layer which will average the values produced by each of the filters (CNN outputs) into single values, giving\n",
"# us a vector.\n",
"intermediate_values = keras.layers.GlobalAveragePooling2D()(intermediate_values)\n",
"# Now we will attach the output layer so that it receives the intermediate features from Xception:\n",
"leaf_classifier_output = keras.layers.Dense(3, activation=\"softmax\")(intermediate_values)\n",
"leaf_classifier_model = keras.Model(inputs=leaf_classifier_input, outputs=leaf_classifier_output)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"If you un-comment the line below and look at the summary, you'll see that there are just over 6,000\n",
"trainable parameters now (about .003% of the total parameters for the model). \n",
"\n",
"We should be able to\n",
"train those even with our small training set (96 images). In fact, we will start to overfit very quickly, so we will only use a small number of epochs for now."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"# leaf_classifier_model.summary()"
],
"outputs": [],
"metadata": {}
},
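{
"cell_type": "markdown",
"source": [
"Where does that count come from? Xception's final convolutional block produces 2048 feature maps, so `GlobalAveragePooling2D` hands our `Dense` layer a 2048-element vector. The new layer therefore has $2048 \\times 3 = 6144$ weights plus $3$ biases, for $2048 \\times 3 + 3 = 6147$ trainable parameters in total."
],
"metadata": {}
},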
{
"cell_type": "markdown",
"source": [
"## Create a generator to handle loading the images\n",
"\n",
"Now we need to set up the training dataset so that we can pass it to a model for training. There is more than one way to do this... The most obvious way is just to read the images using the `PIL` library like we did in the EDA notebook. This has a couple of drawbacks though:\n",
"\n",
"1. We would be loading the entire training dataset into memory at once. Our dataset is small, so this probably isn't a problem - but for larger training sets it could exhaust the machine's memory.\n",
"2. Our images are not all the same size. If you examine the image data, you will see that they come in a range of sizes and aspect ratios. The mean dimensions are about $2362 \\times 717$ pixels, but there is quite a lot of variation. The model we will be using for our first transfer learning experiment (Xception) was designed to support images of shape $299 \\times 299$. We _can_ change that to make it larger, but it will be costly in terms of memory use (both on our computer and on the GPU when we start training). For now, let's choose to rescale the images to make them fit in $299 \\times 299$ pixels.\n",
"\n",
"The good news is that the Keras framework includes several nice tools to help with loading image datasets for training and testing. They have already tackled both of these issues. By using an object called a \"generator\", we will allow Keras to load only enough images at a time to keep the model training without delay, but without loading all of them at once. Keras can also take care of resizing the images automatically (and in a \"smart\" way to assure that we keep the aspect ratio while padding to make the image the correct size).\n",
"\n",
"The object we will use for this purpose is the [`ImageDataGenerator`](https://keras.io/api/preprocessing/image/#imagedatagenerator-class) class from the [`tf.keras.preprocessing`](https://keras.io/api/preprocessing) library. It has a method called [`flow_from_dataframe()`](https://keras.io/api/preprocessing/image/#flowfromdataframe-method) that will do exactly what we need, since our ground truth is already in a Pandas DataFrame.\n",
"\n",
"The `ImageDataGenerator` can also help us perform some simple data augmentation, but we will look at that later.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"# Declare the generator (we will take default initial paramters for now)\n",
"train_generator = keras.preprocessing.image.ImageDataGenerator(\n",
" rescale=1.0/255.0\n",
")\n",
"train_dataflow = train_generator.flow_from_dataframe(\n",
" ground_truth_df.iloc[train_indices,:],\n",
" x_col = \"image_path\",\n",
" y_col = \"class\",\n",
" target_size = (299,299),\n",
" seed = 2021\n",
")\n",
"image_classes = sorted(train_dataflow.class_indices.keys())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found 96 validated image filenames belonging to 3 classes.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Let's train\n",
"\n",
"Now we can start to train the model on the rows of the dataframe that were selected as part of the training set (the indices are stored in `train_indices`). \n",
"\n",
"First we call `compile()` to initialize the (trainable) weights of the new model, and to set up the parameters for the optimization process. Then we will call the `fit()` method and use the `iloc[]` accessor for the ground truth DataFrame to create a filtered DataFrame containing only these training rows, and use the `flow_from_dataframe()` method of our `ImageDataGenerator` object to feed the training images to the training process."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"One note: We need to use the `rescale` parameter to specify that values stored for each pixel must be rescaled. When we load the images, the color of each pixel is determined by a set of three numbers (one per color channel: red, green, and blue). These numbers are integers in the range $[0,255]$ when we load the images.\n",
"\n",
"But, the Xception model (and most models that take images as input) requires that the values be floating-point values in the range $[0,1]$. To accomplish this, we scale each value (for each channel) by dividing it by 255, so that we get a value in $[0,1]$. This is done by passing the scaling value $\\frac{1}{255}$ to the `rescale` parameter."
],
"metadata": {}
},
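{
"cell_type": "markdown",
"source": [
"As a concrete (optional) illustration - the generator applies this same rescaling internally, so you do not need to run this - the rescaling is just an element-wise multiplication by $\\frac{1}{255}$:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Illustration only: this is the rescaling the generator performs for us.\n",
"example_pixels = np.array([0, 128, 255], dtype=np.uint8)\n",
"print(example_pixels * (1.0 / 255.0))  # 0/255 = 0.0, 128/255 = 0.502..., 255/255 = 1.0"
],
"outputs": [],
"metadata": {}
},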
{
"cell_type": "code",
"execution_count": 9,
"source": [
"# We will set the seeds for both the Tensorflow and Numpy random number generators so that we get repeatable \n",
"# results from the training process. If you change the seed value, the results will be different.\n",
"tf.random.set_seed(2021)\n",
"np.random.seed(2021)\n",
"leaf_classifier_model.compile(\n",
" optimizer=keras.optimizers.Adam(),\n",
" loss=keras.losses.CategoricalCrossentropy(),\n",
" metrics=[keras.metrics.CategoricalAccuracy()]\n",
")\n",
"leaf_classifier_model.fit(\n",
" train_dataflow,\n",
" epochs=5\n",
")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n",
"3/3 [==============================] - 8s 2s/step - loss: 1.1551 - categorical_accuracy: 0.2812\n",
"Epoch 2/5\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.9794 - categorical_accuracy: 0.6146\n",
"Epoch 3/5\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.8356 - categorical_accuracy: 0.7812\n",
"Epoch 4/5\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.7289 - categorical_accuracy: 0.8542\n",
"Epoch 5/5\n",
"3/3 [==============================] - 6s 2s/step - loss: 0.6375 - categorical_accuracy: 0.8958\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x194f45b80>"
]
},
"metadata": {},
"execution_count": 9
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Time to predict\n",
"\n",
"Let's see how we did... Just like in the EDA notebook, we will run inferrence on the 20% of images we set aside for the test set, then look at the performance by category. We can use another `ImageDataGenerator` to do this - this one will feed from the testing indices of our ground truth, and it won't bother shuffling the order of the images."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"test_generator = keras.preprocessing.image.ImageDataGenerator(\n",
" rescale=1.0/255.0\n",
")\n",
"test_dataflow = test_generator.flow_from_dataframe(\n",
" ground_truth_df.iloc[test_indices,:],\n",
" x_col = \"image_path\",\n",
" y_col = \"class\",\n",
" target_size = (299,299),\n",
" shuffle = False\n",
")\n",
"predictions = leaf_classifier_model.predict(test_dataflow)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found 24 validated image filenames belonging to 3 classes.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now, we have the predictions. We use a _confusion matrix_ to compare the predicted label versus the correct label for all classes. The number of correct predictions are shown along the diagonal; all other numbers are counting incorrect predictions. \n",
"\n",
"In the EDA notebook we use the Scikit_learn `plot_confusion_matrix()` function, but we saw that it takes the model, inputs, and labels as parameters. It then makes the predictions internally and compares them. This will not work in our case because the model we just trained actually predicts an analog to the _probability_ for each class for _every_ value. So, we need to use the predictions we just made to assign a class label to the most highly probably class.\n",
"\n",
"We can do this very easily by using the Numpy `argmax()` function to find the _index_ of the _column_ whose value is highest, and do that for each row that was predicted. We use the `axis=1` parameter to specify we want the _column_ index for every row. The result will be a vector of values in ${0,1,2}$."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"predicted_classes = list(np.argmax(predictions, axis=1))\n",
"print(predicted_classes)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[2, 0, 2, 1, 2, 1, 1, 1, 2, 2, 1, 0, 1, 2, 0, 1, 2, 1, 2, 1, 0, 1, 0, 1]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"These predictions correspond to our three classes in the order we specified them when training/predicting:\n",
"\n",
" 0: 'leaf_blight'\n",
" 1: 'brown_spot'\n",
" 2: 'leaf_smut'\n",
"\n",
"Let's create a numeric representation for the ground truth as well:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"class_to_int = test_dataflow.class_indices\n",
"gt_classes = test_dataflow.classes\n",
"print(gt_classes)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[2, 2, 2, 1, 0, 1, 1, 0, 2, 2, 1, 0, 0, 0, 1, 1, 2, 1, 2, 1, 0, 0, 2, 0]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now we can get our confusion matrix, although we have to do it a bit differently. We create a `ConfusionMatrixDisplay` object given the confusion matrix computed by the SciKit Learn metrics `confusion_matrix()` function, then tell it to `plot()` itself:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"source": [
"metrics.ConfusionMatrixDisplay(\n",
" metrics.confusion_matrix(gt_classes, predicted_classes), \n",
" display_labels=image_classes\n",
").plot()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x1955bc5e0>"
]
},
"metadata": {},
"execution_count": 13
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV4AAAELCAYAAACRaO5eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAh1klEQVR4nO3de7xUZd338c93w0YOgoCgKWKomWZm6kMkZYZWmuZT1qOVHV6VlXprqT1h56eDlWndT3ealZKVlai3gcc8QKWEkshJRA4qJXjCEyCKgrDZ+3f/sRY67vbMns2eWWvtPd/367Vee2bNNdf6zXL8cc11XetaigjMzCw7TXkHYGbWaJx4zcwy5sRrZpYxJ14zs4w58ZqZZcyJ18wsY068ZmbdJGkfSQtLtuclnVW2vOfxmpnVjqQ+wOPAWyPi4Y7KuMVrZlZb7wL+VS7pAvTNMJiG0dx/UPQbNDzvMApryxD/yupM3+eVdwiFt2HtY6sjYuS2vv+owwfFmrWtVZWdv2jTEuClkl2TImJSmeIfBa6sVJ8Tbx30GzSc/Y8+K+8wCuupI1vyDqHwdp7enHcIhTdn8sSyLcpqrF7byt3TdquqbPMu/3opIsZ2Vk5SP+D9wNcrlXPiNbMGFbRGW60rPRpYEBFPVSrkxGtmDSmANmre7XUinXQzgBOvmTWwNmrX4pU0CHgPcEpnZZ14zawhBUFrDafTRsSLwI7VlHXiNbOGFEBLDVu8XeHEa2YNqw59vFVx4jWzhhRQ066GrnDiNbOGlU9HgxOvmTWoIGh1V4OZWXYioCWnq9edeM2sQYlW8lkTw4nXzBpSAG1u8ZqZZcstXjOzDAVOvGZmmWsLJ14zs8y0ITbTJ5djO/GaWcNyi9fMLEPu4zUzy5xojXzu9+vEa2YNKbkDhROvmVmm3NVgZpahCNESntVgZpaZZHDNXQ1mZhny4JqZWaY8uGZmloNWX0BhZpadQLRE7VKgpKHApcD+JA3qkyLiro7KOvGaWUOqw+DaBcCtEXG8pH7AwHIFnXjNrCEFqllXg6QdgMOATwNExGZgc7ny+fQsm5kVQBtNVW1V2AN4BvidpHskXSppULnCbvH2Yjvt8ALfPvF2hg/eQIS4fvYbuPrON+UdVvG0BaO/u4zWYf1Y9aXX5R1NofTm71AEXZlONkLSvJLnkyJiUsnzvsDBwBcj4m5JFwBfA/5fR5U58fZirW3iwhsP4cHHRzJwu8387qxrmLN8N1Y+NSzv0Apl6PSnadm1P00b2/IOpXB693dItFV/yfDqiBhb4fXHgMci4u70+RSSxNuhunU1SBojaXG96s+DpG/kHUNXrFk/iAcfHwnAhk39WPnUUEYOeTHnqIql79rNDLr3OZ47bETeoRRSb/4OBbA5+la1dVpXxJPAo5L2SXe9C1harnyufbyS8rlQetv1qMRb6jXD1vP6UWtY8shOeYdSKCOueJTVHxkFymc+Z0/S275DgWiL6rYqfRGYLGkRcCBwbrmC9U68fSVNlrRM0hRJAyWtlHS+pAXACZJOlHSfpMWSzgeQdIKkn6aPz5T0UPp4T0mz0scrJX1P0oL0/fuWC0LSOyUtTLd7JA2WNEHSTEk3SXpA0sWSmtLyHcV0HjAgrWNyfU9bbQ3o18KPPjWdn10/ng2b+uUdTmEMWriO1iHNbBpTdgzEUr31O9RKU1VbNSJiYUSMjYgDIuK4iHi2XNl69/HuA3w2ImZJ+i1wWrp/TUQcLGlXYDbwv4BngemSjgPuAL6Sln0HsEbSqPTxzJL6V6f1nAZMBD5XJo6JwOlpHNsDL6X7xwH7AQ8DtwIfkvQP4Pz2MUXE1yR9ISIO7OgAkk4GTgboN7A4/V99mlo591PTmbZgb/6+eM+8wymU/stfZNA96xh073OopY2ml1rZ+ZIVPHXKHnmHVii99TsUQFsvXavh0YiYlT6+HDgjffzf6d+3ADMi4hmAtCV5WERcJ2l7SYOB0cAVJHPk3gFcU1L/1sfzgQ9ViGMW8NO0/msi4jElPy3nRMTW1vSVwKFAS0cxAddV+qDpCOckgEE7jo5KZbMTfPPDf+fhp4Zy1cwD8g6mcNacMIo1J4wCYMCy9Qy79Skn3X/Tm79D6rXr8bZPQFufV9M7/w/gM8ADJC3gk4DxwJdLymxK/7ZS4bNExHmSbgKOAWZJOqqT+HqFA8Y8ydFjl/PPVcP5/ZemAHDxLeO46/7dc47Meore/B3qzS3e3SWNT69X/hhwJ3BQyetzgAsljSD5WX8i8PP0tTuAc9LtHuBwYGNEPNfVICTtFRH3AfdJeguwL7AOGCdpD5Kuho+QtFgrxdQiqTkiWroaQx4WrdyF8RNPyTuMHmHjGwaz8Q2D8w6jcHrzdyjPhdDrne4fAE6XtAwYBvyq9MWIeIJkrtvtwL3A/Ii4Pn35DpJuhpkR0Qo8SpK4t8VZ6UDZIpKuhFvS/XOBi4BlwArg2k5imgQs6mmDa2bWsdZoqmqrtbq1eCNiJUnLsr0x7cpdCVzZwfv/Ba90wETEke1eH1PyeB4woUIsX2y/L+3jfT4iju2gfLmYvgp8tdxxzKznSNbj7Z19vGZmBeU7UNSEpM8AZ7bbPSsiTm9fNiJmADMyCMvMCigZXHOLt9si4nfA7/KOw8yKL1kI3XcZNjPLlO+5ZmaWoWRZSHc1mJllyn28ZmYZSlYnc1eDmVmmeutaDWZmhRSILW2e1WBmlilfuWZmliHPajAzy4EH18zMMrT1nmt5cOI1s4blPl4zswwFeFaDmVmmunbr9ppy4jWzhlTrhdAlrQTWk9wDcktEjC1X1onXzBpWHVq8h0fE6s4KOfGaWUPyQuhmZhlLLhmu6TzeAKZLCuCSiJhUrqATr5k1rC708Y6QNK/k+aQOEuuhEfG4pJ2Av0i6PyJmdlSZE6+ZNaboUlfD6kqDZQAR8Xj692lJ1wLjgA4Tbz7Xy5mZ5WxrH281W2ckDZI0eOtj4EhgcbnybvGaWcOq4eDazsC1kiDJq1dExK3lCjvxmllDquVaDRHxEPDmass78ZpZw2r16mRmZtmJrg2u1ZQTr5k1rHDiNTPLkhfJMTPLnFu8vUjfja0MXfp83mEU1l3/eUXeIRTeUZ8+MO8Qej2v1WBmljXf7NLMLFuBuxrMzDLmwTUzs8xF5HNcJ14za1juajAzy1AEtNZ2IfSqOfGaWcNyV4OZWcbc1WBmlqFATrxmZlnLqafBidfMGlS4q8HMLHPR5sRrZpapws1qkPRzKnSBRMQZdYnIzCwDRV2rYV5mUZiZZS2AoiXeiPh96XNJAyNiQ/1DMjPLRl5dDZ1eLydpvKSlwP3p8zdL+mXdIzMzqysRbdVtVdco9ZF0j6Q/VypXzYXKPwOOAtYARMS9wGFVR2JmVlRR5Va9M4FlnRWqaoWIiHi03a7WLoViZlY06TzearZqSNoNeB9waWdlq5lO9qiktwEhqZkqM7qZWeHVto/3Z8BXgMGdFaymxXsqcDowClgFHJg+NzPr4VTlxghJ80q2k19Vi3Qs8HREzK/mqJ22e
CNiNfDxLnwSM7OeofoW7+qIGFvh9bcD75d0DNAfGCLp8oj4REeFq5nVsKekGyU9I+lpSddL2rPqcM3MiiiANlW3dVZVxNcjYreIGAN8FLitXNKF6roargCuBnYBdgX+BFxZxfvMzAotorqt1qpJvAMj4o8RsSXdLidpSpuZ9Wy1n05GRMyIiGMrlam0VsPw9OEtkr4GXJWG8BHg5q6FYmZWQEW7ZBiYT5Jot0Z2SslrAXy9XkGZmWVBRVudLCL2yDIQM7NMbUM3Qq1UtR6vpP2B/Sjp242IP9QrKDOz+qtuxkI9dJp4JX0HmECSeG8GjgbuBJx4zaxnK+rqZMDxwLuAJyPiM8CbgR3qGpWZWRbqMKuhGtV0NWyMiDZJWyQNAZ4GRtc+FKu1L501m3HjVrFuXX/+47Rj8g6nkB7953ace+qYl58/+Ug/Pnn2k3zo88/kF1TBjJ3wPKd+fxV9moJbrhzO1RftnHdItVHEhdBLzJM0FPg1yUyHF4C76hmU1cZf/ronN9z4eiZ+eXbeoRTW6Ndt4ld/fQCA1lb4+MFv5O1Hr8s3qAJpagpOP/dxvv7RPVn9RDM/v3k5s6ftwCPLe8dU/rxmNXTa1RARp0XEuoi4GHgP8Km0y6HLJL2wLe9L33uCpGWSbi/z+qclXVTpuJJ2lTRlW+OUdJyk/boSd54WL96J9ev75R1Gj7HwjsHs8tpN7LxbS96hFMY+B21g1cp+PPnIdmxpaWLG9UMZf9RzeYdVO0XrapB0cKXXImJB7cOp6LPA5yPizm2tICJWkfRZb6vjgD8DS7tRhxXUjOuHMuG4dXmHUSg7vqaFZ1a98o/36iea2ffg3nMHsMLN4wX+f4XXAjiiOweWdDbwYWA74NqI+E66/zqSPuT+wAURMUnSt4FDgd9IuiEizi5T7WhJM0iWsLw8Ir7X7phjgD9HxP6SBgKXAfsDD5CsQ3F6RMxLy/4QOBbYCHwA2At4P/BOSd8C/k9E/Kuk7pOBkwH6N3vssadp2SxmT9+Bk77xRN6hWJaK1scbEYfX66CSjgT2BsaRXBl3g6TDImImcFJErJU0AJgraWpEnCPpCGDi1sRYxjiSRLohfe9NFcqfBjwbEful85QXlrw2CJgdEd+U9GOSlvYPJN1Akrj/rbsiIiYBkwB2GLhrTv+O2raae9tgXvemDQwbuSXvUAplzZPNjNx188vPR+zSwuonmnOMqIZyvICiqlv/1MGR6XYPsADYlyQRA5wh6V5gNknLd+8Oa+jYXyJiTURsBK4haSWXcyjJ+hNExGJgUclrm0m6FCAZUBzThRisB5px3TB3M3TggYUDGbXHZnYevYm+zW1M+MA6Zk/vRb/oitbHW2cCfhQRl7xqpzQBeDcwPiI2pN0GXRk+bX+KtvWUtUS8vBhcK/mdp2756ldmccABTzNkyCb++Ifr+OPlb2L69L3yDqtwXtrQxII7BnPmj9vfWtDaWsUvvjmKc694iKY+MP2q4Tz8YO+Y0QDF7OOtp2nA9yVNjogXJI0CWkguzHg2Tbr7Aod0sd73pKuqbSQZCDupQtlZJH3Mt6czFd5URf3rqeJ+SkVx/o/fnncIPUL/gW1MWbI47zAKa+5tQ5h725C8w6iPtnwOW80dKCTpE+kAF5J2lzSuOweNiOkkC6zfJek+YApJQrsV6CtpGXAeSXdDV8wBppJ0G0ztpD/4l8BISUuBHwBLgM7myVwFnC3pHkluOpr1YIrqt1qrpsX7S5J/F44AziFp9U0F3tLVg0XE9iWPLwAu6KDY0WXeO6GTui8jmaVQ9rgRsZJk8A3gJeATEfFSmkT/CjzcQZxTSP5hICJmkaxZYWa9QdFmNZR4a0QcLOkegIh4VlJvmJU/kKSboZmkz/m0iNjcyXvMrDcpcB9vi6Q+pCFKGkluPSMg6Sjg/Ha7V0TEB7tST0SsByrdNdTMerkiD65dCFwL7JReVHA88K26RlVBREwjGZwzM+ueoibeiJgsaT7J0pACjouIZXWPzMysngKU02/3ahZC353kSrAbS/dFxCP1DMzMrO6K2uIFbuKVm172B/YgWdvgjXWMy8ys7mrVxyupPzCTZO2ZvsCUrevPdKSaroZXXViQrlp2WjfjNDPrTTYBR6QXhDUDd0q6JSI6vBahy1euRcQCSW/tbpRmZrmrUYs3XWJg6zrezelWtvZq+nj/b8nTJuBgYFU3YjQzy1+NB9fSabfzgdcBv4iIu8uVrWZ1ssEl23Ykfb4fqEGcZmb5qn51shGS5pVsJ/9bVRGtEXEgsBswLl1utkMVW7xpBh8cERO35TOZmRWV6NLg2uqIqOqCq4hYl96i7L1Ah6svlW3xSuobEa2Al7gys96pRuvxShqZ3hSY9CYO7wHuL1e+Uot3Dkl/7sL0zgt/Al58Od6IazoPx8ysoGq78tguwO/TXoIm4OqI+HO5wtXMaugPrCFZnWzrfN4gucODmVnPVbtZDYuAg6otXynx7pTOaFjMKwn35eNsW3hmZsVRxEuG+wDb8+qEu5UTr5n1fAW8ZPiJiDgns0jMzLKU412GKyXefJZmNzPLSBHX431XZlGYmeWhaIk3ItZmGYiZWdaK2OI1M+u9gtxuYubEa2YNSeQ3kOXEa2aNy10NZmbZch+vmVnWnHjNzDJU5LsMm5n1Wm7xmplly328ZmZZc+LtPbYM6MO6/YbkHUZhHbXrgXmHUHgfv/+xvEMovL/u0/063OI1M8tSQVcnMzPrtYRnNZiZZc8tXjOzbCnyybxOvGbWmNzHa2aWPc9qMDPLWF6Da035HNbMrACiyq0TkkZLul3SUklLJJ1ZqbxbvGbWmKKmXQ1bgC9HxAJJg4H5kv4SEUs7KuwWr5k1rhq1eCPiiYhYkD5eDywDRpUr7xavmTUk0aUW7whJ80qeT4qISR3WK40BDgLuLleZE6+ZNa7q5/GujoixnRWStD0wFTgrIp4vV86J18waU40XQpfUTJJ0J0fENZXKOvGaWcOqVeKVJOA3wLKI+Gln5T24ZmaNq0aDa8DbgU8CR0hamG7HlCvsFq+ZNaxaTSeLiDtJxuuq4sRrZo0p6MrgWk058ZpZw/JaDWZmGfJC6GZmWYtwV4OZWdbc1WBmljUnXjOzbLnFa2aWpQBa3cdrZpYpt3jNzLLmWQ1mZtlyi9fMLEu+vbuZWbaSO1C4q8HMLFPyrAYzswy5q8HqYacdXuDbJ97O8MEbiBDXz34DV9/5przDKpSxE57n1O+vok9TcMuVw7n6op3zDqlwNj8vZn9rGM8tbwbBIT98lpEHbc47rBrwWg1WB61t4sIbD+HBx0cycLvN/O6sa5izfDdWPjUs79AKoakpOP3cx/n6R/dk9RPN/Pzm5cyetgOPLO+fd2iFMu+HQ9n1HS9x2IVrad0MrS9Vvd534eU1qyG3W/9IeqEb7z1B0jJJt9cypk6OOVTSaVkdrxbWrB/Eg4+PBGDDpn6sfGooI4e8mHNUxbHPQRtYtbIfTz6yHVtamphx/VDGH/Vc3mEVyub14ul527HX8RsA6NMP+g3JKVvVw9YVyjrb
aqyn3nPts8DnI+LwDI85FOhRibfUa4at5/Wj1rDkkZ3yDqUwdnxNC8+s6vfy89VPNDNil5YcIyqeFx7rS//hbcz++jBu/uBOzP7WMLZs6CUt3vQuw9VstVaIxCvpbElzJS2S9L2S/ddJmi9piaST033fBg4FfiPpJ2Xqe6OkOekN5xZJ2lvSGEn3S7pM0oOSJkt6t6RZkpZLGpe+97uSJpbUtVjSGOA8YK+0zg6PW1QD+rXwo09N52fXj2fDpn6dv8EsFVtg7dJm9j7xRY659mn6Dmhjya8H5x1W7bRFdVuN5d7HK+lIYG9gHMnUuhskHRYRM4GTImKtpAHAXElTI+IcSUcAEyNiXplqTwUuiIjJkvoBfYCdgdcBJwAnAXOBj5Ek8fcD3wCOqxDq14D9I+LAMp/jZOBkgH4Di9OH2qeplXM/NZ1pC/bm74v3zDucQlnzZDMjd31lkGjELi2sfqI5x4iKZ+BrWhm4cysj3pycp92P2tirEm9e83iL0OI9Mt3uARYA+5IkYoAzJN0LzAZGl+zvzF3ANyR9FXhtRGxM96+IiPsiog1YAvwtIgK4DxjTnQ8REZMiYmxEjO3bf1B3qqqh4Jsf/jsPPzWUq2YekHcwhfPAwoGM2mMzO4/eRN/mNiZ8YB2zp++Qd1iFMmBkGwN3aeX5h5I22pN39WeHvbbkHFUN5dTHm3uLl6SV+6OIuORVO6UJwLuB8RGxQdIMoKrh5oi4QtLdwPuAmyWdAjwEbCop1lbyvI1XzsUWXv0PUo8d4j5gzJMcPXY5/1w1nN9/aQoAF98yjrvu3z3nyIqhrVX84pujOPeKh2jqA9OvGs7DD/bY/9x1M/Zb65h19nDaWmD70a0ccu7avEOqjSD5P78GJP0WOBZ4OiL276x8ERLvNOD7kiZHxAuSRgEtwA7As2nS3Rc4pNoKJe0JPBQRF0raHTiAJPFWYyXJCUTSwcAe6f71QI/6jbVo5S6Mn3hK3mEU2tzbhjD3tiF5h1Fow9/QwtFTn847jJoTUcuuhsuAi4A/VFM4966GiJgOXAHcJek+YApJgrsV6CtpGcnA1uwuVPthYLGkhcD+VHkyUlOB4ZKWAF8AHkzjXAPMSgfbetTgmpmV0dZW3daJdEyq6p8CubV4I2L7kscXABd0UOzoMu+d0End55Ek61JrSZLw1jKfLnm8cutraX/wkWXq/Vil45pZD9K1roYRkkoH8ydFxKRtPXQRuhrMzHLRha6G1RExtlbH7dGJV9JRwPntdq+IiA/mEY+Z9TBeq6HrImIayeCcmVkX5bdITu6Da2ZmuQhqNo9X0pUk1w/sI+kxSZ+tVL5Ht3jNzLqjVguhR8SJXSnvxGtmjct9vGZmGQrqsgBONZx4zaxB+Q4UZmbZc+I1M8uYE6+ZWYYioLU1l0M78ZpZ43KL18wsQ57VYGaWA7d4zcwy5sRrZpYhD66ZmeXALV4zs4w58ZqZZSk8q8HMLFMBETW6v3sXOfGaWeNyi9fMLEOe1WBmlgMPrpmZZSva3MdrZpYhL4RuZpYtL5JjZpatACKnwbWmXI5qZpa3CIi26rYqSHqvpAck/VPS1yqVdYvXzBpW1KirQVIf4BfAe4DHgLmSboiIpR2Vd4vXzBpX7Vq844B/RsRDEbEZuAr4QLnCipxG9XozSc8AD+cdR4kRwOq8gyg4n6PKinh+XhsRI7f1zZJuJflc1egPvFTyfFJETCqp63jgvRHxufT5J4G3RsQXOqrMXQ110J0vQz1ImhcRY/OOo8h8jirrjecnIt6b17Hd1WBm1n2PA6NLnu+W7uuQE6+ZWffNBfaWtIekfsBHgRvKFXZXQ2OY1HmRhudzVJnPTwURsUXSF4BpQB/gtxGxpFx5D66ZmWXMXQ1mZhlz4jUzy5gTr5lZxpx4cyJpjKTFecdRS5K+Uad6X+jGe0+QtEzS7WVe/7SkiyodV9KukqZsa5ySjpO0X1fi3hb1PE/1IGmopNOyOl6ROPEWWHr9d09Sl8TbTZ8FPh8Rh29rBRGxKiKO70YMxwF1T7zd1O3ztA2GAk68lrm+kianLY0pkgZKWinpfEkLgBMknSjpPkmLJZ0PL7dOfpo+PlPSQ+njPSXNSh+vlPQ9SQvS9+9bLghJ75S0MN3ukTRY0gRJMyXdlK64dLGkprR8RzGdBwxI65hcrxMm6WxJcyUtkvS9kv3XSZovaYmkk9N93wYOBX4j6ScVqh0taYak5ZK+08ExX/51kv43ulrSUknXSrpb0tiSsj+UdK+k2ZJ2lvQ24P3AT9Jzs1eNTkVFtT5Pkt4oaU76GRZJ2js9L/dLukzSg+l3+d2SZqXnclz63u9KmlhS12JJY4DzgL3SOiv99+l9IsJbDhswhmRJ0Lenz38LTARWAl9J9+0KPAKMJJlzfRtJ6+k1wNy0zBSSydujgE8BP0r3rwS+mD4+Dbi0Qiw3lsSxfXqsCSTXpu9JMi/xL8Dx5WJK3/tCnc7VC+nfI0nmk4qk0fBn4LD0teHp3wHAYmDH9PkMYGyFuj8NPAHsWPLese2OOwZYnD6eCFySPt4f2FJSPoD/nT7+MfCt9PFlwPEZfKfqeZ5+Dnw8fdwvff+Y9PO/KT3OfJLvsUgWiLkuLf9dYGJJXYvT9758Xhttc4s3X49GxKz08eUkrQ6A/07/vgWYERHPRMQWYDLJ/0BPAttLGkxymeIVwGHAO4A7Suq/Jv07n+RLXs4s4KeSzgCGpscCmBPJakutwJVpfB3GtA2ffVscmW73AAuAfYG909fOkHQvMJvknOzdYQ0d+0tErImIjSTn7NAKZQ8lWXmKiFgMLCp5bTNJkoPOz3k91eM83QV8Q9JXSRan2ZjuXxER90VEG7AE+Fsk2fU+8vv8hecr1/LV/uqVrc9frOK9/wA+AzxAkmxPAsYDXy4psyn920qF/9YRcZ6km4BjgFmSjuokvryIpEV/yat2ShOAdwPjI2KDpBkkq0lVq1afsyVNOtDJOa+zmp+niLhC0t3A+4CbJZ0CPMQr3zGAtpLnbbzy+bfw6m7Nrvy36ZXc4s3X7pLGp48/BtzZ7vU5wDsljUgH2k4E/p6+dgfJz96ZJC2bw4FNEfFcV4OQtFfaajmfpNtia3/wOCXXnjcBH0njqxRTi6Tmrh6/C6YBJ0naPo17lKSdgB2AZ9Nksi9wSBfrfY+k4ZIGkHTlzKpQdhbw4fT4+5H8zO7MemBwF2PqjpqfJ0l7Ag9FxIXA9cABXYhnJXBwWs/BwB7p/qzPS2E48ebrAeB0ScuAYcCvSl+MiCeArwG3A/cC8yPi+vTlO0h+Ks5MuwIe5d8Td7XOSgc8FgEtwC3p/rnARcAyYAVwbScxTQIW1WtwLSKmk3Sr3CXpPpL+7cHArSQDlctIBmxmd7HqOcBUkm6DqRExr0LZXwIjJS0FfkDy87qzf+yuAs5OBy7rPrhWp/P0YWCxpIUkfdt/6MJ7pwLDJS0BvgA8mMa5huQX1uJGG1zzWg3WofRn6cSIODbnUAolbeU3R8RLaRL9K7BPJHcdMKuK+3jNumYgcHvapSLgNCdd6yq
3eBuIpM8AZ7bbPSsiTs8jniylA4bnt9u9IiI+mEc8ReXzlA0nXjOzjHlwzcwsY068ZmYZc+K1XEhqTa/RXyzpT5IGdqOuy5TcXhtJl6rCSmBK1qB42zYcY6Wkf7sVeLn97cp0adWw9msbWO/jxGt52RgRB0bE/iSX2p5a+qKkbZpxExGfi4ilFYpMALqceM1qyYnXiuAO4HVpa/QOSTcASyX1kfSTklW2TgFQ4iIlq6b9Fdhpa0VKVhkbmz5+r5LV2e6V9Ld0RaxTgS+lre13SBopaWp6jLmS3p6+d0dJ05Ws4nUpydSxitTByl8lr/1Xuv9vkkam+/aSdGv6njtUYQU56108j9dylbZsjya5qgqSS0v3j4gVafJ6LiLeImk7kqucpgMHAfuQrHG7M7CUZFWs0npHAr8mWVRohaThEbFW0sUkq3j9Z1ruCuC/IuJOSbuTXG77BuA7wJ0RcY6k95GsV9uZk9JjDADmSpqaXp01CJgXEV9SsgTjd0iu4JoEnBoRyyW9leSquCO24TRaD+PEa3kZkF5+CkmL9zckXQBzImJFuv9I4ICt/bckaw3sTbIa2pXppdKrJN3WQf2HkFxOvQIgItaWiePdwH7Syw3aIekaB4cBH0rfe5OkZ6v4TGdI2jrfdevKX2tIFozZuuLc5cA16THeBvyp5NjbVXEM6wWceC0vGyPiwNIdaQIqXZlNJGsKT2tX7pgaxtEEHBIRL3UQS9XUtZW/Ij3uuvbnwBqD+3ityKYB/5Fenouk10saRLIi20fSPuBdSFZma282cJikPdL3Dk/3t18Razrwxa1PJB2YPpxJsmIcko4mWcSokkorfzWRLCJPWuedEfE8sELSCekxJOnNnRzDegknXiuyS0n6bxcoufXOJSS/0q4Flqev/YFkke5XiYhngJNJftbfyys/9W8EPrh1cA04AxibDt4t5ZXZFd8jSdxLSLocHukk1korf71IssTmYpI+3HPS/R8HPpvGt4Tkrg3WAHzJsJlZxtziNTPLmBOvmVnGnHjNzDLmxGtmljEnXjOzjDnxmpllzInXzCxj/wOdMAfGuAhqWwAAAABJRU5ErkJggg=="
},
"metadata": {
"needs_background": "light"
}
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Let's also show the classification report:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"source": [
"print(metrics.classification_report(gt_classes, predicted_classes, target_names=image_classes))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" precision recall f1-score support\n",
"\n",
" brown_spot 0.40 0.25 0.31 8\n",
" leaf_blight 0.64 0.88 0.74 8\n",
" leaf_smut 0.75 0.75 0.75 8\n",
"\n",
" accuracy 0.62 24\n",
" macro avg 0.60 0.62 0.60 24\n",
"weighted avg 0.60 0.62 0.60 24\n",
"\n"
]
}
],
"metadata": {}
},
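{
"cell_type": "markdown",
"source": [
"Earlier we mentioned that `ImageDataGenerator` can also perform simple data augmentation. As a preview (a sketch only - the specific values here are illustrative starting points, not tuned choices), enabling a few augmentation options is as simple as passing extra parameters when constructing the training generator:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: a training generator with a few augmentation options enabled.\n",
"# The parameter values below are illustrative, not tuned.\n",
"augmented_train_generator = keras.preprocessing.image.ImageDataGenerator(\n",
"    rescale=1.0/255.0,\n",
"    rotation_range=15,     # random rotations of up to +/- 15 degrees\n",
"    horizontal_flip=True,  # random left-right flips\n",
"    zoom_range=0.1         # random zoom in/out by up to 10%\n",
")\n",
"# This generator can be used with `flow_from_dataframe()` exactly as before."
],
"outputs": [],
"metadata": {}
},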
{
"cell_type": "markdown",
"source": [
"## Summing it up and thinking about what's next\n",
"\n",
"These results are better than the histogram-based model we used in the EDA notebook, but they still aren't outstanding... We have plenty of room for improvement.\n",
"\n",
"Look around at techniques that might be able to improve the performance of the model. Some things to consider:\n",
"\n",
"* We only trained for 5 epochs because overfitting could be an issue if we trained longer. Look for ways to mitigate overfitting.\n",
"* We used the `GlobalAveragePooling` layer to connect the Xception base model to our new classification layer. This led to only ~6,000 trainable parameters. Using a `Flatten` layer instead would have given us about 600,000 parameters! The downside is that all those additional parameters could make the overfitting issue even worse.\n",
"* Xception is only one of several pre-trained models available from the Keras Applications library - we could try others. Take a look, and notice how the parameter counts vary between different models there."
],
"metadata": {}
}
],
"metadata": {
"interpreter": {
"hash": "5162b779d3627a0aa18120aa507efd6d3bc07a2fac0ffddef9ceec1244a53b8c"
},
"kernelspec": {
"display_name": "Python 3.8.9 ('.venv': poetry)",
"name": "python3"
},
"language_info": {
"name": "python",
"version": ""
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}