Fix Cox IPCW Implementation #28

Merged
23 changes: 16 additions & 7 deletions examples/debaised_bs.py
@@ -218,22 +218,25 @@ def plot_censoring_survival_proba(
g_hat_marginal = estimator_marginal.compute_censoring_survival_proba(time_grid)

estimator_conditional = IPCWCoxEstimator().fit(y_censored, X=X)
g_hat_conditional = estimator_conditional.compute_censoring_survival_proba(
time_grid,
X=X,
)

sampler = IPCWSampler(
shape=shape_censoring,
scale=scale_censoring,
).fit(y_censored)

g_star = []
g_hat_conditional, g_star = [], []
for time_step in time_grid:
time_step = np.full(y_censored.shape[0], fill_value=time_step)

g_star_ = sampler.compute_censoring_survival_proba(time_step)
g_star.append(g_star_.mean())

g_hat_conditional_ = estimator_conditional.compute_censoring_survival_proba(
time_step,
X=X,
)
g_hat_conditional.append(g_hat_conditional_.mean())

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time_grid, g_hat_marginal, label="$\hat{G}$ with KM")
ax.plot(time_grid, g_hat_conditional, label="$\hat{G}$ with Cox")
@@ -271,19 +274,25 @@ def plot_ipcw(X, y_uncensored, y_censored, shape_censoring, scale_censoring, kin
ipcw_pred_marginal = estimator_marginal.compute_ipcw_at(time_grid)

estimator_conditional = IPCWCoxEstimator().fit(y_censored, X=X)
ipcw_pred_conditional = estimator_conditional.compute_ipcw_at(time_grid, X=X)

sampler = IPCWSampler(
shape=shape_censoring,
scale=scale_censoring,
).fit(y_censored)

ipcw_sampled = []
ipcw_sampled, ipcw_pred_conditional = [], []
for time_step in time_grid:
time_step = np.full(y_censored.shape[0], fill_value=time_step)

ipcw_sampled_ = sampler.compute_ipcw_at(time_step)
ipcw_sampled.append(ipcw_sampled_.mean())

ipcw_pred_conditional_ = estimator_conditional.compute_ipcw_at(
time_step,
X=X,
)
ipcw_pred_conditional.append(ipcw_pred_conditional_.mean())

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time_grid, ipcw_sampled, label="$1/G^*$")
ax.plot(time_grid, ipcw_pred_marginal, label="$1/\hat{G}$, using KM")
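
The revised example queries the conditional estimator once per grid time, broadcasting that time to every sample and averaging the per-sample values into a single curve. Here is the broadcast-and-average pattern in isolation, as a minimal pure-NumPy sketch; censoring_survival_proba below is a hypothetical stand-in for estimator_conditional.compute_censoring_survival_proba and the covariate is synthetic:

import numpy as np

rng = np.random.default_rng(0)
n_samples = 200
x = rng.uniform(0.5, 2.0, n_samples)  # hypothetical covariate, one value per sample
time_grid = np.linspace(0.0, 5.0, 50)


def censoring_survival_proba(times, x):
    # Hypothetical stand-in: returns one probability per sample, each evaluated
    # at that sample's own time, like the conditional estimator in the example.
    return np.exp(-x * times)


g_hat_conditional = []
for time_step in time_grid:
    # Broadcast the current grid time to every sample, as in the example script.
    time_step = np.full(n_samples, fill_value=time_step)
    g_hat_conditional.append(censoring_survival_proba(time_step, x).mean())

g_hat_conditional = np.asarray(g_hat_conditional)  # marginal curve, shape (len(time_grid),)
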
79 changes: 65 additions & 14 deletions hazardous/_ipcw.py
@@ -15,9 +15,9 @@ class BaseIPCW(BaseEstimator):
def fit(self, y, X=None):
del X
event, duration = check_y_survival(y)
censoring = event == 0

km = KaplanMeierFitter()
censoring = event == 0
km.fit(
durations=duration,
event_observed=censoring,
@@ -43,13 +43,13 @@ def compute_ipcw_at(self, times, X=None):
X : pandas.DataFrame of shape (n_samples, n_features)
The input data for conditional estimators.

times : np.ndarray of shape (n_times,)
The input times for which to predict the IPCW.
times : np.ndarray of shape (n_samples,)
The input times for which to predict the IPCW for each sample.

Returns
-------
ipcw : np.ndarray of shape (n_times,)
The IPCW for times
ipcw : np.ndarray of shape (n_samples,)
The IPCW of each sample, evaluated at its corresponding time.
"""
check_is_fitted(self, "min_censoring_prob_")

@@ -172,36 +172,87 @@ def compute_censoring_survival_proba(self, times, X=None):


class IPCWCoxEstimator(BaseIPCW):
def __init__(self, transformer=None, cox_estimator=None):
def __init__(
self,
transformer=None,
cox_estimator=None,
n_time_grid_steps=100,
):
self.transformer = transformer
self.cox_estimator = cox_estimator
self.n_time_grid_steps = n_time_grid_steps

def fit(self, y, X=None):
"""TODO"""
"""Fit a nonlinear transformer and a CoxPHFitter estimator.

Parameters
----------
y : pandas.DataFrame of shape (n_samples, 2)
The target, with 'event' and 'duration' columns.

X : pandas.DataFrame of shape (n_samples, n_features)
The covariates used to fit the transformer and the Cox model.

Returns
-------
self : fitted instance of IPCWCoxEstimator
"""
super().fit(y)
self.check_transformer_estimator()

X_trans = self.transformer_.fit_transform(X)

frame = X_trans.copy()
frame["duration"] = y["duration"]
frame["event"] = y["event"] == 0
frame["censoring"] = y["event"] == 0

# XXX: This could be integrated in the pipeline by using the scikit-learn
# interface.
self.cox_estimator_.fit(frame, event_col="event", duration_col="duration")
self.cox_estimator_.fit(frame, event_col="censoring", duration_col="duration")

return self
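
In fit, censoring (event == 0) plays the role of the Cox model's event column, now stored under an explicit 'censoring' name, so the fitted survival function estimates the censoring distribution G. Stripped of the transformer step, the call reduces to something like the following lifelines sketch; the frame below is hypothetical synthetic data, not part of the PR:

import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

rng = np.random.default_rng(0)
n_samples = 200

# Hypothetical design matrix: one transformed covariate plus the survival target.
frame = pd.DataFrame({
    "feature_0": rng.normal(size=n_samples),
    "duration": rng.exponential(5.0, n_samples),
    "event": rng.integers(0, 2, n_samples),
})
# Censoring indicator: 1 where the event of interest did NOT occur.
frame["censoring"] = frame["event"] == 0
frame = frame.drop(columns=["event"])  # keep only covariates, duration and censoring

cox = CoxPHFitter()
# Treating censoring as the "event" makes predict_survival_function estimate
# the conditional censoring survival probability G(t | x).
cox.fit(frame, duration_col="duration", event_col="censoring")
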

def compute_censoring_survival_proba(self, times, X=None):
"""TODO"""
"""Compute the conditional censoring survival probability.

A time grid is used to reduce the size of the prediction matrix
returned by the lifelines CoxPHFitter. Each sample is paired with one
evaluation time in 'times', so for each sample we only return the
probability at the grid bin matching its time of evaluation.

Parameters
----------
times : ndarray of shape (n_samples,)
Each entry is the time at which the corresponding sample is evaluated.

X : pandas.DataFrame of shape (n_samples, n_features)
Covariates used by the conditional estimator.

Returns
-------
cs_prob : ndarray of shape (n_samples,)
The censoring survival probability of each sample, evaluated at
its corresponding time step.
"""
X_trans = self.transformer_.transform(X)

# shape (n_time_steps, n_samples)
cs_prob = self.cox_estimator_.predict_survival_function(X_trans, times=times)
unique_times = np.sort(np.unique(times))
if len(unique_times) > self.n_time_grid_steps:
t_min, t_max = unique_times.min(), unique_times.max()
time_grid = np.linspace(t_min, t_max, self.n_time_grid_steps)
else:
time_grid = unique_times

# shape (n_unique_times, n_samples)
cs_prob_all_t = self.cox_estimator_.predict_survival_function(
X_trans,
times=time_grid,
)

indices = np.searchsorted(time_grid, times)
cs_prob = cs_prob_all_t.to_numpy()[indices, np.arange(X.shape[0])]

# shape (n_time_steps,)
return cs_prob.mean(axis=1).to_numpy()
return cs_prob
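
The grid reduction and per-sample gather described in the docstring can be illustrated in isolation. In this self-contained NumPy sketch the survival matrix is synthetic, standing in for the (n_time_grid_steps, n_samples) output of predict_survival_function:

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_time_grid_steps = 5, 4

# One evaluation time per sample.
times = np.array([0.5, 2.0, 3.5, 1.0, 4.0])

# Coarse grid spanning the requested times, as in the method above.
unique_times = np.sort(np.unique(times))
time_grid = np.linspace(unique_times.min(), unique_times.max(), n_time_grid_steps)

# Synthetic stand-in for the Cox prediction: one non-increasing survival
# curve per column, shape (n_time_grid_steps, n_samples).
cs_prob_all_t = np.sort(rng.uniform(size=(n_time_grid_steps, n_samples)), axis=0)[::-1]

# For each sample, locate its grid bin and gather a single probability.
indices = np.searchsorted(time_grid, times)
cs_prob = cs_prob_all_t[indices, np.arange(n_samples)]  # shape (n_samples,)
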

def check_transformer_estimator(self):
if self.transformer is None: