
Commit

Multiple bug fixes in prioritized multi-agent replay buffer and new multi-agent tuned example for DQN CartPole.

Signed-off-by: Simon Zehnder <[email protected]>
simonsays1980 committed May 17, 2024
1 parent e5f3d83 commit d9ab6bd
Showing 3 changed files with 107 additions and 38 deletions.
76 changes: 76 additions & 0 deletions rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py
@@ -0,0 +1,76 @@
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.test_utils import add_rllib_example_script_args
from ray.tune.registry import register_env

parser = add_rllib_example_script_args()
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

register_env(
"multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})
)

config = (
DQNConfig()
.environment(env="multi_agent_cartpole")
.rl_module(
model_config_dict={
"fcnet_hiddens": [256],
"fcnet_activation": "relu",
"epsilon": [(0, 1.0), (10000, 0.02)],
"fcnet_bias_initializer": "zeros_",
"post_fcnet_bias_initializer": "zeros_",
"post_fcnet_hiddens": [256],
},
)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(
rollout_fragment_length=1,
num_env_runners=2,
num_envs_per_env_runner=1,
)
.training(
double_q=True,
num_atoms=1,
noisy=False,
dueling=True,
replay_buffer_config={
"type": "MultiAgentPrioritizedEpisodeReplayBuffer",
"capacity": 100000,
"alpha": 0.6,
"beta": 0.4,
},
)
.reporting(
metrics_num_episodes_for_smoothing=5,
min_sample_timesteps_per_iteration=1000,
)
.multi_agent(
policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}",
policies={"p0", "p1"},
)
.debugging(
log_level="DEBUG",
)
)

stop = {
NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000,
# `episode_return_mean` is the sum of all agents/policies' returns.
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 450.0 * 2,
}

if __name__ == "__main__":
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

run_rllib_example_script_experiment(config, args, stop=stop)
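
Note on the `epsilon` entry in `model_config_dict` above: it is an exploration schedule given as (timestep, value) pairs, i.e. epsilon decays from 1.0 at timestep 0 to 0.02 at timestep 10,000 and stays there afterwards. A minimal sketch of how such a piecewise-linear schedule is typically evaluated (the helper below is illustrative only, not RLlib code):

def epsilon_at(timestep, schedule=((0, 1.0), (10000, 0.02))):
    # Clamp to the endpoint values outside the schedule's range.
    if timestep <= schedule[0][0]:
        return schedule[0][1]
    if timestep >= schedule[-1][0]:
        return schedule[-1][1]
    # Linearly interpolate between the two surrounding (timestep, value) points.
    for (t0, v0), (t1, v1) in zip(schedule, schedule[1:]):
        if t0 <= timestep <= t1:
            return v0 + (timestep - t0) / (t1 - t0) * (v1 - v0)

# epsilon_at(0) == 1.0, epsilon_at(5000) == 0.51, epsilon_at(20000) == 0.02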
31 changes: 9 additions & 22 deletions rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
@@ -16,7 +16,7 @@

register_env(
"multi_agent_pendulum",
-    lambda _: MultiAgentPendulum({"num_agents": 2})#args.num_agents or 2}),
+    lambda _: MultiAgentPendulum({"num_agents": 2}),
)

config = (
@@ -52,22 +52,22 @@
replay_buffer_config={
"type": "MultiAgentPrioritizedEpisodeReplayBuffer",
"capacity": 100000,
"alpha": 0.6,
"beta": 0.4,
"alpha": 1.0,
"beta": 0.0,
},
num_steps_sampled_before_learning_starts=256,
)
.reporting(
metrics_num_episodes_for_smoothing=5,
min_sample_timesteps_per_iteration=1000,
)
-)

-if True:#args.num_agents:
-    config.multi_agent(
+    .multi_agent(
policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}",
policies={"p0", "p1"},
)
+)



stop = {
NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000,
@@ -76,19 +76,6 @@
}

if __name__ == "__main__":
-    # from ray.rllib.utils.test_utils import run_rllib_example_script_experiment
-
-    # run_rllib_example_script_experiment(config, args, stop=stop)
-    from ray import tune, train
-    import ray
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment


-    tuner = tune.Tuner(
-        "SAC",
-        param_space=config,
-        run_config=train.RunConfig(
-            stop=stop,
-            verbose=2,
-        ),
-    )
-    tuner.fit()
+    run_rllib_example_script_experiment(config, args, stop=stop)
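
The `alpha`/`beta` change above (0.6/0.4 to 1.0/0.0) is worth spelling out: in prioritized experience replay, `alpha` is the exponent that shapes the sampling distribution and `beta` the exponent of the importance-sampling correction. A rough sketch of the standard Schaul et al. formulation (illustrative only, not RLlib's API; the function name and numbers are made up):

import numpy as np

def per_probs_and_weights(priorities, alpha, beta):
    # Sampling probabilities: P(i) = p_i^alpha / sum_j p_j^alpha.
    priorities = np.asarray(priorities, dtype=np.float64)
    probs = priorities**alpha / (priorities**alpha).sum()
    # Importance-sampling weights: w_i = (N * P(i))^-beta, normalized by the largest weight.
    weights = (len(priorities) * probs) ** (-beta)
    return probs, weights / weights.max()

probs, weights = per_probs_and_weights([1.0, 2.0, 4.0], alpha=1.0, beta=0.0)
# probs   -> [0.143, 0.286, 0.571]: alpha=1.0 samples fully proportionally to priority.
# weights -> [1.0, 1.0, 1.0]: beta=0.0 switches the importance-sampling correction off.

With the previous values (alpha=0.6, beta=0.4), sampling was only partially prioritized and the bias correction only partially applied.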
38 changes: 22 additions & 16 deletions in the MultiAgentPrioritizedEpisodeReplayBuffer implementation (file path not shown)
@@ -117,7 +117,7 @@ def add(
)
new_eps_to_evict = episodes.pop(idx)
self._num_timesteps -= new_eps_to_evict.env_steps()
self._num_timesteps_added -= new_eps_to_evict.env_steps()
self._num_timesteps_added -= new_eps_to_evict.env_steps()
# Remove the timesteps of the evicted episode from the counter.
self._num_timesteps -= evicted_episode.env_steps()
self._num_agent_timesteps -= evicted_episode.agent_steps()
@@ -180,12 +180,16 @@ def add(
if self._module_to_max_idx[module_id] == idx_quadlet[3]
else 0
)
- sample_idx = self._module_to_tree_idx_to_sample_idx[module_id][idx_quadlet[3]]
+ sample_idx = self._module_to_tree_idx_to_sample_idx[module_id][
+     idx_quadlet[3]
+ ]
self._module_to_sum_segment[module_id][idx_quadlet[3]] = 0.0
self._module_to_min_segment[module_id][idx_quadlet[3]] = float(
"inf"
)
- self._module_to_tree_idx_to_sample_idx[module_id].pop(sample_idx)
+ self._module_to_tree_idx_to_sample_idx[module_id].pop(
+     sample_idx
+ )
else:
new_module_indices.append(idx_quadlet)
self._module_to_tree_idx_to_sample_idx[module_id][
@@ -257,11 +261,9 @@ def _add_new_module_indices(
module_eps = ma_episode.agent_episodes[agent_id]
# Check if the module episode is already in the buffer.
if exists:
- old_ma_episode = self.episodes[
-     ma_episode_idx
- ]
+ old_ma_episode = self.episodes[ma_episode_idx]
# Is the agent episode already in the buffer?
# Note, at this point we have not yet concatenated the new
# Note, at this point we have not yet concatenated the new
# multi-agent episode.
existing_eps_len = len(old_ma_episode.agent_episodes[agent_id])
else:
@@ -445,9 +447,13 @@ def _sample_independent(
# Sample proportionally from the replay buffer's module segments using the
# respective weights.
module_total_segment_sum = self._module_to_sum_segment[module_id].sum()
- module_p_min = self._module_to_min_segment[module_id].min() / module_total_segment_sum
+ module_p_min = (
+     self._module_to_min_segment[module_id].min() / module_total_segment_sum
+ )
# TODO (simon): Allow individual betas per module.
- module_max_weight = (module_p_min * self.get_num_timesteps(module_id)) ** (-beta)
+ module_max_weight = (module_p_min * self.get_num_timesteps(module_id)) ** (
+     -beta
+ )
B = 0
while B < batch_size_B:
# First, draw a random sample from Uniform(0, sum over all weights).
@@ -466,7 +472,8 @@
)
# Get the theoretical probability mass for drawing this sample.
module_p_sample = (
- self._module_to_sum_segment[module_id][module_idx] / module_total_segment_sum
+ self._module_to_sum_segment[module_id][module_idx]
+ / module_total_segment_sum
)
# Compute the importance sampling weight.
module_weight = (
@@ -536,14 +543,12 @@ def _sample_independent(
# last time step, check, if the single-agent episode is terminated
# or truncated.
terminated=(
- False
- if sa_episode_ts + actual_n_step < len(sa_episode)
- else sa_episode.is_terminated
+ sa_episode_ts + actual_n_step >= len(sa_episode)
+ and sa_episode.is_terminated
),
truncated=(
- False
- if sa_episode_ts + actual_n_step < len(sa_episode)
- else sa_episode.is_truncated
+ sa_episode_ts + actual_n_step >= len(sa_episode)
+ and sa_episode.is_truncated
),
extra_model_outputs={
"weights": [
@@ -598,6 +603,7 @@ def update_priorities(
"""

assert len(priorities) == len(self._module_to_last_sampled_indices[module_id])

for idx, priority in zip(
self._module_to_last_sampled_indices[module_id], priorities
):
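Most of the `_sample_independent` changes above only re-wrap long lines, but the quantities they touch are the core of the per-module prioritized sampling: `module_p_min` is the smallest priority divided by the module's total priority mass, `module_max_weight = (module_p_min * N) ** -beta` is the largest possible importance-sampling weight, and each drawn sample's weight is computed from its own `module_p_sample` and, following the standard PER scheme, presumably normalized by `module_max_weight` (the continuation of that expression is collapsed in the diff). A small worked example with made-up numbers:

total = 10.0                  # module_total_segment_sum: total priority mass in the module's sum tree
n, beta = 100, 0.4            # stored timesteps for this module and the IS exponent
p_min = 0.5 / total           # module_p_min: smallest priority / total
max_weight = (p_min * n) ** (-beta)              # ~0.53, the largest possible IS weight
p_sample = 2.0 / total                           # probability mass of one drawn sample
weight = (p_sample * n) ** (-beta) / max_weight  # ~0.57, the normalized IS weight

The terminated/truncated rewrite in the same method is behavior-preserving: when the sampled n-step window ends before the episode does, both the old conditional expression and the new boolean expression yield False; when it reaches the episode end, both reduce to the episode's own is_terminated/is_truncated flag.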
