From dfaab15567e97c4c280228697a53500238c3db33 Mon Sep 17 00:00:00 2001 From: Shahar Bar <33932594+shaharbar1@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:11:46 +0300 Subject: [PATCH] cMAB Fast Update via Variational Inference (#48) ### Changes * Edited BaseBayesianLogisticRegression and its inheritors in model.py to support variational inference by adding update_method and update_kwargs control parameters as class attributes and using them to control the posterior sampling arguments in the update method. * Edited BaseBayesianLogisticRegression to allow a faster update via vectorization of the PyMC operations. * Edited the "update" UTs in test_cmab.py to support the new inference mode. * Edited the cMAB cold-start functions to support the new inference mode. * Removed redundant test_execution_time.py. * Bumped the version in pyproject.toml. --- pybandits/base.py | 8 + pybandits/cmab.py | 63 +++++-- pybandits/model.py | 191 ++++++++++++++------- pyproject.toml | 2 +- tests/test_base.py | 6 +- tests/test_cmab.py | 186 ++++++++++++++++----- tests/test_execution_time.py | 316 ----------------------------------- 7 files changed, 339 insertions(+), 433 deletions(-) delete mode 100644 tests/test_execution_time.py diff --git a/pybandits/base.py b/pybandits/base.py index 97e42d8..e6a94b2 100644 --- a/pybandits/base.py +++ b/pybandits/base.py @@ -104,8 +104,16 @@ class BaseMab(PyBanditsBaseModel, ABC): @field_validator("actions", mode="before") @classmethod def at_least_2_actions_are_defined(cls, v): + # validate that at least 2 actions are defined if len(v) < 2: raise AttributeError("At least 2 actions should be defined.") + # validate that all actions are of the same configuration + action_models = list(v.values()) + first_action = action_models[0] + first_action_type = type(first_action) + if any(not isinstance(action, first_action_type) for action in action_models[1:]): + raise AttributeError("All actions should follow the same type.") + return v @model_validator(mode="after") diff --git a/pybandits/cmab.py b/pybandits/cmab.py index d26a9b9..fae03c3 100644 --- a/pybandits/cmab.py +++ b/pybandits/cmab.py @@ -32,6 +32,7 @@ BaseBayesianLogisticRegression, BayesianLogisticRegression, BayesianLogisticRegressionCC, + UpdateMethods, create_bayesian_logistic_regression_cc_cold_start, create_bayesian_logistic_regression_cold_start, ) @@ -63,13 +64,21 @@ class BaseCmabBernoulli(BaseMab): predict_with_proba: bool predict_actions_randomly: bool - @field_validator("actions") - def check_bayesian_logistic_regression_models_len(cls, v): - blr_betas_len = [len(b.betas) for b in v.values()] - if not all(blr_betas_len[0] == x for x in blr_betas_len): - raise AttributeError( - f"All bayesian logistic regression models must have the same n_betas. Models betas_len={blr_betas_len}."
- ) + @field_validator("actions", mode="after") + @classmethod + def check_bayesian_logistic_regression_models(cls, v): + action_models = list(v.values()) + first_action = action_models[0] + first_action_type = type(first_action) + for action in action_models[1:]: + if not isinstance(action, first_action_type): + raise AttributeError("All actions should follow the same type.") + if not len(action.betas) == len(first_action.betas): + raise AttributeError("All actions should have the same number of betas.") + if not action.update_method == first_action.update_method: + raise AttributeError("All actions should have the same update method.") + if not action.update_kwargs == first_action.update_kwargs: + raise AttributeError("All actions should have the same update kwargs.") return v @validate_call(config=dict(arbitrary_types_allowed=True)) @@ -329,6 +338,8 @@ def create_cmab_bernoulli_cold_start( n_features: PositiveInt, epsilon: Optional[Float01] = None, default_action: Optional[ActionId] = None, + update_method: UpdateMethods = "MCMC", + update_kwargs: Optional[dict] = None, ) -> CmabBernoulli: """ Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, with default @@ -347,6 +358,12 @@ def create_cmab_bernoulli_cold_start( default_action: Optional[ActionId] The default action to select with a probability of epsilon when using the epsilon-greedy approach. If `default_action` is None, a random action from the action set will be selected with a probability of epsilon. + update_method: UpdateMethods, defaults to MCMC + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method of each of the action models. Returns ------- @@ -354,8 +371,10 @@ def create_cmab_bernoulli_cold_start( Contextual Multi-Armed Bandit with strategy = ClassicBandit """ actions = {} - for a in set(action_ids): - actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features) + for action_id in set(action_ids): + actions[action_id] = create_bayesian_logistic_regression_cold_start( + n_betas=n_features, update_method=update_method, update_kwargs=update_kwargs + ) mab = CmabBernoulli(actions=actions, epsilon=epsilon, default_action=default_action) mab.predict_actions_randomly = True return mab @@ -368,6 +387,8 @@ def create_cmab_bernoulli_bai_cold_start( exploit_p: Optional[Float01] = None, epsilon: Optional[Float01] = None, default_action: Optional[ActionId] = None, + update_method: UpdateMethods = "MCMC", + update_kwargs: Optional[dict] = None, ) -> CmabBernoulliBAI: """ Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, and Best Action @@ -395,6 +416,12 @@ def create_cmab_bernoulli_bai_cold_start( default_action: Optional[ActionId] The default action to select with a probability of epsilon when using the epsilon-greedy approach. If `default_action` is None, a random action from the action set will be selected with a probability of epsilon. + update_method: UpdateMethods, defaults to MCMC + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. 
+ update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method of each of the action models. Returns ------- @@ -403,7 +430,11 @@ def create_cmab_bernoulli_bai_cold_start( """ actions = {} for a in set(action_ids): - actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features) + actions[a] = create_bayesian_logistic_regression_cold_start( + n_betas=n_features, + update_method=update_method, + update_kwargs=update_kwargs, + ) mab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p, epsilon=epsilon, default_action=default_action) mab.predict_actions_randomly = True return mab @@ -416,6 +447,8 @@ def create_cmab_bernoulli_cc_cold_start( subsidy_factor: Optional[Float01] = None, epsilon: Optional[Float01] = None, default_action: Optional[ActionId] = None, + update_method: UpdateMethods = "MCMC", + update_kwargs: Optional[dict] = None, ) -> CmabBernoulliCC: """ Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Cost Control @@ -449,6 +482,12 @@ def create_cmab_bernoulli_cc_cold_start( default_action: Optional[ActionId] The default action to select with a probability of epsilon when using the epsilon-greedy approach. If `default_action` is None, a random action from the action set will be selected with a probability of epsilon. + update_method: UpdateMethods, defaults to MCMC + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. Returns ------- @@ -457,7 +496,9 @@ def create_cmab_bernoulli_cc_cold_start( """ actions = {} for a, cost in action_ids_cost.items(): - actions[a] = create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost) + actions[a] = create_bayesian_logistic_regression_cc_cold_start( + n_betas=n_features, cost=cost, update_method=update_method, update_kwargs=update_kwargs + ) mab = CmabBernoulliCC( actions=actions, subsidy_factor=subsidy_factor, epsilon=epsilon, default_action=default_action ) diff --git a/pybandits/model.py b/pybandits/model.py index c94ba1f..ac09473 100644 --- a/pybandits/model.py +++ b/pybandits/model.py @@ -22,9 +22,11 @@ from random import betavariate -from typing import List, Tuple +from typing import List, Literal, Optional, Tuple, Union -from numpy import array, c_, exp, insert, mean, multiply, ones, sqrt, std +import numpy as np +import pymc.math as pmath +from numpy import array, c_, insert, mean, multiply, ones, sqrt, std from numpy.typing import ArrayLike from pydantic import ( Field, @@ -34,15 +36,16 @@ model_validator, validate_call, ) -from pymc import Bernoulli, Data, Deterministic, sample +from pymc import Bernoulli, Data, Deterministic, fit, sample from pymc import Model as PymcModel from pymc import StudentT as PymcStudentT -from pymc.math import sigmoid -from pytensor.tensor import dot +from pytensor.tensor import TensorVariable, dot from scipy.stats import t from pybandits.base import BinaryReward, Model, Probability, PyBanditsBaseModel +UpdateMethods = Literal["MCMC", "VI"] + class BaseBeta(Model): """ @@ -231,16 +234,66 @@ class BaseBayesianLogisticRegression(Model): Parameters ---------- - alpha: StudentT + alpha : StudentT Student's t-distribution of the alpha coefficient. 
- betas: StudentT + betas : StudentT Student's t-distributions of the betas coefficients. - params_sample: Dict - Parameters for the function pymc.sample() + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. """ alpha: StudentT betas: List[StudentT] = Field(..., min_items=1) + update_method: UpdateMethods = "MCMC" + update_kwargs: Optional[dict] = None + _default_update_kwargs = dict(draws=1000, progressbar=False, return_inferencedata=False) + _default_mcmc_kwargs = dict( + tune=500, + draws=1000, + chains=2, + init="adapt_diag", + cores=1, + target_accept=0.95, + progressbar=False, + return_inferencedata=False, + ) + _default_variational_inference_kwargs = dict(method="advi") + + @model_validator(mode="after") + def arrange_update_kwargs(self): + if self.update_kwargs is None: + self.update_kwargs = self._default_update_kwargs + if self.update_method == "VI": + self.update_kwargs = {**self._default_variational_inference_kwargs, **self.update_kwargs} + elif self.update_method == "MCMC": + self.update_kwargs = {**self._default_mcmc_kwargs, **self.update_kwargs} + else: + raise ValueError("Invalid update method.") + return self + + @classmethod + def _stable_sigmoid(cls, x: Union[np.ndarray, TensorVariable]) -> Union[np.ndarray, TensorVariable]: + """ + Vectorized sigmoid function that avoids overflow and underflow. + Compatible with both numpy and PyMC3 tensors. + + Parameters + ---------- + x : Union[np.ndarray, TensorVariable] + Input values. + + Returns + ------- + prob : Union[np.ndarray, TensorVariable] + Sigmoid function applied to the input values. + """ + backend = np if isinstance(x, np.ndarray) else pmath + prob = backend.where(x >= 0, 1 / (1 + backend.exp(-x)), backend.exp(x) / (1 + backend.exp(x))) + return prob @validate_call(config=dict(arbitrary_types_allowed=True)) def check_context_matrix(self, context: ArrayLike): @@ -249,12 +302,12 @@ def check_context_matrix(self, context: ArrayLike): Parameters ---------- - context: ArrayLike of shape (n_samples, n_features) + context : ArrayLike of shape (n_samples, n_features) Matrix of contextual features. Returns ------- - context: pandas DataFrame of shape (n_samples, n_features) + context : pandas DataFrame of shape (n_samples, n_features) Matrix of contextual features. """ try: @@ -304,25 +357,12 @@ def sample_proba(self, context: ArrayLike) -> Tuple[Probability, float]: weighted_sum = multiply(context_ext, coeff.T).sum(axis=1) # compute the probability with the sigmoid function - prob = 1.0 / (1.0 + exp(-weighted_sum)) + prob = self._stable_sigmoid(weighted_sum) return prob, weighted_sum @validate_call(config=dict(arbitrary_types_allowed=True)) - def update( - self, - context: ArrayLike, - rewards: List[BinaryReward], - tune=500, - draws=1000, - chains=2, - init="adapt_diag", - cores=2, - target_accept=0.95, - progressbar=False, - return_inferencedata=False, - **kwargs, - ): + def update(self, context: ArrayLike, rewards: List[BinaryReward]): """ Update the model parameters. 
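Reviewer note on the _stable_sigmoid helper added in the hunk above (illustration only, not part of the patch): it relies on the identity 1/(1+exp(-x)) = exp(x)/(1+exp(x)) and selects, per element, the form whose exponent is non-positive, so the branch that actually contributes to the result never exponentiates a large positive value. A NumPy-only sketch of the same idea follows; note that np.where still evaluates the discarded branch eagerly, so NumPy may emit a harmless overflow warning even though the returned values are exact.

import numpy as np

def stable_sigmoid(x: np.ndarray) -> np.ndarray:
    # choose, per element, the algebraically equivalent form with a non-positive exponent
    return np.where(x >= 0, 1.0 / (1.0 + np.exp(-x)), np.exp(x) / (1.0 + np.exp(x)))

x = np.linspace(-30.0, 30.0, 7)
assert np.allclose(stable_sigmoid(x), 1.0 / (1.0 + np.exp(-x)))  # matches the naive form in its safe range
print(stable_sigmoid(np.array([-750.0, 750.0])))                 # [0. 1.] -- saturates cleanly at the extremes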
@@ -344,40 +384,41 @@ def update( # if model was never updated priors_parameters = default arguments # else priors_parameters are calculated from traces of the previous update alpha = PymcStudentT("alpha", mu=self.alpha.mu, sigma=self.alpha.sigma, nu=self.alpha.nu) - betas = [ - PymcStudentT("beta" + str(i), mu=self.betas[i].mu, sigma=self.betas[i].sigma, nu=self.betas[i].nu) - for i in range(len(self.betas)) - ] + beta_mu = [b.mu for b in self.betas] + beta_sigma = [b.sigma for b in self.betas] + beta_nu = [b.nu for b in self.betas] + betas = PymcStudentT("betas", mu=beta_mu, sigma=beta_sigma, nu=beta_nu, shape=len(self.betas)) - context = Data("context", context) - rewards = Data("rewards", rewards) + context = Data("context", context, mutable=False) + rewards = Data("rewards", rewards, mutable=False) # Likelihood (sampling distribution) of observations weighted_sum = Deterministic("weighted_sum", alpha + dot(betas, context.T)) - p = Deterministic("p", sigmoid(weighted_sum)) + p = Deterministic("p", self._stable_sigmoid(weighted_sum)) # Bernoulli random vector with probability of success given by sigmoid function and actual data as observed _ = Bernoulli("likelihood", p=p, observed=rewards) # update traces object by sampling from posterior distribution - trace = sample( - tune=tune, - draws=draws, - chains=chains, - init=init, - cores=cores, - target_accept=target_accept, - progressbar=progressbar, - return_inferencedata=return_inferencedata, - **kwargs, - ) + if self.update_method == "VI": + # variational inference + update_kwargs = self.update_kwargs.copy() + approx = fit(method=update_kwargs.pop("method")) + trace = approx.sample(**update_kwargs) + elif self.update_method == "MCMC": + # MCMC + trace = sample(**self.update_kwargs) + else: + raise ValueError("Invalid update method.") # compute mean and std of the coefficients distributions self.alpha.mu = mean(trace["alpha"]) self.alpha.sigma = std(trace["alpha"], ddof=1) - for i in range(len(self.betas)): - self.betas[i].mu = mean(trace["beta" + str(i)]) - self.betas[i].sigma = std(trace["beta" + str(i)], ddof=1) + betas_mu = mean(trace["betas"], axis=0) + betas_std = std(trace["betas"], axis=0, ddof=1) + self.betas = [ + StudentT(mu=mu, sigma=sigma, nu=beta.nu) for mu, sigma, beta in zip(betas_mu, betas_std, self.betas) + ] class BayesianLogisticRegression(BaseBayesianLogisticRegression): @@ -392,12 +433,16 @@ class BayesianLogisticRegression(BaseBayesianLogisticRegression): Parameters ---------- - alpha: StudentT + alpha : StudentT Student's t-distribution of the alpha coefficient. - betas: StudentT + betas : StudentT Student's t-distributions of the betas coefficients. - params_sample: Dict - Parameters for the function pymc.sample() + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs: Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. """ @@ -417,8 +462,12 @@ class BayesianLogisticRegressionCC(BaseBayesianLogisticRegression): Student's t-distribution of the alpha coefficient. betas: StudentT Student's t-distributions of the betas coefficients. 
- params_sample: Dict - Parameters for the function pymc.sample() + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. cost: NonNegativeFloat Cost associated to the Bayesian Logistic Regression model. """ @@ -426,7 +475,9 @@ class BayesianLogisticRegressionCC(BaseBayesianLogisticRegression): cost: NonNegativeFloat -def create_bayesian_logistic_regression_cold_start(n_betas: PositiveInt) -> BayesianLogisticRegression: +def create_bayesian_logistic_regression_cold_start( + n_betas: PositiveInt, update_method: UpdateMethods = "MCMC", update_kwargs: Optional[dict] = None +) -> BayesianLogisticRegression: """ Utility function to create a Bayesian Logistic Regression model, with default parameters. @@ -441,17 +492,31 @@ def create_bayesian_logistic_regression_cold_start(n_betas: PositiveInt) -> Baye n_betas : PositiveInt The number of betas of the Bayesian Logistic Regression model. This is also the number of features expected after in the context matrix. + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. Returns ------- blr: BayesianLogisticRegression The Bayesian Logistic Regression model. """ - return BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_betas)]) + return BayesianLogisticRegression( + alpha=StudentT(), + betas=[StudentT() for _ in range(n_betas)], + update_method=update_method, + update_kwargs=update_kwargs, + ) def create_bayesian_logistic_regression_cc_cold_start( - n_betas: PositiveInt, cost: NonNegativeFloat + n_betas: PositiveInt, + cost: NonNegativeFloat, + update_method: UpdateMethods = "MCMC", + update_kwargs: Optional[dict] = None, ) -> BayesianLogisticRegressionCC: """ Utility function to create a Bayesian Logistic Regression model with cost control, with default parameters. @@ -469,10 +534,22 @@ def create_bayesian_logistic_regression_cc_cold_start( after in the context matrix. cost: NonNegativeFloat Cost associated to the Bayesian Logistic Regression model. + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. Returns ------- blr: BayesianLogisticRegressionCC The Bayesian Logistic Regression model. 
""" - return BayesianLogisticRegressionCC(alpha=StudentT(), betas=[StudentT() for _ in range(n_betas)], cost=cost) + return BayesianLogisticRegressionCC( + alpha=StudentT(), + betas=[StudentT() for _ in range(n_betas)], + cost=cost, + update_method=update_method, + update_kwargs=update_kwargs, + ) diff --git a/pyproject.toml b/pyproject.toml index 565164e..009cbdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pybandits" -version = "0.5.1" +version = "0.6.0" description = "Python Multi-Armed Bandit Library" authors = [ "Dario d'Andrea ", diff --git a/tests/test_base.py b/tests/test_base.py index dbec460..fa3db3b 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -30,7 +30,7 @@ from pytest_mock import MockerFixture from pybandits.base import ActionId, BaseMab, Float01, Probability -from pybandits.model import Beta +from pybandits.model import Beta, BetaCC from pybandits.strategy import ClassicBandit @@ -55,7 +55,7 @@ def get_state(self) -> (str, dict): return model_name, state -def test_base_mab_raise_on_less_than_2_actions(): +def test_base_mab_raise_on_less_than_2_actions(cost=0): with pytest.raises(ValidationError): DummyMab(actions={"a1": Beta(), "a2": Beta()}) with pytest.raises(ValidationError): @@ -68,6 +68,8 @@ def test_base_mab_raise_on_less_than_2_actions(): DummyMab(actions={"a1": None, "a2": None}, strategy=ClassicBandit()) with pytest.raises(AttributeError): DummyMab(actions={"a1": Beta()}, strategy=ClassicBandit()) + with pytest.raises(AttributeError): + DummyMab(actions={"a1": Beta(), "a2": BetaCC(cost=cost)}, strategy=ClassicBandit()) def test_base_mab_check_update_params(): diff --git a/tests/test_cmab.py b/tests/test_cmab.py index 5fe15e4..a992afa 100644 --- a/tests/test_cmab.py +++ b/tests/test_cmab.py @@ -19,6 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from typing import get_args import numpy as np import pandas as pd @@ -40,6 +41,7 @@ BayesianLogisticRegression, BayesianLogisticRegressionCC, StudentT, + UpdateMethods, create_bayesian_logistic_regression_cc_cold_start, create_bayesian_logistic_regression_cold_start, ) @@ -51,6 +53,14 @@ from pybandits.utils import to_serializable_dict from tests.test_utils import is_serializable +literal_update_methods = get_args(UpdateMethods) + + +def _apply_update_method_to_state(state, update_method): + for action in state["actions"]: + state["actions"][action]["update_method"] = update_method + + ######################################################################################################################## @@ -126,40 +136,78 @@ def test_cmab_can_instantiate(n_features): @settings(deadline=500) -@given(st.integers(min_value=1, max_value=10), st.integers(min_value=1, max_value=10)) -def test_cmab_init_with_wrong_blr_models(a, b): - # all blr models must have the same n_betas. If not raise a ValueError. 
- if a != b: - with pytest.raises(AttributeError): - CmabBernoulli( - actions={ - "a1": create_bayesian_logistic_regression_cold_start(n_betas=a), - "a2": create_bayesian_logistic_regression_cold_start(n_betas=a), - "a3": create_bayesian_logistic_regression_cold_start(n_betas=b), - } - ) - else: +@given( + st.integers(min_value=1, max_value=5), + st.integers(min_value=6, max_value=10), + st.integers(min_value=0, max_value=1), + st.just("draws"), + st.just(2), +) +def test_cmab_init_with_wrong_blr_models( + first_n_betas, second_n_betas, first_update_method_index, kwarg_to_alter, factor +): + with pytest.raises(AttributeError): CmabBernoulli( actions={ - "a1": create_bayesian_logistic_regression_cold_start(n_betas=a), - "a2": create_bayesian_logistic_regression_cold_start(n_betas=b), - "a3": create_bayesian_logistic_regression_cold_start(n_betas=b), + "a1": create_bayesian_logistic_regression_cold_start(n_betas=first_n_betas), + "a2": create_bayesian_logistic_regression_cold_start(n_betas=first_n_betas), + "a3": create_bayesian_logistic_regression_cold_start(n_betas=second_n_betas), + } + ) + first_update_method = literal_update_methods[first_update_method_index] + second_update_method = literal_update_methods[1 - first_update_method_index] + with pytest.raises(AttributeError): + CmabBernoulli( + actions={ + "a1": create_bayesian_logistic_regression_cold_start( + n_betas=first_n_betas, update_method=first_update_method + ), + "a2": create_bayesian_logistic_regression_cold_start( + n_betas=first_n_betas, update_method=second_update_method + ), + } + ) + first_model = create_bayesian_logistic_regression_cold_start( + n_betas=first_n_betas, update_method=first_update_method + ) + altered_kwarg = first_model.update_kwargs[kwarg_to_alter] // factor + with pytest.raises(AttributeError): + CmabBernoulli( + actions={ + "a1": first_model, + "a2": create_bayesian_logistic_regression_cold_start( + n_betas=first_n_betas, + update_method=first_update_method, + update_kwargs={kwarg_to_alter: altered_kwarg}, + ), } ) -def test_cmab_update(n_samples=100, n_features=3): +@settings(deadline=60000) +@given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) +def test_cmab_update(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() rewards = np.random.choice([0, 1], size=n_samples).tolist() def run_update(context): - mab = create_cmab_bernoulli_cold_start(action_ids={"a1", "a2"}, n_features=n_features) + mab = create_cmab_bernoulli_cold_start( + action_ids={"a1", "a2"}, n_features=n_features, update_method=update_method + ) assert all( - [mab.actions[a] == create_bayesian_logistic_regression_cold_start(n_betas=n_features) for a in set(actions)] + [ + mab.actions[a] + == create_bayesian_logistic_regression_cold_start(n_betas=n_features, update_method=update_method) + for a in set(actions) + ] ) mab.update(context=context, actions=actions, rewards=rewards) assert all( - [mab.actions[a] != create_bayesian_logistic_regression_cold_start(n_betas=n_features) for a in set(actions)] + [ + mab.actions[a] + != create_bayesian_logistic_regression_cold_start(n_betas=n_features, update_method=update_method) + for a in set(actions) + ] ) assert not mab.predict_actions_randomly @@ -179,26 +227,42 @@ def run_update(context): run_update(context=context) -def test_cmab_update_not_all_actions(n_samples=100, n_feat=3): +@settings(deadline=10000) +@given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) +def 
test_cmab_update_not_all_actions(n_samples, n_feat, update_method): actions = np.random.choice(["a3", "a4"], size=n_samples).tolist() rewards = np.random.choice([0, 1], size=n_samples).tolist() context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_feat)) - mab = create_cmab_bernoulli_cold_start(action_ids={"a1", "a2", "a3", "a4"}, n_features=n_feat) + mab = create_cmab_bernoulli_cold_start( + action_ids={"a1", "a2", "a3", "a4"}, n_features=n_feat, update_method=update_method + ) mab.update(context=context, actions=actions, rewards=rewards) - assert mab.actions["a1"] == create_bayesian_logistic_regression_cold_start(n_betas=n_feat) - assert mab.actions["a2"] == create_bayesian_logistic_regression_cold_start(n_betas=n_feat) - assert mab.actions["a3"] != create_bayesian_logistic_regression_cold_start(n_betas=n_feat) - assert mab.actions["a4"] != create_bayesian_logistic_regression_cold_start(n_betas=n_feat) + assert mab.actions["a1"] == create_bayesian_logistic_regression_cold_start( + n_betas=n_feat, update_method=update_method + ) + assert mab.actions["a2"] == create_bayesian_logistic_regression_cold_start( + n_betas=n_feat, update_method=update_method + ) + assert mab.actions["a3"] != create_bayesian_logistic_regression_cold_start( + n_betas=n_feat, update_method=update_method + ) + assert mab.actions["a4"] != create_bayesian_logistic_regression_cold_start( + n_betas=n_feat, update_method=update_method + ) @settings(deadline=500) -@given(st.integers(min_value=1, max_value=1000), st.integers(min_value=1, max_value=100)) -def test_cmab_update_shape_mismatch(n_samples, n_features): +@given( + st.integers(min_value=1, max_value=1000), + st.integers(min_value=1, max_value=100), + st.sampled_from(literal_update_methods), +) +def test_cmab_update_shape_mismatch(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() rewards = np.random.choice([0, 1], size=n_samples).tolist() context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) - mab = create_cmab_bernoulli_cold_start(action_ids={"a1", "a2"}, n_features=n_features) + mab = create_cmab_bernoulli_cold_start(action_ids={"a1", "a2"}, n_features=n_features, update_method=update_method) with pytest.raises(AttributeError): # actions shape mismatch mab.update(context=context, actions=actions[1:], rewards=rewards) @@ -375,14 +439,16 @@ def test_cmab_get_state(mu, sigma, n_features): ), "strategy": st.fixed_dictionaries({}), } - ) + ), + update_method=st.sampled_from(literal_update_methods), ) -def test_cmab_from_state(state): +def test_cmab_from_state(state, update_method): + _apply_update_method_to_state(state, update_method) cmab = CmabBernoulli.from_state(state) assert isinstance(cmab, CmabBernoulli) - expected_actions = state["actions"] actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict + expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} assert expected_actions == actual_actions # Ensure get_state and from_state compatibility @@ -513,18 +579,30 @@ def test_cmab_bai_predict(n_samples, n_features): assert len(selected_actions) == len(probs) == len(weighted_sums) == n_samples -def test_cmab_bai_update(n_samples=100, n_features=3): +@settings(deadline=10000) +@given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) +def test_cmab_bai_update(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() rewards = np.random.choice([0, 1], 
size=n_samples).tolist() context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) - mab = create_cmab_bernoulli_bai_cold_start(action_ids={"a1", "a2"}, n_features=n_features) + mab = create_cmab_bernoulli_bai_cold_start( + action_ids={"a1", "a2"}, n_features=n_features, update_method=update_method + ) assert mab.predict_actions_randomly assert all( - [mab.actions[a] == create_bayesian_logistic_regression_cold_start(n_betas=n_features) for a in set(actions)] + [ + mab.actions[a] + == create_bayesian_logistic_regression_cold_start(n_betas=n_features, update_method=update_method) + for a in set(actions) + ] ) mab.update(context=context, actions=actions, rewards=rewards) assert all( - [mab.actions[a] != create_bayesian_logistic_regression_cold_start(n_betas=n_features) for a in set(actions)] + [ + mab.actions[a] + != create_bayesian_logistic_regression_cold_start(n_betas=n_features, update_method=update_method) + for a in set(actions) + ] ) assert not mab.predict_actions_randomly @@ -597,15 +675,18 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01): st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)), ), } - ) + ), + update_method=st.sampled_from(literal_update_methods), ) -def test_cmab_bai_from_state(state): +def test_cmab_bai_from_state(state, update_method): + _apply_update_method_to_state(state, update_method) cmab = CmabBernoulliBAI.from_state(state) assert isinstance(cmab, CmabBernoulliBAI) - expected_actions = state["actions"] actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict + expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} assert expected_actions == actual_actions + expected_exploit_p = ( state["strategy"].get("exploit_p", 0.5) if state["strategy"].get("exploit_p") is not None else 0.5 ) # Covers both not existing and existing + None @@ -743,22 +824,32 @@ def test_cmab_cc_predict(n_samples, n_features): assert len(selected_actions) == len(probs) == len(weighted_sums) == n_samples -def test_cmab_cc_update(n_samples=100, n_features=3): +@settings(deadline=10000) +@given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) +def test_cmab_cc_update(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() rewards = np.random.choice([0, 1], size=n_samples).tolist() context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) - mab = create_cmab_bernoulli_cc_cold_start(action_ids_cost={"a1": 10, "a2": 10}, n_features=n_features) + mab = create_cmab_bernoulli_cc_cold_start( + action_ids_cost={"a1": 10, "a2": 10}, n_features=n_features, update_method=update_method + ) assert mab.predict_actions_randomly assert all( [ - mab.actions[a] == create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=10) + mab.actions[a] + == create_bayesian_logistic_regression_cc_cold_start( + n_betas=n_features, cost=10, update_method=update_method + ) for a in set(actions) ] ) mab.update(context=context, actions=actions, rewards=rewards) assert all( [ - mab.actions[a] != create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=10) + mab.actions[a] + != create_bayesian_logistic_regression_cc_cold_start( + n_betas=n_features, cost=10, update_method=update_method + ) for a in set(actions) ] ) @@ -840,15 +931,18 @@ def test_cmab_cc_get_state( st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)), ), } - ) + ), + 
update_method=st.sampled_from(literal_update_methods), ) -def test_cmab_cc_from_state(state): +def test_cmab_cc_from_state(state, update_method): + _apply_update_method_to_state(state, update_method) cmab = CmabBernoulliCC.from_state(state) assert isinstance(cmab, CmabBernoulliCC) - expected_actions = state["actions"] actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict + expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} assert expected_actions == actual_actions + expected_subsidy_factor = ( state["strategy"].get("subsidy_factor", 0.5) if state["strategy"].get("subsidy_factor") is not None else 0.5 ) # Covers both not existing and existing + None diff --git a/tests/test_execution_time.py b/tests/test_execution_time.py deleted file mode 100644 index 647d6ac..0000000 --- a/tests/test_execution_time.py +++ /dev/null @@ -1,316 +0,0 @@ -# # MIT License -# # -# # Copyright (c) 2022 Playtika Ltd. -# # -# # Permission is hereby granted, free of charge, to any person obtaining a copy -# # of this software and associated documentation files (the "Software"), to deal -# # in the Software without restriction, including without limitation the rights -# # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# # copies of the Software, and to permit persons to whom the Software is -# # furnished to do so, subject to the following conditions: -# # -# # The above copyright notice and this permission notice shall be included in all -# # copies or substantial portions of the Software. -# # -# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# # SOFTWARE. - -# import time - -# import numpy as np - -# from pybandits.core.cmab import Cmab - -# verbose = True - - -# def run_cmab_predict_streaming(n_actions, n_features, n_samples, n_iterations, n_jobs, sampling, params_sample, -# verbose=False): -# """ -# This function executes the following steps: -# - initialize cmab with input params -# - simulate first batch of users with actions and rewards -# - update cmab with first batch -# - predict 1 sample at time (i.e. in streaming) with sampling (sampling=True) or without (sampling=False) -# - return the mean and std of the prediction time. -# """ - -# # params -# size_first_batch = 1000 -# actions_ids = ['action' + str(i + 1) for i in range(n_actions)] - -# # init model -# cmab = Cmab(n_features=n_features, actions_ids=actions_ids, n_jobs=n_jobs, params_sample=params_sample) - -# # simulate first batch -# X = 2 * np.random.random_sample((size_first_batch, n_features)) - 1 # float in the interval (-1, 1) -# actions, _ = cmab.predict(X) -# rewards = np.random.randint(2, size=size_first_batch) - -# # update -# start = time.time() -# cmab.update(X=X, actions=actions, rewards=rewards) -# end = time.time() -# t = end - start -# if verbose: -# print('\nUpdate with n_actions = {}, n_features = {}, size_first_batch = {}. Time = {:.6f} sec.' 
-# .format(n_actions, n_features, size_first_batch, t)) - -# # predict 1 sample at time -# t = [] -# for i in range(n_iterations): -# x = 2 * np.random.random_sample((n_samples, n_features)) - 1 # floats in the interval (-1, 1) -# if sampling: -# start = time.time() -# _, _ = cmab.predict(x) -# end = time.time() -# else: -# start = time.time() -# _, _ = cmab.fast_predict(x) -# end = time.time() -# t.append(end-start) -# mu_t, simga_t = np.mean(t), np.std(t) - -# if verbose: -# print('Predict of n_actions={}, n_features={}, n_samples={}, n_iterations={}, sampling={}. ' -# '\nmean execution time = {:.6f} sec, std execution time = {:.6f} sec ' -# .format(n_actions, n_features, n_samples, n_iterations, sampling, mu_t, simga_t)) - -# return mu_t, simga_t - - -# def test_cmab_time_predict_before_update(): -# """ Test cmab.predict() in steaming before the first update(). """ -# # input -# n_iteration = 10000 -# n_actions = 1000 -# n_samples = 1 -# n_features = 1000 -# actions_ids = ['action' + str(i + 1) for i in range(n_actions)] -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # init model -# cmab = Cmab(n_features=n_features, actions_ids=actions_ids, params_sample=params_sample) - -# # predict -# t = [] -# for i in range(n_iteration): -# x = 2 * np.random.random_sample((n_samples, n_features)) - 1 # float in the interval (-1, 1) -# start = time.time() -# _, _ = cmab.predict(x) -# end = time.time() -# t.append(end - start) -# mu_t, simga_t = np.mean(t), np.std(t) - -# if verbose: -# print('\nPredict before the first update of n_samples={}, n_actions={}, n_features={}, n_iteration={}' -# '\nmean execution time = {:.6f} sec, std execution time = {:.6f} sec ' -# .format(n_samples, n_iteration, n_actions, n_features, mu_t, simga_t)) - - -# # test with fast predict - -# def test_cmab_time_predict_2_2_1_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 2 -# n_samples = 1 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_2_10000_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 2 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_5_10000_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). 
""" -# # input -# n_actions = 2 -# n_features = 5 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_100_10000_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 100 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_5_2_10000_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 5 -# n_features = 2 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_20_2_1_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 20 -# n_features = 2 -# n_samples = 1 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_20_100_1_fp(): -# """ Test cmab.fast_predict() in steaming after the first update(). """ -# # input -# n_actions = 20 -# n_features = 100 -# n_samples = 1 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = False -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# # test with sampling - -# def test_cmab_time_predict_2_2_1_w_s(): -# """ Test cmab.predict() in steaming after the first update(). 
""" -# # input -# n_actions = 2 -# n_features = 2 -# n_samples = 1 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = True -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_2_10000_w_s(): -# """ Test cmab.predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 2 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = True -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_5_10000_w_s(): -# """ Test cmab.predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 5 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = True -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose) - - -# def test_cmab_time_predict_2_100_10000_w_s(): -# """ Test cmab.predict() in steaming after the first update(). """ -# # input -# n_actions = 2 -# n_features = 100 -# n_samples = 10000 -# n_iterations = 10 -# n_jobs = n_actions -# sampling = True -# params_sample = {'tune': 500, 'draws': 1000, 'chains': 2, 'init': 'adapt_diag', 'cores': 1, 'target_accept': 0.95, -# 'progressbar': False} - -# # run test -# mu_t, simga_t = run_cmab_predict_streaming(n_actions=n_actions, n_features=n_features, n_samples=n_samples, -# n_iterations=n_iterations, n_jobs=n_jobs, sampling=sampling, -# params_sample=params_sample, verbose=verbose)