Skip to content

Commit

Permalink
Adaptive Windowing for Multi-Armed Bandits
Browse files Browse the repository at this point in the history
 ### Changes:
 * Added adaptive windowing mechanism to detect and handle concept drift in MAB models.
 * Introduced ActionsManager class to handle action memory and updates with configurable window sizes.
 * Refactored Model class hierarchy to support model resetting and memory management.
 * Added support for infinite and fixed-size windows with change detection via delta parameter.
 * Enhanced test coverage for adaptive windowing functionality across MAB variants.
  • Loading branch information
Shahar-Bar committed Jan 1, 2025
1 parent 64913ef commit a2781d7
Show file tree
Hide file tree
Showing 15 changed files with 2,185 additions and 542 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ MANIFEST

# poetry
poetry.lock
.qodo
626 changes: 626 additions & 0 deletions pybandits/actions_manager.py

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# SOFTWARE.


from typing import Any, Dict, List, NewType, Tuple, Union
from typing import Any, Dict, List, NewType, Tuple, Union, _GenericAlias, get_args, get_origin

from pybandits.pydantic_version_compatibility import (
PYDANTIC_VERSION_1,
Expand Down Expand Up @@ -52,6 +52,7 @@
Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
)
ACTION_IDS_PREFIX = "action_ids_"
ACTIONS = "actions"


class _classproperty(property):
Expand Down Expand Up @@ -96,6 +97,18 @@ def _apply_version_adjusted_method(self, v2_method_name: str, v1_method_name: st
def _get_value_with_default(cls, key: str, values: Dict[str, Any]) -> Any:
return values.get(key, cls.model_fields[key].default)

@classmethod
def _get_field_type(cls, key: str) -> Any:
if pydantic_version == PYDANTIC_VERSION_1:
annotation = cls.model_fields[key].type_
elif pydantic_version == PYDANTIC_VERSION_2:
annotation = cls.model_fields[key].annotation
if isinstance(annotation, _GenericAlias) and get_origin(annotation) is dict:
annotation = get_args(annotation)[1] # refer to the type of the Dict values
else:
raise ValueError(f"Unsupported pydantic version: {pydantic_version}")
return annotation

if pydantic_version == PYDANTIC_VERSION_1:

@_classproperty
Expand Down
48 changes: 10 additions & 38 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,26 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, List, Optional, Set, Union
from abc import ABC
from typing import List, Optional, Set, Union

from numpy import array
from numpy.random import choice
from numpy.typing import ArrayLike

from pybandits.actions_manager import CmabActionsManager
from pybandits.base import ActionId, BinaryReward, CmabPredictions
from pybandits.mab import BaseMab
from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC
from pybandits.pydantic_version_compatibility import field_validator, validate_call
from pybandits.pydantic_version_compatibility import validate_call
from pybandits.strategy import (
BestActionIdentificationBandit,
ClassicBandit,
CostControlBandit,
)


class BaseCmabBernoulli(BaseMab):
class BaseCmabBernoulli(BaseMab, ABC):
"""
Base model for a Contextual Multi-Armed Bandit for Bernoulli bandits with Thompson Sampling.
Expand All @@ -54,27 +56,10 @@ class BaseCmabBernoulli(BaseMab):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
actions_manager: CmabActionsManager[BayesianLogisticRegression]
predict_with_proba: bool
predict_actions_randomly: bool

@field_validator("actions", mode="after")
@classmethod
def check_bayesian_logistic_regression_models(cls, v):
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
for action in action_models[1:]:
if not isinstance(action, first_action_type):
raise AttributeError("All actions should follow the same type.")
if not len(action.betas) == len(first_action.betas):
raise AttributeError("All actions should have the same number of betas.")
if not action.update_method == first_action.update_method:
raise AttributeError("All actions should have the same update method.")
if not action.update_kwargs == first_action.update_kwargs:
raise AttributeError("All actions should have the same update kwargs.")
return v

@validate_call(config=dict(arbitrary_types_allowed=True))
def predict(
self,
Expand Down Expand Up @@ -169,20 +154,7 @@ def update(
If strategy is MultiObjectiveBandit, rewards should be a list of list, e.g. (with n_objectives=2):
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
"""
self._validate_update_params(actions=actions, rewards=rewards)
if len(context) != len(rewards):
raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")

# cast inputs to numpy arrays to facilitate their manipulation
context, actions, rewards = array(context), array(actions), array(rewards)

for a in set(actions):
# get context and rewards of the samples associated to action a
context_of_a = context[actions == a]
rewards_of_a = rewards[actions == a].tolist()

# update model associated to action a
self.actions[a].update(context=context_of_a, rewards=rewards_of_a)
super().update(actions=actions, rewards=rewards, context=context)

# always set predict_actions_randomly after update
self.predict_actions_randomly = False
Expand All @@ -208,7 +180,7 @@ class CmabBernoulli(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
actions_manager: CmabActionsManager[BayesianLogisticRegression]
strategy: ClassicBandit
predict_with_proba: bool = False
predict_actions_randomly: bool = False
Expand All @@ -234,7 +206,7 @@ class CmabBernoulliBAI(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
actions_manager: CmabActionsManager[BayesianLogisticRegression]
strategy: BestActionIdentificationBandit
predict_with_proba: bool = False
predict_actions_randomly: bool = False
Expand Down Expand Up @@ -268,7 +240,7 @@ class CmabBernoulliCC(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegressionCC]
actions_manager: CmabActionsManager[BayesianLogisticRegressionCC]
strategy: CostControlBandit
predict_with_proba: bool = True
predict_actions_randomly: bool = False
Loading

0 comments on commit a2781d7

Please sign in to comment.