Skip to content

Commit

Permalink
Merge pull request #57 from BiomedSciAI/version-0.9.4
Browse files Browse the repository at this point in the history
* Revert numpy restriction due to mismatch with matplotlib

* Bump version: 0.9.4

* Add covariate specification to models in `WeightedStandardizedSurvival`
Allow specifying different covariate sets to the outcome model and
weight model in `WeightedStandardizedSurvival`.

* Allow no covariates (intercept-only model) in doubly robust models
Current implementation confused an empty list (`[]`) with `None`
(meaning all covariates)

* Fix onehot and augment in NHEFS survival data
Were ignored so far.

* Add tests for onehot, augment, and mutual index
  • Loading branch information
ehudkr authored May 2, 2023
2 parents 98e6a55 + 02a3cdb commit 942fce7
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 49 deletions.
2 changes: 1 addition & 1 deletion causallib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.3"
__version__ = "0.9.4"
38 changes: 9 additions & 29 deletions causallib/datasets/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,36 +125,16 @@ def load_nhefs_survival(augment=True, onehot=True):
t (pd.Series): Followup duration, size (num_subjects,).
y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,).
"""
nhefs_all = load_nhefs(raw=True)[0]
t = (nhefs_all["yrdth"] - 83) * 12 + nhefs_all["modth"]
t = t.fillna(120)
y = nhefs_all["death"]

nhefs_all = load_nhefs(raw=True, augment=augment, onehot=onehot)[0]

nhefs_all['longevity'] = (nhefs_all.yrdth - 83) * 12 + nhefs_all.modth - 1
nhefs_all['longevity'].fillna(120, inplace=True)

# Pre-process data
a = nhefs_all['qsmk']
t = nhefs_all['longevity']
y = nhefs_all['death']
X = nhefs_all[[
"sex", "race", "age",
"active", "education", "exercise",
"smokeintensity", "smokeyrs",
"wt71"
]]

# Add square terms and dummy variables
squares = {}
for col in ['age', 'wt71', 'smokeintensity', 'smokeyrs']:
squares[f'{col}^2'] = X[col] * X[col]
X = X.assign(**squares)
X = pd.get_dummies(
X, columns=["active", "education", "exercise"], drop_first=True
)

# Make timeline 1-index (to comply with some lifelines fitters that require strictly positive time steps)
t = t + 1

data = Bunch(X=X, a=a, t=t, y=y)
nhefs = load_nhefs(augment=augment, onehot=onehot, restrict=False)
a = nhefs.a
X = nhefs.X

data = Bunch(X=X, a=a, t=t, y=y, descriptors=nhefs.descriptors)
return data


Expand Down
10 changes: 8 additions & 2 deletions causallib/estimation/doubly_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,18 @@ def fit(self, X, a, y, refit_weight_model=True, **kwargs):
raise NotImplementedError

def _extract_outcome_model_data(self, X):
outcome_covariates = self.outcome_covariates or X.columns
if self.outcome_covariates is None:
outcome_covariates = X.columns
else:
outcome_covariates = self.outcome_covariates
X_outcome = X[outcome_covariates]
return X_outcome

def _extract_weight_model_data(self, X):
weight_covariates = self.weight_covariates or X.columns
if self.weight_covariates is None:
weight_covariates = X.columns
else:
weight_covariates = self.weight_covariates
X_weight = X[weight_covariates]
return X_weight

Expand Down
73 changes: 58 additions & 15 deletions causallib/survival/weighted_standardized_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,57 @@


class WeightedStandardizedSurvival(StandardizedSurvival):
"""
Combines WeightedSurvival and StandardizedSurvival:
1. Adjusts for treatment assignment by creating weighted pseudo-population (e.g., inverse propensity weighting).
2. Computes parametric curve by fitting a time-varying hazards model that includes baseline covariates.
"""

def __init__(self,
weight_model: WeightEstimator,
survival_model: Any,
stratify: bool = True):
def __init__(
self,
weight_model: WeightEstimator,
survival_model: Any,
stratify: bool = True,
outcome_covariates=None,
weight_covariates=None,
):
"""
Combines WeightedSurvival and StandardizedSurvival:
1. Adjusts for treatment assignment by creating weighted pseudo-population (e.g., inverse propensity weighting).
2. Computes parametric curve by fitting a time-varying hazards model that includes baseline covariates.
Args:
weight_model: causallib compatible weight model (e.g., IPW)
weight_model: causallib compatible weight model (e.g., IPW)
survival_model: Two alternatives:
1. Scikit-Learn estimator (needs to implement `predict_proba`) - compute parametric curve by fitting a
time-varying hazards model that includes baseline covariates. Note that the model is fitted on a
person-time table with all covariates, and might be computationally and memory expansive.
2. lifelines RegressionFitter - use lifelines fitter to compute survival curves from baseline covariates,
events and durations
stratify (bool): if True, fit a separate model per treatment group
outcome_covariates (array): Covariates to use for outcome model.
If None - all covariates passed will be used.
Either list of column names or boolean mask.
weight_covariates (array): Covariates to use for weight model.
If None - all covariates passed will be used.
Either list of column names or boolean mask.
"""
self.weight_model = weight_model
super().__init__(survival_model=survival_model, stratify=stratify)
self.outcome_covariates = outcome_covariates
self.weight_covariates = weight_covariates

def _prepare_data(self, X, *args, **kwargs):
"""
Extract the relevant parts for outcome model and weight model for the entire data matrix
Args:
X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features).
a (pd.Series): Treatment assignment of size (num_subjects,).
Returns:
(pd.DataFrame, pd.DataFrame): X_outcome, X_weight
Data matrix for outcome model and data matrix weight model
"""
outcome_covariates = X.columns if self.outcome_covariates is None else self.outcome_covariates
X_outcome = X[outcome_covariates]
weight_covariates = X.columns if self.weight_covariates is None else self.weight_covariates
X_weight = X[weight_covariates]
return X_outcome, X_weight

def fit(self,
X: pd.DataFrame,
Expand All @@ -55,11 +80,29 @@ def fit(self,
self
"""
a, t, y, _, X = canonize_dtypes_and_names(a=a, t=t, y=y, w=None, X=X)
self.weight_model.fit(X=X, a=a, y=y)
iptw_weights = self.weight_model.compute_weights(X, a)
X_outcome, X_weight = self._prepare_data(X)

