class: center, middle, inverse, title-slide

# A General Machine Learning Framework for Survival Analysis

## ECML PKDD 2020

### Andreas Bender (@adibender), David Rügamer, Fabian Scheipl, Bernd Bischl

### Department of Statistics, LMU Munich

---
layout: true
background-image: url(mcml-background-slide.svg)
background-size: cover

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">

<div class="my-footer"><span><i class="fa fa-twitter">@adibender</i>                            </span></div>

---

### The framework is *general* in the sense that

.font120[

1. It supports different survival tasks
    - right-censoring, left-truncation
    - time-varying effects, time-varying features
    - competing risks, multi-state models

2. It does not require specialized software: it can be applied in any programming language, using any algorithm that can optimize the Poisson likelihood

]

---

<br>

.center[
<img src="ml-for-survival-graph.svg", width = "900px">
]

???

`$$\usepackage{amsmath,amssymb,bm} \newcommand{\ra}{\rightarrow} \newcommand{\bs}[1]{\boldsymbol{#1}} \newcommand{\tn}[1]{\textnormal{#1}} \newcommand{\mbf}[1]{\mathbf{#1}} \newcommand{\E}{\mathbb{E}} \newcommand{\bfx}{\mathbf{x}} \newcommand{\bfX}{\mathbf{X}} \newcommand{\bfB}{\mathbf{B}} \newcommand{\bff}{\mathbf{f}} \newcommand{\bsbeta}{\boldsymbol{\beta}}$$`

---
layout:false
class: inverse, middle, center

<h1 style="color:#005500;">Survival Task as Poisson Task</h1>

<html>
<div style='float:left'></div>
<hr color='#005500' size=1px width=720px>
</html>

---
layout: true
background-image: url(mcml-background-slide.svg)
background-size: cover

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">

<div class="my-footer"><span><i class="fa fa-twitter">@adibender</i>                            </span></div>

---

<!-- General Idea goes back to e.g.
[Friedman (1982)](http://www.jstor.org/stable/2240502): "Piecewise Exponential Models for Survival Data with Covariates" -->

Consider a setting with right-censored data:

- we observe `\((t_i, \delta_i), i = 1,\ldots, n\)`, where
    - `\(t_i = \min(T_i, C_i)\)`; `\(T_i \sim F \perp C_i \sim G; T_i,C_i > 0\)`
    - `\(\delta_i = I(T_i \leq C_i) \in \{0,1\}\)`

To approximate
`$$\lambda(t; \bfx_i) = \exp(g(\bfx_i(t), t)) \stackrel{PH}{=}\lambda_0(t)\exp(\bfx_i'\bsbeta)$$`

--

- split the follow-up into `\(J\)` intervals `\((\kappa_{j-1}, \kappa_j], j = 1,\ldots, J\)`

--

- assume piecewise constant hazards:
`$$\begin{align} \lambda(t| \bfx_i(t)) & \equiv \exp(g(\bfx_{ij}, t_j)):=\lambda_{ij},\ \ \forall t \in (\kappa_{j-1}, \kappa_j] \end{align}$$`

--

- estimate via the piecewise exponential model (PEM; e.g. [Friedman (1982)](http://www.jstor.org/stable/2240502)) <br> ( `\(\ra\)` Poisson regression with transformed data)

---

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in piecewise exponential data (PED) format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped.svg", width = "300px" align="middle" >
</div>
</div>

---
count: false

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in PED format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped1.svg", width
= "300px" align="middle" >
</div>
</div>

---
count: false

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in PED format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped2.svg", width = "300px" align="middle" >
</div>
</div>

- define: `\(\delta_{ij} = \begin{cases}1 & t_i \in (\kappa_{j-1}, \kappa_j] \wedge \delta_i = 1\\0 & \text{else}\end{cases}\)`

---
count: false

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in PED format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped3.svg", width = "300px" align="middle" >
</div>
</div>

- define: `\(\delta_{ij} = \begin{cases}1 & t_i \in (\kappa_{j-1}, \kappa_j] \wedge \delta_i = 1\\0 & \text{else}\end{cases}\)`, `\(t_{ij} = \begin{cases}t_{i}-\kappa_{j-1} & t_i \in (\kappa_{j-1}, \kappa_j]\\ \kappa_{j}-\kappa_{j-1}& \text{else}\end{cases}\)`

---
count: false

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in PED format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped4.svg", width = "300px" align="middle" >
</div>
</div>

- define: `\(\delta_{ij} = \begin{cases}1 & t_i \in (\kappa_{j-1}, \kappa_j] \wedge \delta_i = 1\\0 & \text{else}\end{cases}\)`, `\(t_{ij} = \begin{cases}t_{i}-\kappa_{j-1} & t_i \in (\kappa_{j-1}, \kappa_j]\\ \kappa_{j}-\kappa_{j-1}& \text{else}\end{cases}\)`, `\(t_j := \kappa_j\)`

---
count: false

<div class="row">
<div class = "column", align = "center">
Data in "standard" time-to-event format <br>
</div>
<div class = "column", align = "center">
Data in PED format <br>
</div>
</div>
<div class = "row" align = "middle">
<div class = "column", align = "middle">
.middle[
<img src="tab-standard.svg", width = "200px" align="middle"><br>
`\(\ra\)` transform to PED using `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)`
]
</div>
<div class = "column" align ="middle">
<img src="tab-ped.svg", width = "300px" align="middle" >
</div>
</div>

- define: `\(\delta_{ij} = \begin{cases}1 & t_i \in (\kappa_{j-1}, \kappa_j] \wedge \delta_i = 1\\0 & \text{else}\end{cases}\)`, `\(t_{ij} = \begin{cases}t_{i}-\kappa_{j-1} & t_i \in (\kappa_{j-1}, \kappa_j]\\ \kappa_{j}-\kappa_{j-1}& \text{else}\end{cases}\)`, `\(t_j := \kappa_j\)`

.font80[
.pull-left[
.boxed_grey[
General log-likelihood contribution (`\(J_i\)`: interval containing `\(t_i\)`):
`$$\begin{align}\ell_i & = \log(\lambda(t_i;\bfx_i)^{\delta_i}S(t_i;\bfx_i))\\ & = \sum_{j=1}^{J_i}\left(\delta_{ij}\log\lambda_{ij} - \lambda_{ij}t_{ij}\right) \end{align}$$`
]
]
.pull-right[
.boxed_grey[
Working assumption `\(\delta_{ij}\stackrel{ind.}{\sim} Po(\mu_{ij} = \lambda_{ij}t_{ij})\)`:
`$$\begin{align} \ell_i & = \log\left(\prod_{j=1}^{J_i} f(\delta_{ij})\right)\\ & = \sum_{j=1}^{J_i} \left(\delta_{ij}\log(\lambda_{ij}) + \delta_{ij}\log(t_{ij}) - \lambda_{ij}t_{ij}\right) \end{align}$$`
]
]

`\(\ra\)` the two differ only by `\(\delta_{ij}\log(t_{ij})\)`, which does not depend on `\(\lambda_{ij}\)`; maximizing the Poisson likelihood with offset `\(\log(t_{ij})\)` therefore maximizes the survival likelihood
]

---

Consider 3 subjects in a competing risks setting with event types `\(k \in
\{1,2\}\)`

- `\(i = 1\)`: `\((t_1 = 1.3, \delta_1 = 2)\)`
- `\(i = 2\)`: `\((t_2 = 0.5, \delta_2 = 0)\)`
- `\(i = 3\)`: `\((t_3 = 2.7, \delta_3 = 1)\)`

.center[
Data in PED format <br>
<img src="tab-ped-cr.svg", width = "410px">
]

`\(\ra\)` estimate `\(\lambda(t|\bfx, k) = \exp(f(\bfx(t),t,k)),\ k \in \{1,2\}\)`

---
count: false

Consider 3 subjects in a competing risks setting with event types `\(k \in \{1,2\}\)`

- `\(i = 1\)`: `\((t_1 = 1.3, \delta_1 = 2)\)`
- `\(i = 2\)`: `\((t_2 = 0.5, \delta_2 = 0)\)`
- `\(i = 3\)`: `\((t_3 = 2.7, \delta_3 = 1)\)`

.center[
Data in PED format <br>
<img src="tab-cr1.svg", width = "410px">
]

`\(\ra\)` estimate `\(\lambda(t|\bfx, k) = \exp(f(\bfx(t),t,k)),\ k \in \{1,2\}\)`

---
count: false

Consider 3 subjects in a competing risks setting with event types `\(k \in \{1,2\}\)`

- `\(i = 1\)`: `\((t_1 = 1.3, \delta_1 = 2)\)`
- `\(i = 2\)`: `\((t_2 = 0.5, \delta_2 = 0)\)`
- `\(i = 3\)`: `\((t_3 = 2.7, \delta_3 = 1)\)`

.center[
Data in PED format <br>
<img src="tab-cr2.svg", width = "410px">
]

`\(\ra\)` estimate `\(\lambda(t|\bfx, k) = \exp(f(\bfx(t),t,k)),\ k \in \{1,2\}\)`

---

.column[Time-varying effects]
.column[Shared vs.
cause-specific effects (in CR)]

.center[
<img src="splits-grey.svg", width = "800px">
]

---
layout:false
class: inverse, middle, center

<h1 style="color:#005500;">Experimental Results</h1>

<html>
<div style='float:left'></div>
<hr color='#005500' size=1px width=720px>
</html>

---
layout: true
background-image: url(mcml-background-slide.svg)
background-size: cover

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">

<div class="my-footer"><span><i class="fa fa-twitter">@adibender</i>                            </span></div>

---
class: middle

We use gradient boosted trees (**GBT**) as the computing engine for PEMs (specifically XGBoost; [Chen and Guestrin, 2016](https://arxiv.org/abs/1603.02754)) and compare them to

+ Oblique Random Survival Forest (**ORSF**; [Jaeger, Long, Long, et al. (2019)](https://doi.org/10.1214/19-AOAS1261))
+ the deep neural network based algorithm **DeepHit** ([Lee, Zame, Yoon, et al., 2018](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/16160))

Single-event and competing risks data sets

- Standard data sets (directly available)
- Synthetic data with time-varying effects (TVE)

For each data set

- 20 subsamples, each split into train (70%) and test (30%) data
- tuning on the training data (random search with fixed budget)
- evaluation on the test data; performance measured by the Brier Score at different time points (25%, 50% and 75% quantiles of the event times in the test data)

---

**Comparison with ORSF** (single-event, right-censoring) <br>
Evaluation w.r.t. Integrated Brier Score

.center[
<img src="tab-orsf1.svg", width = "600px">
]

---

**Comparison with DeepHit** (single-event and competing risks, right-censoring) <br>
Evaluation w.r.t.
weighted Brier Score

.center[
<img src="tab-deephit.svg", width = "600px">
]

---

**Choice of interval split points**

- The number and placement of interval split points could be treated as tuning parameters
- In our experience, placing split points at the observed event times performs well <br> `\(\ra\)` many split points where many events are observed
- For large data sets, select a subset of the unique event times as split points

.center[
<img src="tab-scalability.svg", width = "400px">
]

---
layout:false
class: inverse, middle, center

<h1 style="color:#005500;">Concluding Remarks</h1>

<html>
<div style='float:left'></div>
<hr color='#005500' size=1px width=720px>
</html>

---
layout: true
background-image: url(mcml-background-slide.svg)
background-size: cover

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">

<div class="my-footer"><span><i class="fa fa-twitter">@adibender</i>                            </span></div>

---
class: middle

- **General ML Framework for Survival Analysis** ([Bender, Rügamer, Scheipl, et al., 2020](https://arxiv.org/abs/2006.15442))
    + supports many survival tasks (TVE, TVF, CR, MSM)
    + does not require specialized software/algorithms

- **No assumptions w.r.t. distribution of event times** (the Poisson assumption is just a computational vehicle)

- Framework for **continuous time** survival analysis (exact time enters via the offset; hazards and survival probabilities can be predicted for any time `\(t\)`)

---
layout: false

## References

<a name=bib-bender_general_2020></a>[Bender, A., D. Rügamer, F. Scheipl, et al.](#cite-bender_general_2020) (2020). "A General Machine Learning Framework for Survival Analysis". In: _arXiv:2006.15442 [cs, stat]_. arXiv: [2006.15442](https://arxiv.org/abs/2006.15442).

<a name=bib-chen_xgboost_2016></a>[Chen, T. and C. Guestrin](#cite-chen_xgboost_2016) (2016). "XGBoost: A Scalable Tree Boosting System".
In: _Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining - KDD '16_, pp. 785-794. DOI: [10.1145/2939672.2939785](https://doi.org/10.1145%2F2939672.2939785). arXiv: [1603.02754](https://arxiv.org/abs/1603.02754).

<a name=bib-friedman_piecewise_1982></a>[Friedman, M.](#cite-friedman_piecewise_1982) (1982). "Piecewise Exponential Models for Survival Data with Covariates". In: _The Annals of Statistics_ 10.1, pp. 101-113. ISSN: 00905364. URL: [http://www.jstor.org/stable/2240502](http://www.jstor.org/stable/2240502).

<a name=bib-jaeger_oblique_2019></a>[Jaeger, B. C., D. L. Long, D. M. Long, et al.](#cite-jaeger_oblique_2019) (2019). "Oblique random survival forests". In: _The Annals of Applied Statistics_ 13.3, pp. 1847-1883. ISSN: 1932-6157, 1941-7330. DOI: [10.1214/19-AOAS1261](https://doi.org/10.1214%2F19-AOAS1261).

<a name=bib-lee_deephit_2018></a>[Lee, C., W. R. Zame, J. Yoon, et al.](#cite-lee_deephit_2018) (2018). "DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks". In: _Thirty-Second AAAI Conference on Artificial Intelligence_. URL: [https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/16160](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/16160).
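
---

**Sketch: transformation to PED format**

The transformation to PED format described in the slides can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `to_ped`, the column names and the choice to stack one copy of the data per cause are hypothetical and made up for this example. It uses the split points `\(\kappa_0=0, \kappa_1 = 1, \kappa_2=1.5, \kappa_3=3\)` and the three subjects from the competing risks example, and takes `\(t_{ij}\)` to be the time a subject actually spends at risk in interval `\(j\)`.

```python
from math import log

CUTS = [0.0, 1.0, 1.5, 3.0]  # kappa_0, ..., kappa_3 as used in the slides


def to_ped(t, status, cuts=CUTS, cause=1):
    """Split one observation (t, status) into piecewise exponential data rows.

    Hypothetical helper: status == cause marks the event of interest,
    any other status is treated as censored for this cause.
    """
    rows = []
    for j in range(1, len(cuts)):
        start, end = cuts[j - 1], cuts[j]
        if t <= start:  # subject is no longer at risk in interval j
            break
        in_interval = t <= end  # t lies in (start, end]
        exposure = (t - start) if in_interval else (end - start)  # t_ij
        rows.append({
            "j": j,
            "t_j": end,
            "cause": cause,
            "delta_ij": int(in_interval and status == cause),
            "t_ij": exposure,
            "offset": log(exposure),  # enters the Poisson model as offset
        })
    return rows


# The three subjects from the competing risks example: (t_i, delta_i)
subjects = [(1.3, 2), (0.5, 0), (2.7, 1)]

# Assumption: one stacked copy of the data per cause k, with k as a feature
ped = [
    dict(id=i, **row)
    for k in (1, 2)
    for i, (t, status) in enumerate(subjects, start=1)
    for row in to_ped(t, status, cause=k)
]

for row in ped:
    print(row)
```

The resulting rows could be passed to any learner that optimizes a Poisson likelihood, with `delta_ij` as the target and `offset` as the offset.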