diff --git a/examples/debaised_bs.py b/examples/debaised_bs.py index 5b3b19a..5ebe881 100644 --- a/examples/debaised_bs.py +++ b/examples/debaised_bs.py @@ -218,22 +218,25 @@ def plot_censoring_survival_proba( g_hat_marginal = estimator_marginal.compute_censoring_survival_proba(time_grid) estimator_conditional = IPCWCoxEstimator().fit(y_censored, X=X) - g_hat_conditional = estimator_conditional.compute_censoring_survival_proba( - time_grid, - X=X, - ) sampler = IPCWSampler( shape=shape_censoring, scale=scale_censoring, ).fit(y_censored) - g_star = [] + g_hat_conditional, g_star = [], [] for time_step in time_grid: time_step = np.full(y_censored.shape[0], fill_value=time_step) + g_star_ = sampler.compute_censoring_survival_proba(time_step) g_star.append(g_star_.mean()) + g_hat_conditional_ = estimator_conditional.compute_censoring_survival_proba( + time_step, + X=X, + ) + g_hat_conditional.append(g_hat_conditional_.mean()) + fig, ax = plt.subplots(figsize=(8, 4)) ax.plot(time_grid, g_hat_marginal, label="$\hat{G}$ with KM") ax.plot(time_grid, g_hat_conditional, label="$\hat{G}$ with Cox") @@ -271,19 +274,25 @@ def plot_ipcw(X, y_uncensored, y_censored, shape_censoring, scale_censoring, kin ipcw_pred_marginal = estimator_marginal.compute_ipcw_at(time_grid) estimator_conditional = IPCWCoxEstimator().fit(y_censored, X=X) - ipcw_pred_conditional = estimator_conditional.compute_ipcw_at(time_grid, X=X) sampler = IPCWSampler( shape=shape_censoring, scale=scale_censoring, ).fit(y_censored) - ipcw_sampled = [] + ipcw_sampled, ipcw_pred_conditional = [], [] for time_step in time_grid: time_step = np.full(y_censored.shape[0], fill_value=time_step) + ipcw_sampled_ = sampler.compute_ipcw_at(time_step) ipcw_sampled.append(ipcw_sampled_.mean()) + ipcw_pred_conditional_ = estimator_conditional.compute_ipcw_at( + time_step, + X=X, + ) + ipcw_pred_conditional.append(ipcw_pred_conditional_.mean()) + fig, ax = plt.subplots(figsize=(8, 4)) ax.plot(time_grid, ipcw_sampled, label="$1/G^*$") ax.plot(time_grid, ipcw_pred_marginal, label="$1/\hat{G}$, using KM") diff --git a/hazardous/_ipcw.py b/hazardous/_ipcw.py index ca6c92d..c455782 100644 --- a/hazardous/_ipcw.py +++ b/hazardous/_ipcw.py @@ -15,9 +15,9 @@ class BaseIPCW(BaseEstimator): def fit(self, y, X=None): del X event, duration = check_y_survival(y) + censoring = event == 0 km = KaplanMeierFitter() - censoring = event == 0 km.fit( durations=duration, event_observed=censoring, @@ -43,13 +43,13 @@ def compute_ipcw_at(self, times, X=None): X : pandas.DataFrame of shape (n_samples, n_features) The input data for conditional estimators. - times : np.ndarray of shape (n_times,) - The input times for which to predict the IPCW. + times : np.ndarray of shape (n_samples,) + The input times for which to predict the IPCW for each sample. Returns ------- - ipcw : np.ndarray of shape (n_times,) - The IPCW for times + ipcw : np.ndarray of shape (n_samples,) + The IPCW for each sample at each sample time. """ check_is_fitted(self, "min_censoring_prob_") @@ -172,12 +172,31 @@ def compute_censoring_survival_proba(self, times, X=None): class IPCWCoxEstimator(BaseIPCW): - def __init__(self, transformer=None, cox_estimator=None): + def __init__( + self, + transformer=None, + cox_estimator=None, + n_time_grid_steps=100, + ): self.transformer = transformer self.cox_estimator = cox_estimator + self.n_time_grid_steps = n_time_grid_steps def fit(self, y, X=None): - """TODO""" + """Fit a nonlinear transformer and a CoxPHFitter estimator. + + Parameters + ---------- + y : pandas.DataFrame of shape (n_samples, 2) + The target, with 'event' and 'duration' columns. + + X : pandas.DataFrame of shape (n_samples, n_features) + The covariates to be transformed and fitted. + + Returns + ------- + self : fitted instance of IPCWCoxEstimator + """ super().fit(y) self.check_transformer_estimator() @@ -185,23 +204,55 @@ def fit(self, y, X=None): frame = X_trans.copy() frame["duration"] = y["duration"] - frame["event"] = y["event"] == 0 + frame["censoring"] = y["event"] == 0 # XXX: This could be integrated in the pipeline by using the scikit-learn # interface. - self.cox_estimator_.fit(frame, event_col="event", duration_col="duration") + self.cox_estimator_.fit(frame, event_col="censoring", duration_col="duration") return self def compute_censoring_survival_proba(self, times, X=None): - """TODO""" + """Compute the conditional censoring survival probability. + + A time grid is used to reduce the size of the prediction matrix + returned by lifelines CoxPHFitter. Each sample matches a time of + evaluation in the array 'times', so for each sample we only return + the closest bin to the time of evaluation. + + Parameters + ---------- + times : ndarray of shape (n_samples,) + Each time step corresponds to a single sample to be evaluated. + + X : pandas.DataFrame of shape (n_samples, n_features) + Covariate used by the conditional estimator. + + Returns + ------- + cs_prob : ndarray of shape (n_samples,) + The censoring survival probability of each sample evaluated + at their corresponding time step. + """ X_trans = self.transformer_.transform(X) - # shape (n_time_steps, n_samples) - cs_prob = self.cox_estimator_.predict_survival_function(X_trans, times=times) + unique_times = np.sort(np.unique(times)) + if len(unique_times) > self.n_time_grid_steps: + t_min, t_max = min(self.unique_times_), max(self.unique_times_) + time_grid = np.linspace(t_min, t_max, self.n_time_grid_steps) + else: + time_grid = unique_times + + # shape (n_unique_times, n_samples) + cs_prob_all_t = self.cox_estimator_.predict_survival_function( + X_trans, + times=time_grid, + ) + + indices = np.searchsorted(time_grid, times) + cs_prob = cs_prob_all_t.to_numpy()[indices, np.arange(X.shape[0])] - # shape (n_time_steps,) - return cs_prob.mean(axis=1).to_numpy() + return cs_prob def check_transformer_estimator(self): if self.transformer is None: