Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BetaGeoBetaBinomModel #1031

Merged
merged 12 commits into from
Sep 13, 2024
6,498 changes: 6,072 additions & 426 deletions docs/source/notebooks/clv/dev/beta_geo_beta_binom.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pymc_marketing/clv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""CLV models and utilities."""

from pymc_marketing.clv.models import (
BetaGeoBetaBinomModel,
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
BetaGeoModel,
GammaGammaModel,
GammaGammaModelIndividual,
Expand All @@ -34,6 +35,7 @@

__all__ = (
"BetaGeoModel",
"BetaGeoBetaBinomModel",
"ParetoNBDModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
Expand Down
8 changes: 1 addition & 7 deletions pymc_marketing/clv/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,23 +601,17 @@ def logp(value, alpha, beta, gamma, delta, T):
"""Log-likelihood of the distribution."""
t_x = pt.atleast_1d(value[..., 0])
x = pt.atleast_1d(value[..., 1])
scalar_case = t_x.type.broadcastable == (True,)

for param in (t_x, x, alpha, beta, gamma, delta, T):
if param.type.ndim > 1:
raise NotImplementedError(
f"BetaGeoBetaBinom logp only implemented for vector parameters, got ndim={param.type.ndim}"
)
if scalar_case:
if param.type.broadcastable == (False,):
raise NotImplementedError(
f"Parameter {param} cannot be larger than scalar value"
)

# Broadcast all the parameters so they are sequences.
# Potentially inefficient, but otherwise ugly logic needed to unpack arguments in the scan function,
# since sequences always precede non-sequences.
_, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
t_x, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
t_x, alpha, beta, gamma, delta, T
)

Expand Down
2 changes: 2 additions & 0 deletions pymc_marketing/clv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from pymc_marketing.clv.models.basic import CLVModel
from pymc_marketing.clv.models.beta_geo import BetaGeoModel
from pymc_marketing.clv.models.beta_geo_beta_binom import BetaGeoBetaBinomModel
from pymc_marketing.clv.models.gamma_gamma import (
GammaGammaModel,
GammaGammaModelIndividual,
Expand All @@ -25,6 +26,7 @@

__all__ = (
"CLVModel",
"BetaGeoBetaBinomModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
"BetaGeoModel",
Expand Down
6 changes: 6 additions & 0 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _validate_cols(
data: pd.DataFrame,
required_cols: Sequence[str],
must_be_unique: Sequence[str] = (),
must_be_homogenous: Sequence[str] = (),
):
existing_columns = set(data.columns)
n = data.shape[0]
Expand All @@ -71,6 +72,11 @@ def _validate_cols(
if required_col in must_be_unique:
if data[required_col].nunique() != n:
raise ValueError(f"Column {required_col} has duplicate entries")
if required_col in must_be_homogenous:
if data[required_col].nunique() != 1:
raise ValueError(
f"Column {required_col} has non-homogeneous entries"
)

def __repr__(self) -> str:
"""Representation of the model."""
Expand Down
Loading
Loading