Skip to content

Commit

Permalink
add example
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincent-Maladiere committed Jan 7, 2025
1 parent ef8f923 commit 5b59271
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 17 deletions.
186 changes: 186 additions & 0 deletions examples/plot_03_competing_risks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
==============================
Exploring the accuracy in time
==============================
In this notebook, we showcase how the accuracy in time metric behaves, and how
to interpret it.
"""
# %%
# We begin by generating a linear, synthetic dataset. For each individual, we uniformly
# sample a shape and scale value, which we use to parameterize a Weibull distribution,
# from which we sample a duration.
from hazardous.data import make_synthetic_competing_weibull
from sklearn.model_selection import train_test_split


# Draw 10,000 synthetic individuals exposed to 3 competing events; with
# return_X_y=True the covariates X and the target y come back separately.
X, y = make_synthetic_competing_weibull(n_events=3, n_samples=10_000, return_X_y=True)
# No test_size or random_state is passed: scikit-learn's default split is used and
# the partition differs from one run to the next.
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Bare expression, presumably there to be displayed as the cell output.
X_train.shape, y_train.shape

# %%
# Next, we display the distribution of our target.
import seaborn as sns
from matplotlib import pyplot as plt


# Stacked histogram of the test-set durations, with one colour per value of the
# "event" column.
sns.histplot(
y_test,
x="duration",
hue="event",
multiple="stack",
palette="colorblind",
)

# %%
# We train a Survival Boost model and compute its accuracy in time.
import numpy as np
from hazardous import SurvivalBoost
from hazardous.metrics import accuracy_in_time


# One record per model is appended here; the records are plotted further down.
results = []

# Time horizons 0, 100, ..., 3900 at which the incidence curves are evaluated.
time_grid = np.arange(0, 4000, 100)
surv = SurvivalBoost(show_progressbar=False).fit(X_train, y_train)
# NOTE(review): accuracy_in_time slices y_pred as [:, :, time_index] and takes the
# argmax over axis 1, so y_pred is presumably
# (n_samples, n_events + 1, n_time_steps) -- confirm against SurvivalBoost.
y_pred = surv.predict_cumulative_incidence(X_test, times=time_grid)

# 16 evenly spaced levels from 0.125 to 1.0, passed as `quantiles`; the time
# horizons actually used are returned as `taus`.
quantiles = np.linspace(0.125, 1, 16)
accuracy, taus = accuracy_in_time(y_test, y_pred, time_grid, quantiles=quantiles)
results.append(dict(model_name="Survival Boost", accuracy=accuracy, taus=taus))

# %%
# We also compute the accuracy in time of the Aalen-Johansen estimator, which is
# a marginal model (it doesn't use covariates X), similar to the Kaplan-Meier estimator,
# except that it computes cumulative incidence functions of competing risks instead
# of a survival function.
from scipy.interpolate import interp1d
from lifelines import AalenJohansenFitter
from hazardous.utils import check_y_survival


def predict_aalen_johansen(y_train, time_grid, n_sample_test):
    """Evaluate marginal Aalen-Johansen incidence curves for every test sample.

    One ``AalenJohansenFitter`` is fitted per non-zero event id found in
    ``y_train``. Each fitted cumulative incidence curve is linearly interpolated
    (and extrapolated when needed) on ``time_grid``. The estimator ignores
    covariates, so the same curves are repeated for each test sample.

    Parameters
    ----------
    y_train : object accepted by ``check_y_survival``
        Training target holding the event ids and the durations.
    time_grid : array-like of shape (n_time_steps,)
        Times at which the curves are evaluated.
    n_sample_test : int
        Number of times the curves are repeated along the first axis.

    Returns
    -------
    ndarray of shape (n_sample_test, n_events + 1, n_time_steps)
        Along axis 1, index 0 is one minus the sum of all incidence curves
        (the survival to any event); the following indices hold the incidence
        curve of each event id, in increasing order of id.
    """
    event, duration = check_y_survival(y_train)
    # Event id 0 is not an event of interest, so no curve is fitted for it.
    competing_ids = sorted(set(event) - {0})

    incidence_curves = []
    for competing_id in competing_ids:
        fitter = AalenJohansenFitter(calculate_variance=False)
        fitter.fit(
            durations=duration,
            event_observed=event,
            event_of_interest=competing_id,
        )
        cif_frame = fitter.cumulative_density_
        interpolator = interp1d(
            x=cif_frame.index,
            y=cif_frame[cif_frame.columns[0]],
            kind="linear",
            fill_value="extrapolate",
        )
        incidence_curves.append(interpolator(time_grid))

    # shape: (n_events, n_time_steps)
    incidence_curves = np.stack(incidence_curves, axis=0)
    # Whatever probability mass is not assigned to an event is the survival.
    survival_curve = 1 - incidence_curves.sum(axis=0)
    # shape: (n_events + 1, n_time_steps), survival first
    all_curves = np.vstack([survival_curve[None, :], incidence_curves])

    # The model is marginal: every test sample receives the very same curves.
    return np.tile(all_curves, (n_sample_test, 1, 1))


# The Aalen-Johansen curves are estimated from the training targets only; the test
# set merely provides the number of rows to repeat them over.
y_pred_aj = predict_aalen_johansen(y_train, time_grid, n_sample_test=X_test.shape[0])

accuracy, taus = accuracy_in_time(y_test, y_pred_aj, time_grid, quantiles=quantiles)
# This label ends up in the plot legend, so the estimator name must be spelled
# correctly ("Aalen", not "Aalan").
results.append(dict(model_name="Aalen-Johansen", accuracy=accuracy, taus=taus))

# %%
# We display the accuracy in time to compare Survival Boost with Aalen-Johansen.
# Higher is better. Note that the accuracy is high at the very beginning (t < 1000),
# because both models predict that every individual survives.
# Then, beyond the time horizon 1000, the discriminative power of the conditional
# Survival Boost yields a better accuracy than the marginal, unbiased Aalen-Johansen.
import pandas as pd


fig, ax = plt.subplots(figsize=(6, 3), dpi=300)

# Each record stores whole arrays; exploding both columns together yields one row
# per (model, time horizon) pair. Note that this rebinds `results` from a list to
# a DataFrame.
results = pd.DataFrame(results).explode(column=["accuracy", "taus"])

# Curves first, with legend=False: the scatter layer below is the only one that
# contributes legend entries.
sns.lineplot(
results,
x="taus",
y="accuracy",
hue="model_name",
ax=ax,
legend=False,
)

# Markers on the same axes; the high zorder keeps them above the lines, and
# `style` gives each model its own marker shape in addition to its colour.
sns.scatterplot(
results,
x="taus",
y="accuracy",
hue="model_name",
ax=ax,
s=50,
zorder=100,
style="model_name",
)


# %%
# We can drill into this metric by counting the observed events cumulatively across
# time, and comparing that to the predictions.
#
# We display below the distribution of ground truth labels. Each color bar group
# represents the event distribution at some given horizon.
# Almost no individual has experienced an event at the very beginning.
# Then, as time passes, events occur and the number of censored individuals at each
# time horizon shrinks. Therefore, the very last distribution represents the overall
# event distribution of the dataset.
def plot_event_in_time(y_in_time):
    """Draw, for each time step, how many individuals carry each class label.

    For every label in 0, 1, 2, 3 the entries of ``y_in_time`` equal to that
    label are counted along axis 0, and the counts are shown as grouped bars
    against the module-level ``time_grid``.

    Parameters
    ----------
    y_in_time : ndarray
        Class labels; axis 0 is summed over, so the remaining axis must line up
        with the module-level ``time_grid``. The callers in this script pass
        arrays of shape (n_samples, n_time_steps).
    """
    records = [
        {
            "event_count": (y_in_time == label).sum(axis=0),
            "time_grid": time_grid,
            "event": label,
        }
        for label in range(4)
    ]

    # Each record holds whole arrays: unfold them into one row per time step.
    counts = pd.DataFrame(records)
    counts = counts.explode(["event_count", "time_grid"])

    axes = sns.barplot(
        counts,
        x="time_grid",
        y="event_count",
        hue="event",
        palette="colorblind",
    )

    # Keep one tick out of ten so that the x axis stays readable.
    sparse_ticks = axes.get_xticks()[::10]
    axes.set_xticks(sparse_ticks)


# Repeat the time grid on every row: shape (n_test_samples, n_time_steps).
time_grid_2d = np.tile(time_grid, (y_test.shape[0], 1))
# The boolean mask "duration already reached at this horizon" multiplies the event
# id: an individual is labelled 0 before its duration, and its own event id from
# then on (which stays 0 when that event id is 0).
y_test_class = (y_test["duration"].values[:, None] <= time_grid_2d) * y_test[
"event"
].values[:, None]
plot_event_in_time(y_test_class)
# %%
# Now, we compare this ground truth to the classes predicted by our Survival Boost
# model. Interestingly, it seems too confident about the censoring event at the
# beginning (t < 500), but then becomes underconfident in the middle (t > 1500) and
# very overconfident about the class 3 in the end (t > 3000).

# Most probable class for each individual and time step.
# NOTE(review): assumes axis 1 of y_pred is ordered like the output of
# predict_aalen_johansen (survival first, then one entry per event) -- confirm.
y_pred_class = y_pred.argmax(axis=1)
plot_event_in_time(y_pred_class)

# %%
# Finally, we compare this to the classes predicted by the Aalen-Johansen model.
# They are constant across individuals because this model is marginal and we simply
# duplicated the global cumulative incidences for each individual.
# predict_aalen_johansen puts the survival curve at index 0 of axis 1, followed by
# one curve per event, so the argmax is 0 while survival is the most likely outcome.
y_pred_class_aj = y_pred_aj.argmax(axis=1)
plot_event_in_time(y_pred_class_aj)
# %%
9 changes: 4 additions & 5 deletions hazardous/_survival_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def fit(self, X):


class SurvivalBoost(BaseEstimator, ClassifierMixin):
r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [1]_.
r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [Alberge2024]_.
This model estimates the cause-specific Cumulative Incidence Function (CIF) for
each event of interest, as well as the survival function for any event, using a
Expand Down Expand Up @@ -297,10 +297,9 @@ class SurvivalBoost(BaseEstimator, ClassifierMixin):
References
----------
.. [1] J. Alberge, V. Maladière, O. Grisel, J. Abécassis, G. Varoquaux,
"Teaching Models To Survive: Proper Scoring Rule and Stochastic Optimization
with Competing Risks", 2024.
https://arxiv.org/pdf/2406.14085
.. [Alberge2024] J. Alberge, V. Maladiere, O. Grisel, J. Abécassis, G. Varoquaux,
"Survival Models: Proper Scoring Rule and Stochastic Optimization
with Competing Risks", 2024
Examples
--------
Expand Down
30 changes: 18 additions & 12 deletions hazardous/metrics/_accuracy_in_time.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np

from hazardous.utils import check_y_survival
from ..utils import check_y_survival


def accuracy_in_time(y_test, y_pred, time_grid, quantiles=None, taus=None):
r"""Accuracy in time for prognostic models using competing risks.
.. math::
\mathrm{acc}(\zeta) = \frac{1}{n_{nc}} \sum_{i=1}^n I\{\hat{y}_i\=y_{i,\zeta}\}
\mathrm{acc}(\zeta) = \frac{1}{n_{nc}} \sum_{i=1}^n I\{\hat{y}_i=y_{i,\zeta}\}
\overline{I\{\delta_i = 0 \cap t_i \leq \zeta \}}
where:
Expand All @@ -18,24 +18,24 @@ def accuracy_in_time(y_test, y_pred, time_grid, quantiles=None, taus=None):
- :math:`\delta_i` is the event experienced by the individual :math:`i` at
:math:`t_i`
- :math:`\hat{y} = \text{arg}\max\limits_{k \in [1, K]} \hat{F}_k(\zeta|X=x_i)` is
the most probable event for individual :math:`i` at :math:`\zeta`
the most probable predicted event for individual :math:`i` at :math:`\zeta`
- :math:`y_{i,\zeta} = \delta_i I\{t_i \leq \zeta \}` is the observed event
for individual :math:`i` at :math:`\zeta`
The accuracy in time is a metric introduced in [Alberge2024]_ which evaluates
whether observed events are predicted as the most likely at given times.
It is defined as the probability that the maximum predicted cumulative incidence
function (CIF) across :math:`k` events corresponds to the observed event at a
fixed time horizon :math:`zeta`.
fixed time horizon :math:`\zeta`.
We remove individuals that were censored at times :math:`t \leq \zeta`, so the
accuracy in time essentially represents the accuracy of the estimator up to
:math:`zeta`.
accuracy in time essentially represents the accuracy of the estimator on
observed events up to :math:`\zeta`.
While the C-index can help clinicians to prioritize treatment allocation by ranking
individuals by risk of a given event of interest, the accuracy in time answers
a different question: `what is the most likely event that this individual will
experience at some fixed time horizon?`. Conceptually, it helps clinicians choose
a different question: "`what is the most likely event that this individual will
experience at some fixed time horizon?`". Conceptually, it helps clinicians choose
the right treatment by prioritizing the risk for a given individual.
Parameters
Expand Down Expand Up @@ -68,6 +68,12 @@ def accuracy_in_time(y_test, y_pred, time_grid, quantiles=None, taus=None):
taus : array of shape (n_quantiles or n_taus)
The fixed time horizons effectively used to compute the accuracy in time.
References
----------
.. [Alberge2024] J. Alberge, V. Maladiere, O. Grisel, J. Abécassis, G. Varoquaux,
"Survival Models: Proper Scoring Rule and Stochastic Optimization
with Competing Risks", 2024
"""
event_true, _ = check_y_survival(y_test)

Expand Down Expand Up @@ -109,17 +115,17 @@ def accuracy_in_time(y_test, y_pred, time_grid, quantiles=None, taus=None):

tau_idx = np.searchsorted(time_grid, tau)

# If tau is beyond the time_grid, we extrapolate its accuracy as the accuracy at
# max(time_grid).
# If tau is beyond the time_grid, we extrapolate its accuracy as
# the accuracy at max(time_grid).
if tau_idx == time_grid.shape[0]:
tau_idx = -1

y_pred_at_t = y_pred[:, :, tau_idx]
y_pred_class = y_pred_at_t[~mask_past_censored, :].argmax(axis=1)

y_test_class = y_test["event"] * (y_test["duration"] <= tau)
y_test_class = y_test_class.loc[~mask_past_censored]
y_test_class = y_test_class.loc[~mask_past_censored].values

acc_in_time.append((y_test_class.values == y_pred_class).mean())
acc_in_time.append((y_test_class == y_pred_class).mean())

return np.array(acc_in_time), np.asarray(taus)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ doc = [
"sphinx-design",
"sphinx-copybutton",
"matplotlib",
"seaborn",
"pillow", # to scrape images from the examples
"numpydoc",
"pycox",
Expand Down

0 comments on commit 5b59271

Please sign in to comment.