
Commit 4ea28e4
Merge
RyanNavillus committed Nov 26, 2024
1 parent 3c678d5 commit 4ea28e4
Showing 5 changed files with 29 additions and 24 deletions.
config.yaml: 4 changes (2 additions, 2 deletions)
@@ -18,7 +18,7 @@ train:
seed: 1
torch_deterministic: True
device: cuda
total_timesteps: 10_000_000
total_timesteps: 20_000_000
learning_rate: 1.5e-4
anneal_lr: True
gamma: 0.99
@@ -32,7 +32,7 @@ train:
max_grad_norm: 0.5
target_kl: ~

num_envs: 16
num_envs: 15
envs_per_worker: 1
envs_per_batch: 8
env_pool: True
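
For reference, the values above sit under the train: block of config.yaml. A minimal sketch of reading that block (assuming PyYAML and a plain namespace; the repo's own config loading may differ):

from types import SimpleNamespace
import yaml

# Minimal sketch: load the train block from config.yaml.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

train = SimpleNamespace(**cfg["train"])
print(train.total_timesteps)  # 20_000_000 after this commit
print(train.num_envs)         # 15 after this commit
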
reinforcement_learning/clean_pufferl.py: 21 changes (10 additions, 11 deletions)
@@ -22,7 +22,7 @@
import pufferlib.vectorization
import pufferlib.frameworks.cleanrl
import pufferlib.policy_pool

from syllabus.curricula import CentralizedPrioritizedLevelReplay
SKIP_LOG_KEYS = ["curriculum/Task_", "env_id"]


@@ -123,15 +123,14 @@ def create(
resume_state = {}
path = os.path.join(config.data_dir, exp_name)
if False and os.path.exists(path):
pass
# trainer_path = os.path.join(path, "trainer_state.pt")
# resume_state = torch.load(trainer_path)
# model_path = os.path.join(path, resume_state["model_name"])
# agent = torch.load(model_path, map_location=device)
# print(
# f'Resumed from update {resume_state["update"]} '
# f'with policy {resume_state["model_name"]}'
# )
trainer_path = os.path.join(path, "trainer_state.pt")
resume_state = torch.load(trainer_path)
model_path = os.path.join(path, resume_state["model_name"])
agent = torch.load(model_path, map_location=device)
print(
f'Resumed from update {resume_state["update"]} '
f'with policy {resume_state["model_name"]}'
)
elif not eval_mode:
agent = pufferlib.emulation.make_object(
agent, agent_creator, [pool.driver_env], agent_kwargs
@@ -330,7 +329,7 @@ def evaluate(data):
value = value.flatten()

# Syllabus curriculum update
if data.curriculum is not None and data.prev_value is not None:
if data.curriculum is not None and isinstance(data.curriculum.curriculum.curriculum, CentralizedPrioritizedLevelReplay) and data.prev_value is not None:
tasks = [info["task_id"] for info in i["learner"]]
env_ids = [info["env_id"] for info in i["learner"]]

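
The new guard in evaluate() reaches through two wrapper layers (the multiprocessing wrapper and MultiagentSharedCurriculumWrapper built in syllabus_wrapper.py) before testing the curriculum type, hence the chained .curriculum attributes. A hypothetical helper that does the same unwrapping generically, shown only as a sketch and not part of this commit:

from syllabus.curricula import CentralizedPrioritizedLevelReplay

def unwrap_curriculum(curriculum):
    # Hypothetical helper: walk nested .curriculum attributes down to the
    # base curriculum. Assumes only wrapper classes expose .curriculum.
    while hasattr(curriculum, "curriculum"):
        curriculum = curriculum.curriculum
    return curriculum

# Same effect as the explicit check above:
# if (data.curriculum is not None
#         and isinstance(unwrap_curriculum(data.curriculum), CentralizedPrioritizedLevelReplay)
#         and data.prev_value is not None):
#     ...
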
syllabus_wrapper.py: 7 changes (4 additions, 3 deletions)
@@ -9,7 +9,7 @@
from syllabus.task_space import DiscreteTaskSpace
from syllabus.core.evaluator import CleanRLDiscreteEvaluator, Evaluator
from syllabus.core.task_interface import PettingZooTaskWrapper
from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay, CentralizedPrioritizedLevelReplay
from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay, CentralizedPrioritizedLevelReplay, DomainRandomization
from syllabus.core import MultiagentSharedCurriculumWrapper, make_multiprocessing_curriculum
from nmmo.task.task_api import OngoingTask
from nmmo.task.base_predicates import StayAlive
@@ -125,7 +125,7 @@ def make_syllabus_env_creator(args, agent_module):
pad_obs = flat_observation * 0
task_space = SyllabusSeedWrapper.task_space
# curriculum = create_sequential_curriculum(task_space)
evaluator = PufferEvaluator(None, sample_env.possible_agents, pad_obs, device=args.train.device)
# evaluator = PufferEvaluator(None, sample_env.possible_agents, pad_obs, device=args.train.device)
curriculum = CentralizedPrioritizedLevelReplay(
task_space,
# sample_env.observation_space,
@@ -136,11 +136,12 @@
# buffer_size=128,
gamma=args.train.gamma,
gae_lambda=args.train.gae_lambda,
task_sampler_kwargs_dict={"strategy": "value_l1"},
task_sampler_kwargs_dict={"strategy": "value_l1", "temperature": 0.3, "staleness_coef": 0.3, "alpha": 0.25},
# evaluator=evaluator,
# lstm_size=args.recurrent.input_size,
record_stats=True,
)
# curriculum = DomainRandomization(task_space)
curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents, joint_policy=True)
curriculum = make_multiprocessing_curriculum(curriculum, start=False)

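
Putting this file's pieces together, the curriculum construction after this commit looks roughly like the sketch below. The constructor arguments are the ones visible in the diff; the helper function and its signature are illustrative, not code from the repo:

from syllabus.curricula import CentralizedPrioritizedLevelReplay, DomainRandomization
from syllabus.core import MultiagentSharedCurriculumWrapper, make_multiprocessing_curriculum

def build_curriculum(task_space, possible_agents, train_cfg, use_plr=True):
    if use_plr:
        curriculum = CentralizedPrioritizedLevelReplay(
            task_space,
            gamma=train_cfg.gamma,
            gae_lambda=train_cfg.gae_lambda,
            # Sampler settings tuned in this commit.
            task_sampler_kwargs_dict={
                "strategy": "value_l1",
                "temperature": 0.3,
                "staleness_coef": 0.3,
                "alpha": 0.25,
            },
            record_stats=True,
        )
    else:
        # Domain-randomization baseline, kept commented out in the source.
        curriculum = DomainRandomization(task_space)
    # Share one curriculum across all agents, then move it to its own process.
    curriculum = MultiagentSharedCurriculumWrapper(
        curriculum, possible_agents, joint_policy=True
    )
    return make_multiprocessing_curriculum(curriculum, start=False)
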
train.py: 2 changes (1 addition, 1 deletion)
@@ -236,7 +236,7 @@ def update_args(args, mode=None):
args.exp_name = f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}"

if args.mode == "train":
train(args, env_creator, agent_creator, syllabus)
train(args, env_creator, agent_creator, agent_module, syllabus=syllabus)
exit(0)
elif args.mode == "sweep":
sweep(args, env_creator, agent_creator)
train_helper.py: 19 changes (12 additions, 7 deletions)
@@ -15,6 +15,8 @@
from nmmo.task.task_spec import make_task_from_spec
from syllabus.curricula import PrioritizedLevelReplay
from reinforcement_learning import clean_pufferl, environment
from syllabus_wrapper import PufferEvaluator, SyllabusSeedWrapper, concatenate
from pufferlib.extensions import flatten

# Related to torch.use_deterministic_algorithms(True)
# See also https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
@@ -27,29 +29,31 @@ def init_wandb(args, resume=True):
assert args.wandb.project is not None, "Please set the wandb project in config.yaml"
assert args.wandb.entity is not None, "Please set the wandb entity in config.yaml"
wandb_kwargs = {
"id": args.exp_name or wandb.util.generate_id(),
"id": wandb.util.generate_id(),
"project": args.wandb.project,
"entity": args.wandb.entity,
"config": {
"cleanrl": args.train,
"env": args.env,
"cleanrl": vars(args.train),
"env": vars(args.env),
"agent_zoo": args.agent,
"policy": args.policy,
"recurrent": args.recurrent,
"reward_wrapper": args.reward_wrapper,
"policy": vars(args.policy),
"recurrent": vars(args.recurrent),
"reward_wrapper": vars(args.reward_wrapper),
"syllabus": args.syllabus,
"all": vars(args),
},
"name": args.exp_name,
"monitor_gym": True,
"save_code": True,
"resume": False,
"dir": "/fs/nexus-scratch/rsulli/nmmo-wandb",
}
if args.wandb.group is not None:
wandb_kwargs["group"] = args.wandb.group
return wandb.init(**wandb_kwargs)


def train(args, env_creator, agent_creator, syllabus=None):
def train(args, env_creator, agent_creator, agent_module, syllabus=None):
data = clean_pufferl.create(
config=args.train,
agent_creator=agent_creator,
@@ -201,6 +205,7 @@ def evaluate_agent(args, data, env_outputs, train_wandb, global_step):
# print("logging", global_step)
train_wandb.log({
"global_step": global_step,
"eval/return": np.mean(eval_returns),
**{f"{k}": v for k, v in data.stats.items()}
})

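
The init_wandb changes above wrap each argument namespace in vars() so the logged config is made of plain dicts, and they always generate a fresh run id instead of reusing exp_name. A small self-contained sketch of that pattern (project name and namespaces below are placeholders, not the repo's settings):

from types import SimpleNamespace
import wandb

args = SimpleNamespace(
    train=SimpleNamespace(learning_rate=1.5e-4, gamma=0.99),
    env=SimpleNamespace(num_agents=128),
    exp_name="nmmo_example",
)

run = wandb.init(
    project="my-project",            # placeholder project name
    id=wandb.util.generate_id(),     # fresh id every run, as in this commit
    name=args.exp_name,
    config={
        "cleanrl": vars(args.train),  # vars() turns a namespace into a dict
        "env": vars(args.env),
    },
    resume=False,
    mode="offline",                  # offline so the sketch runs without a wandb account
)
run.finish()
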
