COVID-19 Detection from Chest X-rays Using Transfer Learning¶

Problem Statement¶

Rapid detection of COVID-19 cases is critical to containing the virus and to easing pressure on an overloaded healthcare system. Clinical testing for COVID-19 is lengthy, so imaging tools such as chest X-rays are a key instrument for helping detect COVID-19 and speeding up the identification process. In this notebook, we train a deep convolutional neural network using transfer learning on a novel dataset of 15,000 chest X-rays to aid in the detection of COVID-19. Further, we implement Local Interpretable Model-Agnostic Explanations (LIME) to provide insight into the model's predictions, enabling healthcare providers to make diagnoses that are both time-efficient and correct.

In [8]:
'''Load PyTorch and necessary libraries.'''

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
from torchvision import models
import torchvision.transforms as transforms

import os
from pathlib import Path
import Augmentor
from IPython.display import display  # Only display is needed; PIL's Image (imported below) handles image loading
import splitfolders
import matplotlib.pyplot as plt
import numpy as np
import time
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import seaborn as sns
import pandas as pd
from PIL import Image
dir = os.getcwd()

Data Collection and Augmentation¶

We use chest X-ray images from three categories: normal, pneumonia, and COVID-19 cases. The dataset, containing 5,000 normal images, 5,000 pneumonia images, and 4,420 COVID-19 images, was collected from eleven different publicly available datasets by Badawi et al. [1] Using data augmentation with the Augmentor package, we generate an additional 580 COVID-19 images to construct a more balanced dataset.

In [84]:
dir = os.getcwd()
p = Augmentor.Pipeline(os.path.join(dir, 'preXrayData/covid'))
p.rotate(probability=0.7, max_left_rotation=10, max_right_rotation=10) # Rotate 70% of the images between 10 and -10 degrees
p.random_distortion(probability=1, grid_width=4, grid_height=4, magnitude=8) # Randomly distort the images while maintaining their aspect ratio
p.flip_left_right(probability=1) # Mirror the images from left to right
p.process()
p.sample(580)
Initialised with 4999 image(s) found.
Processing <PIL.Image.Image image mode=L size=3050x2539 at 0x7F103893BF40>: 100%|██████████| 4999/4999 [06:11<00:00, 13.46 Samples/s]   
Processing <PIL.Image.Image image mode=L size=299x299 at 0x7F0FFC56D810>: 100%|██████████| 580/580 [00:49<00:00, 11.61 Samples/s]     
In [146]:
'''Visualize augmented image.'''

aug_img_path = 'preXrayData/covid/covid_original_COVID-597.png_b8c0f5bb-77e0-4749-9359-d746af324bcb.png'
display(Image.open(aug_img_path))
[Output image: an augmented COVID-19 chest X-ray]
In [87]:
'''Move additional generated images into correct folder.'''

src_path = os.path.join(dir, 'preXrayData/covid/output')
for each_file in Path(src_path).glob('*.*'): # grabs all files
    tar_path = each_file.parent.parent
    each_file.rename(tar_path.joinpath(each_file.name))
os.rmdir(src_path) # Remove the now-empty output directory (os.remove fails on directories)
In [94]:
'''Split the training and validation data.'''

data = os.path.join(dir, 'preXrayData')
splitfolders.ratio(data, output='postXrayData', seed=1337, ratio=(0.8,0.2))
Copying files: 20578 files [00:10, 1920.93 files/s]

The Model¶

InceptionV3 is an image recognition model that has been pre-trained on the ImageNet dataset, attaining greater than 78.1% accuracy. It is based on the paper Rethinking the Inception Architecture for Computer Vision by Szegedy et al. [2], which argues that although increased model size and computational cost tend to translate into immediate quality gains for most tasks (as long as enough labeled data is provided for training), computational efficiency and a low parameter count are still valuable enabling factors. InceptionV3 therefore focused on scaling up previous Inception iterations as efficiently as possible, using suitably factorized convolutions and aggressive regularization.
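To make "suitably factorized convolutions" concrete, the short sketch below (our illustration, not code from the paper) compares the parameter count of a single 5x5 convolution against the stack of two 3x3 convolutions that covers the same receptive field; the channel sizes are arbitrary.

import torch.nn as nn

in_ch, out_ch = 64, 64

# A single 5x5 convolution
conv5 = nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2)

# Two stacked 3x3 convolutions: same 5x5 receptive field, fewer parameters
conv3x2 = nn.Sequential(
    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(conv5), count(conv3x2))  # 102464 vs. 73856, roughly 28% fewer parameters

InceptionV3 pushes the same idea further by splitting an n x n convolution into a 1 x n convolution followed by an n x 1 convolution.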

Why InceptionV3?¶
[Figure: motivation for choosing InceptionV3]
Why Transfer Learning?¶

Transfer learning is a method to overcome insufficient data and/or training resources by adding a head, or final layer, to a pre-trained network (replacing the original classifier). This involves taking a pre-trained model, keeping its layers as a feature extractor, and feeding their output into a series of dense layers. A dense layer is a simple layer of neurons in which each neuron receives input from all the neurons of the previous layer, thus capturing complex patterns.
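As a minimal illustration of that connectivity (toy sizes, not this notebook's model), nn.Linear connects every input neuron to every output neuron:

import torch
import torch.nn as nn

dense = nn.Linear(in_features=4, out_features=2)  # Weight matrix of shape (2, 4): each output sees all 4 inputs
x = torch.randn(1, 4)
y = dense(x)     # Computes y = x @ W.T + b
print(y.shape)   # torch.Size([1, 2])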

Model Architecture¶
[Figure: model architecture (InceptionV3 backbone with custom dense head)]
In [6]:
'''Image transformation and normalization so the input data matches the model's expectations.'''

train_pathname = os.path.join(dir, 'postXrayData/train') # Contains 80% of the data
val_pathname = os.path.join(dir, 'postXrayData/val')     # Contains 20% of the data

# Transform pipeline to ensure input data is prepared in the same way as the original training data for the model
train_transform = transforms.Compose([
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # X-rays are greyscale, so one mean/std (broadcast across the loader's RGB channels)
])
val_transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.ImageFolder(train_pathname, train_transform)
val_data = datasets.ImageFolder(val_pathname, val_transform)

batch_size = 8 # Mini-batches of 8 samples at a time

# Dataloaders pass the data through the transform pipelines
train_loader = DataLoader(train_data, 
                          batch_size=batch_size, 
                          shuffle=True)
val_loader = DataLoader(val_data, 
                        batch_size=batch_size, 
                        shuffle=True)
In [95]:
'''Display labelled, transformed images.'''

def imshow(img, title):
    img = torchvision.utils.make_grid(img, normalize=True)
    npimg = img.numpy()
    fig = plt.figure(figsize=(10, 30))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.axis('off')
    plt.show()

dataiter = iter(train_loader)
images, labels = next(dataiter)

imshow(images, [train_data.classes[i] for i in labels])
[Output image: a labelled batch of transformed training X-rays]
In [9]:
'''Define the model and parameters.'''

model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)  # Pre-trained on ImageNet

for param in model.parameters():
    param.requires_grad = False

# Define the custom final layer
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(64, 3),
    nn.Softmax(dim=1)  # dim required; note CrossEntropyLoss applies log-softmax internally, so it receives probabilities here rather than raw logits
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))
model = model.to(device)

num_epochs = 15     # The number of complete passes through the training dataset
batch_size = 8      # The number of samples processed before the model is updated
learn_rate = 0.001  # Determines the step size while moving toward a min of a loss function
Device: cuda:0
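As a quick sanity check (our addition, not an original cell), we can confirm that freezing the backbone left only the custom head trainable:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'{trainable:,} of {total:,} parameters are trainable')  # Only the fc head contributes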
In [10]:
# Use the cross-entropy loss function to measure loss (the difference between the true and predicted class distributions)
criterion = torch.nn.CrossEntropyLoss()
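As a small check of what the criterion computes (illustrative numbers, not model outputs), the loss is the negative log-probability that the softmax assigns to the true class:

logits = torch.tensor([[2.0, 0.5, 0.1]])  # Raw scores for 3 classes
target = torch.tensor([0])                # Index of the true class

manual = -F.log_softmax(logits, dim=1)[0, target]
print(criterion(logits, target).item(), manual.item())  # Both ~0.3168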
In [11]:
# Adam optimization is used in place of classical stochastic gradient descent to update the network
# weights; it speeds up gradient descent by using exponentially weighted averages of the gradients
optimizer = torch.optim.Adam(model.parameters(), 
                             lr=learn_rate, 
                             weight_decay=learn_rate / num_epochs)
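For intuition about those weighted averages, here is a simplified single-parameter sketch of the Adam update rule (standard hyperparameter names; not code used by this notebook):

def adam_step(theta, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad     # Exponentially weighted average of gradients
    v = b2 * v + (1 - b2) * grad**2  # ... and of squared gradients
    m_hat = m / (1 - b1**t)          # Bias correction for zero initialization
    v_hat = v / (1 - b2**t)
    theta -= lr * m_hat / (v_hat**0.5 + eps)
    return theta, m, v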
In [8]:
'''Function to plot the accuracy and loss while training.'''

def plot_accuracy_loss(train_losses, val_losses, train_acc, val_acc, num_epochs):
    sns.set_style('darkgrid')
    sns.set_palette("mako")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True)
    data = {'Epochs': list(range(1, num_epochs + 1)), 'Train Loss': train_losses, 'Val Loss': val_losses,
            'Train Accuracy': train_acc, 'Val Accuracy': val_acc}
    df = pd.DataFrame(data)
    
    # Accuracies
    sns.lineplot(data=df, x="Epochs", y="Train Accuracy", color='C1', ax=ax[0])
    sns.lineplot(data=df, x="Epochs", y="Val Accuracy", color='C4', ax=ax[0])
    ax[0].set_ylabel("Accuracy")
    ax[0].set_xlabel("Epoch")
    ax[0].set_xlim(1, num_epochs)
    ax[0].set_xticks(range(1, num_epochs + 1))
    ax[0].legend(labels=["Train Acc", "Val Acc"])

    # Losses
    sns.lineplot(data=df, x="Epochs", y="Train Loss", color='C1', ax=ax[1])
    sns.lineplot(data=df, x="Epochs", y="Val Loss", color='C4', ax=ax[1])
    ax[1].set_ylabel("Loss")
    ax[1].set_xlabel("Epoch")
    ax[1].set_xlim(1, num_epochs)
    ax[1].set_xticks(range(1, num_epochs + 1))
    ax[1].legend(labels=["Train Loss", "Val Loss"])

    fig.suptitle('Inception-v3')
    plt.show()
In [9]:
'''Function to train the model.'''

def train_model(model, criterion, optimizer, num_epochs, train_loader, val_loader,
                device, dataset_sizes, class_names):
    since = time.time()

    train_losses = []
    val_losses = []
    train_acc = []
    val_acc = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader

            running_loss = 0.0
            running_true = 0

            y_true, y_pred = [], []

            # Iterate over the data
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad() # Sets the gradients of all optimized tensors to zero

                # Only track history when training
                with torch.set_grad_enabled(phase == 'train'):
                    if phase == 'train':
                        # Forward pass --> calculating loss from model outputs
                        outputs, aux_outputs = model(inputs)

                        # Inception has an auxiliary output --> while training we calculate loss by summing final & auxiliary output
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)
                    y_true.append(labels)
                    y_pred.append(preds)

                    # Backward pass --> updating weights with Adam optimizer
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # Some stats
                    running_loss += loss.item() * inputs.size(0)
                    running_true += torch.sum(preds == labels.data)

            # More stats
            y_true, y_pred = torch.cat(y_true), torch.cat(y_pred)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_true.double() / dataset_sizes[phase]
            total_loss = running_loss / y_true.size(0)
            if phase == 'train':
                train_losses.append(epoch_loss)
                train_acc.append(epoch_acc.item())
            if phase == 'val':
                val_losses.append(epoch_loss)
                val_acc.append(epoch_acc.item())

            # Output stats each epoch
            print('{} loss: {:.4f}, {} accuracy: {:.4f}'.format(phase, total_loss, phase, epoch_acc))
            precision = precision_score(y_true.cpu(), y_pred.cpu(), average='macro')
            recall = recall_score(y_true.cpu(), y_pred.cpu(), average='macro')
            print('{} precision: {:.4f}, {} recall: {:.4f}'.format(phase, precision, phase, recall))
            f1 = f1_score(y_true.cpu(), y_pred.cpu(), average='macro')
            print('{} F1 score: {:.4f}'.format(phase, f1))
            print('Support: ')
            print(classification_report(y_true.cpu(), y_pred.cpu(), target_names=class_names))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print()
    print('Train Lists')
    print(train_losses)
    print(train_acc)
    print('Val Lists')
    print(val_losses)
    print(val_acc)

    # Plot the accuracy and loss
    plot_accuracy_loss(train_losses, val_losses, train_acc, val_acc, num_epochs)

    return model
In [10]:
'''Train the model.'''

dataset_sizes = {'train': len(train_data), 'val': len(val_data)}
class_names = train_data.classes

model = train_model(model, criterion, optimizer, 
                    num_epochs, train_loader, val_loader, 
                    device, dataset_sizes, class_names)
Epoch 1/15
----------

train loss: 4.2332, train accuracy: 0.6474
train precision: 0.6456, train recall: 0.6474
train F1 score: 0.6453
Support: 
              precision    recall  f1-score   support

       covid       0.68      0.75      0.71      3999
      normal       0.63      0.57      0.60      4000
   pneumonia       0.63      0.62      0.62      4000

    accuracy                           0.65     11999
   macro avg       0.65      0.65      0.65     11999
weighted avg       0.65      0.65      0.65     11999

val loss: 0.7320, val accuracy: 0.8187
val precision: 0.8379, val recall: 0.8187
val F1 score: 0.8175
Support: 
              precision    recall  f1-score   support

       covid       0.81      0.95      0.88      1000
      normal       0.72      0.81      0.77      1000
   pneumonia       0.97      0.69      0.81      1000

    accuracy                           0.82      3000
   macro avg       0.84      0.82      0.82      3000
weighted avg       0.84      0.82      0.82      3000

Epoch 5/15
----------


train loss: 4.1726, train accuracy: 0.7157
train precision: 0.7149, train recall: 0.7157
train F1 score: 0.7144
Support: 
              precision    recall  f1-score   support

       covid       0.73      0.80      0.76      3999
      normal       0.70      0.65      0.68      4000
   pneumonia       0.71      0.69      0.70      4000

    accuracy                           0.72     11999
   macro avg       0.71      0.72      0.71     11999
weighted avg       0.71      0.72      0.71     11999

val loss: 0.7137, val accuracy: 0.8373
val precision: 0.8505, val recall: 0.8373
val F1 score: 0.8363
Support: 
              precision    recall  f1-score   support

       covid       0.76      0.97      0.86      1000
      normal       0.84      0.75      0.79      1000
   pneumonia       0.95      0.79      0.86      1000

    accuracy                           0.84      3000
   macro avg       0.85      0.84      0.84      3000
weighted avg       0.85      0.84      0.84      3000




Epoch 10/15
----------
train loss: 4.1572, train accuracy: 0.7371
train precision: 0.7363, train recall: 0.7372
train F1 score: 0.7365
Support: 
              precision    recall  f1-score   support

       covid       0.77      0.80      0.79      3999
      normal       0.71      0.68      0.70      4000
   pneumonia       0.72      0.73      0.73      4000

    accuracy                           0.74     11999
   macro avg       0.74      0.74      0.74     11999
weighted avg       0.74      0.74      0.74     11999

val loss: 0.7836, val accuracy: 0.7620
val precision: 0.8357, val recall: 0.7620
val F1 score: 0.7617
Support: 
              precision    recall  f1-score   support

       covid       0.92      0.80      0.85      1000
      normal       0.60      0.95      0.74      1000
   pneumonia       0.99      0.54      0.70      1000

    accuracy                           0.76      3000
   macro avg       0.84      0.76      0.76      3000
weighted avg       0.84      0.76      0.76      3000




Epoch 15/15
----------
train loss: 4.1443, train accuracy: 0.7475
train precision: 0.7470, train recall: 0.7475
train F1 score: 0.7469
Support: 
              precision    recall  f1-score   support

       covid       0.77      0.81      0.79      3999
      normal       0.72      0.71      0.72      4000
   pneumonia       0.75      0.72      0.73      4000

    accuracy                           0.75     11999
   macro avg       0.75      0.75      0.75     11999
weighted avg       0.75      0.75      0.75     11999

val loss: 0.6864, val accuracy: 0.8590
val precision: 0.8616, val recall: 0.8590
val F1 score: 0.8595
Support: 
              precision    recall  f1-score   support

       covid       0.92      0.84      0.88      1000
      normal       0.80      0.84      0.82      1000
   pneumonia       0.87      0.90      0.88      1000

    accuracy                           0.86      3000
   macro avg       0.86      0.86      0.86      3000
weighted avg       0.86      0.86      0.86      3000


Training complete in 129m 10s

Train Lists
[4.233225082772922, 4.184903535199907, 4.167400298351864, 4.1570234450909265, 4.1726095845912115, 4.1639086500546645, 4.159437492555634, 4.164028998881303, 4.155271156558057, 4.157233690180295, 4.146512436125614, 4.153732038953025, 4.144367789832718, 4.150911634311506, 4.144309521059223]
[0.6473872822735227, 0.7044753729477456, 0.7263105258771564, 0.7295607967330611, 0.7157263105258771, 0.7283940328360696, 0.7318943245270438, 0.7298108175681306, 0.7398116509709142, 0.7371447620635052, 0.7368947412284357, 0.7402283523626968, 0.7461455121260104, 0.739311609300775, 0.747478956579715]


Val Lists
[0.7319899916648864, 0.7481962094306945, 0.7327691181500753, 0.7089627049763998, 0.713724932829539, 0.6956825722058614, 0.6898001863161722, 0.7157928222020468, 0.6971481850941976, 0.7836368376413981, 0.6745693442026774, 0.7573272180557251, 0.7451883797645569, 0.7283474817276001, 0.6863970088958741]
[0.8186666666666667, 0.7986666666666666, 0.8136666666666666, 0.8383333333333333, 0.8373333333333333, 0.852, 0.8583333333333333, 0.8316666666666667, 0.851, 0.762, 0.873, 0.7889999999999999, 0.7999999999999999, 0.8173333333333334, 0.859]


[Output: Inception-v3 training/validation accuracy and loss curves]

Results¶

After training for 15 epochs (2 hours 9 minutes) on an NVIDIA GeForce GTX 1050 Ti GPU, the model achieved the following validation scores (a toy check of these metrics follows the list):

  1. Accuracy (fraction of correct labels): 85.90%
  2. Loss (cross-entropy): 0.6864
  3. Precision (TP / (TP + FP)): 86.16%
  4. Recall (TP / (TP + FN)): 85.90%
  5. F1-Score (harmonic mean of precision & recall): 85.95%
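As a toy check of these definitions (made-up labels, not the model's outputs), sklearn's macro-averaged metrics match the hand formulas:

from sklearn.metrics import precision_score, recall_score, f1_score

y_true = [0, 0, 1, 1, 2, 2]  # 0=covid, 1=normal, 2=pneumonia
y_pred = [0, 1, 1, 1, 2, 0]

print(precision_score(y_true, y_pred, average='macro'))  # Mean over classes of TP / (TP + FP)
print(recall_score(y_true, y_pred, average='macro'))     # Mean over classes of TP / (TP + FN)
print(f1_score(y_true, y_pred, average='macro'))         # Per-class harmonic mean of the two, averaged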
In [140]:
'''Visualize some of the model predictions.'''

def get_image(path):
    '''Load an image file and convert it to RGB.'''
    with open(os.path.abspath(path), 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

classes = ["Covid", "Normal", "Pneumonia"]
pathname = os.path.join(dir, 'LIMEImages')
input_imgs = datasets.ImageFolder(pathname, val_transform)

val_loader = DataLoader(input_imgs, 
                        batch_size=batch_size, 
                        shuffle=False)

dataiter = iter(val_loader)
images, labels = next(dataiter)

model.eval()  # Inference mode: no auxiliary output, dropout disabled
outputs = model(images.to(device))
_, predicted = torch.max(outputs.data, 1)
probs = F.softmax(outputs, dim=1).cpu().detach().numpy()

img_pred_labels = []
for j in range(3):
    img_pred_labels.append(classes[predicted[j]])

img_paths = ['LIMEImages/covid/COVID-12.png', 
             'LIMEImages/normal/IM-0031-0001.jpeg', 
             'LIMEImages/pneumonia/person1_virus_12.jpeg']

plt.figure()
f, axarr = plt.subplots(1,3,figsize=(15,15))
f.subplots_adjust(wspace=0.5)

for n in range(3):
    axarr[n].imshow(get_image(os.path.join(dir, img_paths[n])))
    axarr[n].xaxis.set_tick_params(labelbottom=False)
    axarr[n].yaxis.set_tick_params(labelleft=False)
    axarr[n].set_xticks([])
    axarr[n].set_yticks([])
    axarr[n].set_title("Predicted label:{}\nTrue label:{}".format(classes[predicted[n]], classes[n]))
    
[Output: three X-rays with predicted and true labels]

Model Impact¶

Explainable AI helps characterize model accuracy, fairness, transparency, and outcomes, and is crucial in building trust and confidence when putting AI models into production. This becomes exceedingly important when a model is expected to predict outcomes in sensitive, life-changing situations.

Deployment¶

AI-enabled automation is often portrayed as a binary on-or-off process that is either automated or not. In the real world, however, automation is a spectrum, and the team deploying the model must choose where on the spectrum to operate. In the specific case of diagnosing patients with COVID-19 from X-rays, the following are the deployment options:

  1. Human only: No AI involved
  2. Shadow mode: A human doctor reads an X-ray and decides on a diagnosis, but an AI system shadows the doctor with its own attempt.
  3. AI assistance: A human doctor is responsible for the diagnosis, but the AI system may supply suggestions. For example, it can highlight areas of an X-ray for the doctor to focus on.
  4. Partial automation: An AI system looks at an X-ray image and, if it has high confidence in its decision, renders a diagnosis. In cases where it is not confident, it asks a human to make the final decision (see the sketch after this list).
  5. Full automation: The AI system makes the diagnosis on its own.
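A minimal sketch of the partial-automation rule from item 4 (the threshold is illustrative, and classes is the label list defined in the prediction cell above):

def triage(probs, threshold=0.95):
    '''Return the model's diagnosis when confident, otherwise defer to a doctor.'''
    conf, pred = probs.max(dim=1)  # Highest class probability and its index, per image
    return [classes[int(p)] if float(c) >= threshold else 'refer to human doctor'
            for c, p in zip(conf, pred)]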
Model Interpretation¶

Local Interpretable Model-Agnostic Explanations (LIME) can help us understand the reasons behind individual predictions, and can help place the trained model in shadow mode or AI assistance on the automation spectrum.
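Conceptually, LIME hides random subsets of superpixels, queries the model on each perturbed image, and fits a weighted linear surrogate whose coefficients score each superpixel's contribution to the prediction. A rough sketch of that loop (greatly simplified; the lime package handles segmentation, kernel weighting, and feature selection for us):

import numpy as np
from sklearn.linear_model import Ridge

def lime_sketch(img, segments, predict_fn, target_class, num_samples=1000):
    n_segs = segments.max() + 1
    masks = np.random.randint(0, 2, (num_samples, n_segs))  # Random on/off pattern per superpixel
    probs = []
    for mask in masks:
        perturbed = img.copy()
        for seg in range(n_segs):
            if mask[seg] == 0:
                perturbed[segments == seg] = 0               # Hide this superpixel
        probs.append(predict_fn(perturbed[None])[0, target_class])
    distances = 1 - masks.mean(axis=1)                       # Fraction of hidden superpixels
    weights = np.exp(-(distances ** 2) / 0.25)               # Favor samples close to the original
    surrogate = Ridge().fit(masks, probs, sample_weight=weights)
    return surrogate.coef_                                   # One importance score per superpixel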

In [142]:
img = get_image(os.path.join(dir, img_paths[0]))
plt.imshow(img)
Out[142]:
[Output image: the COVID-19 X-ray to be explained]
In [134]:
def get_pil_transform(): 
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224)  # 224 differs from the 299 used in training; torchvision's InceptionV3 still accepts it thanks to adaptive average pooling
    ])    

    return transf   

pill_transf = get_pil_transform()
In [131]:
def get_preprocess_transform():
    # ImageNet statistics; note this differs from the (0.5, 0.5) normalization used during training
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])    

    return transf 

preprocess_transform = get_preprocess_transform()

def batch_predict(images):
    model.eval()
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)
    
    logits = model(batch)             # The fc head already ends in Softmax, so these are probabilities rather than raw logits
    probs = F.softmax(logits, dim=1)  # A second softmax; it preserves the ranking of classes
    return probs.detach().cpu().numpy()
In [136]:
from lime import lime_image

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(pill_transf(img)), 
                                         batch_predict,     # Classification function returning class probabilities
                                         top_labels=3,      # Explain the 3 highest-probability classes
                                         hide_color=0,      # Hidden superpixels are filled with black
                                         num_samples=1000)  # Number of perturbed samples in the local neighbourhood
In [137]:
from skimage.segmentation import mark_boundaries

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=3, hide_rest=False)
img_boundry1 = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry1)
Out[137]:
[Output image: boundaries of the top superpixels for the predicted class]
In [148]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=5, hide_rest=False)
img_boundry2 = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry2)
Out[148]:
[Output image: superpixels supporting and contradicting the predicted class]
In [139]:
ind = explanation.top_labels[0]                  # Most probable class
dict_heatmap = dict(explanation.local_exp[ind])  # Superpixel index -> surrogate model weight
heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)

plt.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())  # Symmetric diverging scale about zero
Out[139]:
[Output image: LIME explanation heatmap]

Future Work¶

  1. Fine-tune the model
  2. Experiment with adding multiple layers
  3. Remove text from X-ray images (to reduce bias)

Citations¶

  1. Badawi, A.; Elgazzar, K. Detecting Coronavirus from Chest X-rays Using Transfer Learning. COVID 2021, 1, 403-415. https://doi.org/10.3390/covid1010034
  2. Szegedy, Christian, et al. "Rethinking the inception architecture for computer vision." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. https://arxiv.org/abs/1512.00567