diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aebf867f..fbc6a95a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ### Changelogs +#### 0.17.4 + - Fix bug in `plot_covariate_groups` that wasn't allowing for strata to be used. + - change name of `multicenter_aids_cohort_study` to `load_multicenter_aids_cohort_study` + #### 0.17.3 - Fix in `compute_residuals` when using `schoenfeld` and the minumum duration has only censored subjects. diff --git a/docs/conf.py b/docs/conf.py index f6b361343..a92ed6e8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -60,7 +60,7 @@ # # The short X.Y version. -version = "0.17.3" +version = "0.17.4" # The full version, including dev info release = version diff --git a/lifelines/datasets/__init__.py b/lifelines/datasets/__init__.py index f41f39379..b37e47234 100644 --- a/lifelines/datasets/__init__.py +++ b/lifelines/datasets/__init__.py @@ -50,7 +50,7 @@ def load_recur(**kwargs): return load_dataset("recur.csv", **kwargs) -def multicenter_aids_cohort_study(**kwargs): +def load_multicenter_aids_cohort_study(**kwargs): """ Originally in [1] @@ -62,7 +62,7 @@ def multicenter_aids_cohort_study(**kwargs): D: indicator of death during follow up - i AIDSY W T D + i AIDSY W T D 1 1990.425 4.575 7.575 0 2 1991.250 3.750 6.750 0 3 1992.014 2.986 5.986 0 diff --git a/lifelines/fitters/coxph_fitter.py b/lifelines/fitters/coxph_fitter.py index 7eb790cd3..5da2a352d 100644 --- a/lifelines/fitters/coxph_fitter.py +++ b/lifelines/fitters/coxph_fitter.py @@ -31,6 +31,7 @@ from lifelines.utils import ( _get_index, _to_list, + _to_array, survival_table_from_events, inv_normal_cdf, normalize, @@ -1537,7 +1538,7 @@ def plot(self, columns=None, display_significance_code=True, **errorbar_kwargs): return ax - def plot_covariate_groups(self, covariate, groups, **kwargs): + def plot_covariate_groups(self, covariate, values, **kwargs): """ Produces a visual representation comparing the baseline survival curve of the model versus what happens when a covariate is varied over values in a group. This is useful to compare @@ -1548,14 +1549,14 @@ def plot_covariate_groups(self, covariate, groups, **kwargs): ---------- covariate: string a string of the covariate in the original dataset that we wish to vary. - groups: iterable + values: iterable an iterable of the values we wish the covariate to take on. kwargs: pass in additional plotting commands Returns ------- - ax: matplotlib axis + ax: matplotlib axis, or list of axis' the matplotlib axis that be edited. """ from matplotlib import pyplot as plt @@ -1563,15 +1564,34 @@ def plot_covariate_groups(self, covariate, groups, **kwargs): if covariate not in self.hazards_.columns: raise KeyError("covariate `%s` is not present in the original dataset" % covariate) - ax = kwargs.get("ax", None) or plt.figure().add_subplot(111) - x_bar = self._norm_mean.to_frame().T - X = pd.concat([x_bar] * len(groups)) - X.index = ["%s=%s" % (covariate, g) for g in groups] - X[covariate] = groups + if self.strata is None: + axes = kwargs.get("ax", None) or plt.figure().add_subplot(111) + x_bar = self._norm_mean.to_frame().T + X = pd.concat([x_bar] * len(values)) + X.index = ["%s=%s" % (covariate, g) for g in values] + X[covariate] = values - self.predict_survival_function(X).plot(ax=ax) - self.baseline_survival_.plot(ax=ax, ls="--") - return ax + self.predict_survival_function(X).plot(ax=axes) + self.baseline_survival_.plot(ax=axes, ls="--") + + else: + axes = [] + for stratum, baseline_survival_ in self.baseline_survival_.iteritems(): + ax = plt.figure().add_subplot(1, 1, 1) + x_bar = self._norm_mean.to_frame().T + + for name, value in zip(self.strata, _to_array(stratum)): + x_bar[name] = value + + X = pd.concat([x_bar] * len(values)) + X.index = ["%s=%s" % (covariate, g) for g in values] + X[covariate] = values + + self.predict_survival_function(X).plot(ax=ax) + baseline_survival_.plot(ax=ax, ls="--", label="stratum %s baseline survival" % str(stratum)) + plt.legend() + axes.append(ax) + return axes def check_assumptions( self, training_df, advice=True, show_plots=True, p_value_threshold=0.05, plot_n_bootstraps=10 diff --git a/lifelines/plotting.py b/lifelines/plotting.py index 9bd4588de..751d44aec 100644 --- a/lifelines/plotting.py +++ b/lifelines/plotting.py @@ -171,11 +171,11 @@ def add_at_risk_counts(*fitters, **kwargs): def plot_lifetimes( - duration, + durations, event_observed=None, entry=None, left_truncated=False, - sort_by_duration=False, + sort_by_duration=True, event_observed_color="#A60628", event_censored_color="#348ABD", **kwargs @@ -185,7 +185,7 @@ def plot_lifetimes( Parameters ----------- - duration: (n,) numpy array or pd.Series + durations: (n,) numpy array or pd.Series duration subject was observed for. event_observed: (n,) numpy array or pd.Series array of booleans: True if event observed, else False. @@ -209,7 +209,7 @@ def plot_lifetimes( set_kwargs_ax(kwargs) ax = kwargs.pop("ax") - N = duration.shape[0] + N = durations.shape[0] if N > 80: warnings.warn("For less visual clutter, you may want to subsample to less than 80 individuals.") @@ -219,20 +219,23 @@ def plot_lifetimes( if entry is None: entry = np.zeros(N) + assert durations.shape == (N,) + assert event_observed.shape == (N,) + if sort_by_duration: - # order by length of lifetimes; probably not very informative. - ix = np.argsort(duration, 0) - duration = duration[ix] + # order by length of lifetimes; + ix = np.argsort(entry + durations, 0) + durations = durations[ix] event_observed = event_observed[ix] entry = entry[ix] for i in range(N): c = event_observed_color if event_observed[i] else event_censored_color - ax.hlines(N - 1 - i, entry[i], entry[i] + duration[i], color=c, lw=1.5) + ax.hlines(N - 1 - i, entry[i], entry[i] + durations[i], color=c, lw=1.5) if left_truncated: ax.hlines(N - 1 - i, 0, entry[i], color=c, lw=1.0, linestyle="--") m = "" if not event_observed[i] else "o" - ax.scatter(entry[i] + duration[i], N - 1 - i, color=c, marker=m, s=10) + ax.scatter(entry[i] + durations[i], N - 1 - i, color=c, marker=m, s=10) ax.set_ylim(-0.5, N) return ax diff --git a/lifelines/version.py b/lifelines/version.py index f9a4af8fd..90daa1f8c 100644 --- a/lifelines/version.py +++ b/lifelines/version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -__version__ = "0.17.3" +__version__ = "0.17.4" diff --git a/tests/test_estimation.py b/tests/test_estimation.py index 7cc9b34e6..4d0ab6496 100644 --- a/tests/test_estimation.py +++ b/tests/test_estimation.py @@ -57,7 +57,7 @@ load_holly_molly_polly, load_regression_dataset, load_stanford_heart_transplants, - multicenter_aids_cohort_study, + load_multicenter_aids_cohort_study, ) from lifelines.generate_datasets import generate_hazard_rates, generate_random_lifetimes @@ -435,7 +435,7 @@ def kaplan_meier(self, lifetimes, observed=None): return km.reshape(len(km), 1) def test_left_truncation_against_Cole_and_Hudgens(self): - df = multicenter_aids_cohort_study() + df = load_multicenter_aids_cohort_study() kmf = KaplanMeierFitter() kmf.fit(df["T"], event_observed=df["D"], entry=df["W"]) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c1d146639..261f9890e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -17,6 +17,8 @@ load_lcd, load_panel_test, load_stanford_heart_transplants, + load_rossi, + load_multicenter_aids_cohort_study, ) from lifelines.generate_datasets import cumulative_integral @@ -171,6 +173,22 @@ def test_plot_lifetimes_left_truncation(self, block): self.plt.title("test_plot_lifetimes_left_truncation") self.plt.show(block=block) + def test_MACS_data_with_plot_lifetimes(self, block): + df = load_multicenter_aids_cohort_study() + + plot_lifetimes( + df["T"] - df["W"], + event_observed=df["D"], + entry=df["W"], + event_observed_color="#383838", + event_censored_color="#383838", + left_truncated=True, + ) + self.plt.ylabel("Patient Number") + self.plt.xlabel("Years from AIDS diagnosis") + self.plt.title("test_MACS_data_with_plot_lifetimes") + self.plt.show(block=block) + def test_plot_lifetimes_relative(self, block): t = np.linspace(0, 20, 1000) hz, coef, covrt = generate_hazard_rates(1, 5, t) @@ -279,12 +297,28 @@ def test_coxph_plotting_with_subset_of_columns(self, block): self.plt.title("test_coxph_plotting_with_subset_of_columns") self.plt.show(block=block) + def test_coxph_plot_covariate_groups(self, block): + df = load_rossi() + cp = CoxPHFitter() + cp.fit(df, "week", "arrest") + cp.plot_covariate_groups("age", [10, 50, 80]) + self.plt.title("test_coxph_plot_covariate_groups") + self.plt.show(block=block) + + def test_coxph_plot_covariate_groups_with_strata(self, block): + df = load_rossi() + cp = CoxPHFitter() + cp.fit(df, "week", "arrest", strata=["paro"]) + cp.plot_covariate_groups("age", [10, 50, 80]) + self.plt.title("test_coxph_plot_covariate_groups_with_strata") + self.plt.show(block=block) + def test_coxtv_plotting_with_subset_of_columns(self, block): df = load_stanford_heart_transplants() ctv = CoxTimeVaryingFitter() ctv.fit(df, id_col="id", event_col="event") ctv.plot(columns=["age", "year"]) - self.plt.title("test_coxtv_plotting_with_subset_of_columns_and_standardized") + self.plt.title("test_coxtv_plotting_with_subset_of_columns") self.plt.show(block=block) def test_coxtv_plotting(self, block):