Fish species classification using deep learning and the fastai library
Deep learning is everywhere. The surge of new methods for analyzing all kinds of data is astonishing. Image analysis in particular has been impacted by deep learning, with new methods and rapid improvements in model performance for many different tasks. Convolutional neural networks (CNNs) can classify images with high accuracy, and new libraries have made it easier than ever to build and train such networks. Best of all, you do not need large amounts of data or specialized GPU hardware to experiment with techniques such as transfer learning, where we only fine-tune the last part of a pre-trained network. Here, we download image data from the web, clean it up and train an accurate CNN to classify five freshwater fish species using the fastai Python library.
Getting some data
We start by downloading images from https://duckduckgo.com/ using the following JavaScript snippet and these steps:
urls = Array.from(document.querySelectorAll('.tile--img__img')).map(el=> el.hasAttribute('data-src')? "https:"+el.getAttribute('data-src'):"https:"+el.getAttribute('src'));
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));
- Enter search query at https://duckduckgo.com/
- Scroll down past many images
- Execute snippet in console (right click webpage -> inspect -> console)
- Save the downloaded file of image URLs in the working folder, named after the category (e.g. tinca_tinca.csv)
The result is a text file with URLs that we can use to download the images (200 images for each of the five categories, i.e. 1,000 images in total).
#Import the fastai library (version 1.0) and set up a path for the images
from fastai.vision import *
from fastai.metrics import error_rate
from fastai.widgets import *
#Name of categories
#The five different freshwater fish species we searched images for
categories = ["tinca_tinca", "anguilla_anguilla", "esox_lucius", "perca_fluviatilis", "rutilus_rutilus"]
#Define working folder
working_folder = %pwd
working_folder = Path(working_folder)
#Download images
#Get urls in .csv files, make dirs and download images
for i in categories:
    folder = i
    file = folder + '.csv'
    dest = working_folder/folder
    dest.mkdir(parents=True, exist_ok=True)
    download_images(working_folder/file, dest, max_pics=200)
Approaches using other search engines such as Google Images or Bing are also available; however, I found this to be the easiest way. For more on this, see this post on the fast.ai forum.
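Some of the downloaded files may be corrupt or not images at all. As an optional extra step (a small sketch, not part of the snippet above), fastai v1's verify_images function can be used to delete files that cannot be opened as images:
#Remove files that cannot be opened as images (sketch using fastai v1's verify_images)
for i in categories:
    verify_images(working_folder/i, delete=True)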
Cleaning the image data
The downloaded images may vary a lot in quality and some fraction of the images are most likely not what you searched for. We build an initial network and use a fastai widget to clean up the images.
#Define DataBunch, use all data as we are interested in filtering out nonsense images
rawdata = ImageDataBunch.from_folder(working_folder, train=".", valid_pct=0,
                                     ds_tfms=get_transforms(), bs=32, size=224,
                                     num_workers=4).normalize(imagenet_stats)
#Define learner and fit 5 epochs
learn = cnn_learner(rawdata, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(5)
#Use image cleaner widget to delete or re-categorise images.
#Paths and labels of the "cleaned" dataset are available in the cleaned.csv file.
ds, idxs = DatasetFormatter().from_toplosses(learn)
ImageCleaner(ds, idxs, working_folder)
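To get a quick overview of what remains after cleaning, we can read cleaned.csv with pandas (a small sketch; it assumes the file contains the 'name' and 'label' columns that ImageCleaner writes by default):
#Inspect the cleaned dataset (sketch; assumes a 'label' column in cleaned.csv)
import pandas as pd
cleaned = pd.read_csv(working_folder/'cleaned.csv')
print(len(cleaned))                     #total number of images kept
print(cleaned['label'].value_counts())  #images remaining per category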
Training the CNN fish classifier
With the cleaned image data (664 images in total) we will fine-tune the last part of a pre-trained network (ResNet34 architecture). First, we define a function that loads the cleaned image data and resamples it to a given resolution. We will use images of differing resolutions for some of the experiments.
#Function that returns a DataBunch for a given input image size
def data_cleaned(image_size):
    data = ImageDataBunch.from_csv(working_folder, folder=".", valid_pct=0.2, csv_labels='cleaned.csv',
                                   ds_tfms=get_transforms(), size=image_size, num_workers=4, bs=32).normalize(imagenet_stats)
    return data
First approach (without resizing) - We use the ‘high-resolution’ data to train a model:
#Regular training network on "high" resolution images
#Further improvement could maybe be obtained by searching for more appropriate learning rates.
learn_without = cnn_learner(data_cleaned(224), models.resnet34, metrics=error_rate)
learn_without.fit_one_cycle(6) #~02:30 per epoch, error rate = 0.036
We train for 6 epochs, each taking approximately 2 minutes 30 seconds, and obtain an error rate of 3.6%. Pretty good considering the noisy image data and the small effort put into training the model. We could have trained much faster if a GPU had been available, but this goes to show that decent models can be created from smaller datasets on regular desktop computers.
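To see where the model goes wrong, fastai's ClassificationInterpretation can be used to plot a confusion matrix and the images with the highest losses (a sketch using standard fastai v1 tools, not part of the original workflow):
#Inspect model errors (sketch)
interp = ClassificationInterpretation.from_learner(learn_without)
interp.plot_confusion_matrix()              #which species are confused with each other
interp.plot_top_losses(9, figsize=(12, 9))  #the most confidently wrong predictions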
Second approach (with resizing) - We train the last part of the network on small images, then use these weights to train on medium images, and finally use the medium-resolution weights to train the network on high-resolution images:
learn_low = cnn_learner(data_cleaned(64), models.resnet34, metrics=error_rate)
learn_low.fit_one_cycle(2)
learn_low.save("low") #~00:30 per epoch
learn_med = cnn_learner(data_cleaned(112), models.resnet34, metrics=error_rate).load("low")
learn_med.fit_one_cycle(2)
learn_med.save("med") #~01:00 per epoch
learn_high = cnn_learner(data_cleaned(224), models.resnet34, metrics=error_rate).load("med")
learn_high.fit_one_cycle(2)
learn_high.save("high") #~02:30 per epoch
We train the model in three steps (6 epochs in total), increasing the resolution and re-using the weights each time. This is also known as progressive resizing, although here it is only applied to the last part of the network. We obtain an error rate of 5.4%. Notice the drastic drop in the time required to train the model (~45% reduction) while obtaining a similar error rate. The error rates could likely be improved further by searching for more appropriate learning rates, increasing the number of epochs and finally fine-tuning by unfreezing all weights in the network. Other architectures could be tried as well but would increase the computing required.
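As a sketch of what such further fine-tuning could look like (the learning rates below are placeholders, not values from this post), we could run the learning rate finder, unfreeze the whole network and train a couple more epochs with discriminative learning rates:
#Possible further fine-tuning (sketch; learning rates are placeholders)
learn_high.lr_find()
learn_high.recorder.plot()       #inspect the plot to choose a learning rate
learn_high.unfreeze()            #make all layers trainable, not just the head
learn_high.fit_one_cycle(2, max_lr=slice(1e-6, 1e-4))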
Using the model
We can use the model to perform classification for new images.
new_image = open_image(working_folder/'pike_small.jpg')
pred_class, pred_idx, outputs = learn_high.predict(new_image)
print("%.3f probability of %s" % (outputs[pred_idx], pred_class))
The trained CNN model correctly predicts that the fish below is very likely (probability of 98.4%) a pike (Esox lucius).
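If we want to reuse the classifier outside this notebook, the learner can be exported and loaded back without the training data (a sketch; export.pkl is the default file name used by fastai v1):
#Export the trained model and reload it for inference only (sketch)
learn_high.export()                            #writes export.pkl to the learner's path
learn_inference = load_learner(working_folder)
pred_class, pred_idx, outputs = learn_inference.predict(new_image)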
Conclusions
Finally, a few take-aways from this post:
- Image data from the web can be used to build CNN classification models
- Accurate classification models can be built rapidly using the fastai library
- While the quality of the images is not consistent, we can use an initial network to guide the process of filtering out bad images.
- Progressive resizing can reduce the time spent fine-tuning the network.
- Specialised hardware is not needed for deep learning on smaller datasets