For prostate cancer diagnosis using PyTorch and AWS SageMaker data parallelism
Introduction
This post consists of the following parts:
Part 1 is an overview of why AI is positioned to transform the healthcare industry.
Part 2 is an explanation of a machine learning technique called multiple instance learning and why it is suitable for pathology applications.
These serve as a build-up for Part 3 which outlines the implementation of an attention-based deep multiple instance learning model for prostate cancer diagnosis using PyTorch and AWS SageMaker’s data parallelism toolkit.
An abridged version of this article has been published in the November 2021 issue of College of American Pathologists Today.
Part 1 – Why AI is positioned to transform the healthcare industry
Before diving into code, let’s step back and consider why Artificial Intelligence is positioned to transform healthcare.
Much of AI’s momentum today can be attributed to the success of deep neural networks which would not be possible without a perfect storm of the following four driving forces:
- The increasing availability of massive datasets such as ImageNet’s 15 million labeled images, Facebook’s library of billions of images, YouTube’s video library which grows by 300 hours of video per minute, and Tesla’s collection of driving data which adds 1 million miles of data per hour.¹
- The use of graphics processing units (GPUs), and later more AI-specialized hardware called tensor processing units (TPUs), which are optimized for training deep learning models. TPUs consist of many cores which enable them to process large amounts of data and perform multiple computations in parallel. A 2018 report by OpenAI observed that prior to 2012, AI compute growth closely tracked Moore’s law, doubling every two years, and that post-2012, compute has been doubling every three to four months. Overall, since 2012, this compute metric has grown by a factor of more than 300,000, while a two-year doubling period would have only yielded a 16x increase.²
- The availability of cloud computing which has made the ability to store large datasets and use them to train models more accessible and economical.
- Open-source algorithmic development modules such as Facebook’s PyTorch, Google’s TensorFlow, Microsoft’s Cognitive Toolkit, and others.
This burgeoning of rich resources to fuel AI’s rapid advancement comes at a time when physicians are more overwhelmed than ever. In the United States, the average length of a new-patient clinic visit has dropped from 60 minutes in 1975 to 12 minutes today, despite an increase in the number of healthcare employees from 4 million in 1975 to 16 million today.¹ Aside from facing a growing population, doctors have become increasingly inundated with electronic health records, managed care, health maintenance organizations, and relative value units, all of which divert their attention from forming meaningful relationships with patients. Burnt-out doctors who are disconnected from their patients are more likely to make judgments ridden with cognitive biases. As a result, they reflexively order incorrect tests and subsequently misinterpret them, leading to misdiagnoses. A 2014 review concluded that the United States faces approximately 12 million misdiagnoses per year.³
As Eric Topol, MD, founder and director of the Scripps Research Translational Institute, among others, asserts, the exciting promise of AI lies in using deep and comprehensive collections of relevant patient data to improve decision-making, reduce misdiagnosis and unnecessary procedures, guide the selection and interpretation of tests, and recommend the safest treatment regimen.¹ However, the extent to which health care becomes infused with AI needs to be tempered by the industry’s inherent need for empathy and connection between clinicians and patients. A physician affords patients a sense of ethics and core values, neither of which can be replicated by a computer. Assertions that AI will soon become sophisticated enough to lead to full automation in other industries have stoked fears among medical professionals that human involvement in medicine will become a thing of the past. However, the possibility of AI leading to full automation in medicine is still far off. Those involved in AI have time to strike the right balance between doctors and machines.
An appropriate balance might involve an AI system taking on a digital assistant role in which it alerts the physician to the most probable diagnoses and best courses of action, and leaves the physician responsible for making the final decision. A "human in the loop" approach is in line with Friedman’s Fundamental Theorem, which states that a human working in partnership with a computer will always outperform a human working alone, and warrants the need for transparency into how an algorithm arrived at a particular prediction.⁴ Explainable AI is a set of processes and methods that provides this insight, and is crucial for building trust among those relying on AI-enabled systems while also ensuring accuracy, fairness, and compliance with regulatory standards. Explainability provides clinicians with quality control and checks and balances, and can help them become more confident about relying on algorithms to make final diagnoses.
Robust models will need to be developed with a satisfactory level of explainability. But if done correctly, the resulting gains in workflow efficiency can afford clinicians more time to connect with patients. Paradoxically, the rise of machines can restore the humanity in medicine, and allow medical professionals to get back in touch with their motivations for pursuing a medical career in the first place.
Part 2 – Artificial intelligence in pathology
One of the most effective applications of AI in healthcare has been in medical imaging. Radiology, pathology, and dermatology are specialties that rely on visual pattern analysis and are, therefore, positioned to undergo a rapid and dramatic transformation due to integration with AI.
Pathologists play a crucial role in diagnosing cancer, and their report helps dictate a patient’s treatment strategy. Typically, pathologists look at hematoxylin and eosin (H&E) stained tissue samples under a microscope, and describe the types of cells they see, how they’re arranged, whether they’re abnormal, and any other features that are important for diagnosis. The practice of using microscopes to examine glass slides containing tissue samples has been largely unchanged for a century. In recent years, however, slides are increasingly being digitized using digital slide scanners to produce whole slide images (WSIs) that can be examined on a computer. Yet pathologists have been slow to adopt WSIs and other digital techniques, which has made AI’s encroachment into pathology slower than expected. Nevertheless, WSIs have laid the groundwork for incorporating neural network image processing into pathology, thereby making a new AI-assisted era in the field imminent.¹

AI can be used to carry out routine workflows that are typically performed by pathologists, such as detection of tumor tissue in biopsy samples and determination of tumor subtype based on morphology, with greater efficiency and accuracy. A major milestone for AI in pathology was the CAMELYON16 challenge, which set the goal of developing algorithms to detect metastatic breast cancer in WSIs of lymph node biopsies. The dataset provided consists of 400 WSIs in which pathologists manually delineated regions of metastatic cancer, and is one of the largest labeled pathology datasets. This enabled the team with the top submission, whose algorithm performed on par with pathologists, to make use of supervised learning.⁶
In general, supervised learning is a machine learning approach in which an algorithm is shown a sequence of many input data and their corresponding output labels until it can detect the underlying pattern which reveals the relationship between these inputs and outputs. This technique allows it to accurately label data which it hasn’t seen before, and can be used for either classification (sorts an input into a given number of categories) or regression (given an input, predicts a target numeric value) tasks.
A major downside of supervised learning is that it typically requires the training dataset to be hand-labeled by domain experts. This is particularly the case when dealing with WSIs: for scale, roughly 470 pathology images contain approximately the same number of pixels as the entire ImageNet dataset. Additionally, despite the CAMELYON16 dataset being one of the largest in pathology, 400 WSIs are not enough to capture the wide variety of cases seen on a regular basis in the clinic. It would therefore be highly expensive and time-consuming to obtain an appropriately large dataset, whose billions of pixels would also be computationally demanding to process for training. As a result, designing supervised learning models for everyday use in pathology is highly impractical.⁷
Multiple instance learning (MIL) and its suitability for pathology applications
MIL is a variation of supervised learning that is better suited to pathology applications. The technique involves assigning a single class label to a collection of inputs, referred to in this context as a bag of instances. While it is assumed that labels exist for each instance within a bag, there is no access to those labels and they remain unknown during training. A bag is typically labeled as negative if all instances in the bag are negative, or positive if there is at least one positive instance (known as the standard MIL assumption). A simple example is shown in the figure below, in which we only know whether a keychain contains the key that can open a given door. This allows us to infer that the green key can open the door.
![Simplified illustration of multiple instance learning using keychains (figure by author - inspired by [source])](https://towardsdatascience.com/wp-content/uploads/2021/06/12dHiAk7NnBNh-jC18Q1X6A.png)
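In code, the standard MIL assumption amounts to taking a maximum (logical OR) over the hidden instance labels; the toy helper below is purely illustrative:

def bag_label(instance_labels):
    # Standard MIL assumption: a bag is positive if at least one instance is positive
    return int(any(instance_labels))

bag_label([0, 0, 0])   # 0 -> negative bag (no key on the keychain opens the door)
bag_label([0, 1, 0])   # 1 -> positive bag (at least one key opens the door)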
The MIL formulation naturally fits the task of imaging-based patient diagnosis: analogous to the standard MIL assumption, diseased tissue samples have both abnormal and healthy regions, while healthy tissue samples only have healthy regions. As a result, it is possible to divide WSIs into tiles, where each collection of tiles can be labeled as either "malignant" or "benign." These weak labels are significantly easier to obtain than strong labels (i.e. outlines of diseased areas provided manually by an expert), and therefore save pathologists from the painstaking task of having to annotate WSIs themselves. MIL models can also be made to be highly interpretable, which caters to the explainability requirement of AI systems in healthcare discussed earlier.⁸
Furthermore, a particularly exciting part about the use of MIL in pathology is that it can be integrated into a deep learning model which allows for the creation of a smooth end-to-end pipeline in which WSIs are fed in as inputs and diagnoses are returned as outputs. As a byproduct, deep MIL models can automatically uncover novel, abstract features from WSIs that perform better than traditional features at determining survival, treatment response, and genetic defects. It is remarkable that it is possible to derive these insights directly from H&E slides, which are easily accessible in pathology labs, rather than performing additional tests that may be costly.⁹ Deep MIL forms the basis of Paige.AI Inc’s prostate cancer software which was authorized for use by the FDA in September 2021.¹⁰
Part 3 – Implementation of an attention-based deep MIL model for prostate cancer diagnosis using PyTorch and AWS SageMaker’s data parallelism toolkit
In my previous post, I further discussed the merits of formulating MIL as a deep learning problem. I also outlined the mathematical basis for the model described in Attention-based Deep Multiple Instance Learning (Ilse et al.) which allows for WSI classification using deep MIL.⁸ This model makes use of a modified version of the attention mechanism as its aggregation operator which allows for a greater degree of interpretability than models relying on typical aggregation operators such as mean and max. In other words, this attention-based MIL pooling operator provides insight into the contribution of each instance to the predicted bag label.
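Concretely, the attention-based pooling operator from Ilse et al. assigns each instance embedding a weight produced by a small two-layer network, normalizes the weights with a softmax over the instances in the bag, and aggregates the bag as the weighted sum of embeddings. The toy snippet below (with random, untrained parameters) illustrates the shapes involved:

import torch
import torch.nn.functional as F

N, L, D = 16, 512, 128            # instances per bag, embedding size, attention size
H = torch.randn(N, L)             # one embedding per tile
V = torch.randn(D, L)             # attention parameters (learned in the real model)
w = torch.randn(1, D)

A = F.softmax(w @ torch.tanh(V @ H.T), dim=1)   # attention weights, shape 1xN
Z = A @ H                                       # bag representation, shape 1xL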
Here, we use the dataset provided in the Prostate cANcer graDe Assessment (PANDA) Kaggle Challenge, containing 11,000 WSIs of digitized H&E-stained prostate biopsies, to train an attention-based deep MIL model to diagnose prostate cancer based on the approach in Ilse et al. The goal is to outline how this can be implemented using PyTorch and AWS SageMaker’s data parallelism toolkit.
Dataset
Each tissue sample in the PANDA Challenge dataset is classified into Gleason patterns based on the architectural growth patterns of the tumor, and a corresponding ISUP grade on a scale of 0–5, where 0 denotes benign tissue. Gleason scores are determined based on the extent to which white branched cavities, or glandular tissue, persist throughout the tissue sample. A greater loss of glandular tissue implies greater severity and corresponds to a higher Gleason score. If multiple Gleason patterns are present in a single biopsy, they can be broken down into the most and second most frequently occurring patterns (majority and minority, respectively) as judged by a pathologist.
![Illustration of Gleason grading process for an example prostate cancer biopsy containing prostate (figure by author - inspired by [source])](https://towardsdatascience.com/wp-content/uploads/2021/06/1IO0xDIATnXj29Srtsq1KWA.png)
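To make the grading scheme concrete, the mapping from majority/minority Gleason patterns to ISUP grades can be sketched as a simple lookup (a simplified summary of the standard ISUP grade groups; the PANDA dataset provides these labels directly, so this is for orientation only):

# Simplified (majority, minority) Gleason pattern -> ISUP grade lookup
ISUP_GRADE = {
    (3, 3): 1,                           # Gleason score 6
    (3, 4): 2,                           # Gleason score 7
    (4, 3): 3,                           # Gleason score 7
    (4, 4): 4, (3, 5): 4, (5, 3): 4,     # Gleason score 8
    (4, 5): 5, (5, 4): 5, (5, 5): 5,     # Gleason scores 9-10
}

def isup_grade(majority=None, minority=None):
    # Benign samples have no Gleason pattern and receive ISUP grade 0
    if majority is None:
        return 0
    return ISUP_GRADE[(majority, minority)]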
In order to make use of the dataset in a deep MIL model, I referred to the following Kaggle notebook to divide the WSIs into bags of 16 tiles of 128x128 pixels each. As in Ilse et al., bags are labeled as either malignant or benign: slides with ISUP grades of 1, 2, 3, 4, or 5 are labeled "malignant", and those with an ISUP grade of 0 are labeled "benign."⁸
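The sketch below captures the idea behind that tiling step (it is not the notebook’s exact code): pad each WSI with white so it divides evenly into 128x128 patches, then keep the 16 patches containing the most tissue, i.e. the least white background. The make_bag helper and its exact parameters are illustrative:

import numpy as np

def make_bag(img, tile_size=128, n_tiles=16):
    # img is a WSI as an H x W x 3 uint8 array; pad with white so it splits evenly into tiles
    h, w, _ = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
    img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=255)
    tiles = img.reshape(img.shape[0] // tile_size, tile_size,
                        img.shape[1] // tile_size, tile_size, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    # Lower pixel sums mean darker patches (more tissue); keep the n_tiles darkest tiles
    idx = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(axis=1))[:n_tiles]
    return tiles[idx]   # bag of shape (16, 128, 128, 3); transpose to channels-first before the model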

Model
In the code below, we implement a modified version of the model used in Ilse et al. which takes into account the dataset described above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 512  # 512 node fully connected layer
        self.D = 128  # 128 node attention layer
        self.K = 1    # number of attention heads

        # Convolutional feature extractor applied to each 3x128x128 tile
        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 36, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(36, 48, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        # Fully connected layers producing a 512-dimensional embedding per tile
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(48 * 30 * 30, self.L),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(self.L, self.L),
            nn.ReLU(),
            nn.Dropout()
        )

        # Attention-based MIL pooling: scores each tile embedding
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        # Bag-level classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)  # (1, N, 3, 128, 128) -> (N, 3, 128, 128)

        H = self.feature_extractor_part1(x)
        H = H.view(-1, 48 * 30 * 30)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)                # NxK
        A = torch.transpose(A, 1, 0)         # KxN
        A = F.softmax(A, dim=1)              # softmax over N

        M = torch.mm(A, H)                   # KxL bag representation

        # The probability that a given bag is malignant
        Y_prob = self.classifier(M)
        # The prediction given the probability (Y_prob >= 0.5 returns a Y_hat of 1, meaning malignant)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        # A holds the attention weight assigned to each tile (useful for interpretability)
        return Y_prob, Y_hat, A
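As a quick sanity check, we can pass a synthetic bag of 16 random tiles through the model and confirm the output shapes (this uses random data, not the PANDA tiles):

model = Attention()
bag = torch.randn(1, 16, 3, 128, 128)       # one bag of 16 RGB tiles, 128x128 each
Y_prob, Y_hat, A = model(bag)
print(Y_prob.shape, Y_hat.shape, A.shape)   # torch.Size([1, 1]) torch.Size([1, 1]) torch.Size([1, 16])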
Model training using AWS SageMaker data parallelism (SDP)
In general, neural networks are trained by systematically adjusting their parameters in a direction that reduces prediction error. A common technique is stochastic gradient descent, in which these parameter updates occur iteratively using equally sized samples called mini-batches. It is possible to speed up training by evenly distributing mini-batches across a collection of independent machines, each of which has its own copy of the model, optimizer, and other essentials. Here, we use AWS SageMaker’s data parallelism toolkit, which has been shown to achieve superior performance over PyTorch’s DistributedDataParallel.¹¹
![Schematic illustration of data parallelism (figure by author - inspired by [source])](https://towardsdatascience.com/wp-content/uploads/2021/06/1uVmZpC1aoCx66ARj-RfqbA.png)
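To make the idea concrete, the sketch below uses PyTorch’s generic torch.distributed API (not SDP itself) to show the core operation in data-parallel training: every worker computes gradients on its own shard of the mini-batches, and the gradients are averaged across workers before each optimizer step. SDP provides an optimized implementation of this communication pattern:

import torch.distributed as dist

def average_gradients(model):
    # Assumes the default process group has already been initialized
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient across all workers, then average
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size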
SageMaker notebook setup
To prepare for SDP training, we can upload the aforementioned data into an Amazon S3 bucket, and launch a Jupyter notebook instance using SageMaker’s pre-built PyTorch container. For this project, training is initialized by calling the PyTorch estimator from the Amazon SageMaker Python SDK. Notably, we pass the training script, specify the instance count and type, and enable the SDP distribution method, as shown below:
import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

estimator = PyTorch(base_job_name='pytorch-smdataparallel-Histopathology-mil',
                    source_dir='code',
                    entry_point='train.py',
                    role=role,
                    framework_version='1.8.1',
                    py_version='py36',
                    instance_count=2,
                    instance_type='ml.p3.16xlarge',
                    sagemaker_session=sagemaker_session,
                    # Enable the SageMaker data parallelism (SDP) distribution method
                    distribution={'smdistributed': {'dataparallel': {'enabled': True}}},
                    debugger_hook_config=False,
                    volume_size=40)
ml.p3.16xlarge is one of three instance types supported by the SageMaker data parallelism toolkit, and AWS recommends using at least two instances to get the best performance out of it.¹² Each instance of this type contains 8 NVIDIA V100 GPUs, each with 16 GB of memory, so our two instances amount to running 16 independent copies of our model.
We can then fit our PyTorch estimator by passing in the data we uploaded to S3. This imports our data into the local filesystem of the training cluster so that our train.py script can simply read the data from disk.
channels = {
    'training': 's3://sagemaker-us-east-1-318322629142/train/',
    'testing': 's3://sagemaker-us-east-1-318322629142/test/'
}

estimator.fit(inputs=channels)
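Inside train.py, SageMaker exposes the local path of each channel (and of the model output directory) through environment variables named after the channel, which the script can read, for example, with argparse:

import argparse
import os

parser = argparse.ArgumentParser()
# SageMaker copies each S3 channel to local disk and exports its path as SM_CHANNEL_<NAME>
parser.add_argument('--train-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
parser.add_argument('--test-dir', type=str, default=os.environ['SM_CHANNEL_TESTING'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
args = parser.parse_args()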
Entry point script
In our train.py entry point script, we define our train function as shown below:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0.
    train_error = 0.
    predictions = []
    labels = []
    for batch_idx, (data, label) in enumerate(train_loader):
        bag_label = label
        data = torch.squeeze(data)
        data, bag_label = Variable(data), Variable(bag_label)
        data, bag_label = data.to(device), bag_label.to(device)
        # reset gradients
        optimizer.zero_grad()
        # calculate error
        bag_label = bag_label.float()
        Y_prob, Y_hat, _ = model(data)
        error = 1. - Y_hat.eq(bag_label).cpu().float().mean().data
        train_error += error
        # calculate loss
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        loss = -1. * (bag_label * torch.log(Y_prob) + (1. - bag_label) * torch.log(1. - Y_prob))
        train_loss += loss.data[0]
        # Keep track of predictions and labels to calculate accuracy after each epoch
        predictions.append(int(Y_hat))
        labels.append(int(bag_label))
        # backward pass
        loss.backward()
        # step
        optimizer.step()

    # calculate loss and error for epoch
    train_loss /= len(train_loader)
    train_error /= len(train_loader)
    print('Train Set, Epoch: {}, Loss: {:.4f}, Error: {:.4f}, Accuracy: {:.2f}%'.format(
        epoch, train_loss.cpu().numpy()[0], train_error, accuracy_score(labels, predictions) * 100))
We also create a function to save our model once training has completed:
def save_model(model, model_dir):
    with open(os.path.join(model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.module.state_dict(), f)
In the main guard, we load our dataset (see repository for details), train over 10 epochs, and save our model:
device = torch.device("cuda")
model = DDP(Attention().to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0005)

print('Start Training')
for epoch in range(1, 10 + 1):
    train(model, device, train_loader, optimizer, epoch)

save_model(model, args.model_dir)
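The snippet above omits the SDP-specific initialization that precedes it in the full script. Based on the AWS smdistributed examples,¹² it looks roughly like the following (module paths may differ slightly between toolkit versions, and train_dataset stands in for the Dataset defined in the repository):

# Drop-in replacements for torch.distributed and DistributedDataParallel provided by SDP
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

dist.init_process_group()                       # one process per GPU across both instances
torch.cuda.set_device(dist.get_local_rank())    # pin this process to its own GPU

# Shard the training set so each of the 16 workers sees a different subset of bags
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, sampler=train_sampler)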
Deployment, prediction, and evaluation
After training is complete, we can use the PyTorch estimator to deploy an endpoint which runs a SageMaker-provided PyTorch model server and hosts our trained model. In general, deployment is used to serve real-time predictions to a client application, but here we deploy for demonstration purposes.
import sagemaker
from sagemaker.pytorch import PyTorchModel

role = sagemaker.get_execution_role()

model = PyTorchModel(model_data=model_data, source_dir='code',
                     entry_point='inference.py', role=role,
                     framework_version='1.6.0', py_version='py3')
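Calling deploy on this model creates the endpoint and returns the predictor used below (the instance type here is only illustrative):

# Deploy the trained model to a real-time endpoint and get a Predictor back
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')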
We can now use our predictor to predict labels for our test data and determine our accuracy score:
from sklearn.metrics import accuracy_score

predictions = []
true_labels = []
for batch_idx, (data, label) in enumerate(test_loader):
    _, Y_hat, _ = predictor.predict(data)
    predictions.append(int(Y_hat))
    true_labels.append(int(label))

accuracy_score(true_labels, predictions)
With the above implementation, I achieved an accuracy of 67.2%, which is approximately 7.5% lower than the accuracy reported in Ilse et al. This could be attributed to the choices I made to minimize AWS SageMaker training costs: I used only 624 of the 11,000 WSIs available in the dataset, and while the model in the literature was trained over 100 epochs, this model was trained for only 10. Given more financial investment, I anticipate that a larger training dataset and longer training time would lead to results closer to those seen in the paper.
Repository (Updated in October 2021 to include unit testing)
References
- Topol EJ. Deep Medicine: How Artificial Intelligence Can Make Healthcare Human Again. Basic Books; 2019.
- Amodei D, Hernandez D. AI and compute. OpenAI. May 16, 2018. https://openai.com/blog/ai-and-compute
- Singh H, Meyer AN, Thomas EJ. "The Frequency of Diagnostic Errors in Outpatient Care: Estimations from Three Large Observational Studies Involving US Adult Populations." BMJ Qual Saf. 2014. 23(9): 727–731.
- Friedman, C. P. (2009). A "fundamental theorem" of Biomedical Informatics. Journal of the American Medical Informatics Association, 16(2), 169–170. https://doi.org/10.1197/jamia.m3092
- Campanella G, Hanna MG, Geneslaw L, et al. Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nat Med. 2019;25(8):1301–1309.
- Wang, D., Khosla, A., Gargeya, R., Irshad, H., Beck, A. H. (2016). Deep Learning for Identifying Metastatic Breast Cancer. arXiv preprint arXiv:1606.05718.
- Campanella, G., Hanna, M. G., Geneslaw, L., Miraflor, A., Werneck Krauss Silva, V., Busam, K. J., Brogi, E., Reuter, V. E., Klimstra, D. S., & Fuchs, T. J. (2019). Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature Medicine, 25(8), 1301–1309. https://doi.org/10.1038/s41591-019-0508-1
- Ilse, M., Tomczak, J. M., Welling, M. (2018). Attention-based Deep Multiple Instance Learning. Proceedings of the 35th International Conference on Machine Learning, Stockholm, Sweden, PMLR 80. https://arxiv.org/abs/1802.04712.
- Rajpurkar P, Saporta A. The AI Health Podcast. Pathology AI and entrepreneurship with PathAI’s Dr. Aditya Khosla. December 16, 2020. https://theaihealthpodcast.com/episodes/patholgy-ai-and-entrepreneurship-with-pathais-aditya-khosla.
- McCormick, J. (2021, September 27). FDA Authorizes AI Software Designed to Help Spot Prostate Cancer. Wall Street Journal.
- Webber, E., & Cruchant, O. (2020, December 9). Scale deep learning with 2 new libraries for distributed training on Amazon SageMaker [web log]. https://towardsdatascience.com/scale-neural-network-training-with-sagemaker-distributed-8cf3aefcff51.
- AWS. aws/amazon-sagemaker-examples. GitHub. https://github.com/aws/amazon-sagemaker-examples/blob/35e2faf7d1cc48ccedf0b2ede1da9987a18727a5/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb.