Feat Sebulba recurrent IQL #1148

Open · wants to merge 11 commits into develop
11 changes: 11 additions & 0 deletions mava/configs/default/rec_iql_sebulba.yaml
@@ -0,0 +1,11 @@
defaults:
- logger: logger
- arch: sebulba
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: lbf_gym # [rware_gym, lbf_gym, smac_gym]
- _self_

hydra:
searchpath:
- file://mava/configs
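
For reference, a minimal sketch of how a composed config like this could be loaded with Hydra; the entry-point layout and config_path below are illustrative assumptions, not Mava's actual launcher:

import hydra
from omegaconf import DictConfig, OmegaConf

# Compose rec_iql_sebulba.yaml: each entry under `defaults` pulls in a config group
# (arch, system, network, env), and the hydra.searchpath above lets those groups
# resolve against mava/configs. The relative config_path here is an assumption.
@hydra.main(config_path="mava/configs/default", config_name="rec_iql_sebulba", version_base="1.2")
def main(cfg: DictConfig) -> None:
    # Groups can be swapped on the command line, e.g. `env=smac_gym network=rcnn`.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()
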
8 changes: 6 additions & 2 deletions mava/configs/system/q_learning/rec_iql.yaml
@@ -15,9 +15,9 @@ rollout_length: 2 # Number of environment steps per vectorised environment.
epochs: 2 # Number of learn epochs per training data batch.

# sizes
-buffer_size: 5000 # size of the replay buffer. Note: total size is this * num_devices
+buffer_size: 1000 # size of the replay buffer. Note: total size is this * num_devices
sample_batch_size: 32 # size of training data batch sampled from the buffer
-sample_sequence_length: 20 # 20 transitions are sampled, giving 19 complete data points
+sample_sequence_length: 32 # 32 transitions are sampled, giving 31 complete data points

# learning rates
q_lr: 3e-4 # the learning rate of the Q network optimizer
@@ -31,3 +31,7 @@ gamma: 0.99 # discount factor

eps_min: 0.05
eps_decay: 1e5

+# --- Sebulba parameters ---
+replay_ratio: 4 # Average number of times the learner should sample each item from the replay buffer.
+error_tolerance: ~ # Tolerance for how much the learner/actor can sample/insert before being blocked. Must be greater than 2 to avoid deadlocks.
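
One way to read replay_ratio and error_tolerance together is as a sample-to-insert rate limiter between the actor and learner threads. The sketch below is illustrative bookkeeping only; the class name, the units of error_tolerance, and the exact inequalities are assumptions, not the Pipeline's implementation:

class SampleInsertRateLimiter:
    """Tracks inserts vs. samples so that, on average, each inserted item is
    sampled `replay_ratio` times, within a slack of `error_tolerance`."""

    def __init__(self, replay_ratio: float, error_tolerance: float) -> None:
        self.replay_ratio = replay_ratio
        self.error_tolerance = error_tolerance
        self.inserts = 0
        self.samples = 0

    def _deficit(self) -> float:
        # How many samples the learner still "owes" given what the actors have inserted.
        return self.inserts * self.replay_ratio - self.samples

    def can_insert(self) -> bool:
        # Block actors once they are too far ahead of the learner.
        return self._deficit() < self.error_tolerance

    def can_sample(self) -> bool:
        # Block the learner once it is too far ahead of the actors.
        return -self._deficit() < self.error_tolerance

With a positive tolerance at least one of can_insert/can_sample is always true in this sketch, which is the kind of property the "must be greater than 2 to avoid deadlocks" note is presumably guarding in the real pipeline; setting error_tolerance to ~ (null) presumably disables the limit.
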
3 changes: 1 addition & 2 deletions mava/evaluator.py
@@ -290,8 +290,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

-# find the first instance of done to get the metrics at that timestep, we don't
-# care about subsequent steps because we only the results from the first episode
+# Find the first instance of done to get the metrics at that timestep.
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging
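
As a concrete illustration of the done_idx / advanced-indexing lines above (assuming metrics are stacked with shape (time, n_parallel_envs); values are made up):

import numpy as np

dones = np.array([[False, False],
                  [True, False],    # env 0 finishes its first episode at t=1
                  [False, True]])   # env 1 finishes its first episode at t=2
episode_return = np.array([[0.0, 0.0],
                           [5.0, 1.0],
                           [7.0, 3.0]])

done_idx = np.argmax(dones, axis=0)  # first True per env -> [1, 2]
# Advanced indexing picks, for each env, the metric at its first done step.
first_returns = episode_return[done_idx, np.arange(dones.shape[1])]  # -> [5.0, 3.0]
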
26 changes: 13 additions & 13 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -51,11 +51,13 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
-from mava.utils.config import check_sebulba_config, check_total_timesteps
+from mava.utils.config import check_total_timesteps
+from mava.utils.config import ppo_sebulba_checks as check_sebulba_config
from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
-from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime
+from mava.utils.sebulba.pipelines import Pipeline
+from mava.utils.sebulba.utils import ParamsSource, RecordTimeTo
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
from mava.wrappers.gym import GymToJumanji
@@ -70,7 +72,7 @@ def rollout(
apply_fns: Tuple[ActorApply, CriticApply],
actor_device: int,
seeds: List[int],
-thread_lifetime: ThreadLifetime,
+stop_event: threading.Event,
) -> None:
"""Runs rollouts to collect trajectories from the environment.

@@ -110,7 +112,7 @@ def act_fn(
dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1)

# Loop till the desired num_updates is reached.
-while not thread_lifetime.should_stop():
+while not stop_event.is_set():
# Rollout
traj: List[PPOTransition] = []
episode_metrics: List[Dict] = []
@@ -596,20 +598,18 @@ def run_experiment(_config: DictConfig) -> float:
inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate

# the rollout queue / the pipe between actor and learner
-pipe_lifetime = ThreadLifetime()
-pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding, pipe_lifetime)
+pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding)
pipe.start()

params_sources: List[ParamsSource] = []
actor_threads: List[threading.Thread] = []
-actor_lifetime = ThreadLifetime()
-params_sources_lifetime = ThreadLifetime()
+actors_stop_event = threading.Event()

# Create the actor threads
print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}")
for actor_device in actor_devices:
# Create 1 params source per device
-params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime)
+params_source = ParamsSource(inital_params, actor_device)
params_source.start()
params_sources.append(params_source)
# Create multiple rollout threads per actor device
@@ -630,7 +630,7 @@ def run_experiment(_config: DictConfig) -> float:
apply_fns,
actor_device,
seeds,
-actor_lifetime,
+actors_stop_event,
),
name=f"Actor-{actor_device}-{thread_id}",
)
@@ -708,19 +708,19 @@ def run_experiment(_config: DictConfig) -> float:

# Stop all the threads.
logger.stop()
-actor_lifetime.stop()
+actors_stop_event.set()
pipe.clear() # We clear the pipeline before stopping the actor threads to avoid deadlock
print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}")
print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}")
for actor in actor_threads:
actor.join()
print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}")
print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}")
-pipe_lifetime.stop()
+pipe.stop()
pipe.join()
print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}")
-params_sources_lifetime.stop()
for params_source in params_sources:
+params_source.stop()
params_source.join()
print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}")

6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -44,7 +44,7 @@
TrainState,
Transition,
)
-from mava.types import MarlEnv, Observation
+from mava.types import MarlEnv, MavaObservation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
@@ -209,7 +209,7 @@ def make_update_fns(
# ---- Acting functions ----

def select_eps_greedy_action(
-action_selection_state: ActionSelectionState, obs: Observation, term_or_trunc: Array
+action_selection_state: ActionSelectionState, obs: MavaObservation, term_or_trunc: Array
) -> Tuple[ActionSelectionState, Array]:
"""Select action to take in epsilon-greedy way. Batch and agent dims are included."""
params, hidden_state, t, key = action_selection_state
@@ -279,7 +279,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:

# ---- Training functions ----

-def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array:
+def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array:
"""Prepares the inputs to the RNN network for either getting q values or the
eps-greedy distribution.
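
A minimal JAX sketch of the epsilon-greedy choice that select_eps_greedy_action makes over per-agent q-values. The shapes, names, and fixed epsilon below are assumptions for illustration; the actual function also threads the recurrent hidden state and the epsilon schedule through ActionSelectionState:

import jax
import jax.numpy as jnp

def eps_greedy(key: jax.Array, q_values: jax.Array, eps: float) -> jax.Array:
    """q_values: (batch, n_agents, n_actions) -> actions: (batch, n_agents)."""
    explore_key, action_key = jax.random.split(key)
    greedy_actions = jnp.argmax(q_values, axis=-1)
    random_actions = jax.random.randint(
        action_key, greedy_actions.shape, minval=0, maxval=q_values.shape[-1]
    )
    # With probability eps take the random action, otherwise the greedy one.
    explore = jax.random.uniform(explore_key, greedy_actions.shape) < eps
    return jnp.where(explore, random_actions, greedy_actions)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (4, 3, 5))   # 4 envs, 3 agents, 5 actions
actions = eps_greedy(key, q, eps=0.05)  # (4, 3) integer actions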

6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -44,7 +44,7 @@
TrainState,
Transition,
)
-from mava.types import MarlEnv, Observation
+from mava.types import MarlEnv, MavaObservation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
@@ -241,7 +241,7 @@
) -> Callable[[LearnerState[QMIXParams]], Tuple[LearnerState[QMIXParams], Tuple[Metrics, Metrics]]]:
def select_eps_greedy_action(
action_selection_state: ActionSelectionState,
-obs: Observation,
+obs: MavaObservation,
term_or_trunc: Array,
) -> Tuple[ActionSelectionState, Array]:
"""Select action to take in eps-greedy way. Batch and agent dims are included."""
@@ -310,7 +310,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:

return new_act_state, next_timestep.extras["episode_metrics"]

-def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array:
+def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array:
"""Prepares the inputs to the RNN network for either getting q values or the
eps-greedy distribution.
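
Both q-learning systems keep prep_inputs_to_scannedrnn; a common pattern here, and presumably what the switch_leading_axes utility imported in ff_ippo is for, is converting batch-major replay samples into the time-major layout an RNN scan expects. A hedged sketch under that assumption (shapes are illustrative):

import jax
import jax.numpy as jnp

def to_time_major(tree):
    # Swap the first two axes of every array leaf: (B, T, ...) -> (T, B, ...),
    # so the recurrent network can scan over axis 0 (time).
    return jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), tree)

batch = {
    "agents_view": jnp.zeros((32, 20, 3, 64)),  # (sample_batch_size, sequence, n_agents, obs_dim)
    "term_or_trunc": jnp.zeros((32, 20, 3)),    # (sample_batch_size, sequence, n_agents)
}
time_major = to_time_major(batch)
assert time_major["agents_view"].shape == (20, 32, 3, 64)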
