Skip to content

Commit

Permalink
Merge pull request #355 from pymc-labs/summary
Browse files Browse the repository at this point in the history
Enable `summary` method for all currently implemented frequentist experiments
  • Loading branch information
drbenvincent authored Jun 19, 2024
2 parents 4af4af6 + 9fc0798 commit 2916688
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 86 deletions.
3 changes: 0 additions & 3 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,9 @@ def __init__(
self.post_X = np.asarray(new_x)
self.post_y = np.asarray(new_y)

# DEVIATION FROM SKL EXPERIMENT CODE =============================
# fit the model to the observed (pre-intervention) data
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
# ================================================================

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
Expand Down Expand Up @@ -347,7 +345,6 @@ def summary(self, round_to=None) -> None:

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
# TODO: extra experiment specific outputs here
self.print_coefficients(round_to)


Expand Down
72 changes: 55 additions & 17 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ExperimentalDesign:
"""Base class for experiment designs"""

model = None
expt_type = None
outcome_variable_name = None

def __init__(self, model=None, **kwargs):
Expand All @@ -53,6 +54,24 @@ def __init__(self, model=None, **kwargs):
if self.model is None:
raise ValueError("fitting_model not set or passed.")

def print_coefficients(self, round_to=None) -> None:
"""
Prints the model coefficients
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
print("Model coefficients:")
# Determine the width of the longest label
max_label_length = max(len(name) for name in self.labels)
# Print each coefficient with formatted alignment
for name, val in zip(self.labels, self.model.coef_[0]):
# Left-align the name
formatted_name = f"{name:<{max_label_length}}"
# Right-align the value with width 10
formatted_val = f"{round_num(val, round_to):>10}"
print(f" {formatted_name}\t{formatted_val}")


class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
"""
Expand Down Expand Up @@ -95,6 +114,8 @@ def __init__(
super().__init__(model=model, **kwargs)
self._input_validation(data, treatment_time)
self.treatment_time = treatment_time
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
# split data in to pre and post intervention
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]
Expand All @@ -103,10 +124,10 @@ def __init__(

# set things up with pre-intervention data
y, X = dmatrices(formula, self.datapre)
self.outcome_variable_name = y.design_info.column_names[0]
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.outcome_variable_name = y.design_info.column_names[0]
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
# process post-intervention data
(new_y, new_x) = build_design_matrices(
Expand Down Expand Up @@ -222,6 +243,18 @@ def plot_coeffs(self):
palette=sns.color_palette("husl"),
)

def summary(self, round_to=None) -> None:
"""
Print text output summarising the results
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
self.print_coefficients(round_to)


class InterruptedTimeSeries(PrePostFit):
"""
Expand Down Expand Up @@ -253,7 +286,6 @@ class InterruptedTimeSeries(PrePostFit):
... formula="y ~ 1 + t + C(month)",
... model = LinearRegression()
... )
"""

expt_type = "Interrupted Time Series"
Expand Down Expand Up @@ -351,6 +383,7 @@ def __init__(
):
super().__init__(model=model, **kwargs)
self.data = data
self.expt_type = "Difference in Differences"
self.formula = formula
self.time_variable_name = time_variable_name
self.group_variable_name = group_variable_name
Expand Down Expand Up @@ -509,6 +542,20 @@ def plot(self, round_to=None):
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)

def summary(self, round_to=None) -> None:
"""
Print text output summarising the results.
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
print("\nResults:")
print(f"Causal impact = {round_num(self.causal_impact[0], round_to)}")
self.print_coefficients(round_to)


class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
"""
Expand Down Expand Up @@ -542,17 +589,6 @@ class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
... model=LinearRegression(),
... treatment_threshold=0.5,
... )
>>> result.summary() # doctest: +NORMALIZE_WHITESPACE,+NUMBER
Difference in Differences experiment
Formula: y ~ 1 + x + treated
Running variable: x
Threshold on running variable: 0.5
Results:
Discontinuity at threshold = 0.19
Model coefficients:
Intercept 0.0
treated[T.True] 0.19
x 1.23
"""

def __init__(
Expand Down Expand Up @@ -687,16 +723,18 @@ def plot(self, round_to=None):
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)

def summary(self):
def summary(self, round_to=None) -> None:
"""
Print text output summarising the results
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
print("Difference in Differences experiment")
print(f"Formula: {self.formula}")
print(f"Running variable: {self.running_variable_name}")
print(f"Threshold on running variable: {self.treatment_threshold}")
print("\nResults:")
print(f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}")
print("Model coefficients:")
for name, val in zip(self.labels, self.model.coef_[0]):
print(f"\t{name}\t\t{val}")
print("\n")
self.print_coefficients(round_to)
7 changes: 7 additions & 0 deletions causalpy/tests/test_integration_skl_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_did():
)
assert isinstance(data, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.DifferenceInDifferences)
result.summary()


@pytest.mark.integration
Expand All @@ -68,6 +69,7 @@ def test_rd_drinking():
)
assert isinstance(df, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
result.summary()


@pytest.mark.integration
Expand All @@ -94,6 +96,7 @@ def test_its():
)
assert isinstance(df, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.SyntheticControl)
result.summary()


@pytest.mark.integration
Expand All @@ -115,6 +118,7 @@ def test_sc():
)
assert isinstance(df, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.SyntheticControl)
result.summary()


@pytest.mark.integration
Expand All @@ -136,6 +140,7 @@ def test_rd_linear_main_effects():
)
assert isinstance(data, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
result.summary()


@pytest.mark.integration
Expand All @@ -159,6 +164,7 @@ def test_rd_linear_main_effects_bandwidth():
)
assert isinstance(data, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
result.summary()


@pytest.mark.integration
Expand All @@ -180,6 +186,7 @@ def test_rd_linear_with_interaction():
)
assert isinstance(data, pd.DataFrame)
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
result.summary()


@pytest.mark.integration
Expand Down
Binary file modified docs/source/_static/classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/_static/packages.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
32 changes: 29 additions & 3 deletions docs/source/notebooks/did_skl.ipynb

Large diffs are not rendered by default.

45 changes: 31 additions & 14 deletions docs/source/notebooks/its_skl.ipynb

Large diffs are not rendered by default.

134 changes: 110 additions & 24 deletions docs/source/notebooks/rd_skl.ipynb

Large diffs are not rendered by default.

22 changes: 8 additions & 14 deletions docs/source/notebooks/rd_skl_drinking.ipynb

Large diffs are not rendered by default.

103 changes: 95 additions & 8 deletions docs/source/notebooks/sc_skl.ipynb

Large diffs are not rendered by default.

0 comments on commit 2916688

Please sign in to comment.