Feat Sebulba recurrent IQL #1148

Open
wants to merge 11 commits into base: develop
11 changes: 11 additions & 0 deletions mava/configs/default/rec_iql_sebulba.yaml
@@ -0,0 +1,11 @@
defaults:
- logger: logger
- arch: sebulba
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: lbf_gym # [rware_gym, lbf_gym, smac_gym]
- _self_

hydra:
searchpath:
- file://mava/configs
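This default config composes the Sebulba architecture with the recurrent IQL system, an RNN network and a Gym-style environment. A minimal sketch of loading it through Hydra's Compose API follows; the working directory, override keys and values are assumptions for illustration, not something prescribed by the PR.

from hydra import compose, initialize

# Illustrative only: assumes the script sits at the repository root and that
# config groups keep their group names as packages (e.g. cfg.system.*).
with initialize(version_base=None, config_path="mava/configs/default"):
    cfg = compose(
        config_name="rec_iql_sebulba",
        overrides=["env=smac_gym", "system.rollout_length=4"],
    )
print(cfg.system.rollout_length)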
10 changes: 7 additions & 3 deletions mava/configs/system/q_learning/rec_iql.yaml
@@ -11,13 +11,13 @@ add_agent_id: True
min_buffer_size: 32
update_batch_size: 1 # Number of vectorised gradient updates per device.

rollout_length: 2 # Number of environment steps per vectorised environment.
rollout_length: 2 # Number of environment steps per vectorised environment.
epochs: 2 # Number of learn epochs per training data batch.

# sizes
buffer_size: 5000 # size of the replay buffer. Note: total size is this * num_devices
buffer_size: 1000 # size of the replay buffer. Note: total size is this * num_devices
sample_batch_size: 32 # size of training data batch sampled from the buffer
sample_sequence_length: 20 # 20 transitions are sampled, giving 19 complete data points
sample_sequence_length: 32 # 32 transitions are sampled, giving 31 complete data points

# learning rates
q_lr: 3e-4 # the learning rate of the Q network optimizer
@@ -31,3 +31,7 @@ gamma: 0.99 # discount factor

eps_min: 0.05
eps_decay: 1e5

# --- Sebulba parameters ---
data_sample_mean: 150 # Average number of times the learner should sample each item from the replay buffer.
error_tolerance: 2 # Tolerance for how much the learner/actor can sample/insert before being blocked. Must be greater than 2 to avoid deadlocks.
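A rough sketch of how these two values can drive a sample-to-insert rate limiter between the actor and learner threads; the class below is illustrative only and does not reflect Mava's actual implementation.

import threading

class SampleInsertRateLimiter:
    """Illustrative rate limiter: keeps samples roughly equal to
    data_sample_mean * inserts, blocking whichever side (actor insert or
    learner sample) runs too far ahead."""

    def __init__(self, data_sample_mean: float, error_tolerance: float):
        self.spi = data_sample_mean   # target samples per inserted item
        self.tol = error_tolerance    # allowed slack, in items, either way
        self.inserts = 0.0
        self.samples = 0.0
        self.cond = threading.Condition()

    def await_insert(self) -> None:
        # Actor blocks if it is more than `tol` items ahead of what the
        # learner has consumed so far.
        with self.cond:
            self.cond.wait_for(lambda: self.inserts - self.samples / self.spi < self.tol)
            self.inserts += 1
            self.cond.notify_all()

    def await_sample(self) -> None:
        # Learner blocks if it has sampled more than `tol` items' worth
        # beyond what the actor has inserted.
        with self.cond:
            self.cond.wait_for(lambda: self.samples / self.spi - self.inserts < self.tol)
            self.samples += 1
            self.cond.notify_all()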
3 changes: 1 addition & 2 deletions mava/evaluator.py
@@ -290,8 +290,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

# find the first instance of done to get the metrics at that timestep, we don't
# care about subsequent steps because we only the results from the first episode
# Find the first instance of done to get the metrics at that timestep.
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging
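The shortened comment leans on np.argmax returning the index of the first maximum, so on a boolean done mask it picks out the first completed episode per environment. A standalone illustration with made-up values:

import numpy as np

# dones: [time, n_envs]; True marks the last step of an episode.
dones = np.array([[False, False],
                  [True,  False],
                  [True,  True]])
# returns: [time, n_envs]; running episode return logged at each step.
returns = np.array([[1.0, 2.0],
                    [5.0, 3.0],
                    [9.0, 7.0]])

done_idx = np.argmax(dones, axis=0)                      # first True per env -> [1, 2]
first_ep = returns[done_idx, np.arange(dones.shape[1])]  # -> [5.0, 7.0]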
6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -44,7 +44,7 @@
TrainState,
Transition,
)
from mava.types import MarlEnv, Observation
from mava.types import MarlEnv, MavaObservation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
@@ -241,7 +241,7 @@ def make_update_fns(
) -> Callable[[LearnerState[QMIXParams]], Tuple[LearnerState[QMIXParams], Tuple[Metrics, Metrics]]]:
def select_eps_greedy_action(
action_selection_state: ActionSelectionState,
obs: Observation,
obs: MavaObservation,
term_or_trunc: Array,
) -> Tuple[ActionSelectionState, Array]:
"""Select action to take in eps-greedy way. Batch and agent dims are included."""
@@ -310,7 +310,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:

return new_act_state, next_timestep.extras["episode_metrics"]

def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array:
def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array:
"""Prepares the inputs to the RNN network for either getting q values or the
eps-greedy distribution.

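For context, preparing per-step inputs for a scanned RNN typically means adding a leading time axis of length one; a minimal sketch under that assumption (the shapes are illustrative, not necessarily what Mava's ScannedRNN expects):

import jax.numpy as jnp

def add_time_dim(agents_view: jnp.ndarray, term_or_trunc: jnp.ndarray):
    """Illustrative: [batch, n_agents, ...] -> [1, batch, n_agents, ...] so the
    RNN can scan over a time axis of length one when acting step by step."""
    return agents_view[jnp.newaxis, ...], term_or_trunc[jnp.newaxis, ...]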