Skip to content

Commit

Permalink
V0.17.4 (#616)
Browse files Browse the repository at this point in the history
* v0.17.4

* lint

* bump version

* lint
  • Loading branch information
CamDavidsonPilon authored Jan 25, 2019
1 parent 47f5ede commit 89f1264
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
### Changelogs

#### 0.17.4
- Fix bug in `plot_covariate_groups` that wasn't allowing for strata to be used.
- change name of `multicenter_aids_cohort_study` to `load_multicenter_aids_cohort_study`

#### 0.17.3
- Fix in `compute_residuals` when using `schoenfeld` and the minumum duration has only censored subjects.

Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
#
# The short X.Y version.

version = "0.17.3"
version = "0.17.4"
# The full version, including dev info
release = version

Expand Down
4 changes: 2 additions & 2 deletions lifelines/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def load_recur(**kwargs):
return load_dataset("recur.csv", **kwargs)


def multicenter_aids_cohort_study(**kwargs):
def load_multicenter_aids_cohort_study(**kwargs):
"""
Originally in [1]
Expand All @@ -62,7 +62,7 @@ def multicenter_aids_cohort_study(**kwargs):
D: indicator of death during follow up
i AIDSY W T D
i AIDSY W T D
1 1990.425 4.575 7.575 0
2 1991.250 3.750 6.750 0
3 1992.014 2.986 5.986 0
Expand Down
42 changes: 31 additions & 11 deletions lifelines/fitters/coxph_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from lifelines.utils import (
_get_index,
_to_list,
_to_array,
survival_table_from_events,
inv_normal_cdf,
normalize,
Expand Down Expand Up @@ -1537,7 +1538,7 @@ def plot(self, columns=None, display_significance_code=True, **errorbar_kwargs):

return ax

def plot_covariate_groups(self, covariate, groups, **kwargs):
def plot_covariate_groups(self, covariate, values, **kwargs):
"""
Produces a visual representation comparing the baseline survival curve of the model versus
what happens when a covariate is varied over values in a group. This is useful to compare
Expand All @@ -1548,30 +1549,49 @@ def plot_covariate_groups(self, covariate, groups, **kwargs):
----------
covariate: string
a string of the covariate in the original dataset that we wish to vary.
groups: iterable
values: iterable
an iterable of the values we wish the covariate to take on.
kwargs:
pass in additional plotting commands
Returns
-------
ax: matplotlib axis
ax: matplotlib axis, or list of axis'
the matplotlib axis that be edited.
"""
from matplotlib import pyplot as plt

if covariate not in self.hazards_.columns:
raise KeyError("covariate `%s` is not present in the original dataset" % covariate)

ax = kwargs.get("ax", None) or plt.figure().add_subplot(111)
x_bar = self._norm_mean.to_frame().T
X = pd.concat([x_bar] * len(groups))
X.index = ["%s=%s" % (covariate, g) for g in groups]
X[covariate] = groups
if self.strata is None:
axes = kwargs.get("ax", None) or plt.figure().add_subplot(111)
x_bar = self._norm_mean.to_frame().T
X = pd.concat([x_bar] * len(values))
X.index = ["%s=%s" % (covariate, g) for g in values]
X[covariate] = values

self.predict_survival_function(X).plot(ax=ax)
self.baseline_survival_.plot(ax=ax, ls="--")
return ax
self.predict_survival_function(X).plot(ax=axes)
self.baseline_survival_.plot(ax=axes, ls="--")

else:
axes = []
for stratum, baseline_survival_ in self.baseline_survival_.iteritems():
ax = plt.figure().add_subplot(1, 1, 1)
x_bar = self._norm_mean.to_frame().T

for name, value in zip(self.strata, _to_array(stratum)):
x_bar[name] = value

X = pd.concat([x_bar] * len(values))
X.index = ["%s=%s" % (covariate, g) for g in values]
X[covariate] = values

self.predict_survival_function(X).plot(ax=ax)
baseline_survival_.plot(ax=ax, ls="--", label="stratum %s baseline survival" % str(stratum))
plt.legend()
axes.append(ax)
return axes

def check_assumptions(
self, training_df, advice=True, show_plots=True, p_value_threshold=0.05, plot_n_bootstraps=10
Expand Down
21 changes: 12 additions & 9 deletions lifelines/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def add_at_risk_counts(*fitters, **kwargs):


def plot_lifetimes(
duration,
durations,
event_observed=None,
entry=None,
left_truncated=False,
sort_by_duration=False,
sort_by_duration=True,
event_observed_color="#A60628",
event_censored_color="#348ABD",
**kwargs
Expand All @@ -185,7 +185,7 @@ def plot_lifetimes(
Parameters
-----------
duration: (n,) numpy array or pd.Series
durations: (n,) numpy array or pd.Series
duration subject was observed for.
event_observed: (n,) numpy array or pd.Series
array of booleans: True if event observed, else False.
Expand All @@ -209,7 +209,7 @@ def plot_lifetimes(
set_kwargs_ax(kwargs)
ax = kwargs.pop("ax")

N = duration.shape[0]
N = durations.shape[0]
if N > 80:
warnings.warn("For less visual clutter, you may want to subsample to less than 80 individuals.")

Expand All @@ -219,20 +219,23 @@ def plot_lifetimes(
if entry is None:
entry = np.zeros(N)

assert durations.shape == (N,)
assert event_observed.shape == (N,)

if sort_by_duration:
# order by length of lifetimes; probably not very informative.
ix = np.argsort(duration, 0)
duration = duration[ix]
# order by length of lifetimes;
ix = np.argsort(entry + durations, 0)
durations = durations[ix]
event_observed = event_observed[ix]
entry = entry[ix]

for i in range(N):
c = event_observed_color if event_observed[i] else event_censored_color
ax.hlines(N - 1 - i, entry[i], entry[i] + duration[i], color=c, lw=1.5)
ax.hlines(N - 1 - i, entry[i], entry[i] + durations[i], color=c, lw=1.5)
if left_truncated:
ax.hlines(N - 1 - i, 0, entry[i], color=c, lw=1.0, linestyle="--")
m = "" if not event_observed[i] else "o"
ax.scatter(entry[i] + duration[i], N - 1 - i, color=c, marker=m, s=10)
ax.scatter(entry[i] + durations[i], N - 1 - i, color=c, marker=m, s=10)

ax.set_ylim(-0.5, N)
return ax
Expand Down
2 changes: 1 addition & 1 deletion lifelines/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

__version__ = "0.17.3"
__version__ = "0.17.4"
4 changes: 2 additions & 2 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
load_holly_molly_polly,
load_regression_dataset,
load_stanford_heart_transplants,
multicenter_aids_cohort_study,
load_multicenter_aids_cohort_study,
)
from lifelines.generate_datasets import generate_hazard_rates, generate_random_lifetimes

Expand Down Expand Up @@ -435,7 +435,7 @@ def kaplan_meier(self, lifetimes, observed=None):
return km.reshape(len(km), 1)

def test_left_truncation_against_Cole_and_Hudgens(self):
df = multicenter_aids_cohort_study()
df = load_multicenter_aids_cohort_study()
kmf = KaplanMeierFitter()
kmf.fit(df["T"], event_observed=df["D"], entry=df["W"])

Expand Down
36 changes: 35 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
load_lcd,
load_panel_test,
load_stanford_heart_transplants,
load_rossi,
load_multicenter_aids_cohort_study,
)
from lifelines.generate_datasets import cumulative_integral

Expand Down Expand Up @@ -171,6 +173,22 @@ def test_plot_lifetimes_left_truncation(self, block):
self.plt.title("test_plot_lifetimes_left_truncation")
self.plt.show(block=block)

def test_MACS_data_with_plot_lifetimes(self, block):
df = load_multicenter_aids_cohort_study()

plot_lifetimes(
df["T"] - df["W"],
event_observed=df["D"],
entry=df["W"],
event_observed_color="#383838",
event_censored_color="#383838",
left_truncated=True,
)
self.plt.ylabel("Patient Number")
self.plt.xlabel("Years from AIDS diagnosis")
self.plt.title("test_MACS_data_with_plot_lifetimes")
self.plt.show(block=block)

def test_plot_lifetimes_relative(self, block):
t = np.linspace(0, 20, 1000)
hz, coef, covrt = generate_hazard_rates(1, 5, t)
Expand Down Expand Up @@ -279,12 +297,28 @@ def test_coxph_plotting_with_subset_of_columns(self, block):
self.plt.title("test_coxph_plotting_with_subset_of_columns")
self.plt.show(block=block)

def test_coxph_plot_covariate_groups(self, block):
df = load_rossi()
cp = CoxPHFitter()
cp.fit(df, "week", "arrest")
cp.plot_covariate_groups("age", [10, 50, 80])
self.plt.title("test_coxph_plot_covariate_groups")
self.plt.show(block=block)

def test_coxph_plot_covariate_groups_with_strata(self, block):
df = load_rossi()
cp = CoxPHFitter()
cp.fit(df, "week", "arrest", strata=["paro"])
cp.plot_covariate_groups("age", [10, 50, 80])
self.plt.title("test_coxph_plot_covariate_groups_with_strata")
self.plt.show(block=block)

def test_coxtv_plotting_with_subset_of_columns(self, block):
df = load_stanford_heart_transplants()
ctv = CoxTimeVaryingFitter()
ctv.fit(df, id_col="id", event_col="event")
ctv.plot(columns=["age", "year"])
self.plt.title("test_coxtv_plotting_with_subset_of_columns_and_standardized")
self.plt.title("test_coxtv_plotting_with_subset_of_columns")
self.plt.show(block=block)

def test_coxtv_plotting(self, block):
Expand Down

0 comments on commit 89f1264

Please sign in to comment.