Crop disease classification
By Emile Hazard
This project uses the dataset provided at https://www.kaggle.com/datasets/mexwell/crop-diseases-classification/data. It is a set of pictures of cassava (manioc) plants, with labels indicating whether they are healthy or afflicted with one of 4 diseases. My objective is to classify pictures of plants by their likely disease. This uses a convolutional neural network (CNN) and requires some work on the initial dataset to make it more amenable to training.
Data exploration and cleaning
The first step is to import the necessary libraries.
# Data
import numpy as np
import pandas as pd
import json
# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Loading bar
from tqdm import tqdm
# Random sampling
import random
RANDOM_STATE = 42 # Used when consistency is needed
# Images
from PIL import Image
#import piexif # (metadata access)
# Path manipulation with os
import os
# Machine learning
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.regularizers import L1, L2
from tensorflow.keras.utils import Sequence
from tensorflow.keras.metrics import Accuracy
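This next part only checks whether a GPU is available for computation, which speeds up the learning process significantly. It also enables memory growth, so that TensorFlow allocates GPU memory as needed instead of reserving it all at once.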
list_GPUs = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(list_GPUs))
for gpu in list_GPUs:
    tf.config.experimental.set_memory_growth(gpu, True)
Num GPUs Available: 1
Loading the dataframe containing labels:
labels_path = "data/train.csv"
images_directory = "data/train_images"
labels = pd.read_csv(labels_path)
labels.tail(2)
 | image_id | label |
---|---|---|
21395 | 999616605.jpg | 4 |
21396 | 999998473.jpg | 4 |
Here is how these images look:
img = Image.open(os.path.join(images_directory, labels.loc[:, 'image_id'].iloc[-1]))
plt.imshow(img)
plt.axis('off')
plt.show()
However, some pictures seem to be missing from the data:
try:
    img = Image.open(os.path.join(images_directory, labels.loc[:, 'image_id'].iloc[0]))
except FileNotFoundError:
    print(f"Could not load image {labels.loc[0, 'image_id']}: not found in directory")
Could not load image 1000015157.jpg: not found in directory
Here are the pictures that are actually there:
print(f"{len(labels)} pictures according to the labels dataframe")
pictures_set = set(os.listdir(images_directory))
labels = labels.loc[labels['image_id'].isin(pictures_set)]
print(f"{len(labels)} pictures in the directory")
21397 pictures according to the labels dataframe
17938 pictures in the directory
The names of the various diseases we are studying are as follows:
labels_map_path = "data/label_num_to_disease_map.json"
# Use a context manager so the file is closed once the mapping is loaded
with open(labels_map_path) as labels_map_file:
    labels_map = json.load(labels_map_file)
labels_map
{'0': 'Cassava Bacterial Blight (CBB)', '1': 'Cassava Brown Streak Disease (CBSD)', '2': 'Cassava Green Mottle (CGM)', '3': 'Cassava Mosaic Disease (CMD)', '4': 'Healthy'}
Here are examples of each class from the dataset (the pictures are drawn at random, so running the cell again shows different ones):
fig, ax = plt.subplots(2,3, figsize = (16,9))
for i,j in [(0,0), (0,1), (0,2), (1,0), (1,1)]:
    image_name = labels.loc[labels['label'] == 3*i+j, 'image_id'].iloc[random.randint(0,100)]
    img = Image.open(os.path.join(images_directory, image_name))
    ax[i,j].imshow(img)
    ax[i,j].set_axis_off()
    ax[i,j].set_title(labels_map[str(3*i+j)])
ax[1,2].set_axis_off()
plt.show()
And here is how the data is split between them:
plt.hist(labels['label'], bins = 5, range =(-.5, 4.5))
plt.show()
We may want to balance this before training to avoid a bias toward label 3 (Cassava Mosaic Disease).
A quick workaround is to keep only 2000 pictures with label 3, but this drastically reduces the training data. I tried this first, as it is faster to implement and to run, but the results were disappointing. Instead, we will augment the data using cropping and mirroring in the next step, with larger increases for less populated labels.
Choosing the size of the pictures
The pictures are all 800x600 pixels:
sizes = set()
for image_name in labels.loc[:,'image_id']:
    img = Image.open(os.path.join(images_directory, image_name))
    sizes.add(img.size)
print(f"Set of all picture sizes in the database: {sizes}")
Set of all picture sizes in the database: {(800, 600)}
Let us reduce them to 200x150. We check on an example that features are still visible to the human eye.
new_size = (200, 150)
i = random.randint(0,100)
image_name = labels.loc[:, 'image_id'].iloc[i]
img = Image.open(os.path.join(images_directory, image_name))
fig, ax = plt.subplots(1,2, figsize = (12,4))
fig.suptitle(f"Example labelled {labels_map[str(labels.loc[:,'label'].iloc[i])]}")
ax[0].imshow(img)
ax[0].set_axis_off()
ax[0].set_title("Before size reduction", fontsize = 10)
img = img.resize(new_size)
ax[1].imshow(img)
ax[1].set_axis_off()
ax[1].set_title("After size reduction", fontsize = 10)
plt.show()
Creating the training set
First, we define a function to help with data augmentation. It creates up to 10 different variations of a given picture.
def transform_img(img, i, new_size = new_size, init_size = (800, 600)):
    """
    Applies a different transform on the image for each value of i from 0 to 9:
    even values of i are mirrored, and i // 2 selects the crop (0 means no crop).
    """
    imgout = img
    width, height = init_size
    if i%2 == 0:
        # Symmetry (left-right mirror)
        imgout = img.transpose(Image.FLIP_LEFT_RIGHT)
    if i//2 == 0:
        # Leave as is
        pass
    elif i//2 == 1:
        # Crop to upper left
        imgout = imgout.crop((0, 0, width * 3 // 4, height * 3 // 4))
    elif i//2 == 2:
        # Crop to upper right
        imgout = imgout.crop((width // 4 , 0, width, height * 3 // 4))
    elif i//2 == 3:
        # Crop to lower left
        imgout = imgout.crop((0, height // 4, width * 3 // 4, height))
    elif i//2 == 4:
        # Crop to lower right
        imgout = imgout.crop((width // 4 , height // 4, width, height))
    return imgout.resize(new_size)
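To make the ten variants concrete, here is a small sketch that lists the crop box chosen for each value of i, without loading any picture. The helper name crop_box is hypothetical and only mirrors the arithmetic of transform_img for the 800x600 source size.

```python
def crop_box(i, init_size=(800, 600)):
    """Return the (left, upper, right, lower) box used for variant i, or None if uncropped."""
    width, height = init_size
    boxes = {
        0: None,                                       # no crop
        1: (0, 0, width * 3 // 4, height * 3 // 4),    # upper left
        2: (width // 4, 0, width, height * 3 // 4),    # upper right
        3: (0, height // 4, width * 3 // 4, height),   # lower left
        4: (width // 4, height // 4, width, height),   # lower right
    }
    return boxes[i // 2]


for i in range(10):
    box = crop_box(i)
    mirrored = i % 2 == 0  # even values of i are mirrored, as in transform_img
    size = (800, 600) if box is None else (box[2] - box[0], box[3] - box[1])
    print(i, "mirrored" if mirrored else "original", box, size)
```

Every crop is 600x450, so it keeps the 4:3 aspect ratio of the source and the final resize does not distort the picture. Note also that i = 0 is the mirrored, uncropped variant, so a label with an augmentation factor of 1 receives only the mirrored picture.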
Now the training set can be created by augmenting the data on the fly. The list duplicate contains the augmentation factor for each label: labels 1, 2 and 4 are augmented fivefold and label 0 tenfold, while label 3 keeps a single picture per original.
duplicate = [10,5,5,1,5]
total_length = sum(duplicate[i] * len(labels.loc[labels['label'] == i]) for i in range(5))
X = np.empty((total_length, new_size[1], new_size[0], 3))
y = np.empty((total_length))
j = 0
for i in tqdm(range(len(labels))):
    image_name = labels.loc[:, 'image_id'].iloc[i]
    image_label = labels.loc[:, 'label'].iloc[i]
    img = Image.open(os.path.join(images_directory, image_name))
    for k in range(duplicate[image_label]):
        X[j] = np.array(transform_img(img, k)) / 255
        y[j] = image_label
        j += 1
100%|███████████████████████████████████████████████████████████████████████████| 17938/17938 [02:56<00:00, 101.42it/s]
This yields a more balanced dataset:
plt.hist(y, bins = 5, range =(-.5, 4.5))
plt.show()
array([3., 3., 3., ..., 4., 4., 4.])
print(X.shape, y.shape)
(50187, 150, 200, 3) (50187,)
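An array of this shape is large. Since np.empty defaults to float64 when no dtype is given, the sketch below estimates the memory footprint of X for a few element types. The helper array_nbytes is hypothetical, and the float32 and uint8 rows are alternatives for comparison, not what the code above uses.

```python
def array_nbytes(shape, bytes_per_element):
    """Number of bytes needed for a dense array of the given shape."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    return n_elements * bytes_per_element


shape = (50187, 150, 200, 3)
for name, size in [("float64", 8), ("float32", 4), ("uint8", 1)]:
    gib = array_nbytes(shape, size) / 2**30
    print(f"{name}: {gib:.1f} GiB")
```

If memory is tight, passing dtype = np.float32 to np.empty would halve the footprint.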
Here is an example of all images that were derived from a single one:
fig, ax = plt.subplots(2, 5, figsize = (15,4))
for i in range(5):
    for j in range(2):
        ax[j][i].imshow(X[3+i+5*j])
        ax[j][i].axis('off')
plt.show()
Splitting the set into training (75%), cross-validation (15%) and test (10%) subsets. The second call takes 40% of the remaining 25% as the test set:
X_train, X_, y_train, y_ = train_test_split(X, y, train_size = 0.75, random_state = RANDOM_STATE)
del X
X_cv, X_test, y_cv, y_test = train_test_split(X_, y_, test_size = 0.4, random_state = RANDOM_STATE)
print(f"X_train: {X_train.shape}, y_train: {y_train.shape}, X_test: {X_test.shape}, y_test: {y_test.shape}")
X_train: (37640, 150, 200, 3), y_train: (37640,), X_test: (5019, 150, 200, 3), y_test: (5019,)
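One point to keep in mind: the split happens after augmentation, so variants of the same source picture can land in both the training and the test subsets. This may make the test accuracy optimistic. The sketch below shows one possible alternative, which is not what this notebook does: splitting by source picture so that all variants of a picture stay together. The function split_by_source and the source_ids list are hypothetical. scikit-learn's GroupShuffleSplit offers similar behaviour.

```python
import random


def split_by_source(source_ids, train_frac=0.75, cv_frac=0.15, seed=42):
    """Split sample indices so that all variants of a source picture share a subset."""
    unique_sources = sorted(set(source_ids))
    random.Random(seed).shuffle(unique_sources)
    n_train = round(len(unique_sources) * train_frac)
    n_cv = round(len(unique_sources) * cv_frac)

    # Assign each source picture to exactly one subset
    subset_of = {}
    for rank, source in enumerate(unique_sources):
        if rank < n_train:
            subset_of[source] = "train"
        elif rank < n_train + n_cv:
            subset_of[source] = "cv"
        else:
            subset_of[source] = "test"

    # Every augmented sample follows its source picture
    split = {"train": [], "cv": [], "test": []}
    for index, source in enumerate(source_ids):
        split[subset_of[source]].append(index)
    return split


# Hypothetical example: 20 source pictures with 5 augmented variants each
source_ids = [picture for picture in range(20) for _ in range(5)]
split = split_by_source(source_ids)
print({name: len(indices) for name, indices in split.items()})
```

The returned index lists could then be used to slice X and y, provided the source picture of each augmented sample was recorded while building the arrays.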
We use generators to feed the data, as TensorFlow works better this way (otherwise it tries to load everything into GPU memory).
class DataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))
    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return batch_x, batch_y
train_gen = DataGenerator(X_train, y_train, 32)
cv_gen = DataGenerator(X_cv, y_cv, 32)
test_gen = DataGenerator(X_test, y_test, 32)
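As a quick check of the batching arithmetic in DataGenerator, this sketch reproduces the number of batches and the bounds of the last, shorter batch for the training set. The helper batch_bounds is hypothetical and mimics the slicing done in __getitem__, where a slice past the end of the array is simply truncated.

```python
import math


def batch_bounds(n_samples, batch_size, idx):
    """Start and stop indices of batch idx, truncated at the end of the data."""
    start = idx * batch_size
    stop = min((idx + 1) * batch_size, n_samples)
    return start, stop


n_train, batch_size = 37640, 32
n_batches = math.ceil(n_train / batch_size)
print(n_batches)                                         # 1177
print(batch_bounds(n_train, batch_size, n_batches - 1))  # (37632, 37640)
```

With 37640 training samples and a batch size of 32, each epoch runs 1177 steps, and the last batch holds only 8 samples.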
lambda_reg = 0.001
model = Sequential(
    [
        Conv2D(16, (3,3), 1, activation = 'relu', input_shape = (new_size[1], new_size[0], 3),
               kernel_regularizer=L2(lambda_reg), bias_regularizer=L2(lambda_reg)),
        MaxPooling2D(),
        Conv2D(32, (3,3), activation = 'relu', kernel_regularizer=L2(lambda_reg),
               bias_regularizer=L2(lambda_reg)),
        MaxPooling2D(),
        Conv2D(32, (3,3), activation = 'relu', kernel_regularizer=L2(lambda_reg),
               bias_regularizer=L2(lambda_reg)),
        MaxPooling2D(),
        Flatten(),
        Dense(128, activation = 'relu', kernel_regularizer=L2(lambda_reg), bias_regularizer=L2(lambda_reg)),
        Dense(5, activation = 'linear', kernel_regularizer=L2(lambda_reg), bias_regularizer=L2(lambda_reg))
    ]
)
model.compile(
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
    optimizer = 'adam',
    metrics = ['accuracy']
)
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                    Output Shape              Param #
=================================================================
 conv2d (Conv2D)                 (None, 148, 198, 16)      448
 max_pooling2d (MaxPooling2D)    (None, 74, 99, 16)        0
 conv2d_1 (Conv2D)               (None, 72, 97, 32)        4640
 max_pooling2d_1 (MaxPooling2D)  (None, 36, 48, 32)        0
 conv2d_2 (Conv2D)               (None, 34, 46, 32)        9248
 max_pooling2d_2 (MaxPooling2D)  (None, 17, 23, 32)        0
 flatten (Flatten)               (None, 12512)             0
 dense (Dense)                   (None, 128)               1601664
 dense_1 (Dense)                 (None, 5)                 645
=================================================================
Total params: 1,616,645
Trainable params: 1,616,645
Non-trainable params: 0
_________________________________________________________________
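The shapes and parameter counts in this summary can be checked by hand. The sketch below redoes the arithmetic in plain Python, assuming 3x3 convolutions without padding, 2x2 max pooling, and one bias per filter or unit. The helper names are hypothetical.

```python
def conv_output(height, width, kernel=3):
    """Spatial size after an unpadded convolution with stride 1."""
    return height - kernel + 1, width - kernel + 1


def pool_output(height, width, pool=2):
    """Spatial size after max pooling (odd sizes are rounded down)."""
    return height // pool, width // pool


def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * out_channels


def dense_params(n_in, n_out):
    """Weights plus one bias per unit."""
    return (n_in + 1) * n_out


height, width, channels = 150, 200, 3
total = 0
for filters in (16, 32, 32):
    total += conv_params(channels, filters)
    height, width = pool_output(*conv_output(height, width))
    channels = filters

flat = height * width * channels
total += dense_params(flat, 128) + dense_params(128, 5)
print(flat, total)  # 12512 1616645
```

Almost all of the parameters (1,601,664 out of 1,616,645) sit in the first Dense layer, because it is connected to the 12512 flattened features.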
Training the model
We will keep the logs from training here.
log_dir = "logs"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)
hist = model.fit(
    train_gen,
    epochs = 20,
    validation_data = cv_gen,
    callbacks = [tensorboard_callback]
)
Epoch 1/20
1177/1177 [==============================] - 118s 93ms/step - loss: 1.4800 - accuracy: 0.4010 - val_loss: 1.3920 - val_accuracy: 0.4320
Epoch 2/20
1177/1177 [==============================] - 72s 60ms/step - loss: 1.3434 - accuracy: 0.4698 - val_loss: 1.3268 - val_accuracy: 0.4849
Epoch 3/20
1177/1177 [==============================] - 70s 59ms/step - loss: 1.3169 - accuracy: 0.4813 - val_loss: 1.2900 - val_accuracy: 0.4884
Epoch 4/20
1177/1177 [==============================] - 79s 67ms/step - loss: 1.2997 - accuracy: 0.4880 - val_loss: 1.2838 - val_accuracy: 0.4975
Epoch 5/20
1177/1177 [==============================] - 71s 59ms/step - loss: 1.2835 - accuracy: 0.4973 - val_loss: 1.2730 - val_accuracy: 0.4983
Epoch 6/20
1177/1177 [==============================] - 69s 58ms/step - loss: 1.2574 - accuracy: 0.5163 - val_loss: 1.3036 - val_accuracy: 0.4997
Epoch 7/20
1177/1177 [==============================] - 59s 49ms/step - loss: 1.2280 - accuracy: 0.5382 - val_loss: 1.2059 - val_accuracy: 0.5522
Epoch 8/20
1177/1177 [==============================] - 73s 61ms/step - loss: 1.2127 - accuracy: 0.5478 - val_loss: 1.2146 - val_accuracy: 0.5387
Epoch 9/20
1177/1177 [==============================] - 68s 57ms/step - loss: 1.1972 - accuracy: 0.5612 - val_loss: 1.1752 - val_accuracy: 0.5687
Epoch 10/20
1177/1177 [==============================] - 69s 58ms/step - loss: 1.1826 - accuracy: 0.5694 - val_loss: 1.1723 - val_accuracy: 0.5709
Epoch 11/20
1177/1177 [==============================] - 76s 64ms/step - loss: 1.1693 - accuracy: 0.5766 - val_loss: 1.1583 - val_accuracy: 0.5770
Epoch 12/20
1177/1177 [==============================] - 55s 46ms/step - loss: 1.1596 - accuracy: 0.5815 - val_loss: 1.1574 - val_accuracy: 0.5677
Epoch 13/20
1177/1177 [==============================] - 71s 60ms/step - loss: 1.1511 - accuracy: 0.5889 - val_loss: 1.1796 - val_accuracy: 0.5728
Epoch 14/20
1177/1177 [==============================] - 76s 64ms/step - loss: 1.1501 - accuracy: 0.5891 - val_loss: 1.2306 - val_accuracy: 0.5481
Epoch 15/20
1177/1177 [==============================] - 77s 65ms/step - loss: 1.1405 - accuracy: 0.5952 - val_loss: 1.1177 - val_accuracy: 0.6071
Epoch 16/20
1177/1177 [==============================] - 57s 47ms/step - loss: 1.1395 - accuracy: 0.5976 - val_loss: 1.1521 - val_accuracy: 0.5883
Epoch 17/20
1177/1177 [==============================] - 53s 44ms/step - loss: 1.1290 - accuracy: 0.6058 - val_loss: 1.1822 - val_accuracy: 0.5902
Epoch 18/20
1177/1177 [==============================] - 70s 59ms/step - loss: 1.1303 - accuracy: 0.6065 - val_loss: 1.1214 - val_accuracy: 0.6099
Epoch 19/20
1177/1177 [==============================] - 76s 63ms/step - loss: 1.1243 - accuracy: 0.6085 - val_loss: 1.1695 - val_accuracy: 0.5933
Epoch 20/20
1177/1177 [==============================] - 49s 41ms/step - loss: 1.1161 - accuracy: 0.6154 - val_loss: 1.1564 - val_accuracy: 0.5964
Evaluating the model
Here is the progress of loss and accuracy during training:
fig, ax = plt.subplots(1,2, figsize = (12,5))
ax[0].plot(hist.history['loss'], color = 'blue', label = 'Train')
ax[0].plot(hist.history['val_loss'], color = 'red', label = 'Validation')
ax[0].set_title('Loss')
ax[0].legend(loc = "upper right")
ax[1].plot(hist.history['accuracy'], color = 'blue', label = 'Train')
ax[1].plot(hist.history['val_accuracy'], color = 'red', label = 'Validation')
ax[1].set_title('Accuracy')
ax[1].legend(loc = "upper left")
plt.show()
The model can now be evaluated:
y_test, yhat_test = [], []
for i in range(len(test_gen)):
    # Index the generator directly rather than calling __getitem__ explicitly
    X_batch, y_batch = test_gen[i]
    yhat = np.argmax(model.predict(X_batch, verbose = 0), axis = 1)
    y_test += list(y_batch)
    yhat_test += list(yhat)
print(f"Accuracy: {sum(np.equal(yhat_test, y_test)) / len(y_test)*100:.2f}%")
Accuracy: 59.47%
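The last Dense layer uses a linear activation and the loss is built with from_logits = True, so model.predict returns raw scores (logits) rather than probabilities. Taking the argmax of the logits selects the same class as taking the argmax of the softmax probabilities, which is why the evaluation loop can skip the softmax. Here is a minimal pure-Python sketch with made-up logit values:

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for numerical stability)."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.2, -1.3, 0.5, 2.1, 0.9]  # hypothetical scores for the 5 classes
probs = softmax(logits)

# Softmax is monotonic, so both argmax computations agree
predicted_from_logits = max(range(len(logits)), key=lambda k: logits[k])
predicted_from_probs = max(range(len(probs)), key=lambda k: probs[k])
print(predicted_from_logits, predicted_from_probs)
```

The softmax is only needed when actual class probabilities are wanted, for example to apply a confidence threshold.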
list_bars = [[],[],[],[],[],[]]
for i in range(len(y_test)):
    if yhat_test[i] == y_test[i]:
        list_bars[0].append(yhat_test[i])
    else:
        # Group by predicted label and store the actual label, so the x axis shows the actual condition
        list_bars[int(yhat_test[i])+1].append(y_test[i])
# Plotting this into a seaborn histogram
h = sns.histplot(list_bars, multiple = "stack", discrete = True, hue_order = np.array(range(5, -1,-1)), palette = ['#320a28', '#511730', '#702e37', '#8e443d', '#cb9173', '#32e875'])
h.set_xlabel('Actual condition')
h.set_ylabel('Number of pictures')
h.legend(['Correct', 'CBB', 'CBSD', 'CGM', 'CMD', 'Healthy'])
h.set_title('Summary of model classification')
h.set_xticks([0,1,2,3,4])
h.set_xticklabels(['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy'])
plt.show()
This accuracy of about 60% may seem low, but it is well above the 20% expected from uniformly random guessing among five classes. Furthermore, looking at the pictures shows that even a human may struggle to tell these diseases apart.
One caveat is that this version of the model may be problematic in practice, as it is prone to false negatives, that is, labelling sick plants as healthy. One way to address this would be to reduce the number of healthy examples in the training set.
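To put a number on the false negatives, one could measure the share of diseased plants that the model labels as healthy. The sketch below is one way to do this. The function false_negative_rate is hypothetical and the label lists are made up for illustration. In the notebook, the y_test and yhat_test lists built during evaluation would be passed instead.

```python
def false_negative_rate(y_true, y_pred, healthy_label=4):
    """Share of diseased samples (true label != healthy) that are predicted as healthy."""
    diseased = [(t, p) for t, p in zip(y_true, y_pred) if t != healthy_label]
    if not diseased:
        return 0.0
    missed = sum(1 for _, p in diseased if p == healthy_label)
    return missed / len(diseased)


# Made-up labels for illustration only
y_true = [0, 1, 2, 3, 3, 4, 4, 1]
y_pred = [0, 4, 2, 3, 4, 4, 3, 1]
rate = false_negative_rate(y_true, y_pred)
print(f"False negative rate: {rate:.2%}")
```

Tracking this rate alongside the accuracy would show whether a change to the training set, such as using fewer healthy examples, actually reduces the number of missed diseases.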