Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loss factor to CostControlBandit and update logic for dynamic subsidy factor #65

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/continuous_delivery.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
poetry run pre-commit run --all-files
- name: Run tests
run: |
poetry run pytest -vv -k 'not time and not update_parallel'
poetry run pytest -vv -k 'not time and not update_parallel' --cov=pybandits
- name: Extract version from pyproject.toml
id: extract_version
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ jobs:
poetry run pre-commit run --all-files
- name: Run tests
run: |
poetry run pytest -vv -k 'not time and not update_parallel'
poetry run pytest -vv -k 'not time and not update_parallel' --cov=pybandits
17 changes: 7 additions & 10 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Set, Union

from numpy import array
from numpy.random import choice
Expand Down Expand Up @@ -76,22 +76,20 @@ def check_bayesian_logistic_regression_models(cls, v):
return v

@validate_call(config=dict(arbitrary_types_allowed=True))
def predict(
def _predict(
self,
context: ArrayLike,
forbidden_actions: Optional[Set[ActionId]] = None,
valid_actions: Set[ActionId],
) -> CmabPredictions:
"""
Predict actions.

Parameters
----------
context: ArrayLike of shape (n_samples, n_features)
context : ArrayLike of shape (n_samples, n_features)
Matrix of contextual features.
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
valid_actions : Set[ActionId]
The set of valid actions to consider.

Returns
-------
Expand All @@ -102,7 +100,6 @@ def predict(
ws : List[Dict[ActionId, float]]
The weighted sum of logistic regression logits.
"""
valid_actions = self._get_valid_actions(forbidden_actions)

# cast inputs to numpy arrays to facilitate their manipulation
context = array(context)
Expand Down Expand Up @@ -149,7 +146,7 @@ def predict(
return selected_actions, probs, weighted_sums

@validate_call(config=dict(arbitrary_types_allowed=True))
def update(
def _update(
self, context: ArrayLike, actions: List[ActionId], rewards: List[Union[BinaryReward, List[BinaryReward]]]
):
"""
Expand Down
47 changes: 44 additions & 3 deletions pybandits/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def _validate_update_params(

####################################################################################################################

@abstractmethod
@validate_call
def update(
self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]], *args, **kwargs
Expand All @@ -182,10 +181,27 @@ def update(
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""
self._validate_update_params(actions=actions, rewards=rewards)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This side effect (the `_validate_update_params` call) should be implemented as a function decorator.

self._update(actions=actions, rewards=rewards, *args, **kwargs)
if hasattr(self.strategy, "update"):
self.strategy.update(rewards=rewards)

@abstractmethod
def _update(
self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]], *args, **kwargs
):
"""
Update the multi-armed bandit model.

Abstract hook with no implementation here; it must be overridden by concrete bandit classes.
The public ``update`` method calls it after validating ``actions`` and ``rewards``.

Parameters
----------
actions : List[ActionId]
The selected action for each sample.
rewards : List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
*args, **kwargs
Extra arguments forwarded unchanged from ``update``.
"""
pass

@validate_call
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None) -> Predictions:
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None, **kwargs) -> Predictions:
"""
Predict actions.

Expand All @@ -196,15 +212,40 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None) -> Predicti
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.

Returns
-------
actions : List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs : List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action
ws : List[Dict[ActionId, float]], only relevant for some of the MABs
The weighted sum of logistic regression logits.
"""
if hasattr(self.strategy, "reset"):
self.strategy.reset()
valid_actions = self._get_valid_actions(forbidden_actions)
return self._predict(valid_actions=valid_actions, **kwargs)

@abstractmethod
def _predict(self, valid_actions: Set[ActionId], **kwargs) -> Predictions:
"""
Predict actions.

Parameters
----------
valid_actions : Set[ActionId]
The set of valid actions.

Returns
-------
actions: List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs: List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action
ws : List[Dict[ActionId, float]], only relevant for some of the MABs
The weighted sum of logistic regression logits..
The weighted sum of logistic regression logits.
"""
pass

def get_state(self) -> (str, dict):
"""
Expand Down
21 changes: 6 additions & 15 deletions pybandits/smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


from collections import defaultdict
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Set, Union

from pydantic import PositiveInt, field_validator, validate_call

Expand Down Expand Up @@ -59,22 +59,16 @@ class BaseSmabBernoulli(BaseMab):
actions: Dict[ActionId, BaseBeta]

@validate_call
def predict(
self,
n_samples: PositiveInt = 1,
forbidden_actions: Optional[Set[ActionId]] = None,
) -> SmabPredictions:
def _predict(self, n_samples: PositiveInt, valid_actions: Set[ActionId]) -> SmabPredictions:
"""
Predict actions.

Parameters
----------
n_samples : int > 0, default=1
n_samples : PositiveInt
Number of samples to predict.
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
valid_actions : Set[ActionId]
The set of valid actions.

Returns
-------
Expand All @@ -83,7 +77,6 @@ def predict(
probs: List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""
valid_actions = self._get_valid_actions(forbidden_actions)

selected_actions: List[ActionId] = []
probs: List[Dict[ActionId, Probability]] = []
Expand All @@ -96,7 +89,7 @@ def predict(
return selected_actions, probs

@validate_call
def update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]]):
def _update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]]):
"""
Update the stochastic Bernoulli bandit given the list of selected actions and their corresponding binary
rewards.
Expand All @@ -113,8 +106,6 @@ def update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], Lis
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
"""

self._validate_update_params(actions=actions, rewards=rewards)

rewards_dict = defaultdict(list)

for a, r in zip(actions, rewards):
Expand Down
Loading