
Using Multi-Task and Ensemble Learning to Predict Alzheimer’s Cognitive Functioning

Realizing the impact I can make in ML coming from cognitive science and the publication of my first scientific paper

Personal Tales Into Data Science

In one of my previous articles, I detailed my experience of transitioning into machine learning from Cognitive Science and the imposter syndrome that taunted me. In that article, I mentioned:

"An idea began to slowly unravel – perhaps, my background [in cognitive science] provided a much more solid foundation than I had initially anticipated".

In this article, I’ll share a concrete example of when my cognitive science background enabled me to 1) develop innovative modeling approaches for a disorder that holds personal significance for me in the field of neuroscience and 2) forge unique connections that are often overlooked in conventional discussions.

Through this experience, it became evident to me that the field of deep learning, with all its potential, is still in its formative stages, serving as a reminder of the inclusive opportunities it offers to individuals from both traditional and non-traditional backgrounds.

The Brain Networks Laboratory

A lingering feeling that haunted me after completing my undergraduate degree was a sense of having a decent theoretical foundation but lacking the practical understanding to apply those tools effectively. I envisioned the ideal scenario where I could apply these tools within the space of neuroscience or mental health.

Consequently, I was thrilled when the opportunity came up to apply for the Wicklow AI in Medicine Research Initiative, which meant committing to a research-oriented practicum in my graduate program focused on "utilizing artificial intelligence models to advance […] medical research across areas including oncology, cardiology and neurology".

I applied, was accepted 🎉, and ultimately ended up working with the Brain Networks Lab at UCSF. The lab centers on:

"Understanding the mechanisms of healthy and diseased brains by applying computational tools to neuroimaging data."

I eagerly waited for further details on my practicum; excitement bubbled within me to start bridging that gap between theory and practice. Finally, the long-awaited moment arrived when I received my task.

Predict a Cognitive Score for Alzheimer’s Patients

The Task: Predict a cognitive score for Alzheimer’s patients.

The Problem: I had absolutely no experience with computer vision.

My first impression was, "How can I possibly contribute to this?" Naturally, the voices of fear and imposter syndrome wanted to make their presence known. Additionally, my experience with deep learning was limited – more or less negligible – and on top of that, the position was at risk of being discontinued due to the complexity of the task.

However, what I did have was a drive to understand Alzheimer’s holistically, heartfelt empathy for those affected, and an eagerness to apply the computational tools we have. On top of that, the disconnect between theory and practice fueled my determination to contribute and grow.

Lifestyle Correlates to Alzheimer’s Disease

My second impression was that relying solely on MRI data to forecast a cognitive score seemed peculiar, considering the known demographic, genetic, and lifestyle correlations with Alzheimer’s disease.

In the study, "The Links Between the Gut Microbiome, Aging, Modern Lifestyle and Alzheimer’s Disease," the authors highlight that "significant […] changes of gut microbiome have been reported in patients with Alzheimer’s disease […] gut microbiome is highly sensitive to negative external lifestyle aspects, such as diet, sleep deprivation, circadian rhythm disturbance, chronic noise, and sedentary behavior, which are also considered as important risk factors for the development of sporadic Alzheimer’s disease".

I knew I needed to focus on my assignment; nonetheless, curiosity got the better of me and encouraged me to inquire about the possibility of obtaining access to demographic and clinical data, just in case the opportunity arose… rubs hands together

We’ll get back to this later.

Measuring Cognitive Functioning

The ADAS-Cog-11 Score

In the meantime, I continued to brainstorm modeling approaches using MRI data alone. However, you may be wondering, how are we defining ‘cognitive functioning’? In other words, what are we aiming to predict? What is our target label?

We’re measuring cognitive function using ADAS-Cog-11 – a metric employed to evaluate the deterioration of memory, language, and praxis in individuals affected by Alzheimer’s disease. According to Wikipedia, "it is one of the most widely used cognitive scales […] and considered to be the ‘gold standard’ for assessing anti-dementia treatments."

The ADAS-Cog-11 score is derived from eleven cognitive tasks spanning memory, language, and praxis.

You can find a detailed summary of the tasks here.

Simple Convolutional Neural Network

To begin the experiments, we trained a baseline Convolutional Neural Network (CNN) model on MRI data to predict the cognitive score.

This simple CNN consisted of the following layers: a first convolutional layer, a pooling layer, a second convolutional layer, another pooling layer, a two-layer fully connected network, and a regression layer.
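
As a rough illustration (not our exact implementation), a baseline along these lines could look like the following in PyTorch; the kernel sizes, channel counts, and the use of 3D convolutions over full volumes are assumptions:

```python
import torch
import torch.nn as nn

# Illustrative sketch of the baseline CNN described above; layer sizes and the use
# of 3D convolutions are assumptions, not the exact configuration from the project.
class BaselineCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),    # first convolutional layer
            nn.ReLU(),
            nn.MaxPool3d(2),                              # pooling layer
            nn.Conv3d(8, 16, kernel_size=3, padding=1),   # second convolutional layer
            nn.ReLU(),
            nn.MaxPool3d(2),                              # pooling layer
            nn.AdaptiveAvgPool3d((4, 4, 4)),              # fix the spatial size before flattening
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 4 * 4 * 4, 64),                # two-layer fully connected network
            nn.ReLU(),
            nn.Linear(64, 1),                             # regression layer -> cognitive score
        )

    def forward(self, x):
        return self.head(self.features(x))

model = BaselineCNN()
dummy_scan = torch.randn(2, 1, 96, 96, 96)  # (batch, channel, depth, height, width)
print(model(dummy_scan).shape)              # torch.Size([2, 1])
```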

The results were unremarkable, with cross-validation R2 ranging from 0.33 to 0.52 and a testing R2 of 0.47. Nevertheless, considering the modest expectations for this simple model, we were content to establish a starting baseline that we could work on improving.

Multi-task Learning & Capturing Brain Structural Context

The next step was to research structural predictors of dementia. This led to a body of literature that highlighted the independent relationships between gray and white matter volumes and dementia severity.

According to Stout et al.,

"Quantitative magnetic resonance methods provided strong evidence that cortical gray matter volume, which may reflect atrophy, and abnormal white matter volume are independently related to dementia severity in probable Alzheimer disease: lower gray matter and higher abnormal white matter volumes are associated with more severe dementia".

From this, a theory surfaced: If the model could capture information about gray matter and white matter volumes, it should boost predictive power.

But how can we do this?

The short answer: ✨ multi-task learning ✨

"Multi-task learning is a subfield of machine learning in which multiple learning tasks are solved at the same time, while exploiting commonalities and differences across tasks […] using the domain information contained in the training signals of related tasks […] what is learned for each task can help other tasks be learned better"

  • Multi-task learning. (2021, July 6). In Wikipedia

The intuition: if our model is predicting a cognitive score and simultaneously learning to segment the input MRI scan into white matter, gray matter, and cerebrospinal fluid, these interrelated tasks would capitalize on the shared domain information and boost performance for each separate task.
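
One common way to train such a model, sketched below under assumed loss choices and weighting (not necessarily the exact setup we used), is to optimize a weighted sum of the segmentation and regression losses:

```python
import torch.nn as nn

# Illustrative multi-task objective: a weighted sum of a segmentation loss and a
# regression loss. The loss functions and the 0.5 weighting are assumptions.
seg_loss_fn = nn.CrossEntropyLoss()  # segmentation into gray matter, white matter, CSF, background
reg_loss_fn = nn.MSELoss()           # cognitive-score regression
task_weight = 0.5                    # assumed trade-off between the two tasks

def multitask_loss(seg_logits, seg_target, score_pred, score_target):
    # seg_logits: (N, num_classes, H, W); seg_target: (N, H, W) class indices
    # score_pred / score_target: (N, 1) predicted and true cognitive scores
    seg_loss = seg_loss_fn(seg_logits, seg_target)
    reg_loss = reg_loss_fn(score_pred, score_target)
    return task_weight * seg_loss + (1 - task_weight) * reg_loss
```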

The U-Net Architecture

For this model, we used a U-Net architecture:

"An architecture [that] relies on the strong use of data augmentation, […] consists of a contracting path to capture context and a symmetric expanding path that enables precise localization […] such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network)" – Ronneberger et al., 2015

This was an attractive architecture to experiment with, since we’d be able to use very few images and still achieve better performance than previous methods, as long as we applied data augmentation techniques.

What are Data Augmentation Techniques for Medical Imaging?

Data augmentation is the process of applying randomized alterations (e.g. translating, rotating, flipping, stretching, shearing, etc.) to the input images in order to increase the variability. This process allows our model to better generalize by introducing minor shifts to the input data, as long as the modified images remain within the realm of possible inputs.

These techniques also serve as a valuable solution to address the challenge of labeled medical imaging scarcity. Acquiring sufficiently large medical imaging datasets is a prominent issue for two main reasons: 1) the manual annotation of medical scans is exceptionally time-consuming and 2) the sharing of clinical data is encumbered by increasingly stringent patient privacy laws.
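
As a simple illustration (the parameter ranges are assumptions, and libraries such as TorchIO or MONAI offer richer, anatomy-aware transforms), random flips and small rotations of an MRI volume could look like this:

```python
import numpy as np
from scipy.ndimage import rotate

# Illustrative augmentation of an MRI volume stored as a NumPy array.
# The flip probability and rotation range are assumptions chosen to keep
# the augmented scans within the realm of plausible inputs.
def augment_volume(volume, rng=np.random.default_rng()):
    if rng.random() < 0.5:
        volume = np.flip(volume, axis=0)  # random left-right flip
    angle = rng.uniform(-10, 10)          # small random rotation (degrees)
    volume = rotate(volume, angle, axes=(1, 2), reshape=False, order=1)
    return volume

augmented = augment_volume(np.zeros((96, 96, 96)))
print(augmented.shape)  # (96, 96, 96)
```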

Performance Across Different Organs & Modalities

It’s common practice to experiment with various data augmentation techniques to see which ones work best for your particular task. However, a way to steer the exploration process, particularly in medical imaging, is to understand the suitable – and most effective – augmentation techniques for different organ / structure, modality, and task combinations. Here, "suitable" refers to ensuring that the augmented data comprises valid examples within the input space.

For instance, elastic deformations are generally suitable for organs that possess inherent elasticity or deformability. Brain tissue is an excellent example of this due to the brain’s remarkable ability to undergo structural and functional changes in response to experiences, learning and environmental changes, also known as neuroplasticity.
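
To make this concrete, below is a rough sketch of an elastic deformation applied to a 2D slice, in the spirit of the classic approach from Simard et al. (2003); the displacement strength (alpha) and smoothness (sigma) values are illustrative assumptions:

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

# Illustrative 2D elastic deformation: a smoothed random displacement field is
# applied to the image. alpha controls displacement strength, sigma its smoothness.
def elastic_deform(image, alpha=30.0, sigma=5.0, rng=np.random.default_rng()):
    shape = image.shape
    dx = gaussian_filter(rng.uniform(-1, 1, shape), sigma) * alpha
    dy = gaussian_filter(rng.uniform(-1, 1, shape), sigma) * alpha
    y, x = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    coords = np.array([y + dy, x + dx])
    return map_coordinates(image, coords, order=1, mode="reflect")

deformed = elastic_deform(np.random.rand(96, 96))
print(deformed.shape)  # (96, 96)
```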

Santiago Ramón y Cajal, a revered figure known as the father of modern neuroscience once proclaimed:

"Any man could, if he were so inclined, be the sculptor of his own brain"

  • Santiago Ramón y Cajal

However, bones have limited deformability and blood vessels are comparatively rigid structures; therefore, applying elastic deformations to them may not accurately represent realistic variations or preserve anatomical integrity.

In addition, zooming (scaling the image to emphasize a specific region) is often well-suited for X-ray images, as they generally encompass a broader field of view. However, for MRIs that already have a narrower field of view and concentrate on specific areas of interest, zooming may inadvertently exclude important contextual information, rendering it less suitable as an augmentation technique for MRIs.

There’s an incredibly comprehensive literature review on data augmentation techniques for medical imaging here. The authors highlight this point:

"Depending on the nature of the input and the visual task, different data augmentation strategies are likely to perform differently. For this reason, it is conceivable that medical imaging requires specific augmentation strategies that generate plausible data samples and enable effective regularization of deep neural networks"

Multi-task Learning Architecture Visualized

I’ve provided a visual representation of the multi-task model below. At the lowermost part of the U-Net architecture, we incorporated a regression block. In this block, all the pixels are flattened, then passed through a linear layer, with the resulting output representing the cognitive score.
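
As a complement to the diagram, the snippet below sketches such a regression block in PyTorch, applied to the bottleneck feature maps of a U-Net; the channel count and spatial size are placeholder assumptions:

```python
import torch
import torch.nn as nn

# Illustrative regression block attached at the lowermost part of the U-Net:
# the bottleneck feature maps are flattened and passed through a linear layer
# whose single output is the predicted cognitive score. Sizes are placeholders.
regression_block = nn.Sequential(
    nn.Flatten(),
    nn.Linear(256 * 8 * 8, 1),
)

bottleneck_features = torch.randn(4, 256, 8, 8)   # (batch, channels, height, width)
predicted_scores = regression_block(bottleneck_features)
print(predicted_scores.shape)                     # torch.Size([4, 1])
```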

U-Net & Multi-task Learning Performance

To recap, initially, we employed a simple CNN baseline, yielding cross-validation R2 values ranging from 0.33 to 0.52, with a testing R2 of 0.47.

However, with the implementation of the multi-task model, we witnessed notable improvements:

  • Cross-validation R2: Enhanced to the range of 0.41 to 0.69
  • Testing R2: Improved to 0.57

It’s also worth mentioning that we developed a baseline model for the segmentation task using the U-Net architecture, which demonstrated an accuracy of 93.7%.

Notably, the segmentation task also experienced a performance boost, elevating its accuracy from 93.7% to 97.27% 🎉

The multi-task model appeared to outperform the state of the art at the time for brain segmentation, as demonstrated below:

What about the demographic and genetic features?

But what about the demographic and genetic features that I’d promised to get back to?

Well, we got access to the data! cheers

The first step was to train another baseline model on the tabular features alone. We ultimately decided on a Histogram-based Gradient Boosting regressor (HGB Regressor) with a Poisson loss function, which yielded surprisingly decent results.

  • Cross-validation R2: Ranging from 0.56 to 0.63
  • Testing R2: 0.51
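
For illustration (not our exact configuration), such a baseline can be set up in scikit-learn as below, with dummy data standing in for the real demographic, genetic, and clinical features:

```python
import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import cross_val_score

# Illustrative tabular baseline: a histogram-based gradient boosting regressor with
# a Poisson loss, which assumes a non-negative target such as the ADAS-Cog-11 score.
# The dummy features and hyperparameters below are placeholders.
rng = np.random.default_rng(0)
X_tabular = rng.normal(size=(500, 12))                            # placeholder tabular features
y_scores = rng.poisson(lam=np.exp(1.5 + 0.3 * X_tabular[:, 0]))   # placeholder non-negative scores

hgb = HistGradientBoostingRegressor(loss="poisson", max_iter=200, random_state=0)
cv_r2 = cross_val_score(hgb, X_tabular, y_scores, cv=5, scoring="r2")
print(cv_r2)
```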

Integrating Tabular Data into an Ensemble Multi-task Model

In the final experiment, we brainstormed ways to integrate these features into our model in an effective manner.

The integration of tabular data and imaging data was a challenge due to the substantial difference in the number of inputs between imaging data (where each voxel is considered an input) and the tabular dataset (where each feature is represented as a single value). Treating all features equally risked diminishing the significance of the demographic and genetic risk factors in the overall model.

To tackle this, we developed a separate model for the tabular data using the HGB regressor. We then applied a weighted-average ensemble approach to combine the predictions from the HGB regressor and the multi-task model, assigning a higher weight to the more accurate and reliable model. This ensemble technique effectively balanced the contributions of each model.
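
In code, the weighted-average step itself is simple; the sketch below uses illustrative weights rather than the ones tuned on validation performance:

```python
import numpy as np

# Illustrative weighted-average ensemble of the two models' predictions.
# The weights are assumptions; in practice they are chosen from validation results.
w_multitask, w_hgb = 0.6, 0.4   # higher weight for the more reliable model

def ensemble_predict(multitask_preds, hgb_preds):
    return w_multitask * np.asarray(multitask_preds) + w_hgb * np.asarray(hgb_preds)

print(ensemble_predict([18.0, 25.5, 12.0], [20.0, 24.0, 10.5]))  # [18.8, 24.9, 11.4]
```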

Below, you can see a visualization of this ensemble approach.

Final Model Performance

So how did this ensemble, multi-task approach compare to the previous experiments?

drum roll

Regression Task Performance:

  • Cross-validation R2: 0.73–0.78
  • Testing R2: 0.67

Segmentation Task Performance:

  • Accuracy: 98.12%

Much like the notable performance improvement witnessed with multi-task learning, the incorporation of demographic features and genetic risk factors via an ensemble method not only substantially enhanced the performance of the regression task but also further bolstered the performance of the segmentation task. This clearly demonstrated the power of leveraging multiple data sources and capitalizing on their synergistic potential.

Concluding Remarks

Having the opportunity to delve into the crossroads of neuroscience and machine learning, particularly in relation to Alzheimer’s disease, and realizing how my background enabled me to connect concepts across fields was transformative. Since experiencing the substantial improvements achieved by aligning current neuroscience research with ML concepts, I’ve developed a greater appreciation for the power of integrating diverse data sources and applying model architectures in a domain-driven manner.

I hope this research has been a source of inspiration for individuals who:

  • Have transitioned from cognitive science into ML
  • Are fueled to acquire and apply techniques for neurological disorders
  • Have an interest in computer vision within the medical domain

Please feel free to reach out with any questions, and I hope this has been as thrilling for you as it has been for me!

Articles Used for Figure II:

  1. Combination of healthy lifestyle traits may substantially reduce Alzheimer’s
  2. The Links Between the Gut Microbiome, Aging, Modern Lifestyle and Alzheimer’s Disease
  3. Cognitive reserve and lifestyle: moving towards preclinical Alzheimer’s disease
