COVID-19 Detection from Chest X-rays Using Transfer Learning¶
Problem Statement¶
Rapid detection of COVID-19 cases is critical both to containing the virus and to relieving pressure on an overflowing healthcare system. Clinical testing for COVID-19 is time-consuming, so imaging tools such as chest X-rays are a key instrument for detecting COVID-19 and speeding up identification. In this notebook, we train a deep convolutional neural network using transfer learning on a dataset of 15,000 chest X-rays to aid in the detection of COVID-19. Further, we implement local interpretable model-agnostic explanations (LIME) to provide insight into the model's predictions, enabling healthcare providers to make diagnoses that are both time-efficient and correct.
'''Load PyTorch and necessary libraries.'''
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
from torchvision import models
import torchvision.transforms as transforms
import os
from pathlib import Path
import Augmentor
from IPython.display import display  # display() renders PIL images inline; PIL's Image is imported below
import splitfolders
import matplotlib.pyplot as plt
import numpy as np
import time
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import seaborn as sns
import pandas as pd
from PIL import Image
dir = os.getcwd()
Data Collection and Augmentation¶
We use chest X-ray images from 3 categories: normal, pneumonia, and COVID-19 cases. The dataset, containing 5000 normal images, 5000 pneumonia images, and 4420 COVID-19 images, was collected from eleven different publicly available datasets by Badawi et al. [1]. Using data augmentation with the Augmentor package, we generate an additional 580 COVID-19 images to construct a more balanced dataset.
p = Augmentor.Pipeline(os.path.join(dir, 'preXrayData/covid'))
p.rotate(probability=0.7, max_left_rotation=10, max_right_rotation=10) # Rotate 70% of the images between -10 and 10 degrees
p.random_distortion(probability=1, grid_width=4, grid_height=4, magnitude=8) # Randomly distort the images while maintaining their aspect ratio
p.flip_left_right(probability=1) # Mirror the images from left to right
p.process() # Apply the pipeline once to every image in the source directory (results are written to an output/ subfolder)
p.sample(580) # Generate the 580 additional augmented samples described above
Initialised with 4999 image(s) found.
Processing <PIL.Image.Image image mode=L size=3050x2539 at 0x7F103893BF40>: 100%|██████████| 4999/4999 [06:11<00:00, 13.46 Samples/s] Processing <PIL.Image.Image image mode=L size=299x299 at 0x7F0FFC56D810>: 100%|██████████| 580/580 [00:49<00:00, 11.61 Samples/s]
'''Visualize augmented image.'''
aug_img_path = 'preXrayData/covid/covid_original_COVID-597.png_b8c0f5bb-77e0-4749-9359-d746af324bcb.png'
display(Image.open(aug_img_path))
'''Move additional generated images into correct folder.'''
src_path = os.path.join(dir, 'preXrayData/covid/output')
for each_file in Path(src_path).glob('*.*'): # grabs all files
tar_path = each_file.parent.parent
each_file.rename(tar_path.joinpath(each_file.name))
os.rmdir(src_path)  # The output directory is empty now that its files have been moved
'''Split the training and validation data.'''
data = os.path.join(dir, 'preXrayData')
splitfolders.ratio(data, output='postXrayData', seed=1337, ratio=(0.8,0.2))
Copying files: 20578 files [00:10, 1920.93 files/s]
The Model¶
Why InceptionV3?¶
InceptionV3 is an image recognition model that has been pre-trained on the ImageNet dataset, attaining greater than 78.1% accuracy. Based on the paper Rethinking the Inception Architecture for Computer Vision by Szegedy et al. [2], the model set out to demonstrate that although increased model size and computational cost tend to translate into immediate quality gains for most tasks (as long as enough labeled data is provided for training), computational efficiency and a low parameter count are also valuable enabling factors. Thus, InceptionV3 focused on scaling up previous Inception iterations as efficiently as possible, using suitably factorized convolutions and aggressive regularization.
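To make the factorization idea concrete, here is a minimal sketch (our own illustration, not code from [2]) of replacing a 5x5 convolution with two stacked 3x3 convolutions, which preserves the receptive field while cutting the parameter count by roughly 28% at this width:
'''Sketch: factorized convolutions (illustration only).'''
import torch.nn as nn

# A 5x5 convolution and its factorized equivalent: two stacked 3x3
# convolutions cover the same 5x5 receptive field with fewer weights.
five_by_five = nn.Conv2d(64, 64, kernel_size=5, padding=2)
factorized = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
)

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(five_by_five), count_params(factorized))  # 102464 vs. 73856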
Why Transfer Learning?¶
Transfer learning is a method to overcome insufficient data and/or training resources by reusing a network that has already been trained on a large dataset: the original classifier head is removed and replaced with a new final layer (or stack of layers) trained on the target task. The pre-trained network then acts as a feature extractor whose output feeds a series of dense layers. A dense (fully connected) layer is a simple layer of neurons in which each neuron receives input from all the neurons of the previous layer, allowing it to capture complex patterns. We follow this recipe below by freezing InceptionV3's weights and replacing its final fully connected layer.
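As a quick, self-contained illustration of why this is cheap to train (using ResNet18 as a stand-in backbone here; the actual model we fine-tune is defined under Model Architecture), freezing the backbone leaves only the new head's parameters trainable:
'''Sketch: transfer-learning recipe with a stand-in backbone (illustration only).'''
import torch.nn as nn
from torchvision import models

# Freeze a pre-trained backbone and attach a new 3-class head.
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in backbone.parameters():
    param.requires_grad = False  # backbone becomes a fixed feature extractor
backbone.fc = nn.Linear(backbone.fc.in_features, 3)  # new head is trainable

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"training {trainable:,} of {total:,} parameters")  # ~1.5k of ~11.2M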
Model Architecture¶
'''Image transformation and normalization so the input data matches the model's expectations.'''
train_pathname = os.path.join(dir, 'postXrayData/train') # Contains 80% of the data
val_pathname = os.path.join(dir, 'postXrayData/val') # Contains 20% of the data
# Transform pipeline to ensure input data is prepared in the same way as the original training data for the model
train_transform = transforms.Compose([
transforms.RandomResizedCrop(299),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # Greyscale X-rays: the single value is broadcast across all channels
])
val_transform = transforms.Compose([
transforms.Resize((299, 299)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.ImageFolder(train_pathname, train_transform)
val_data = datasets.ImageFolder(val_pathname, val_transform)
batch_size = 8 # Mini-batches of 8 samples at a time
# Dataloaders pass the data through the transform pipelines
train_loader = DataLoader(train_data,
batch_size=batch_size,
shuffle=True)
val_loader = DataLoader(val_data,
batch_size=batch_size,
shuffle=True)
'''Display labelled, transformed images.'''
def imshow(img, title):
img = torchvision.utils.make_grid(img, normalize=True)
npimg = img.numpy()
fig = plt.figure(figsize=(10, 30))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.title(title)
plt.axis('off')
plt.show()
dataiter = iter(train_loader)
images, labels = next(dataiter)
imshow(images, [train_data.classes[i] for i in labels])
'''Define the model and parameters.'''
model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
for param in model.parameters():
    param.requires_grad = False  # Freeze the pre-trained backbone
# Replace the classifier heads: Inception's auxiliary head and the custom final layer
model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 3)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(64, 3)
    # No softmax here: CrossEntropyLoss expects raw logits, and we apply
    # F.softmax explicitly wherever probabilities are needed at inference
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))
model = model.to(device)
num_epochs = 15 # The number of complete passes through the training dataset
batch_size = 8 # The number of samples processed before the model is updated
learn_rate = 0.001 # Determines the step size while moving toward a min of a loss function
Device: cuda:0
# Use the cross-entropy loss function to measure loss (the difference between the true and predicted class distributions)
criterion = torch.nn.CrossEntropyLoss()
# Adam optimization is used in place of classical stochastic gradient descent to update the network weights (it speeds up gradient descent by using exponentially weighted averages of the gradients)
optimizer = torch.optim.Adam(model.parameters(),
lr=learn_rate,
weight_decay=learn_rate / num_epochs)
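For a single sample, cross-entropy reduces to the negative log-probability the model assigns to the true class. With one-hot label $y$ and predicted class probabilities $\hat{y}$ over our three classes:

$$\mathcal{L}(y, \hat{y}) = -\sum_{c=1}^{3} y_c \log \hat{y}_c = -\log \hat{y}_{\text{true class}}$$

PyTorch's CrossEntropyLoss computes this directly from raw logits by applying log-softmax internally, which is why the model head above outputs logits rather than probabilities.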
'''Function to plot the accuracy and loss while training.'''
def plot_accuracy_loss(train_losses, val_losses, train_acc, val_acc, num_epochs):
sns.set_style('darkgrid')
sns.set_palette("mako")
fig, ax = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True)
data = {'Epochs': list(range(1, num_epochs + 1)), 'Train Loss': train_losses, 'Val Loss': val_losses,
'Train Accuracy': train_acc, 'Val Accuracy': val_acc}
df = pd.DataFrame(data)
    # Accuracies
sns.lineplot(data=df, x="Epochs", y="Train Accuracy", color='C1', ax=ax[0])
sns.lineplot(data=df, x="Epochs", y="Val Accuracy", color='C4', ax=ax[0])
ax[0].set_ylabel("Accuracy")
ax[0].set_xlabel("Epoch")
ax[0].set_xlim(1, num_epochs)
ax[0].set_xticks(range(1, num_epochs + 1))
ax[0].legend(labels=["Train Acc", "Val Acc"])
    # Losses
sns.lineplot(data=df, x="Epochs", y="Train Loss", color='C1', ax=ax[1])
sns.lineplot(data=df, x="Epochs", y="Val Loss", color='C4', ax=ax[1])
ax[1].set_ylabel("Loss")
ax[1].set_xlabel("Epoch")
ax[1].set_xlim(1, num_epochs)
ax[1].set_xticks(range(1, num_epochs + 1))
ax[1].legend(labels=["Train Loss", "Val Loss"])
fig.suptitle('Inception-v3')
plt.show()
'''Function to train the model.'''
def train_model(model, criterion, optimizer, num_epochs, train_loader, val_loader,
device, dataset_sizes, class_names):
since = time.time()
train_losses = []
val_losses = []
train_acc = []
val_acc = []
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
print('-' * 10)
for phase in ['train', 'val']:
if phase == 'train':
model.train()
loader = train_loader
else:
model.eval()
loader = val_loader
running_loss = 0.0
running_true = 0
y_true, y_pred = [], []
# Iterate over the data
for inputs, labels in loader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad() # Sets the gradients of all optimized tensors to zero
# Only track history when training
with torch.set_grad_enabled(phase == 'train'):
if phase == 'train':
# Forward pass --> calculating loss from model outputs
outputs, aux_outputs = model(inputs)
# Inception has an auxiliary output --> while training we calculate loss by summing final & auxiliary output
loss1 = criterion(outputs, labels)
loss2 = criterion(aux_outputs, labels)
loss = loss1 + 0.4 * loss2
else:
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
y_true.append(labels)
y_pred.append(preds)
# Backward pass --> updating weights with Adam optimizer
if phase == 'train':
loss.backward()
optimizer.step()
# Some stats
running_loss += loss.item() * inputs.size(0)
running_true += torch.sum(preds == labels.data)
# More stats
y_true, y_pred = torch.cat(y_true), torch.cat(y_pred)
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_true.double() / dataset_sizes[phase]
if phase == 'train':
train_losses.append(epoch_loss)
train_acc.append(epoch_acc.item())
if phase == 'val':
val_losses.append(epoch_loss)
val_acc.append(epoch_acc.item())
# Output stats each epoch
            print('{} loss: {:.4f}, {} accuracy: {:.4f}'.format(phase, epoch_loss, phase, epoch_acc))
precision = precision_score(y_true.cpu(), y_pred.cpu(), average='macro')
recall = recall_score(y_true.cpu(), y_pred.cpu(), average='macro')
print('{} precision: {:.4f}, {} recall: {:.4f}'.format(phase, precision, phase, recall))
f1 = f1_score(y_true.cpu(), y_pred.cpu(), average='macro')
print('{} F1 score: {:.4f}'.format(phase, f1))
print('Support: ')
print(classification_report(y_true.cpu(), y_pred.cpu(), target_names=class_names))
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print()
print('Train Lists')
print(train_losses)
print(train_acc)
print('Val Lists')
print(val_losses)
print(val_acc)
# Plot the accuracy and loss
plot_accuracy_loss(train_losses, val_losses, train_acc, val_acc, num_epochs)
return model
'''Train the model.'''
dataset_sizes = {'train': len(train_data), 'val': len(val_data)}
class_names = train_data.classes
model = train_model(model, criterion, optimizer,
num_epochs, train_loader, val_loader,
device, dataset_sizes, class_names)
Epoch 1/15
----------
train loss: 4.2332, train accuracy: 0.6474
train precision: 0.6456, train recall: 0.6474
train F1 score: 0.6453
Support:
              precision    recall  f1-score   support

       covid       0.68      0.75      0.71      3999
      normal       0.63      0.57      0.60      4000
   pneumonia       0.63      0.62      0.62      4000

    accuracy                           0.65     11999
   macro avg       0.65      0.65      0.65     11999
weighted avg       0.65      0.65      0.65     11999

val loss: 0.7320, val accuracy: 0.8187
val precision: 0.8379, val recall: 0.8187
val F1 score: 0.8175
Support:
              precision    recall  f1-score   support

       covid       0.81      0.95      0.88      1000
      normal       0.72      0.81      0.77      1000
   pneumonia       0.97      0.69      0.81      1000

    accuracy                           0.82      3000
   macro avg       0.84      0.82      0.82      3000
weighted avg       0.84      0.82      0.82      3000
Epoch 5/15
----------
train loss: 4.1726, train accuracy: 0.7157
train precision: 0.7149, train recall: 0.7157
train F1 score: 0.7144
Support:
              precision    recall  f1-score   support

       covid       0.73      0.80      0.76      3999
      normal       0.70      0.65      0.68      4000
   pneumonia       0.71      0.69      0.70      4000

    accuracy                           0.72     11999
   macro avg       0.71      0.72      0.71     11999
weighted avg       0.71      0.72      0.71     11999

val loss: 0.7137, val accuracy: 0.8373
val precision: 0.8505, val recall: 0.8373
val F1 score: 0.8363
Support:
              precision    recall  f1-score   support

       covid       0.76      0.97      0.86      1000
      normal       0.84      0.75      0.79      1000
   pneumonia       0.95      0.79      0.86      1000

    accuracy                           0.84      3000
   macro avg       0.85      0.84      0.84      3000
weighted avg       0.85      0.84      0.84      3000
Epoch 10/15
----------
train loss: 4.1572, train accuracy: 0.7371
train precision: 0.7363, train recall: 0.7372
train F1 score: 0.7365
Support:
              precision    recall  f1-score   support

       covid       0.77      0.80      0.79      3999
      normal       0.71      0.68      0.70      4000
   pneumonia       0.72      0.73      0.73      4000

    accuracy                           0.74     11999
   macro avg       0.74      0.74      0.74     11999
weighted avg       0.74      0.74      0.74     11999

val loss: 0.7836, val accuracy: 0.7620
val precision: 0.8357, val recall: 0.7620
val F1 score: 0.7617
Support:
              precision    recall  f1-score   support

       covid       0.92      0.80      0.85      1000
      normal       0.60      0.95      0.74      1000
   pneumonia       0.99      0.54      0.70      1000

    accuracy                           0.76      3000
   macro avg       0.84      0.76      0.76      3000
weighted avg       0.84      0.76      0.76      3000
Epoch 15/15
----------
train loss: 4.1443, train accuracy: 0.7475
train precision: 0.7470, train recall: 0.7475
train F1 score: 0.7469
Support:
              precision    recall  f1-score   support

       covid       0.77      0.81      0.79      3999
      normal       0.72      0.71      0.72      4000
   pneumonia       0.75      0.72      0.73      4000

    accuracy                           0.75     11999
   macro avg       0.75      0.75      0.75     11999
weighted avg       0.75      0.75      0.75     11999

val loss: 0.6864, val accuracy: 0.8590
val precision: 0.8616, val recall: 0.8590
val F1 score: 0.8595
Support:
              precision    recall  f1-score   support

       covid       0.92      0.84      0.88      1000
      normal       0.80      0.84      0.82      1000
   pneumonia       0.87      0.90      0.88      1000

    accuracy                           0.86      3000
   macro avg       0.86      0.86      0.86      3000
weighted avg       0.86      0.86      0.86      3000

Training complete in 129m 10s

Train Lists
[4.233225082772922, 4.184903535199907, 4.167400298351864, 4.1570234450909265, 4.1726095845912115, 4.1639086500546645, 4.159437492555634, 4.164028998881303, 4.155271156558057, 4.157233690180295, 4.146512436125614, 4.153732038953025, 4.144367789832718, 4.150911634311506, 4.144309521059223]
[0.6473872822735227, 0.7044753729477456, 0.7263105258771564, 0.7295607967330611, 0.7157263105258771, 0.7283940328360696, 0.7318943245270438, 0.7298108175681306, 0.7398116509709142, 0.7371447620635052, 0.7368947412284357, 0.7402283523626968, 0.7461455121260104, 0.739311609300775, 0.747478956579715]
Val Lists
[0.7319899916648864, 0.7481962094306945, 0.7327691181500753, 0.7089627049763998, 0.713724932829539, 0.6956825722058614, 0.6898001863161722, 0.7157928222020468, 0.6971481850941976, 0.7836368376413981, 0.6745693442026774, 0.7573272180557251, 0.7451883797645569, 0.7283474817276001, 0.6863970088958741]
[0.8186666666666667, 0.7986666666666666, 0.8136666666666666, 0.8383333333333333, 0.8373333333333333, 0.852, 0.8583333333333333, 0.8316666666666667, 0.851, 0.762, 0.873, 0.7889999999999999, 0.7999999999999999, 0.8173333333333334, 0.859]
Results¶
After running for 15 epochs (approximately 2 hours and 9 minutes) on an NVIDIA GeForce GTX 1050 Ti GPU, the model achieved the following validation scores (precision, recall, and F1 are macro-averaged over the three classes):
- Accuracy (fraction of correct labels): 85.90%
- Loss (cross-entropy): 0.6864
- Precision (TP / (TP + FP)): 86.16%
- Recall (TP / (TP + FN)): 85.90%
- F1-Score (harmonic mean of precision and recall): 85.95%
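As a sanity check, the macro-averaged F1 reported above is the unweighted mean of the per-class F1 scores from the final epoch's validation report:

$$F_1^{\text{macro}} = \frac{0.88 + 0.82 + 0.88}{3} \approx 0.86$$

which matches the 85.95% figure up to rounding of the per-class scores.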
'''Visualize some of the model predictions.'''
def get_image(path):
'''Display images.'''
with open(os.path.abspath(path), 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')
classes = ["Covid", "Normal", "Pneumonia"]
pathname = os.path.join(dir, 'LIMEImages')
input_imgs = datasets.ImageFolder(pathname, val_transform)
val_loader = DataLoader(input_imgs,
batch_size=batch_size,
shuffle=False)
dataiter = iter(val_loader)
images, labels = next(dataiter)
model.eval()  # Disable dropout for inference
outputs = model(images.to(device))
_, predicted = torch.max(outputs.data, 1)
probs = F.softmax(outputs, dim=1).cpu().detach().numpy()  # Class probabilities from the raw logits
img_pred_labels = []
for j in range(3):
img_pred_labels.append(classes[predicted[j]])
img_paths = ['LIMEImages/covid/COVID-12.png',
'LIMEImages/normal/IM-0031-0001.jpeg',
'LIMEImages/pneumonia/person1_virus_12.jpeg']
plt.figure()
f, axarr = plt.subplots(1,3,figsize=(15,15))
f.subplots_adjust(wspace=0.5)
for n in range(3):
axarr[n].imshow(get_image(os.path.join(dir, img_paths[n])))
axarr[n].xaxis.set_tick_params(labelbottom=False)
axarr[n].yaxis.set_tick_params(labelleft=False)
axarr[n].set_xticks([])
axarr[n].set_yticks([])
    axarr[n].set_title("Predicted label: {}\nTrue label: {}".format(img_pred_labels[n], classes[n]))
Model Impact¶
Explainable AI helps characterize model accuracy, fairness, transparency, and outcomes, and is crucial for building trust and confidence when putting AI models into production. This becomes especially important when a model is expected to predict outcomes in sensitive, life-changing situations.
Deployment¶
AI-enabled automation is often portrayed as a binary on-or-off process that is either automated or not. In the real world, however, automation is a spectrum, and the team deploying the model must choose where on the spectrum to operate. In the specific case of diagnosing patients with COVID-19 from X-rays, the following are the deployment options:
- Human only: No AI involved
- Shadow mode: A human doctor reads an X-ray and decides on a diagnosis, but an AI system shadows the doctor with its own attempt.
- AI assistance: A human doctor is responsible for the diagnosis, but the AI system may supply suggestions. For example, it can highlight areas of an X-ray for the doctor to focus on.
- Partial automation: An AI system looks at an X-ray image and, if it has high confidence in its decision, renders a diagnosis. In cases where it's not confident, it defers to a human to make the final decision (see the sketch after this list).
- Full automation: An AI system makes the diagnosis with no human involvement.
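As a minimal sketch of the partial-automation option (our own illustration; the 0.95 threshold is an arbitrary assumption, not a clinically validated value), the model's softmax confidence can gate which cases are auto-diagnosed and which are deferred:
'''Sketch: confidence-gated triage for partial automation (illustration only).'''
import torch
import torch.nn.functional as F

CONFIDENCE_THRESHOLD = 0.95  # assumed cutoff, for illustration only

def triage(logits, class_names):
    """Auto-diagnose confident cases; defer the rest to a human reader."""
    probs = F.softmax(logits, dim=1)
    confidences, preds = torch.max(probs, dim=1)
    decisions = []
    for conf, pred in zip(confidences.tolist(), preds.tolist()):
        if conf >= CONFIDENCE_THRESHOLD:
            decisions.append("auto-diagnose: {} ({:.2f})".format(class_names[pred], conf))
        else:
            decisions.append("defer to human reader (confidence {:.2f})".format(conf))
    return decisions

# Example with the validation batch from above:
# print(triage(model(images.to(device)), classes))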
Model Interpretation¶
Local Interpretable Model-Agnostic Explanations (LIME) can help us understand the reasons behind individual predictions: it perturbs superpixel regions of the input image, queries the model on the perturbed samples, and fits a local linear surrogate model whose weights indicate how much each region contributed to the prediction. These explanations can help place the trained model at the shadow-mode or AI-assistance point on the automation spectrum.
img = get_image(os.path.join(dir, img_paths[0]))
plt.imshow(img)
def get_pil_transform():
    # Match the spatial preprocessing used for validation (299x299, no crop)
    transf = transforms.Compose([
        transforms.Resize((299, 299))
    ])
    return transf

pil_transf = get_pil_transform()

def get_preprocess_transform():
    # Normalization must match the val_transform the model was trained with
    transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    return transf

preprocess_transform = get_preprocess_transform()
preprocess_transform = get_preprocess_transform()
def batch_predict(images):
model.eval()
batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch = batch.to(device)
logits = model(batch)
probs = F.softmax(logits, dim=1)
return probs.detach().cpu().numpy()
from lime import lime_image
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(pil_transf(img)),
batch_predict,
top_labels=3,
hide_color=0,
num_samples=1000)
from skimage.segmentation import mark_boundaries
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=3, hide_rest=False)
img_boundry1 = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry1)
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=5, hide_rest=False)
img_boundry2 = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry2)
ind = explanation.top_labels[0]
dict_heatmap = dict(explanation.local_exp[ind])
heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
plt.imshow(heatmap, cmap = 'RdBu', vmin = -heatmap.max(), vmax = heatmap.max())
Future Work¶
- Fine-tune the model (e.g., unfreeze and retrain some of the later backbone layers)
- Experiment with adding multiple layers to the classification head
- Remove text markers from the X-ray images (to reduce bias)
Citations¶
- [1] Badawi, A.; Elgazzar, K. Detecting Coronavirus from Chest X-rays Using Transfer Learning. COVID 2021, 1, 403-415. https://doi.org/10.3390/covid1010034
- [2] Szegedy, C., et al. "Rethinking the Inception Architecture for Computer Vision." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016. https://arxiv.org/abs/1512.00567