self.weight_model.fit(X=X_weight, a=a, y=y)
iptw_weights = self.weight_model.compute_weights(X_weight, a)

# Call fit from StandardizedSurvival, with added ipt weights
super().fit(X=X, a=a, t=t, y=y, w=iptw_weights, fit_kwargs=fit_kwargs)
super().fit(X=X_outcome, a=a, t=t, y=y, w=iptw_weights, fit_kwargs=fit_kwargs)
return self


def estimate_individual_outcome(
self,
X: pd.DataFrame,
a: pd.Series,
t: pd.Series,
y: Optional[Any] = None,
timeline_start: Optional[int] = None,
timeline_end: Optional[int] = None
) -> pd.DataFrame:
X_outcome, _ = self._prepare_data(X)
potential_outcomes = super().estimate_individual_outcome(
X_outcome,
a, t, y,
timeline_start=timeline_start,
timeline_end=timeline_end,
)
return potential_outcomes
90 changes: 89 additions & 1 deletion causallib/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

import unittest

import pandas as pd
from pandas import DataFrame, Series

from causallib.datasets import load_nhefs, load_acic16
from causallib.datasets import load_nhefs, load_nhefs_survival, load_acic16


class BaseTestDatasets(unittest.TestCase):
Expand Down Expand Up @@ -72,6 +73,93 @@ def test_restrict_parameter(self):
data = load_nhefs(restrict=False)
self.assertTrue(data.y.isnull().any())

def test_augment(self):
data_aug = load_nhefs(augment=True, onehot=False).X
self.assertTrue("age" in data_aug.columns)
self.assertTrue("age^2" in data_aug.columns)
pd.testing.assert_series_equal(
data_aug["age"]**2, data_aug["age^2"],
check_names=False
)

data = load_nhefs(augment=False, onehot=False).X
self.assertGreater(data_aug.shape[1], data.shape[1])
self.assertEqual(data_aug.shape[0], data.shape[0])

def test_onehot(self):
data_aug = load_nhefs(augment=False, onehot=True).X
self.assertTrue("active_1" in data_aug.columns)
self.assertTrue("active_2" in data_aug.columns)
self.assertTrue("active_0" not in data_aug.columns)
self.assertTrue("active" not in data_aug.columns)

data = load_nhefs(augment=False, onehot=False).X
self.assertGreater(data_aug.shape[1], data.shape[1])
self.assertEqual(data_aug.shape[0], data.shape[0])

self.assertSetEqual(set(data_aug["active_1"]), {0, 1})
self.assertSetEqual(set(data_aug["active_2"]), {0, 1})
self.assertSetEqual(set(data["active"]), {0, 1, 2})

def test_index(self):
data = load_nhefs()
pd.testing.assert_index_equal(data.X.index, data.a.index)
pd.testing.assert_index_equal(data.X.index, data.y.index)

