Plant ID app (part 1): Data and model training

Plant species can be truly difficult to tell apart, and the job often requires expert knowledge. However, when images are available, computer vision methods can guide us in the right direction. Deep learning methods are very useful for image analysis, and training convolutional neural networks has become the standard way to solve a wide range of image tasks, including segmentation and classification. Here, we will train a lightweight image classification model to identify 100 different plant species. The model will be served from a REST API, which can be accessed from a simple landing page that we will create in part 2. First things first: we need to get some data and train a classification model!

All code from parts 1 and 2 of this blog post can be found on GitHub.

Get plant images

To train our classification model, we need some data. Fortunately, the Global Biodiversity Information Facility (GBIF) hosts data on species occurrences, which in many cases include images. These can be obtained by querying the GBIF database, which returns species occurrences and, when present, URLs to the corresponding images. Accessing the data only requires a GBIF user account, but beware that licensing differs between images. Here, we narrow our plant species of interest to those used to determine the Danish stream plant index ("Dansk Vandplante Indeks", or DVPI for short). This way, the web application can be used both to identify plant species and to calculate the DVPI value by calling an external API.

From R, the GBIF request can be constructed and submitted:

library(tidyverse);library(rgbif)

#Read file containing species of interest
dvpi <- read_csv("data/dvpi_species_sc.csv") |> 
  select(long_edit, sc)

#Lookup GBIF names/ids/keys for each species
gbif_keys <- name_backbone_checklist(dvpi$long_edit)

#Bind to dataframe and write to file
dvpi_keys <- bind_cols(dvpi, gbif_keys)

write_csv(dvpi_keys, "data/dvpi_species_sc_gbif.csv")

#Create GBIF request for all species with images (requires GBIF user)
#The resulting Darwin Core Archive can be downloaded from gbif.org
res <- occ_download(
  pred_in("taxonKey", dvpi_keys$usageKey),
  pred_in("basisOfRecord", 
          c('HUMAN_OBSERVATION','OBSERVATION', 
            'MACHINE_OBSERVATION', 'LIVING_SPECIMEN')),
  pred("mediatype", "StillImage"),
  user = "",
  email = "",
  pwd = "",
  format = "DWCA"
)

Once the request has been processed, the Darwin Core Archive can be downloaded from the GBIF website. After unpacking the archive, the occurrence and media data can be combined and filtered:

library(data.table);library(rjson)

#Read file with species data
dvpi_gbif <- fread("data/dvpi_species_sc_gbif.csv")

#Read text files from the GBIF request
#One file has media data and the other species occurrence data
media <- fread("data/0364086-210914110416597/multimedia.txt")
occ <- fread("data/0364086-210914110416597/occurrence.txt")
occ_sub <- occ[, .(gbifID, taxonKey, datasetName)]

#Join files
media_taxon <- media[occ_sub, on="gbifID"]

#Define the licenses to keep - these are the most permissive ones
valid_licenses <- c("http://creativecommons.org/licenses/by/4.0/",
                    "http://creativecommons.org/publicdomain/zero/1.0/")

#Filter media data
#In addition to licenses, we also filter only jpeg and observations from 
#the iNaturalist dataset which are generally of high quality (good labels)
media_valid <- media_taxon[license %in% valid_licenses & 
                             taxonKey %in% dvpi_gbif$usageKey & 
                             format == "image/jpeg" & 
                             datasetName == "iNaturalist research-grade observations", ]

#Data has been requested for 194 species 
#but we select the 100 most common for our classification model
#Get the top 100 species with most media data
taxon_occur <- table(media_valid$taxonKey)
taxon_occur_top_100 <- names(sort(taxon_occur, decreasing = TRUE)[1:100])

#For each of the 100 species, sample 250 image URLs
url_list <- lapply(taxon_occur_top_100, 
                   \(x) sample(media_valid[taxonKey == x, identifier], size = 250))
names(url_list) <- taxon_occur_top_100

#Convert to JSON and write to file
url_list_json <- toJSON(url_list)
writeLines(url_list_json, "data/url_list_100.json")

In this case, we end up with a subset of the 100 most common plant species and 250 image URLs per species. The images can be downloaded using a simple Python script, where asynchronous requests and file writing are used to speed up the download process:

import json 
import asyncio
from aiohttp import ClientSession
import aiofiles
import os

#Read JSON file with 250 URLs for each of the 100 species
with open("data/url_list_100.json") as url_list_file:
    url_list = json.load(url_list_file)

#Define functions for async HTTP GET request to download images
#This is much faster than the corresponding synchronous version
#https://progerhub.com/tutorial/downloading-random-images-using-python-requests-and-asyncio-aiohttp
async def make_request(session, url, taxon, counter):

    img_dir = "data/images"
    spec_dir = os.path.join(img_dir, taxon)

    if not os.path.exists(spec_dir):
        os.mkdir(spec_dir)

    img_path = os.path.join(spec_dir,  "{}_{}.jpeg".format(taxon, str(counter)))

    if os.path.exists(img_path):
        return

    try:
        #Use a context manager so the connection is released back to the pool
        async with session.get(url) as resp:
            if resp.status == 200:
                async with aiofiles.open(img_path, 'wb') as f:
                    await f.write(await resp.read())
    except Exception as ex:
        print(ex)

async def bulk_request(url_list, taxon):
    async with ClientSession() as session:
        tasks = [make_request(session, url, taxon, count) for count, url in enumerate(url_list)]
        await asyncio.gather(*tasks)

def download_images(url_list, taxon):
    asyncio.run(bulk_request(url_list, taxon))

#Download images for each species
for taxon in url_list.keys():
    download_images(url_list[taxon], taxon)

After the approximately 25,000 images (100 species times 250 images) have been downloaded, they are resized using another Python script to speed up subsequent model training:

from PIL import Image 
from pathlib import Path
import os
import pandas as pd
import pickle

#Read species data and write taxonkey-name dictionary for later
df = pd.read_csv("data/dvpi_species_sc_gbif.csv")
taxon_key_dict = {str(k): v for k, v in zip(df["usageKey"], df["long_edit"])}

with open("data/taxon_key_dict.p", "wb") as output_file:
    pickle.dump(taxon_key_dict, output_file)

#Preprocess images by resizing to ease later model training
data_dir = Path("data")
image_dir = data_dir/"images"
preproc_dir = data_dir/"images_preproc"
size = (640, 640) 

for species_dir in image_dir.glob('*'):

    for img in species_dir.glob('*.jpeg'):

        _, _, taxon, file = img.parts

        dir_out = preproc_dir/taxon
        image_out = dir_out/file

        if not os.path.exists(dir_out):
            os.mkdir(dir_out)

        #Skip images that have already been resized
        if os.path.exists(image_out):
            continue

        try:
            image = Image.open(img)
        except IOError as er:
            print(er)
            continue

        #Resize in place while preserving the aspect ratio
        image.thumbnail(size)
        image.save(image_out)

Train classification model

The fastai Python library is used to create the plant species classification model, as it provides a quick way to train models with good defaults. Additionally, a pre-trained model from the timm model library is used. Training a ResNet model would be a reasonable choice for this task and would often serve as a good baseline. However, newer architectures such as the EfficientNet family may yield similar performance with fewer parameters, and reducing model size and complexity is advantageous when deploying the service online. As is often the case, the actual training of the model is straightforward, and most of the time/code is spent on data wrangling (above). That said, spending some additional time experimenting with augmentation, model, and training settings would likely squeeze out a bit more performance.
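To get a sense of the size difference, the parameter counts can be compared directly with timm. A minimal sketch (exact counts may vary slightly between timm versions):

import timm

#Compare parameter counts of ResNet18 and EfficientNet-B0
#EfficientNet-B0 has roughly half the parameters of ResNet18
for name in ["resnet18", "efficientnet_b0"]:
    model = timm.create_model(name, pretrained=False)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")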

Libraries are imported, and paths and dataloaders are defined:

from fastai.vision.all import *
import timm

#Define directories
data_dir = Path("../data")
img_dir = data_dir/"images_preproc"
model_dir = data_dir/"model"

#Define function to get dataloader with different image and batch size
def get_dls(bs, size):
    dl = ImageDataLoaders.from_folder(
        path = img_dir, 
        valid_pct = 0.2, 
        bs = bs, 
        item_tfms = Resize(460),
        batch_tfms = aug_transforms(size=size, min_scale=0.85))
    return dl

We choose a simple progressive resizing strategy to speed up model training: two rounds of training, first on small images (112x112 pixels) and then on larger ones (224x224 pixels). The learner is a pre-trained small EfficientNet, which turned out to perform similarly to a ResNet18 model at only half the size.

The model is trained in two rounds and saved:

#Get dataloader with lower resolution images for first round of training
#Second round of training will use higher resolution images (progressive resizing)
dls = get_dls(128, 112)

#Use model from the timm library
#Here we use a pretrained light-weight efficientnet
learn = vision_learner(dls, 
                       "efficientnet_b0",
                       pretrained=True,
                       loss_func=CrossEntropyLossFlat(),
                       metrics=[accuracy, top_k_accuracy]).to_fp16()
                       
#First round of training
learn.fit_one_cycle(20, 1e-3)

#Replace dataloader and loss function for second round of training
learn.dls = get_dls(64, 224)
learn.loss_func = LabelSmoothingCrossEntropy()

#Second round of training
learn.fit_one_cycle(10, slice(1e-6, 1e-4))

#Save model
learn.path = model_dir
learn.export("effnet_b0.export")

The model ends up with an accuracy of ~62% and a top-5 accuracy of ~86%. This may not seem very high, but in practice the model performs well on unseen data. Downloading more images per species would likely improve performance further. Even though the images and labels are generally of good quality, the images themselves vary a lot, which adds noise during training; for example, they range from close-ups to photos taken at a distance containing multiple species.
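Before building the API in part 2, the exported learner can be sanity-checked on a single image. A minimal sketch, assuming the paths from the training script above (the image path is a placeholder for any downloaded image):

from fastai.vision.all import load_learner, PILImage

#Load the exported model (path assumed from the training script above)
learn = load_learner("data/model/effnet_b0.export")

#Classify a single image - substitute the placeholder path with a real image
img = PILImage.create("data/images_preproc/some_taxon/some_image.jpeg")
pred_class, pred_idx, probs = learn.predict(img)
print(pred_class, float(probs[pred_idx]))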

Next steps

In part 2 of this blog post, the REST API serving the model and the landing page will be created…
