
Which models to use for epidemic prediction?

Comparison of Data-Driven and Model-Based Methods for Prediction of Influenza and COVID-19 Cases

By Erin Stafford and Eli Shlizerman

Source: frankundfrei, via Pixabay (CC0)

In the time of COVID-19, the need for accurate predictions of both the long-term and the short-term evolution of epidemics has become apparent. We propose to compare traditional model-based methods, such as the Susceptible-Infected-Recovered (SIR) model, with emerging data-driven models, including recurrent neural networks (RNNs) for time-series prediction. We compare these methods on influenza (flu) data, which covers many seasons and is therefore more robust, and then examine applications to COVID-19 data.

Our findings show that (i) commonly used model-based methods (e.g., SIR) and data-driven RNN methods (e.g., vanilla-LSTM) do not provide accurate long-term predictions on flu data and require constant updating to be more accurate, and (ii) the data-driven Seq2Seq RNN model is the most promising approach for both short-term and long-term predictions. Since epidemics follow similar patterns, we propose that Seq2Seq trained on flu data could be used as a model for COVID-19. Such a model would require only ‘few-shot’ retraining (on several samples) to provide predictions.

Source: frankundfrei, via Pixabay (CC0)

Code files are available from: https://github.com/shlizee/Seq2SeqEpidemics


The Flu Data

The CDC has surveillance systems in place for tracking the seasonal spread of influenza. One such network is the US Outpatient Influenza-like Illness Surveillance Network (ILINet). Each week, outpatient healthcare providers in ILINet report the number of patients with influenza-like illness (ILI) by age group. ILINet provides data at the national, state, and regional levels, as well as the percentage of visits due to ILI, both weighted by population and unweighted. We use the national-level weekly case counts.

Plotting the data, we can see that influenza epidemics follow a similar pattern. Furthermore, the time series exhibits a yearly seasonal pattern. To make the data easier to work with, we reformat it by season (i.e., we look at each yearly period, centered at week 14, the first week of April).

Flu data from ILINet. Right: 2018–19 flu season number of cases used for evaluation of models
> python3 fludata.py
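The seasonal reformatting can be pictured as slicing the weekly series into equal 52-week rows. The snippet below is a minimal NumPy sketch, not the code in fludata.py; `to_seasons` and `start_offset` are hypothetical names, and the sketch assumes that 53-week years have already been handled upstream.

```python
import numpy as np


def to_seasons(weekly_counts, start_offset, season_len=52):
    """Split a 1-D series of weekly case counts into consecutive seasons.

    start_offset (hypothetical): index of the first week of the first full
    season, chosen so that every season window lines up with the same
    calendar week. Trailing weeks that do not fill a whole season are dropped.
    """
    x = np.asarray(weekly_counts, dtype=float)[start_offset:]
    n_seasons = len(x) // season_len
    # Each row of the result is one season of `season_len` weeks.
    return x[: n_seasons * season_len].reshape(n_seasons, season_len)


weekly = np.arange(120)          # stand-in for the national weekly ILI counts
seasons = to_seasons(weekly, start_offset=10)
print(seasons.shape)             # (2, 52)
```

Keeping each season as a row makes it easy to line up seasons for plotting and, later, to pair one season with the next as model input and target.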

Making predictions with a model-based SIR

The SIR model is a system of ordinary differential equations (ODEs) that have traditionally been used to study the dynamics of epidemics. In this model, the population is divided into three groups: Susceptible, Infected, and Recovered. A flow diagram of this model can be seen below.

Susceptible individuals become infected at rate β, and infected individuals recover at rate γ. To use the SIR model to make predictions, we must optimize the parameters of the model to fit the data. The parameters we optimize are: (i) the percentage of the population that is initially susceptible, (ii) β, the rate of infection, and (iii) γ, the rate of recovery.
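For reference, the standard form of the SIR equations, with total population N = S + I + R, is shown below. We assume this common normalization; the exact parameterization in SIRfludata.py may differ (for example, in time units).

```latex
\begin{aligned}
\frac{dS}{dt} &= -\beta \frac{S I}{N}, \\
\frac{dI}{dt} &= \beta \frac{S I}{N} - \gamma I, \\
\frac{dR}{dt} &= \gamma I, \qquad N = S + I + R.
\end{aligned}
```

Since dI/dt > 0 initially only when β S(0)/N > γ, the fitted fraction of initially susceptible people and the ratio β/γ jointly decide whether, and how quickly, the modeled outbreak grows.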

First, we look at making predictions based on the previous season (Blue-Solid). We optimize the model parameters to fit the previous season, use these parameters to run the model for the next season, and compare the results with the real data for that season.

Next, we see how well the SIR model predicts from only 10 data points (Blue-Dashed). We do this in two ways: (i) fitting to the previous season and using these parameters as the starting point for optimization (Smart SIR), and (ii) using a general starting point (Naive SIR).

SIR model prediction of last season flu cases
> python3 SIRfludata.py
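To make the fitting procedure concrete, here is a minimal NumPy sketch. It is not the code in SIRfludata.py: it integrates the SIR equations with forward Euler steps, reports weekly new infections as fractions of the population (real counts would need a population scale), and swaps in a simple derivative-free pattern search for whatever optimizer the authors used. `simulate_sir`, `fit_sir`, and the fixed initial infected fraction `i0` are hypothetical. Because the search is local, its starting point matters, which is exactly the difference between Smart SIR and Naive SIR.

```python
import numpy as np


def simulate_sir(s0_frac, beta, gamma, n_weeks, i0=1e-4, steps_per_week=10):
    """Forward-Euler SIR run; returns weekly new infections (population fractions).

    beta and gamma are per-week rates; i0 (hypothetical) is the initial infected fraction.
    """
    s, i = s0_frac, i0
    dt = 1.0 / steps_per_week
    new_cases = np.zeros(n_weeks)
    for week in range(n_weeks):
        for _ in range(steps_per_week):
            infections = beta * s * i * dt   # S -> I flow (N normalized to 1)
            recoveries = gamma * i * dt      # I -> R flow
            s -= infections
            i += infections - recoveries
            new_cases[week] += infections
    return new_cases


def fit_sir(observed, start, n_iter=100, step=0.5):
    """Local pattern search over (s0_frac, beta, gamma), starting from `start`."""
    observed = np.asarray(observed, dtype=float)

    def loss(p):
        s0, b, g = p
        if not (0 < s0 <= 1 and b > 0 and g > 0):
            return np.inf                    # reject infeasible parameters
        pred = simulate_sir(s0, b, g, len(observed))
        return float(np.sum((pred - observed) ** 2))

    best = np.asarray(start, dtype=float)
    best_loss = loss(best)
    for _ in range(n_iter):
        improved = False
        for j in range(3):                   # scale each parameter up and down
            for sign in (1.0, -1.0):
                trial = best.copy()
                trial[j] *= 1.0 + sign * step
                trial_loss = loss(trial)
                if trial_loss < best_loss:
                    best, best_loss, improved = trial, trial_loss, True
        if not improved:
            step /= 2                        # shrink the search radius
            if step < 1e-4:
                break
    return best, best_loss


# Synthetic seasons (stand-ins for real ILINet data).
prev_season = simulate_sir(0.9, 2.0, 1.0, n_weeks=52)
current_season = simulate_sir(0.85, 2.2, 1.0, n_weeks=52)

# Previous Season fit: fit last season, then run the model for a whole new season.
prev_params, _ = fit_sir(prev_season, start=[0.5, 1.5, 0.5])
next_season_pred = simulate_sir(*prev_params, n_weeks=52)

# Smart vs Naive SIR on the first 10 points of the current season.
first_10 = current_season[:10]
smart_params, smart_loss = fit_sir(first_10, start=prev_params)
naive_params, naive_loss = fit_sir(first_10, start=[0.5, 1.5, 0.5])
```

Both variants minimize the same squared error; only the starting point changes, so any difference between them comes from where the local search begins.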

The results show that the shape of the Previous Season fit SIR curve is closest to the real data; however, it does not predict the peak well, and the initial segment of the curve is still very different. Overall, we see that the SIR model, even when updated with previous data, does not capture important features of the epidemic's evolution, which suggests that data-driven approaches could improve the prediction.

Making Predictions using vanilla-LSTM

A commonly used prediction strategy for time-series data (e.g., the weekly and daily epidemic case counts that we consider here) is the recurrent neural network (RNN). One type of RNN that can provide robust predictions is the LSTM (Long Short-Term Memory) network. LSTMs are able to recognize temporal patterns in time-series data, which are then used for prediction. Our implementation follows the tutorial on time-series prediction with LSTM in PyTorch [3].

We perform two preprocessing steps: (i) we normalize the data for the LSTM, and (ii) we create input-output tuples to train the model. Each tuple contains 52 weeks of data as the input and the following week as the output (label).
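The two preprocessing steps can be sketched in NumPy as follows. This assumes min-max scaling to [-1, 1]; the exact scaling used in LSTMfludata.py may differ, and `minmax_scale` and `make_windows` are hypothetical helper names.

```python
import numpy as np


def minmax_scale(x, lo=-1.0, hi=1.0):
    """Linearly rescale a series to [lo, hi]; also return (min, max) to invert later."""
    x = np.asarray(x, dtype=float)
    x_min, x_max = x.min(), x.max()
    scaled = lo + (x - x_min) * (hi - lo) / (x_max - x_min)
    return scaled, (x_min, x_max)


def make_windows(series, window=52):
    """Slide a `window`-long input over the series one step at a time;
    the label for each input is the value that immediately follows it."""
    return [(series[k:k + window], series[k + window])
            for k in range(len(series) - window)]


train_series = np.arange(60.0)             # stand-in for the weekly training data
scaled, (x_min, x_max) = minmax_scale(train_series)
pairs = make_windows(scaled, window=52)
print(len(pairs))                          # 8
```

Computing the minimum and maximum on the training seasons only keeps the held-out season from leaking into the scaling; predictions are mapped back with `x_min + (y - lo) * (x_max - x_min) / (hi - lo)`.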

> python3 LSTMfludata.py

Now we construct the model. We train it on the input-output sequences created earlier, i.e., we slide a window of 52 data points (one season) over the whole training data one step at a time and predict the next point. We then compute the L1 loss between the predicted value and the ground truth. The model appears to converge, although the loss (left plot) remains fairly high, and we examine the predictions for the next season (right plot).
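A vanilla LSTM of this kind can be sketched in PyTorch as below. This is an illustrative sketch, not the model in LSTMfludata.py: the hidden size, the full-batch Adam training loop, and the autoregressive `forecast` helper (which feeds each prediction back as input to cover a whole season) are assumptions, and a toy sine wave stands in for the scaled flu series.

```python
import math

import torch
from torch import nn


class VanillaLSTM(nn.Module):
    """Single-layer LSTM mapping a window of scaled weekly values to the next value."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):                          # x: (batch, window)
        out, _ = self.lstm(x.unsqueeze(-1))        # out: (batch, window, hidden)
        return self.head(out[:, -1]).squeeze(-1)   # predict from the last time step


def train_lstm(inputs, labels, epochs=50, lr=1e-2):
    """Full-batch training with the L1 loss; returns the model and per-epoch losses."""
    model = VanillaLSTM()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss()
    losses = []
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses


def forecast(model, history, steps, window):
    """Predict `steps` values ahead, feeding each prediction back in as input."""
    buf = list(history[-window:])
    with torch.no_grad():
        for _ in range(steps):
            x = torch.tensor(buf[-window:], dtype=torch.float32).unsqueeze(0)
            buf.append(model(x).item())
    return buf[window:]


torch.manual_seed(0)
t = torch.arange(200, dtype=torch.float32)
series = torch.sin(2 * math.pi * t / 52)           # toy seasonal signal
window = 52
inputs = torch.stack([series[k:k + window] for k in range(len(series) - window)])
labels = series[window:]
model, losses = train_lstm(inputs, labels)
next_season = forecast(model, series.tolist(), steps=52, window=window)
```

Because each forecast step consumes the previous prediction, small one-step errors can compound over a 52-week horizon, which is one plausible reason long-range LSTM forecasts drift.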

Left: Loss of LSTM prediction. Right: LSTM prediction (orange) compared with Real data (black).

It turns out that vanilla-LSTM does worse than the SIR model! Furthermore, LSTM predictions were not very reliable and varied from one training run to the next. However, in some cases the results were as accurate as SIR, and notably, in many of the trials, the LSTM captured the uptick in new cases towards the end of the season.

The Seq2Seq Model

Seq2Seq models use an RNN encoder and an RNN decoder to convert one sequence into another. Although this architecture is typically used for language processing, it is also suited to time-series forecasting [4]. Our implementation is adapted from Time Series prediction Seq2Seq [5] and [1,2].

> python3 Seq2Seqfludata.py
Left: Loss of Seq2Seq prediction. Right: Comparison of Seq2Seq decoder prediction (red dashed) and the target flu Real data (green).

We set up Seq2Seq to perform a prediction similar to those of ‘Previous Season fit SIR’ and ‘LSTM’. The input to the encoder is the previous season (‘Encoding Series’ of 52 weeks, black), and the decoder is trained to output the 52 weeks of the current season (‘Target Series’ of 52 weeks, green). We train the model on all previous seasons and leave the last season for testing. The Seq2Seq prediction is shown in red.
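The encoder-decoder setup can be sketched in PyTorch as below. It is a minimal illustration under stated assumptions, not the authors' adaptation of [5]: the decoder runs free, feeding back its own predictions in both training and testing, and the L1 loss, hidden size, and toy bell-shaped seasons are all hypothetical choices. `season_pairs` pairs each season with the next and holds out the last pair for testing.

```python
import torch
from torch import nn


class Seq2Seq(nn.Module):
    """LSTM encoder reads one season; an LSTM-cell decoder emits the next season."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.encoder = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.decoder = nn.LSTMCell(input_size=1, hidden_size=hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, src, target_len=52):
        # src: (batch, src_len), e.g. the 52-week "Encoding Series".
        _, (h, c) = self.encoder(src.unsqueeze(-1))
        h, c = h[0], c[0]                   # final state of the single encoder layer
        y = src[:, -1:]                     # seed the decoder with the last observed week
        outputs = []
        for _ in range(target_len):
            h, c = self.decoder(y, (h, c))
            y = self.head(h)                # (batch, 1), fed back at the next step
            outputs.append(y)
        return torch.cat(outputs, dim=1)    # (batch, target_len)


def season_pairs(seasons):
    """(previous season -> current season) pairs; the last pair is held out for testing."""
    src, tgt = seasons[:-1], seasons[1:]
    return (src[:-1], tgt[:-1]), (src[-1:], tgt[-1:])


# Toy data: 8 bell-shaped seasons (stand-ins for scaled flu seasons).
torch.manual_seed(0)
weeks = torch.arange(52, dtype=torch.float32)
bump = torch.exp(-((weeks - 26) / 6) ** 2)
seasons = torch.stack([bump * (1 + 0.1 * k) for k in range(8)])
(train_src, train_tgt), (test_src, test_tgt) = season_pairs(seasons)

model = Seq2Seq()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.L1Loss()
losses = []
for epoch in range(30):
    opt.zero_grad()
    loss = loss_fn(model(train_src, target_len=52), train_tgt)
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    prediction = model(test_src, target_len=52)   # forecast of the held-out season
```

Unlike the one-step LSTM, the decoder is trained directly on the full 52-week target, so the loss penalizes errors across the entire season rather than one week at a time.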

Comparison of all methods in prediction of Real data (black). Seq2Seq (red) provides the best prediction among the examined models.

As the comparison plot above shows, Seq2Seq does significantly better than both SIR and LSTM at predicting the peak value and the peak week of the epidemic, and the overall shape of its curve is close to the real data curve, almost a smoothed version of it.
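The peak comparison above is made visually. One simple way to quantify it is sketched below; `peak_errors` is a hypothetical helper, not part of the repository, that measures the signed error in peak week and peak height between a predicted and a real season curve.

```python
import numpy as np


def peak_errors(pred, actual):
    """Signed errors in peak timing (weeks) and peak height between two curves."""
    pred, actual = np.asarray(pred, dtype=float), np.asarray(actual, dtype=float)
    week_error = int(np.argmax(pred)) - int(np.argmax(actual))   # > 0: predicted peak is late
    value_error = float(pred.max() - actual.max())               # > 0: peak overestimated
    return week_error, value_error


print(peak_errors([0, 1, 5, 2], [0, 4, 3, 1]))   # (1, 1.0)
```

Applied to each model's 52-week prediction for the test season, these two numbers summarize the peak-related part of the comparison, while a curve-wide metric such as the L1 error captures the overall shape.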


Application of Seq2Seq to predict COVID-19 epidemic

Since the results of Seq2Seq on the flu data are promising, we were motivated to see how well the Seq2Seq model can predict COVID-19. The data used for this application is from the COVID-19 Global Forecasting challenge on Kaggle [6]. It includes confirmed cases and fatalities at the country, state, and county levels, as well as the population of each region. We focus on the US data and address two main questions with our experiments:

  1. Can Seq2Seq predict the last 45 days of epidemic data (number of cases) for each state when trained on the other states? This determines the accuracy of this data-driven prediction for COVID-19.
  2. Can Seq2Seq trained on the flu be used as a ‘model-based’ approach for COVID-19? With this setup, we explore whether training Seq2Seq on an epidemic similar to COVID-19 yields a hybrid data- and model-based approach that gives accurate predictions.

1. Using Seq2Seq to predict the last 45 days of epidemic by state

For this application, we train a Seq2Seq neural network to predict the last 45 days of an epidemic using the beginning of the epidemic as the input. We predict 3 example states (and train on the other 47 states). The daily prediction is shown in red and compared with the target in green.
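The per-state split can be sketched as follows. This assumes every state's daily series covers the same date range, which the article does not state; `split_states` is a hypothetical helper, and the Poisson counts are synthetic stand-ins for the Kaggle data.

```python
import numpy as np


def split_states(series_by_state, test_states, horizon=45):
    """Split each state's daily series into (input, last-`horizon`-days target)
    and route the held-out states to the test set."""
    train, test = {}, {}
    for state, series in series_by_state.items():
        series = np.asarray(series, dtype=float)
        pair = (series[:-horizon], series[-horizon:])
        (test if state in test_states else train)[state] = pair
    return train, test


rng = np.random.default_rng(0)
states = {name: rng.poisson(100, size=120)
          for name in ["Alabama", "New York", "Texas", "Ohio", "Utah"]}
train, test = split_states(states, test_states={"Alabama", "New York", "Texas"})
```

Holding out whole states, rather than the last days of every state, tests whether the model generalizes to an epidemic curve it has never seen any part of during training.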

> python3 Seq2SeqCovid19_45days.py
Seq2Seq prediction (red, best compared with magenta: 7-day moving average) for the last 45 days of COVID-19 for 3 states: Alabama, New York, Texas.

We are able to capture the overall trend pretty well! We do not capture the spikes in the data, which are often attributed to the days on which data was entered into the system; notably, we do not expect to predict these. To further test the accuracy of the prediction, we look at moving averages updated each day (compare the red prediction with the magenta 7-day moving average of the target) [7]. As can be seen, the accuracy is impressive for all three examined states.
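A trailing 7-day moving average like the magenta curve can be computed with a simple convolution. The sketch below uses NumPy and the hypothetical name `trailing_moving_average`; each output value is the mean of a 7-day block ending on that day.

```python
import numpy as np


def trailing_moving_average(daily, window=7):
    """Mean of each `window`-day block ending at day t, for t = window-1, ..., len-1."""
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(daily, dtype=float), kernel, mode="valid")


smoothed = trailing_moving_average(np.arange(10), window=7)
```

With pandas, `pd.Series(daily).rolling(7).mean()` gives the same values, padded with NaN for the first six days.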

2. Does flu data give us a better Seq2Seq model to use for COVID-19?

For this application, we first train Seq2Seq on the flu and then re-train it to predict COVID-19 (we label this model FC-Seq2Seq). In particular, we train on all states for the flu and then re-train on only 3 states of the COVID-19 data. By following this protocol, we aim to obtain a model-based approach for COVID-19 that provides accurate predictions with a small amount of re-training.
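The re-training step can be sketched as a generic fine-tuning loop in PyTorch. This is an assumption about the procedure, not the code in FCSeq2Seq.py: it copies the flu-trained model so the original stays intact and continues training with a modest learning rate on the small COVID-19 subset. `fine_tune` is a hypothetical name, and `nn.Linear` is only a stand-in for the flu-trained network.

```python
import copy

import torch
from torch import nn


def fine_tune(pretrained, inputs, targets, epochs=20, lr=1e-3, **forward_kwargs):
    """Re-train a copy of a flu-trained model on a small COVID-19 subset.

    The original model is left untouched so it can be re-used for other targets.
    Extra forward arguments (e.g. a target length) are passed via forward_kwargs.
    """
    model = copy.deepcopy(pretrained)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss()
    history = []
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(inputs, **forward_kwargs), targets)
        loss.backward()
        opt.step()
        history.append(loss.item())
    return model, history


torch.manual_seed(0)
flu_model = nn.Linear(4, 2)                                # stand-in for a flu-trained model
covid_x, covid_y = torch.randn(3, 4), torch.randn(3, 2)    # e.g. data from 3 states
fc_model, history = fine_tune(flu_model, covid_x, covid_y, epochs=5)
```

Starting from flu-trained weights means the few COVID-19 samples only need to adjust an already epidemic-shaped model, which is the intuition behind the ‘few-shot’ claim.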

> python3 FCSeq2Seq.py
Loss of re-trained FC-Seq2Seq.

We observe that the loss of FC-Seq2Seq is the smallest of all examined models. This indicates that this setup obtains the highest accuracy of all the setups we examined, while using minimal data for training (only the data of 3 states). The accuracy is evident in the predictions (red) for other states out of the 47 in the plots below, compared with the 7-day moving average of the COVID-19 data (magenta). The referenced code can be used to generate predictions for the remaining 44 states.

FC-Seq2Seq prediction (red, best compared with magenta: 7-day moving average) for COVID-19 cases.

We therefore propose that it would be useful to re-train the Seq2Seq flu model on earlier COVID-19 outbreaks to predict later outbreaks.


Conclusion

Our analysis of epidemic prediction methods shows that the Seq2Seq RNN is a reliable approach for epidemic prediction. Seq2Seq outperforms SIR and LSTM in predicting flu data and shows promising results for COVID-19. Furthermore, the Seq2Seq model trained on flu was able to make accurate predictions of COVID-19 data with little re-training. We propose that the Seq2Seq model be used in, and compared with, **** epidemic predictions of case numbers and fatality counts currently generated for COVID-19. Next, we plan to study Seq2Seq to clarify and predict the effect of mitigation strategies [8].

References

[1] Su, K., & Shlizerman, E. (2019). Dimension Reduction Approach for Interpretability of Sequence to Sequence Recurrent Neural Networks. arXiv preprint arXiv:1905.12176.

[2] Su, K., Liu, X., & Shlizerman, E. (2020). Predict & cluster: Unsupervised skeleton based action recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 9631–9640).

[3] https://stackabuse.com/time-series-prediction-using-lstm-with-pytorch-in-python/

[4] https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html

[5] https://github.com/JEddy92/TimeSeries_Seq2Seq/blob/master/notebooks/TS_Seq2Seq_Intro.ipynb

[6] https://www.kaggle.com/c/covid19-global-forecasting-week-5/data?select=test.csv

[7] https://machinelearningmastery.com/moving-average-smoothing-for-time-series-forecasting-python/

[8] https://wallethub.com/edu/states-with-the-fewest-coronavirus-restrictions/73818/

