Skip to content

Commit

Permalink
Vectorization of cMAB Predict (#61)
Browse files Browse the repository at this point in the history
### Changes
 * Changed the predict method of BaseCmabBernoulli in cmab.py to a vectorized version that samples the whole context matrix at once, instead of looping over its rows.
  • Loading branch information
shaharbar1 authored Sep 24, 2024
1 parent dfaab15 commit 9c15f78
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 28 deletions.
53 changes: 26 additions & 27 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,33 +123,32 @@ def predict(
probs = len(context) * [{k: 0.5 for k in valid_actions}] # all probs are set to 0.5
weighted_sums = len(context) * [{k: 0 for k in valid_actions}] # all weighted sums are set to 0
else:
selected_actions: List[ActionId] = []
probs: List[Dict[ActionId, Probability]] = []
weighted_sums: List[Dict[ActionId, float]] = []

# sample_proba() and select_action() each row of the context
for i in range(len(context)):
# p is a dict of the sampled probability "prob" and weighted_sum "ws", e.g.
#
# p = {'a1': ([0.5], [200]), 'a2': ([0.4], [100]), ...}
# | | | |
# prob ws prob ws
p = {
action: model.sample_proba(context=context[i].reshape(1, -1)) # reshape row i-th to (1, n_features)
for action, model in self.actions.items()
if action in valid_actions
}

prob = {a: x[0][0] for a, x in p.items()} # e.g. prob = {'a1': 0.5, 'a2': 0.4, ...}
ws = {a: x[1][0] for a, x in p.items()} # e.g. ws = {'a1': 200, 'a2': 100, ...}

# select either "prob" or "ws" to use as input argument in select_actions()
p_to_select_action = prob if self.predict_with_proba else ws

# predict actions, probs, weighted_sums
selected_actions.append(self._select_epsilon_greedy_action(p=p_to_select_action, actions=self.actions))
probs.append(prob)
weighted_sums.append(ws)
# p is a dict of the sampled probability "prob" and weighted_sum "ws", e.g.
#
# p = {'a1': ([0.5, 0.2, 0.3], [200, 100, 130]), 'a2': ([0.4, 0.5, 0.6], [180, 200, 230]), ...}
# | | | |
# prob ws prob ws
p = {
action: model.sample_proba(context=context) # sample probabilities for the entire context matrix
for action, model in self.actions.items()
if action in valid_actions
}

prob = {a: x[0] for a, x in p.items()} # e.g. prob = {'a1': [0.5, 0.4, ...], 'a2': [0.4, 0.3, ...], ...}
ws = {a: x[1] for a, x in p.items()} # e.g. ws = {'a1': [200, 100, ...], 'a2': [100, 50, ...], ...}

# select either "prob" or "ws" to use as input argument in select_actions()
p_to_select_action = prob if self.predict_with_proba else ws

# predict actions, probs, weighted_sums
selected_actions = [
self._select_epsilon_greedy_action(
p={a: p_to_select_action[a][i] for a in p_to_select_action}, actions=self.actions
)
for i in range(len(context))
]
probs = [{a: prob[a][i] for a in prob} for i in range(len(context))]
weighted_sums = [{a: ws[a][i] for a in ws} for i in range(len(context))]

return selected_actions, probs, weighted_sums

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pybandits"
version = "0.6.0"
version = "0.6.1"
description = "Python Multi-Armed Bandit Library"
authors = [
"Dario d'Andrea <[email protected]>",
Expand Down

0 comments on commit 9c15f78

Please sign in to comment.