Commit: Set up central PLR
RyanNavillus committed Nov 11, 2024
1 parent 2aa6ad6 commit faad8e0
Showing 4 changed files with 45 additions and 30 deletions.
2 changes: 1 addition & 1 deletion config.yaml
@@ -32,7 +32,7 @@ train:
max_grad_norm: 0.5
target_kl: ~

-num_envs: 15
+num_envs: 16
envs_per_worker: 1
envs_per_batch: 6
env_pool: True
24 changes: 23 additions & 1 deletion reinforcement_learning/clean_pufferl.py
@@ -83,6 +83,7 @@ def create(
eval_model_path: str = None,
# Policy Pool options
policy_selector: callable = None,
+    curriculum=None,
):
if config is None:
config = pufferlib.args.CleanPuffeRL()
@@ -247,6 +248,8 @@ def create(
device=device,
start_time=start_time,
eval_mode=eval_mode,
+        curriculum=curriculum,
+        prev_value=None,
)


@@ -313,7 +316,7 @@ def evaluate(data):
next_lstm_state[0][:, env_id],
next_lstm_state[1][:, env_id],
)
print("puffer shape", o.shape)

actions, logprob, value, next_lstm_state = data.policy_pool.forwards(
o.to(data.device), next_lstm_state
)
@@ -325,6 +328,25 @@

value = value.flatten()

+# Syllabus curriculum update
+if data.curriculum is not None and data.prev_value is not None:
+    tasks = [info["task_id"] for info in i["learner"]]
+    env_ids = [info["env_id"] for info in i["learner"]]
+
+    update = {
+        "update_type": "on_demand",
+        "metrics": {
+            "value": data.prev_value,
+            "next_value": value,
+            "rew": r,
+            "dones": d,
+            "tasks": tasks,
+            "env_ids": env_ids,
+        },
+    }
+    data.curriculum.update(update)
+data.prev_value = value

with misc_profiler:
actions = actions.cpu().numpy()

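Note on the curriculum update added above: on the first evaluate() pass, data.prev_value is None
and no update is sent; from the second batch onward, the previous batch's value predictions
bootstrap a one-step TD error against the current rewards. A minimal sketch of what a "value_l1"
consumer might compute from this update dict (hypothetical helper operating on CPU arrays, with an
assumed gamma; the real scoring lives inside Syllabus's CentralizedPrioritizedLevelReplay):

import numpy as np

def value_l1_scores(update):
    m = update["metrics"]
    value = np.asarray(m["value"], dtype=np.float32)
    next_value = np.asarray(m["next_value"], dtype=np.float32)
    rew = np.asarray(m["rew"], dtype=np.float32)
    dones = np.asarray(m["dones"], dtype=np.float32)
    gamma = 0.99  # assumed; the curriculum itself is configured with args.train.gamma
    # One-step TD error, dropping the bootstrap term at episode boundaries
    td_error = rew + gamma * next_value * (1.0 - dones) - value
    # Average |TD error| per task: higher error -> higher replay priority
    per_task = {}
    for task, err in zip(m["tasks"], np.abs(td_error)):
        per_task.setdefault(task, []).append(float(err))
    return {task: sum(errs) / len(errs) for task, errs in per_task.items()}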
40 changes: 15 additions & 25 deletions syllabus_wrapper.py
@@ -1,5 +1,6 @@
"""Syllabus task wrapper for NMMO."""

+import time
import numpy as np
import torch
import copy
@@ -8,7 +9,7 @@
from syllabus.task_space import TaskSpace
from syllabus.core.evaluator import CleanRLDiscreteEvaluator, Evaluator
from syllabus.core.task_interface import PettingZooTaskWrapper
-from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay
+from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay, CentralizedPrioritizedLevelReplay
from syllabus.core import MultiagentSharedCurriculumWrapper, make_multiprocessing_curriculum
from nmmo.task.task_api import OngoingTask
from nmmo.task.base_predicates import StayAlive
@@ -48,7 +49,6 @@ def __init__(self, agent, possible_agents, pad_obs, *args, **kwargs):
self.set_agent(agent)

def set_agent(self, agent):
print("Setting agent", agent)
original_device = "cuda"
agent.to(self.device)
self.agent = copy.deepcopy(agent)
@@ -105,8 +105,6 @@ def _prepare_state(self, state):
new_state.append(np.stack(padded_obs.values()))

state = torch.Tensor(np.stack(new_state)).to(self.device)
print("syllabus shape", state.shape)

return state

def _set_eval_mode(self):
@@ -125,20 +123,22 @@ def make_syllabus_env_creator(args, agent_module):

flat_observation = concatenate(flatten(sample_obs[sample_env.possible_agents[0]]))
pad_obs = flat_observation * 0
-task_space = SyllabusTaskWrapper.task_space
+task_space = SyllabusSeedWrapper.task_space
# curriculum = create_sequential_curriculum(task_space)
evaluator = PufferEvaluator(None, sample_env.possible_agents, pad_obs, device=args.train.device)
-curriculum = PrioritizedLevelReplay(
+curriculum = CentralizedPrioritizedLevelReplay(
task_space,
-    sample_env.observation_space,
-    num_steps=args.train.batch_rows*4,
-    num_processes=args.train.num_envs,
-    num_minibatches=1,
+    # sample_env.observation_space,
+    num_steps=args.train.batch_rows,
+    # num_processes=args.train.num_envs,
+    num_processes=args.train.num_envs * args.env.num_agents,
+    # num_minibatches=2,
+    # buffer_size=128,
gamma=args.train.gamma,
gae_lambda=args.train.gae_lambda,
task_sampler_kwargs_dict={"strategy": "value_l1"},
-    evaluator=evaluator,
-    lstm_size=args.recurrent.input_size,
+    # evaluator=evaluator,
+    # lstm_size=args.recurrent.input_size,
+    record_stats=True,
)
curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents, joint_policy=True)
@@ -204,26 +204,16 @@ def __init__(self, env: gym.Env):

def reset(self, **kwargs):
seed = kwargs.pop("seed", None)
new_task = kwargs.pop("new_task", None)
if new_task is not None:
self.change_task(new_task)
seed = new_task
elif seed is not None:
self.change_task(seed)

if seed is not None:
obs, info = self.env.reset(seed=int(seed), **kwargs)
else:
obs, info = self.env.reset(**kwargs)

new_task = kwargs.pop("new_task", seed)
obs, info = super().reset(new_task=new_task, **kwargs)
return self.observation(obs), info

def change_task(self, new_task):
self.env.seed(int(new_task))
self.task = new_task

def step(self, action):
-    obs, rew, terms, truncs, info = self.env.step(action)
+    obs, rew, terms, truncs, info = super().step(action)
return self.observation(obs), rew, terms, truncs, info

def action_space(self, agent):
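The reset() rewrite above collapses the old seed/new_task branching into a single call to the base
PettingZooTaskWrapper.reset(), and step() now defers to super().step() for the same reason. A small
usage sketch of the resulting seed-as-task flow (make_env is a hypothetical stand-in for the NMMO
env factory, and this assumes the base wrapper routes new_task into change_task()):

env = SyllabusSeedWrapper(make_env())
obs, info = env.reset(new_task=42)   # change_task(42) reseeds via env.seed(42)
assert env.task == 42
obs, info = env.reset(seed=7)        # with no explicit task, the seed becomes the task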
9 changes: 6 additions & 3 deletions train_helper.py
@@ -10,7 +10,7 @@
import pufferlib.policy_pool as pp
from nmmo.render.replay_helper import FileReplayHelper
from nmmo.task.task_spec import make_task_from_spec

+from syllabus.curricula import PrioritizedLevelReplay
from reinforcement_learning import clean_pufferl

# Related to torch.use_deterministic_algorithms(True)
@@ -56,10 +56,13 @@ def train(args, env_creator, agent_creator, syllabus=None):
vectorization=args.vectorization,
exp_name=args.exp_name,
track=args.track,
+    curriculum=syllabus,
)

-    syllabus.curriculum.curriculum.evaluator.set_agent(data.agent)
-    syllabus.start()
+    if syllabus is not None:
+        if isinstance(syllabus.curriculum, PrioritizedLevelReplay):
+            syllabus.curriculum.curriculum.evaluator.set_agent(data.agent)
+        syllabus.start()

while not clean_pufferl.done_training(data):
clean_pufferl.evaluate(data)
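One subtlety in the guard above: syllabus is the handle returned by make_multiprocessing_curriculum,
so reaching the underlying PLR object takes two hops. A sketch of the chain as wired up in
syllabus_wrapper.py (assuming these wrapper attributes; names taken from this diff):

mp_handle = syllabus                   # from make_multiprocessing_curriculum(...)
shared = mp_handle.curriculum          # MultiagentSharedCurriculumWrapper
plr = shared.curriculum                # CentralizedPrioritizedLevelReplay
plr.evaluator.set_agent(data.agent)    # point the evaluator at the live policy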
