Deep Learning for Survival Analysis

Sundar V
Towards Data Science
Jul 13, 2019

Recently I had the opportunity to work on Survival Analysis. As with any new project, I got excited and started to explore the topic further. According to Wikipedia,

“Survival analysis is a branch of statistics for analyzing the expected duration of time until one or more events happen, such as death in biological organisms and failure in mechanical systems.”

In short, it is a time-to-event analysis that focuses on the time at which the event of interest occurs. The event can be death, sensor failure, the onset of a disease, and so on. Survival analysis is a popular field with a wide range of use cases in medicine, epidemiology, engineering, and beyond.

Have you wondered how deep learning is transforming such a significant field? If so, you have come to the right place.

When I started looking at how deep learning is used in survival analysis, I was surprised not to find a single good article. All the materials and tutorials I came across used only statistical methods. But when I dug a little deeper, I discovered that a significant study applying neural networks to survival analysis had been published as early as 1995 by Faraggi and Simon [5].

Intro to Survival Analysis

Let's start with the basics by understanding the key definitions in survival analysis.

Let T be the random variable representing the waiting time until an event occurs. Besides the usual probability functions, we can define some functions that are essential to survival analysis, such as the survival function and the hazard function.

The survival function S(t) = P(T > t) is the probability that the event of interest has not occurred by time t. Its properties are S(0) = 1, S(∞) = 0, and S(t) is non-increasing.

The hazard function λ(t) can be viewed as the instantaneous rate at which the event occurs at time t, given that it has not occurred before time t. Strictly speaking it is a rate, not a probability: λ(t) = lim_{Δt→0} P(t ≤ T < t + Δt | T ≥ t) / Δt.
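The two functions are tied together by two standard identities: λ(t) = f(t) / S(t), where f is the density of T, and S(t) = exp(−∫₀ᵗ λ(u) du). The sketch below checks both numerically. It assumes, purely for illustration, that T follows a Weibull distribution with made-up shape `K` and scale `SCALE`. All function names are hypothetical helpers, not part of any library.

```python
import math

# Assumed example: T ~ Weibull(shape K, scale SCALE); values are illustrative only.
K, SCALE = 1.5, 10.0

def survival(t):
    """S(t) = P(T > t) for the Weibull distribution."""
    return math.exp(-((t / SCALE) ** K))

def density(t):
    """f(t), the probability density of T."""
    return (K / SCALE) * (t / SCALE) ** (K - 1) * survival(t)

def hazard(t):
    """lambda(t) = f(t) / S(t): event rate at t, given survival up to t."""
    return density(t) / survival(t)

def cumulative_hazard(t, steps=10_000):
    """Integral of lambda(u) from 0 to t, approximated with the midpoint rule."""
    dt = t / steps
    return sum(hazard((i + 0.5) * dt) for i in range(steps)) * dt

t = 7.0
print(survival(0.0))                                     # S(0) = 1
print(hazard(t), (K / SCALE) * (t / SCALE) ** (K - 1))   # f/S equals the closed-form Weibull hazard
print(survival(t), math.exp(-cumulative_hazard(t)))      # S(t) = exp(-cumulative hazard), up to quadrature error
```

The Weibull distribution is only a convenient choice because its hazard has a closed form. The identities themselves hold for any continuous event-time distribution.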

Survival Analysis Dataset

An important point to note is that we are not dealing with a conventionally labeled dataset. Data for survival analysis differ from the data in standard regression or classification problems, because in survival analysis we are dealing with censored data. Censoring is a form of missing-data problem.

In survival analysis, the training data are collected within a study period. What if an object under study does not experience the event within that period? Its observation is then said to be right-censored: we only know that the event had not occurred by the end of the study. Similarly, a dataset can also contain left-censored and interval-censored samples. Most commonly, though, we deal with datasets containing right-censored samples.

Example dataset — lifelines.datasets.load_static_test

In the example dataset above, column id identifies each object under study. Column E tells us whether the event was observed. If E is 1, the corresponding object experienced the event of interest during the study. If E is 0, the sample is censored. Here all the objects experienced the event except the object with id 3, which is right-censored. Column t is the time until the event occurred or, for censored data, the time the object was observed during the study. Columns var1 and var2 are the features, or covariates, of each object. For example, when predicting the probability of organ failure or a heart attack for a patient, covariates could include gender and age.
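The table in the figure comes from lifelines.datasets.load_static_test, as its caption notes. The (t, E) convention behind it can be sketched in a few lines. This is a minimal illustration under stated assumptions:

- every object enters the study at time 0;
- the study lasts a fixed, hypothetical `study_end`;
- the raw observations are made up;
- `Record` and `encode` are illustrative names, not a lifelines API.

```python
from typing import NamedTuple, Optional

class Record(NamedTuple):
    id: int
    t: float  # event time, or time observed in the study if censored
    E: int    # 1 = event observed, 0 = right-censored

def encode(obj_id: int, event_time: Optional[float], study_end: float) -> Record:
    """Turn a raw observation into the (t, E) pair used in survival data.

    Assumes all objects enter at time 0. An event that is never seen
    (None) or happens after study_end yields a right-censored record.
    """
    if event_time is not None and event_time <= study_end:
        return Record(obj_id, event_time, 1)
    return Record(obj_id, study_end, 0)

# Hypothetical raw data: (id, true event time or None); the study lasts 10 time units.
raw = [(1, 4.0), (2, 7.5), (3, None), (4, 12.0)]
records = [encode(i, et, study_end=10.0) for i, et in raw]
for r in records:
    print(r)
# Object 4's event (at 12.0) falls after the study ends, so it is recorded as t=10.0, E=0.
```

The key point is that a censored row still carries information: we know the object survived at least until t.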

Standard Methods

We will now briefly look into conventional statistical methods before using neural networks to solve the survival analysis problem. We can broadly classify the standard methods into non-parametric, semiparametric, and parametric approaches.

The most popular non-parametric approach to estimating the survival function is the Kaplan–Meier estimator. Its biggest limitation is that it cannot take covariates into account when estimating survival.
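The Kaplan–Meier estimator is the product Ŝ(t) = Π_{tᵢ ≤ t} (1 − dᵢ / nᵢ). Here dᵢ is the number of events at time tᵢ and nᵢ is the number of objects still at risk just before tᵢ. The minimal sketch below implements that product in plain Python on hypothetical data. It is not the lifelines implementation.

```python
from collections import Counter

def kaplan_meier(times, events):
    """Kaplan-Meier estimate of the survival function.

    times  : observed durations t (event time, or censoring time)
    events : 1 if the event was observed, 0 if right-censored
    Returns [(t_i, S(t_i)), ...] at each distinct event time t_i.
    """
    deaths = Counter(t for t, e in zip(times, events) if e == 1)
    leaving = Counter(times)  # everyone with duration t leaves the risk set after t
    at_risk = len(times)
    surv, curve = 1.0, []
    for t in sorted(leaving):
        d = deaths.get(t, 0)
        if d:
            surv *= 1.0 - d / at_risk  # conditional probability of surviving past t
            curve.append((t, surv))
        at_risk -= leaving[t]
    return curve

# Hypothetical durations; objects with E = 0 are right-censored.
times = [3.0, 4.0, 6.0, 7.5, 10.0]
events = [1, 1, 0, 1, 0]
for t, s in kaplan_meier(times, events):
    print(f"S({t}) = {s:.3f}")
# prints S(3.0) = 0.800, S(4.0) = 0.600, S(7.5) = 0.300
```

The censored object at t = 6.0 does not lower the curve, but it does shrink the risk set for later event times. This is how the estimator uses partial information. The function never looks at covariates, so every object gets the same curve, which is exactly the limitation mentioned above.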

The basic model for survival data that considers covariates is the Proportional Hazards model. This model was proposed by Prof. D. Cox (1972) and has also come to be known as the Cox regression model.

The Cox proportional hazards model (CoxPH) assumes that the hazard function is the product of two non-negative functions: a baseline hazard function, λ₀(t), and a hazard ratio or risk score, r(x) = exp{h(x)}:

$$\lambda(t \mid x) = \lambda_0(t) \cdot \exp\{h(x)\}$$

The risk score captures the effect of an object's observed covariates on the baseline hazard. The Cox model estimates the log-risk function h(x) as a linear function of the covariates, h(x) = βᵀx. It leaves the baseline hazard unspecified, without assuming any particular form for it. This semiparametric approach offers more flexibility and has become a more popular choice than fully parametric methods.

This proportional hazards assumption can also be expressed in terms of the survival function, where S₀(t) is the baseline survival function:

$$S(t \mid x) = S_0(t)^{\exp\{h(x)\}}$$
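The proportional hazards assumption says the ratio of two objects' hazards does not depend on t. The sketch below makes that concrete. It assumes a linear log-risk with hypothetical coefficients `BETA` and, because the Cox model itself leaves λ₀(t) unspecified, an arbitrary Weibull-shaped baseline chosen only so the functions can be evaluated.

```python
import math

# Hypothetical coefficients for two covariates (illustration only).
BETA = [0.8, -0.5]

def log_risk(x):
    """Linear log-risk h(x) = beta^T x, as in the standard Cox model."""
    return sum(b * xi for b, xi in zip(BETA, x))

def baseline_cum_hazard(t):
    """Assumed baseline cumulative hazard Lambda_0(t); Cox leaves it unspecified."""
    return (t / 10.0) ** 1.5

def baseline_hazard(t):
    """lambda_0(t), the derivative of baseline_cum_hazard."""
    return 0.15 * (t / 10.0) ** 0.5

def hazard(t, x):
    """lambda(t | x) = lambda_0(t) * exp{h(x)}."""
    return baseline_hazard(t) * math.exp(log_risk(x))

def survival(t, x):
    """S(t | x) = S_0(t) ** exp{h(x)}, with S_0(t) = exp(-Lambda_0(t))."""
    return math.exp(-baseline_cum_hazard(t)) ** math.exp(log_risk(x))

a, b = [1.0, 0.0], [0.0, 1.0]
for t in (1.0, 5.0, 20.0):
    # The hazard ratio is exp{h(a) - h(b)} at every t: the hazards are proportional.
    print(t, hazard(t, a) / hazard(t, b), survival(t, a), survival(t, b))
```

Because the ratio is constant, the higher-risk object a has a lower survival curve than b at every time point.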

Adapting CoxPH to Neural Networks

Many researchers have tried using neural networks for survival analysis, and with recent advances in deep learning the results have become promising.

There are two broad approaches to deep learning for survival analysis. One adapts the Cox proportional hazards assumption, while the other uses a fully parametric survival model. DeepSurv [1] and Cox-nnet [2] belong to the CoxPH adaptation approach, whereas Nnet-survival [3] and RNN-SURV [4] belong to the second. We will focus on the CoxPH adaptation approach, since the Cox model has proven very useful and is familiar to most medical researchers [3].

Besides the proportional hazards assumption, the Cox model makes a second assumption when estimating the log-risk function h(x): that h(x) is a linear function of the covariates. In many cases, for example when modeling non-linear gene interactions, we cannot assume that the data satisfy this linear proportional hazards condition [1]. We then need a more complex, non-linear model of h(x), and neural networks can fulfill this requirement.
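A toy case shows why linearity can be too restrictive. Suppose the true log-risk is a pure interaction, h(x) = x₁·x₂, for example two hypothetical genes that only matter in combination. The sketch fits the best linear function to that target on a balanced grid. It uses ordinary least squares as a simple stand-in for "any linear h(x)", not the Cox partial-likelihood fit itself.

```python
import itertools

# Hypothetical "true" log-risk: a pure interaction between two covariates.
def true_log_risk(x1, x2):
    return x1 * x2

# Covariates on a balanced grid of +/-1 values.
points = list(itertools.product([-1.0, 1.0], repeat=2))
y = [true_log_risk(x1, x2) for x1, x2 in points]

# On this grid the columns 1, x1, x2 are mutually orthogonal, so each
# least-squares coefficient is simply <column, y> / <column, column>.
cols = {
    "intercept": [1.0 for _ in points],
    "x1": [p[0] for p in points],
    "x2": [p[1] for p in points],
}
coef = {name: sum(c * yi for c, yi in zip(col, y)) / sum(c * c for c in col)
        for name, col in cols.items()}
print(coef)  # every coefficient is 0: the linear model sees no signal at all
```

Every coefficient comes out as zero, so a linear log-risk would assign the same risk to everyone even though risk varies strongly. A non-linear model such as a neural network can represent the interaction directly.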

DeepSurv

In this section, we will discuss DeepSurv in detail to understand exactly how deep learning is used for survival analysis. DeepSurv is a deep feed-forward neural network that predicts the log-risk function h(x), parameterized by the network weights θ.

DeepSurv Architecture

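The object's observed covariates, also called features, are given as input to the network. Each hidden layer consists of a fully connected layer of nodes followed by a dropout layer. The output layer has a single node with linear activation, which estimates the log-risk function of the Cox model. The loss function for the network is the negative log partial likelihood of the Cox model, with L2 regularization:

$$l(\theta) = -\frac{1}{N_{E=1}} \sum_{i:\,E_i=1} \left( \hat{h}_\theta(x_i) - \log \sum_{j \in \mathcal{R}(T_i)} \exp\{\hat{h}_\theta(x_j)\} \right) + \lambda \lVert \theta \rVert_2^2$$

Here ĥ_θ(x) is the network's predicted log-risk and N_{E=1} is the number of objects with an observed event. R(T_i) = {j : T_j ≥ T_i} is the risk set, the objects that have not yet experienced the event (or been censored) before time T_i. The λ in this formula is the regularization weight, not the hazard function.

Here is the intuition behind the loss function. To minimize the loss, we need to maximize the term in parentheses for each object i with an observed event (E_i = 1). To do that, we must increase the predicted risk of object i and decrease the predicted risk of the objects j in the risk set, which have not experienced the event by time T_i, the time at which object i experiences it.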
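The architecture and loss described above fit in a short, dependency-free sketch. It is a forward pass through fully connected layers with dropout and a single linear output node, followed by the negative mean log partial likelihood. Some choices here are assumptions, not DeepSurv's actual code:

- the SELU activation;
- the weight initialization;
- inverted dropout;
- the layer sizes;
- all function names.

Training (gradients, the Adam optimizer) and the L2 term are omitted.

```python
import math
import random

# SELU constants (Klambauer et al., 2017); using SELU here is an assumption.
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(z):
    return SELU_SCALE * (z if z > 0 else SELU_ALPHA * (math.exp(z) - 1.0))

def init_layer(n_in, n_out, rng):
    """Random weights (n_out rows of n_in) and zero biases for a dense layer."""
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)] for _ in range(n_out)]
    return w, [0.0] * n_out

def dense(x, layer):
    w, b = layer
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]

def log_risk(x, hidden, output, dropout=0.0, rng=None):
    """Predicted log-risk h_theta(x): dense + SELU + dropout per hidden layer,
    then a single output node with linear activation."""
    for layer in hidden:
        x = [selu(z) for z in dense(x, layer)]
        if dropout > 0.0 and rng is not None:  # inverted dropout, training only
            x = [0.0 if rng.random() < dropout else v / (1.0 - dropout) for v in x]
    return dense(x, output)[0]

def cox_loss(h, t, e):
    """Negative mean log partial likelihood over objects with E = 1
    (the loss above without its L2 term)."""
    total = 0.0
    for hi, ti, ei in zip(h, t, e):
        if ei != 1:
            continue
        risk = [hj for hj, tj in zip(h, t) if tj >= ti]  # risk set R(T_i)
        m = max(risk)                                     # log-sum-exp for stability
        total += hi - (m + math.log(sum(math.exp(r - m) for r in risk)))
    return -total / sum(e)

rng = random.Random(0)
hidden = [init_layer(2, 8, rng), init_layer(8, 8, rng)]
output = init_layer(8, 1, rng)

# Hypothetical toy data: covariates, observed times, event indicators.
X = [[0.5, -1.0], [1.2, 0.3], [-0.7, 0.8], [0.0, 0.0]]
T = [5.0, 3.0, 8.0, 6.0]
E = [1, 1, 0, 1]
h = [log_risk(x, hidden, output) for x in X]  # inference: no dropout
print(cox_loss(h, T, E))
```

Two properties of the loss are worth noticing. Adding a constant to every prediction leaves it unchanged, because only relative risks matter. Ranking earlier events as riskier lowers it, which matches the intuition given above.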

DeepSurv uses modern deep learning techniques, such as Scaled Exponential Linear Units (SELU) and the Adam optimizer, for better performance. Cox-nnet is very similar to DeepSurv, and both output only the risk score. Neither estimates the baseline functions, but methods such as the Breslow estimator can be used to obtain them [3].
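One way to turn predicted risk scores into full survival curves is the Breslow estimator of the baseline cumulative hazard, Λ̂₀(t) = Σ_{tᵢ ≤ t, Eᵢ = 1} 1 / Σ_{j ∈ R(tᵢ)} exp{h(x_j)}. Tied events are grouped by counting them in the numerator. The resulting baseline is then combined with S(t | x) = S₀(t)^exp{h(x)}. The sketch below is a minimal plain-Python version. The log-risk values stand in for hypothetical network outputs, and the function names are illustrative.

```python
import math

def breslow_baseline(h, t, e):
    """Breslow estimate of the baseline cumulative hazard Lambda_0(t).

    h : predicted log-risk for each object (e.g. network outputs)
    t : observed times;  e : 1 = event observed, 0 = right-censored
    Returns [(s, Lambda_0(s)), ...] at each distinct event time s.
    """
    cum, steps = 0.0, []
    for s in sorted({ti for ti, ei in zip(t, e) if ei == 1}):
        d = sum(1 for ti, ei in zip(t, e) if ti == s and ei == 1)     # events at time s
        denom = sum(math.exp(hj) for hj, tj in zip(h, t) if tj >= s)  # risk set R(s)
        cum += d / denom
        steps.append((s, cum))
    return steps

def survival_at(time, h_x, steps):
    """S(t | x) = S_0(t) ** exp{h(x)}, where S_0(t) = exp(-Lambda_0(t))."""
    lam0 = 0.0
    for s, c in steps:
        if s <= time:
            lam0 = c
    return math.exp(-lam0) ** math.exp(h_x)

# Hypothetical network outputs and observed data.
h = [0.4, -0.2, 1.1, 0.0, -0.6]
t = [3.0, 4.0, 6.0, 7.5, 10.0]
e = [1, 1, 0, 1, 0]
steps = breslow_baseline(h, t, e)
print([round(survival_at(8.0, hx, steps), 3) for hx in (-1.0, 0.0, 1.0)])
```

With every log-risk set to zero, this reduces to the Nelson–Aalen estimator. A higher predicted log-risk always yields a lower survival curve, as the proportional hazards form requires.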

As I said earlier, people have been applying neural networks to survival analysis since the mid-1990s, and given recent advances in deep learning, the results are promising. Deep learning techniques perform as well as or better than other state-of-the-art survival models when there is a complex relationship between an object's covariates and its hazard [1]. There is still a lot of scope for improvement and research. This topic is especially interesting for someone like me who believes AI can positively reshape many essential fields.

This is my first article, and I hope you found it informative. Thank you :)

References

[1] Katzman JL, Shaham U, Cloninger A, Bates J, Jiang T, and Kluger Y, DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network (2018), BMC Medical Research Methodology

[2] Ching T, Zhu X, and Garmire LX, Cox-nnet: an artificial neural network method for prognosis prediction of high-throughput omics data (2018), PLOS Computational Biology

[3] Gensheimer MF, and Narasimhan B, A Scalable Discrete-Time Survival Model for Neural Networks (2019), arXiv:1805.00917

[4] Giunchiglia E, Nemchenko A, and van der Schaar M, RNN-SURV: A Deep Recurrent Model for Survival Analysis (2018), LNCS 11139–11141

[5] Faraggi D, and Simon R, A neural network model for survival data (1995), Statistics in Medicine
