Skip to content

Commit

Permalink
Add epsilon greedy support in from_state
Browse files Browse the repository at this point in the history
 ### Changes
 * Add epsilon and default_action to from_state methods in smab.py and cmab.py.
 * Updated state UTs.
  • Loading branch information
Shahar-Bar committed Dec 11, 2024
1 parent 9c15f78 commit 56436b6
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 37 deletions.
1 change: 1 addition & 0 deletions .github/workflows/release_draft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- develop
- main

jobs:
draft_release:
Expand Down
20 changes: 17 additions & 3 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,11 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulli":
return cls(actions=state["actions"])
return cls(
actions=state["actions"],
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down Expand Up @@ -271,7 +275,12 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))
return cls(
actions=state["actions"],
exploit_p=state["strategy"].get("exploit_p", None),
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down Expand Up @@ -324,7 +333,12 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))
return cls(
actions=state["actions"],
subsidy_factor=state["strategy"].get("subsidy_factor", None),
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down
32 changes: 27 additions & 5 deletions pybandits/smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulli":
return cls(actions=state["actions"])
return cls(
actions=state["actions"],
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down Expand Up @@ -187,7 +191,12 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))
return cls(
actions=state["actions"],
exploit_p=state["strategy"].get("exploit_p", None),
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down Expand Up @@ -232,7 +241,12 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))
return cls(
actions=state["actions"],
subsidy_factor=state["strategy"].get("subsidy_factor", None),
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)

@validate_call
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
Expand Down Expand Up @@ -303,7 +317,11 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMO":
return cls(actions=state["actions"])
return cls(
actions=state["actions"],
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)


class SmabBernoulliMOCC(BaseSmabBernoulliMO):
Expand Down Expand Up @@ -337,7 +355,11 @@ def __init__(

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMOCC":
return cls(actions=state["actions"])
return cls(
actions=state["actions"],
epsilon=state.get("epsilon", None),
default_action=state.get("default_action", None),
)


@validate_call
Expand Down
38 changes: 27 additions & 11 deletions tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import get_args
from typing import Optional, get_args

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -381,21 +381,26 @@ def run_predict(mab):


@settings(deadline=500)
@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100))
def test_cmab_get_state(mu, sigma, n_features):
@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.sampled_from([None, 0.1]),
)
def test_cmab_get_state(mu, sigma, n_features, epsilon):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulli(actions=actions)
cmab = CmabBernoulli(actions=actions, epsilon=epsilon)
expected_state = to_serializable_dict(
{
"actions": actions,
"strategy": {},
"predict_with_proba": False,
"predict_actions_randomly": False,
"epsilon": None,
"epsilon": epsilon,
"default_action": None,
}
)
Expand Down Expand Up @@ -438,6 +443,7 @@ def test_cmab_get_state(mu, sigma, n_features):
min_size=2,
),
"strategy": st.fixed_dictionaries({}),
"epsilon": st.sampled_from([None, 0.1]),
}
),
update_method=st.sampled_from(literal_update_methods),
Expand Down Expand Up @@ -613,21 +619,22 @@ def test_cmab_bai_update(n_samples, n_features, update_method):
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.floats(min_value=0, max_value=1),
st.sampled_from([None, 0.1]),
)
def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01, epsilon: Optional[Float01]):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
cmab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p, epsilon=epsilon)
expected_state = to_serializable_dict(
{
"actions": actions,
"strategy": {"exploit_p": exploit_p},
"predict_with_proba": False,
"predict_actions_randomly": False,
"epsilon": None,
"epsilon": epsilon,
"default_action": None,
}
)
Expand Down Expand Up @@ -674,6 +681,7 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
st.just({"exploit_p": None}),
st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)),
),
"epsilon": st.sampled_from([None, 0.1]),
}
),
update_method=st.sampled_from(literal_update_methods),
Expand Down Expand Up @@ -864,9 +872,16 @@ def test_cmab_cc_update(n_samples, n_features, update_method):
st.floats(min_value=0),
st.floats(min_value=0),
st.floats(min_value=0, max_value=1),
st.sampled_from([None, 0.1]),
)
def test_cmab_cc_get_state(
mu, sigma, n_features, cost_1: NonNegativeFloat, cost_2: NonNegativeFloat, subsidy_factor: Float01
mu,
sigma,
n_features,
cost_1: NonNegativeFloat,
cost_2: NonNegativeFloat,
subsidy_factor: Float01,
epsilon: Optional[Float01],
):
actions: dict = {
"a1": BayesianLogisticRegressionCC(
Expand All @@ -875,14 +890,14 @@ def test_cmab_cc_get_state(
"a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost_2),
}

cmab = CmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
cmab = CmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor, epsilon=epsilon)
expected_state = to_serializable_dict(
{
"actions": actions,
"strategy": {"subsidy_factor": subsidy_factor},
"predict_with_proba": True,
"predict_actions_randomly": False,
"epsilon": None,
"epsilon": epsilon,
"default_action": None,
}
)
Expand Down Expand Up @@ -930,6 +945,7 @@ def test_cmab_cc_get_state(
st.just({"subsidy_factor": None}),
st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)),
),
"epsilon": st.sampled_from([None, 0.1]),
}
),
update_method=st.sampled_from(literal_update_methods),
Expand Down
Loading

0 comments on commit 56436b6

Please sign in to comment.