joined = pd.concat(
[data.X, data.a, data.y],
axis="columns", join="outer",
)
pd.testing.assert_index_equal(data.y.index, joined.index)


class TestSmokingSurvival(BaseTestDatasets):
def test_return_types(self):
self.ensure_return_types(load_nhefs_survival)
data = load_nhefs_survival()
self.assertTrue(hasattr(data, "y"))
self.assertIsInstance(data.t, pd.Series)

def test_augment(self):
data_aug = load_nhefs_survival(augment=True, onehot=False).X
self.assertTrue("age" in data_aug.columns)
self.assertTrue("age^2" in data_aug.columns)
pd.testing.assert_series_equal(
data_aug["age"]**2, data_aug["age^2"],
check_names=False
)

data = load_nhefs_survival(augment=False, onehot=False).X
self.assertGreater(data_aug.shape[1], data.shape[1])
self.assertEqual(data_aug.shape[0], data.shape[0])

def test_onehot(self):
data_aug = load_nhefs_survival(augment=False, onehot=True).X
self.assertTrue("active_1" in data_aug.columns)
self.assertTrue("active_2" in data_aug.columns)
self.assertTrue("active_0" not in data_aug.columns)
self.assertTrue("active" not in data_aug.columns)

data = load_nhefs_survival(augment=False, onehot=False).X
self.assertGreater(data_aug.shape[1], data.shape[1])
self.assertEqual(data_aug.shape[0], data.shape[0])

self.assertSetEqual(set(data_aug["active_1"]), {0, 1})
self.assertSetEqual(set(data_aug["active_2"]), {0, 1})
self.assertSetEqual(set(data["active"]), {0, 1, 2})

def test_index(self):
data = load_nhefs_survival()
pd.testing.assert_index_equal(data.X.index, data.a.index)
pd.testing.assert_index_equal(data.X.index, data.t.index)
pd.testing.assert_index_equal(data.X.index, data.y.index)

joined = pd.concat(
[data.X, data.a, data.t, data.y],
axis="columns", join="outer",
)
pd.testing.assert_index_equal(data.y.index, joined.index)


class TestACIC16(BaseTestDatasets):
def test_return_types(self):
Expand Down
25 changes: 25 additions & 0 deletions causallib/tests/test_doublyrobust.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def ensure_data_is_separated_between_models(self, estimator, n_added_outcome_mod
len(estimator.outcome_covariates) + n_added_outcome_model_features)
self.assertEqual(estimator.weight_model.learner.coef_.size, len(estimator.weight_covariates))

def ensure_fit_with_no_outcome_covariates(self, estimator, n_added_outcome_model_features):
data = self.create_uninformative_ox_dataset()
# Reinitialize estimator:
estimator = estimator.__class__(estimator.outcome_model, estimator.weight_model,
outcome_covariates=[], weight_covariates=None)
estimator.fit(data["X"], data["a"], data["y"])
self.assertEqual(
estimator.outcome_model.learner.coef_.size,
n_added_outcome_model_features
)
self.assertEqual(
estimator.weight_model.learner.coef_.size,
data["X"].shape[1]
)

def ensure_weight_refitting_refits(self, estimator):
data = self.create_uninformative_ox_dataset()
with self.subTest("Test first fit of weight_model did fit the model"):
Expand Down Expand Up @@ -243,6 +258,9 @@ def test_is_fitted(self):
def test_data_is_separated_between_models(self):
self.ensure_data_is_separated_between_models(self.estimator, 1) # 1 treatment assignment feature

def test_fit_with_no_outcome_covariates(self):
self.ensure_fit_with_no_outcome_covariates(self.estimator, 1)

def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

Expand Down Expand Up @@ -343,6 +361,10 @@ def test_is_fitted(self):
def test_data_is_separated_between_models(self):
self.ensure_data_is_separated_between_models(self.estimator, 1) # 1 treatment assignment feature

def test_fit_with_no_outcome_covariates(self):
# Basically an Marginal Structural Model (MSM)
self.ensure_fit_with_no_outcome_covariates(self.estimator, 1)

def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

Expand Down Expand Up @@ -453,6 +475,9 @@ def test_is_fitted(self):
def test_data_is_separated_between_models(self):
self.ensure_data_is_separated_between_models(self.estimator, 1 + 1) # 1 ip-feature + 1 treatment assignment

def test_fit_with_no_outcome_covariates(self):
self.ensure_fit_with_no_outcome_covariates(self.estimator, 1 + 1)

def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

Expand Down
Loading

0 comments on commit 942fce7

Please sign in to comment.