Classifying ADHD from Healthy Controls using LSTMs with rs-fMRI Data

A hands-on tutorial: from extracting resting-state networks to calculating the classification significance

Gili Karni
Towards Data Science

--

Image credit: American Health Imaging

This tutorial walks through the practical steps of using machine learning to analyze rs-fMRI data, and more specifically to discriminate Attention Deficit Hyperactivity Disorder (ADHD) patients from healthy controls. I discuss (1) feature extraction using masks, (2) the benefits and drawbacks of recurrent neural networks (RNNs), particularly long short-term memory (LSTM) networks, for classifying fMRI data, and (3) hypothesis testing and its use in model evaluation.

For more details on how to prepare fMRI images for analysis, visit Identifying resting-state networks from fMRI data using ICAs for a preprocessing tutorial.

A note: if you have already read these previous posts, you may have noticed that I have switched from a schizophrenia (SZ) dataset to an ADHD one. The main reason is that I lack the computational resources to preprocess the full schizophrenia dataset myself; I could not find a preprocessed SZ dataset, but I did find an ADHD one.

The full code for this tutorial is available here. To prepare, please install Python and the following packages: nilearn, scikit-learn (sklearn), and Keras.

Data preparation

As mentioned previously, fMRI images are 4D matrices reflecting the activation level of each voxel across three-dimensional space and time. Often, however, the relevant information is carried by only a subset of this data; here, for example, we are only interested in the resting-state networks. To omit the irrelevant data, we apply masks. Masks are simply filters that pass the desired subset of the data while dropping the rest. In practice, a mask sets the activation value of unwanted voxels to 0.

There are many ways to mask fMRI data, and the choice is mostly determined by the goal of the analysis. This tutorial focuses on classifying ADHD patients versus controls via their resting-state networks. Thus, I apply Smith’s rs-fMRI components atlas (Smith et al., 2009). The Smith atlas contains seventy resting-state networks (RSNs) acquired using an independent component analysis (ICA) on thousands of healthy subjects. I prefer the Smith atlas over dataset-specific ICA components because it helps avoid double-dipping (i.e., using the data twice), which could lead to overfitting.

The Smith atlas is available via Nilearn’s datasets module:
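For example, a minimal fetch could look like this (the variable names are my own):

```python
from nilearn import datasets

# Fetch the Smith et al. (2009) ICA maps; the `rsn70` field holds the
# 70-component resting-state-network decomposition used below.
smith_atlas = datasets.fetch_atlas_smith_2009()
rsn70 = smith_atlas.rsn70
```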

The ADHD dataset is also available via Nilearn. From the header of the first image, we see that it contains 61 × 73 × 61 voxels over 176 time points, and that each voxel is 3 × 3 × 3 mm. It is important to note that this information may not be uniform across the whole dataset.
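A sketch of loading the data and printing that header (assuming Nilearn’s `fetch_adhd` 40-subject preprocessed sample):

```python
import nibabel as nib
from nilearn import datasets

# Fetch the preprocessed ADHD-200 resting-state sample (40 subjects).
adhd = datasets.fetch_adhd(n_subjects=40)

# Print the matrix dimensions and voxel sizes of the first functional image.
header = nib.load(adhd.func[0]).header
print(header['dim'], '//', header['pixdim'])
```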

OUTPUT: [ 4 61 73 61 176 1 1 1] // [-1. 3. 3. 3. 2. 0. 0. 0.]

To mask the data, we first need to build a masker from the Smith atlas. Applying standardization contributes to the features’ robustness: in the case of masking, it enhances the signal by centering and normalizing each extracted time series. Passing the data’s confounds to the transformation also helps enhance the signal by regressing out confounding noise.
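Roughly, the masking step could look as follows (assuming the `rsn70` maps and `adhd` dataset from above; the name of the phenotypic column is an assumption and may differ between Nilearn versions):

```python
from nilearn.maskers import NiftiMapsMasker  # nilearn.input_data in older versions

# Build a masker from the 70 Smith RSN maps and standardize (center and
# normalize) each extracted time series.
masker = NiftiMapsMasker(maps_img=rsn70, standardize=True, detrend=True)

# One (n_timepoints, 70) matrix per subject; the dataset's confounds are
# regressed out during the transformation.
subjects = [masker.fit_transform(func, confounds=conf)
            for func, conf in zip(adhd.func, adhd.confounds)]

# Labels from the phenotypic table: 1 = ADHD, 0 = control (column name assumed).
labels = list(adhd.phenotypic['adhd'])
print('N control:', labels.count(0), '/ N ADHD:', labels.count(1))
```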

OUTPUT: N control: 20 / N ADHD: 20

As mentioned before, a dataset might not have a homogeneous scan length, and these 40 subjects indeed show quite a large variation. However, most machine learning frameworks (Keras included) require a uniform shape across all subjects. To keep as much data as possible, we can use padding: append zeros to each subject’s scan so that it matches the length of the longest scan. In addition to padding, I reshape the data to fit the requirements posed by Keras; (40, 261, 70) means we have 40 subjects, each with a 261-time-point sample over 70 regions.
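A minimal padding and reshaping step, given the per-subject matrices from the masking step above:

```python
import numpy as np

# Zero-pad each subject's time series to the length of the longest scan,
# producing a single (n_subjects, max_len, 70) array for Keras.
max_len = max(ts.shape[0] for ts in subjects)
data = np.zeros((len(subjects), max_len, subjects[0].shape[1]))
for i, ts in enumerate(subjects):
    data[i, :ts.shape[0], :] = ts

labels = np.asarray(labels)
print('data shape:', data.shape)
```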

OUTPUT: data shape: (40, 261, 70)

Train/test split

We use the train/test split paradigm to make sure the model is tested on completely new data. The function below randomly splits the data into train and test and reshapes each part according to the model’s requirements.
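A stand-in for that function, using scikit-learn (the stratified split is my own choice):

```python
from sklearn.model_selection import train_test_split

def split_data(data, labels, test_size=0.2, seed=None):
    """Randomly split subjects into train and test sets, keeping the
    ADHD/control ratio similar in both parts."""
    return train_test_split(data, labels, test_size=test_size,
                            stratify=labels, random_state=seed)

X_train, X_test, y_train, y_test = split_data(data, labels, seed=0)
```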

LSTM model

Long short-term memory (LSTM) models provide some benefits for learning from fMRI data. The main reason is that, unlike most machine learning or deep learning methods, they keep the contextual information of the inputs, incorporating details from earlier parts of the input sequence while processing the current one. That said, being highly contextual is not always a good thing. There are cases where LSTMs are not the best choice; being contextual could lead to over-interpretation of the data. In addition, LSTMs can take longer to train than a simple NN and may have far more parameters to tune. It is important to consider various options and find the most suitable model for the scenario at hand (Culurciello, 2018).

In this case, I chose to demonstrate the ability of LSTMs to analyze fMRI data precisely because of their contextual nature. fMRI data represents dynamic brain activity over time, so using LSTMs lets us take advantage of temporal information (which would otherwise be lost) when analyzing functional connectivity (Dvornek et al., 2017).

RSN LSTM classifier pipeline

Note: a common enhancement of LSTMs is combining them with convolutional neural networks (CNNs), which support analyzing spatial structure. Here, however, we extracted seventy discrete values reflecting independent components of whole-network activation, and thus gave up the spatial properties of the data. Therefore, a CNN is unlikely to be useful.

An issue with LSTMs is that they can easily overfit the training data, reducing their predictive capacity. A solution to this problem is regularization, which counters the model’s tendency to overfit. A conventional regularization for LSTM networks is dropout, which probabilistically excludes units from the layer connections. There are two types of dropout: input dropout and recurrent dropout. Dropout on the input means that, with a given probability, the data on the input connection to each LSTM unit will be excluded from node activation and weight updates (see the dropout argument). Dropout on the recurrent input acts the same way but on the recurrent connection (see the recurrent_dropout argument). However, it is important not to over-regularize, as that will prevent the model from learning at all (which shows up as flat, chance-level predictions).

The model I present here is a sequential model with three stacked LSTM layers and one dense layer with sigmoid activation. To be frank, there is no single right answer to how one chooses the hyperparameters; there is a lot of trial and error, and there are rules of thumb. I present a few that I have collected:

  • Generally, we would like to have at least two hidden layers (not including the last layer), since the power of NNs stems from their depth (a network with zero hidden layers can only represent linear functions). Yet there is a trade-off between accuracy and training time: find the number of layers that adds significant value to the accuracy without costing too much time (more details here).
  • We would like to start with a number of units smaller or equal to the input size and decrease it (by about a half at the time) until reaching the final layer. (see the full discussion here)
  • In the case of classification, the number of units at the output layer should be equal to the number of categories. Normally, a binary classification problem has one output.
  • The output layer activation is mostly sigmoid for binary classification, softmax for a multi-class classifier, and linear for regression.
  • The classification loss function should be matched to the number of labels and their type. The binary_crossentropy loss function is best for binary classification, categorical_crossentropy for multi-class problems with one-hot-encoded labels, and sparse_categorical_crossentropy for integer labels.
  • The accuracy metric is great to show the percent of correct classifications. You can use more than one!
  • More epochs generally help: start with 30 and monitor your validation set for a decreasing loss and improving accuracy. If you see improvement, increase the number of epochs; otherwise, go back to the drawing board.
  • Choose the optimizer based on your knowledge of the properties of your data and the depth of the network. For example, Ada-family optimizers (e.g., Adagrad) work well with sparse data but tend to perform more poorly as the network deepens. The Adam optimizer is a safe choice to start with, as it is efficient and does not require a lot of tuning (see more details here).
Image Credit: CS231n

And here is the model:
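The following sketch matches the description above (three stacked LSTM layers plus a sigmoid dense layer); the unit counts and dropout rates are my own guesses, not the exact values from the original code:

```python
from keras.models import Sequential
from keras.layers import LSTM, Dense

n_timepoints, n_networks = data.shape[1], data.shape[2]   # 261, 70

def build_model():
    model = Sequential([
        # The first two LSTM layers return full sequences so that the next
        # LSTM layer still receives a time series rather than a single vector.
        LSTM(64, return_sequences=True, dropout=0.2, recurrent_dropout=0.2,
             input_shape=(n_timepoints, n_networks)),
        LSTM(32, return_sequences=True, dropout=0.2, recurrent_dropout=0.2),
        LSTM(16, dropout=0.2, recurrent_dropout=0.2),
        # A single sigmoid unit for the binary ADHD-vs-control decision.
        Dense(1, activation='sigmoid'),
    ])
    model.compile(loss='binary_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model

model = build_model()
model.summary()
```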

Let’s train the model and see whether it manages to learn, that is, whether its accuracy increases and its loss decreases.
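A possible training call (the batch size and epoch count are my assumptions; older Keras versions name the history keys 'acc'/'val_acc' instead of 'accuracy'/'val_accuracy'):

```python
# Train for 30 epochs, holding out 20% of the training subjects for validation.
history = model.fit(X_train, y_train, epochs=30, batch_size=8,
                    validation_split=0.2, verbose=1)

# Plot the learning curves.
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('epoch'); plt.ylabel('accuracy'); plt.legend(); plt.show()
```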

We see some increase in the model’s accuracy and some decrease in its loss as the number of epochs grows. Additionally, the training and validation sets show similar trends. Thus, we (1) have a working model and (2) are probably not overfitting, so we can continue.

How do we actually use the model?

Hypothesis testing

I use a hypothesis testing framework to evaluate the model, since it yields both the accuracy of the model and its significance. To calculate the model’s significance, I use bootstrapping. Bootstrapping, in short, is a powerful, computer-based method for statistical inference that does not rely on many assumptions. It approximates properties of the population from estimates computed on small, resampled data samples.

How does it work (briefly)?

Our null hypothesis, to be tested, states that there is no difference between ADHD patients and controls, so the accuracy of our model in separating the two should be about 0.5 (i.e., chance).

To evaluate this hypothesis, we repeatedly resample the data with replacement: split it into train and test parts, fit the model, predict the labels of the test data, and calculate the accuracy. These repetitions create a distribution of accuracies. Following the central limit theorem, this distribution is likely to approach a Gaussian shape as the number of iterations grows, so we can employ statistical tests that assume a normal distribution (like the t-test).

If 0.5 falls in the tail of the distribution (i.e., in the outer 5%), we can conclude that the distribution is very unlikely to be centered around 0.5 and reject the null hypothesis. Otherwise, we fail to reject it.

The next function runs the bootstrapped experiment:
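A sketch of such a function, reusing the split_data and build_model helpers defined above (the number of iterations and training settings are arbitrary choices):

```python
from sklearn.metrics import accuracy_score
from sklearn.utils import resample

def bootstrap_experiment(data, labels, n_iterations=100):
    """Resample, refit, and score the model repeatedly; return all accuracies."""
    accuracies = []
    for _ in range(n_iterations):
        # Resample subjects with replacement, then split into train and test.
        X_boot, y_boot = resample(data, labels)
        X_train, X_test, y_train, y_test = split_data(X_boot, y_boot)
        m = build_model()                      # a fresh, untrained model per run
        m.fit(X_train, y_train, epochs=30, batch_size=8, verbose=0)
        y_pred = (m.predict(X_test) > 0.5).astype(int).ravel()
        accuracies.append(accuracy_score(y_test, y_pred))
    return np.asarray(accuracies)

accuracies = bootstrap_experiment(data, labels)
```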

And these two calculate and plot the p-value associated with the accuracy of the bootstrapped experiment.
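Under the Gaussian approximation described above, they could look roughly like this (the helper names are my own):

```python
from scipy import stats

def accuracy_p_value(accuracies, chance=0.5):
    """One-sided p-value: how likely is chance-level (0.5) accuracy under the
    approximately Gaussian bootstrap distribution?"""
    z = (chance - accuracies.mean()) / accuracies.std(ddof=1)
    return stats.norm.cdf(z)

def plot_accuracy_distribution(accuracies, chance=0.5):
    """Histogram of the bootstrapped accuracies with the chance level marked."""
    plt.hist(accuracies, bins=20)
    plt.axvline(chance, color='red', linestyle='--', label='chance (0.5)')
    plt.xlabel('accuracy'); plt.ylabel('count'); plt.legend(); plt.show()
    print('p-value:', accuracy_p_value(accuracies, chance))
```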

Lastly, let us present the results using the Receiver operating characteristic (ROC) curve. The ROC curve reflects the sensitivity and specificity of the model by plotting the true-positive rate against the false-positive rate.

I chose to present the median (rather than the mean) of the ROC curves, since it is a more robust measurement if the data are skewed (and carries no drawbacks if they are not).
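A sketch of that plotting step, assuming the bootstrap loop also stored each iteration’s test labels and predicted scores (an assumption; the version of the loop sketched earlier would need a small extension to collect them):

```python
from sklearn.metrics import roc_curve

def plot_median_roc(y_true_list, y_score_list):
    """Interpolate each iteration's ROC onto a common grid, then plot the
    median curve with a +/- 1 SD band and the chance diagonal."""
    fpr_grid = np.linspace(0, 1, 100)
    tprs = []
    for y_true, y_score in zip(y_true_list, y_score_list):
        fpr, tpr, _ = roc_curve(y_true, y_score)
        tprs.append(np.interp(fpr_grid, fpr, tpr))
    tprs = np.vstack(tprs)
    median_tpr, sd = np.median(tprs, axis=0), tprs.std(axis=0)
    plt.plot(fpr_grid, median_tpr, label='median ROC')
    plt.fill_between(fpr_grid, np.clip(median_tpr - sd, 0, 1),
                     np.clip(median_tpr + sd, 0, 1), alpha=0.3, label='+/- 1 SD')
    plt.plot([0, 1], [0, 1], linestyle='--', label='chance')
    plt.xlabel('false-positive rate'); plt.ylabel('true-positive rate')
    plt.legend(); plt.show()
```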

Interpretation

The ROC curve seen above shows a significant, above-chance classification capacity. We see that the median curve, as well as one SD around it, lies entirely above the chance diagonal. The 95% confidence interval (represented by 2 SD around the mean ROC curve) extends below the chance diagonal mostly at the bottom-left point. This could mean that the model is more likely to express higher sensitivity but lower specificity.

The model’s sensitivity is the proportion of patients correctly identified as having the disease (i.e., true positives) out of the total number of patients who actually have the disease. The model’s specificity is the proportion of patients correctly identified as not having the disease (i.e., true negatives) out of the total number of patients who do not have the disease. Usually, these two display an inverse relation.

Sensitivity or specificity?

It depends. However, it is common in diagnostic experiments to favor sensitivity, thus not missing any actually ill patients (at the risk of falsely diagnosing healthy ones). Such a decision is, in any case, highly tied to the nature of the hypothesis (Parikh et al., 2008).

Summary

In this tutorial, we have built an LSTM model to classify ADHD patients from healthy controls via rs-fMRI. We have discussed hypothesis testing and demonstrated its benefits in a diagnostic experiment.

References

Borlase, N., Melzer, T. R., Eggleston, M. J., Darling, K. A., & Rucklidge, J. J. (2019). Resting-state networks and neurometabolites in children with ADHD after 10 weeks of treatment with micronutrients: results of a randomised placebo-controlled trial. Nutritional Neuroscience, 1–11.

Culurciello, E. (2018). The fall of RNN / LSTM. Retrieved 14 December 2019, from https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Dvornek, N. C., Ventola, P., Pelphrey, K. A., & Duncan, J. S. (2017, September). Identifying autism from resting-state fMRI using long short-term memory networks. In International Workshop on Machine Learning in Medical Imaging (pp. 362–370). Springer, Cham.

Parikh, R., Mathai, A., Parikh, S., Sekhar, G. C., & Thomas, R. (2008). Understanding and using sensitivity, specificity and predictive values. Indian Journal of Ophthalmology, 56(1), 45.

Smith, S. M., Fox, P. T., Miller, K. L., Glahn, D. C., Fox, P. M., Mackay, C. E., … & Beckmann, C. F. (2009). Correspondence of the brain’s functional architecture during activation and rest. Proceedings of the National Academy of Sciences, 106(31), 13040–13045.
