diff --git a/reinforcement_learning/environment.py b/reinforcement_learning/environment.py index db59003f..7a270dc0 100644 --- a/reinforcement_learning/environment.py +++ b/reinforcement_learning/environment.py @@ -1,13 +1,17 @@ from argparse import Namespace +import nmmo +import nmmo.core.config as nc +import nmmo.core.game_api as ng import pufferlib import pufferlib.emulation - from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper +from syllabus.core import PettingZooMultiProcessingSyncWrapper +from syllabus_task_wrapper import NMMOTaskWrapper -import nmmo -import nmmo.core.config as nc -import nmmo.core.game_api as ng + +def alt_combat_damage_formula(offense, defense, multiplier, minimum_proportion): + return int(max(multiplier * offense - defense, offense * minimum_proportion)) class Config( @@ -23,7 +27,6 @@ class Config( nc.Exchange, ): """Configuration for Neural MMO.""" - def __init__(self, env_args: Namespace): super().__init__() @@ -48,11 +51,23 @@ def __init__(self, env_args: Namespace): self.set("CURRICULUM_FILE_PATH", env_args.curriculum_file_path) -def make_env_creator(reward_wrapper_cls: BaseParallelWrapper): +def make_env_creator(reward_wrapper_cls: BaseParallelWrapper, task_wrapper=False, curriculum=None): def env_creator(*args, **kwargs): """Create an environment.""" env = nmmo.Env(Config(kwargs["env"])) # args.env is provided as kwargs env = reward_wrapper_cls(env, **kwargs["reward_wrapper"]) + + # Add Syllabus task wrapper + if task_wrapper or curriculum is not None: + env = NMMOTaskWrapper(env) + + # Use curriculum if provided + if curriculum is not None: + # Add Syllabus Sync Wrapper + env = PettingZooMultiProcessingSyncWrapper( + env, curriculum.get_components(), update_on_step=False, task_space=env.task_space, + ) + env = pufferlib.emulation.PettingZooPufferEnv(env) return env diff --git a/reinforcement_learning/stat_wrapper.py b/reinforcement_learning/stat_wrapper.py index febcf865..c018e440 100644 --- a/reinforcement_learning/stat_wrapper.py +++ b/reinforcement_learning/stat_wrapper.py @@ -21,6 +21,9 @@ def __init__( self._reset_episode_stats() self._stat_prefix = stat_prefix + def seed(self, seed): + self.env.seed(seed) + def observation(self, agent_id, agent_obs): """Called before observations are returned from the environment Use this to define custom featurizers. Changing the space itself requires you to diff --git a/syllabus_task_wrapper.py b/syllabus_task_wrapper.py new file mode 100644 index 00000000..2aa18e34 --- /dev/null +++ b/syllabus_task_wrapper.py @@ -0,0 +1,435 @@ +""" Task wrapper for NMMO. 
""" +import gym +import numpy as np +from nmmo.core import realm +from nmmo.core.agent import Agent +from nmmo.core.tile import Tile +from nmmo.entity.entity import Entity +from nmmo.lib.material import Harvestable, Material +from nmmo.systems import skill as nmmo_skill +from nmmo.systems.item import Item +from nmmo.systems.skill import Skill +from nmmo.task import base_predicates as bp +from nmmo.systems import item as i +from nmmo.entity import entity as e +from nmmo.task import task_api, task_spec +from nmmo.task.base_predicates import AllDead, StayAlive +from nmmo.task.game_state import GameState +from nmmo.task.group import Group +from nmmo.task.task_api import OngoingTask, Task, make_predicate, make_same_task +from syllabus.core.task_interface import PettingZooTaskWrapper +from syllabus.task_space import TaskSpace + +import nmmo +from nmmo.lib import utils + + +class NMMOTaskWrapper(PettingZooTaskWrapper): + """ + Wrapper to handle tasks for the Neural MMO environment. + """ + # task_space = TaskSpace((18, 200), [tuple(np.arange(18)), tuple(np.arange(200))]) + task_space = TaskSpace(200) + + # task_space = TaskSpace((2719, 200), [tuple(np.arange(2719)), tuple(np.arange(200))]) + + def __init__(self, env: gym.Env): + super().__init__(env) + self.env = env + + self.task_list = self.sequential_task_list() + # self.task_list, task_names = self.create_manual_task_list() + # self.task_list = self._reformat_tasks(self.task_list) + self.task_space = NMMOTaskWrapper.task_space + self.task = None + self._task_index = None + self.task_fn = None + + def _parse_tasks(self): + """ + Parse LLM generated tasks from a python list of strings. + Strings are stored in tasks (imported from sample_tasks.py). + Returns a list of task functions. + """ + # Check if tasks are already parsed + # TODO: Find a way to efficiently compare current task list to parsed task list and overwrite if different + globs = globals() + if 'parsed_tasks' in globs: + parsed_tasks = [globs[task] for task in globs['parsed_tasks']] + return parsed_tasks + + parsed_predicates = [] + previous_globals = globs.copy() + for task in tasks: + # Check for common errors in LLM defined tasks + # NOTE: If the LLM decides to produce dangerous code, it will be executed + # DO NOT let users affect the input to the LLM. 
+            try:
+                exec(task.strip(), globs)
+            except NameError as e:
+                print(f"\nFailed to parse task: {repr(e)}\n", task.strip(), "\n")
+            except SyntaxError as e:
+                print(f"\nFailed to parse task: {repr(e)}\n", task.strip(), "\n")
+
+        for name, obj in globals().items():
+            if name not in previous_globals:
+                parsed_predicates.append(obj)
+
+        globals()['parsed_tasks'] = [name for name in globals() if name not in previous_globals]
+        return parsed_predicates
+
+    def reset(self, **kwargs):
+        seed = kwargs.pop("seed", None)
+        new_task = kwargs.pop("new_task", None)
+        if new_task is not None:
+            self.change_task(new_task)
+            self.task = new_task
+            new_task_specs = self.task_list[new_task]
+            self.task_fn = task_spec.make_task_from_spec(
+                self.env.possible_agents, [new_task_specs] * len(self.env.possible_agents)
+            )
+
+        make_task_fn = (lambda: self.task_fn) if self.task_fn is not None else None
+        if seed is not None:
+            self.env.seed(int(seed))
+            obs, info = self.env.reset(seed=int(seed), make_task_fn=make_task_fn, **kwargs)
+        else:
+            obs, info = self.env.reset(make_task_fn=make_task_fn, **kwargs)
+
+        return self.observation(obs), info
+
+    def change_task(self, new_task):
+        pass
+
+    def step(self, action):
+        obs, rew, terms, truncs, info = self.env.step(action)
+        # obs[1]["Task"] = self._task_index
+        return self.observation(obs), rew, terms, truncs, info
+
+    def action_space(self, agent):
+        """Implement Neural MMO's action_space method."""
+        return self.env.action_space(agent)
+
+    def create_original_task_list(self):
+        return [('agent', StayAlive, {'task_cls': OngoingTask})]
+
+    def sequential_task_list(self):
+        # Stage 1 - Survival
+        stage1 = []
+        stage1.append(task_spec.TaskSpec(bp.TickGE, {'num_tick': 500}, reward_to='agent'))
+        stage1.append(task_spec.TaskSpec(bp.CountEvent, {'event': "EAT_FOOD", 'N': 20}, reward_to='agent'))
+        stage1.append(task_spec.TaskSpec(bp.CountEvent, {'event': "DRINK_WATER", 'N': 20}, reward_to='agent'))
+        stage1.append(task_spec.TaskSpec(bp.CountEvent, {'event': "GO_FARTHEST", 'N': 20}, reward_to='agent'))
+
+        # Stage 2 - Harvest Equipment
+        stage2 = []
+        stage2.append(task_spec.TaskSpec(bp.HarvestItem, {'item': i.Ammunition, 'level': 1, 'quantity': 20}, reward_to='agent'))
+        stage2.append(task_spec.TaskSpec(bp.HarvestItem, {'item': i.Weapon, 'level': 1, 'quantity': 20}, reward_to='agent'))
+
+        # # Stage 3 - Equip Weapons
+        stage3 = []
+        stage3.append(task_spec.TaskSpec(bp.EquipItem, {'item': i.Weapon, 'level': 1, 'num_agent': 1}, reward_to='agent'))
+        # stage3.append(task_spec.TaskSpec(bp.EquipItem, {'item': i.ammunition, 'level': 1, 'num_agent': 1}, reward_to='agent'))
+        stage3.append(task_spec.TaskSpec(bp.EquipItem, {'item': i.Weapon, 'level': 1, 'num_agent': 8}, reward_to='agent'))
+        # stage3.append(task_spec.TaskSpec(bp.EquipItem, {'item': i.ammunition, 'level': 1, 'num_agent': 8}, reward_to='agent'))
+
+        # # Stage 4 - Fight
+        stage4 = []
+        stage4.append(task_spec.TaskSpec(bp.CanSeeGroup, {'target': 'all_foes'}, reward_to='agent'))
+        stage4.append(task_spec.TaskSpec(bp.CountEvent, {'event': "SCORE_HIT", 'N': 20}, reward_to='agent'))
+
+        # # Stage 5 - Kill
+        stage5 = []
+        stage5.append(task_spec.TaskSpec(bp.DefeatEntity, {'agent_type': 'player', 'level': 1, 'num_agent': 1}, reward_to='agent'))
+
+        return stage1 + stage2 + stage3 + stage4 + stage5
+
+    def create_manual_task_list(self):
+        STAY_ALIVE_GOAL = [50, 100, 150, 200, 300, 500]
+        # AGENT_NUM_GOAL = [1]  # competition team size: 8
+        task_specs = []
+        task_names = []
+
+        # Find resource tiles
+        for resource in Harvestable:
+            for reward_to in ['agent']:
+                spec = task_spec.TaskSpec(bp.CanSeeTile, {'tile_type': resource}, reward_to=reward_to)
+                task_specs.append(spec)
+                # task_names.append("see_" + resource.name)
+
+        # Stay alive
+        for reward_to in ['agent']:
+            for num_tick in STAY_ALIVE_GOAL:
+                spec = task_spec.TaskSpec(bp.TickGE, {'num_tick': num_tick}, reward_to=reward_to)
+                task_specs.append(spec)
+                # task_names.append("stay_alive_" + str(num_tick))
+
+        # Explore the map
+        for dist in [10, 20, 30, 50, 100]:  # each agent
+            spec = task_spec.TaskSpec(bp.DistanceTraveled, {'dist': dist}, reward_to='agent')
+            task_specs.append(spec)
+            # task_names.append("explore_" + str(dist) + "m")
+
+        return task_specs, task_names
+
+    def _create_testing_task_list(self):
+        """
+        Manually generate a list of tasks used for testing.
+        """
+        EVENT_NUMBER_GOAL = [1, 2, 3, 4, 5, 7, 9, 12, 15, 20, 30, 50]
+        INFREQUENT_GOAL = list(range(1, 10))
+        STAY_ALIVE_GOAL = [50, 100, 150, 200, 300, 500]
+        TEAM_NUMBER_GOAL = [10, 20, 30, 50, 70, 100]
+        LEVEL_GOAL = list(range(1, 10))  # TODO: get config
+        AGENT_NUM_GOAL = [1]  # competition team size: 8
+        ITEM_NUM_GOAL = AGENT_NUM_GOAL
+        TEAM_ITEM_GOAL = [1, 3, 5, 7, 10, 15, 20]
+        SKILLS = e.combat_skills + e.harvest_skills
+        COMBAT_STYLE = e.combat_skills
+        ALL_ITEM = i.armour + i.weapons + i.tools + i.ammunition + i.consumables
+        EQUIP_ITEM = i.armour + i.weapons + i.tools + i.ammunition
+        HARVEST_ITEM = i.weapons + i.ammunition + i.consumables
+
+        """ task_specs is a list of tuples (reward_to, predicate class, kwargs)
+
+            each tuple in task_specs will create tasks for a team in teams
+
+            reward_to: must be in ['team', 'agent']
+              * 'team' creates a single team task, in which all team members get rewarded
+              * 'agent' creates a task for each agent, in which only the agent gets rewarded
+
+            predicate class from the base predicates or custom predicates like above
+
+            kwargs are the additional args that go into the predicate. There are also special keys
+              * 'target' must be ['left_team', 'right_team', 'left_team_leader', 'right_team_leader']
+                these strs will be translated into the actual agent ids
+              * 'task_cls' is optional. If not provided, the standard Task is used. """
+        task_specs = []
+
+        # explore, eat, drink, attack any agent, harvest any item, level up any skill
+        #   which can happen frequently
+        essential_skills = ['GO_FARTHEST', 'EAT_FOOD', 'DRINK_WATER',
+                            'SCORE_HIT', 'HARVEST_ITEM', 'LEVEL_UP']
+        for event_code in essential_skills:
+            task_specs += [('agent', bp.CountEvent, {'event': event_code, 'N': cnt})
+                           for cnt in EVENT_NUMBER_GOAL]
+
+        # item/market skills, which happen less frequently or should not do too much
+        item_skills = ['CONSUME_ITEM', 'GIVE_ITEM', 'DESTROY_ITEM', 'EQUIP_ITEM',
+                       'GIVE_GOLD', 'LIST_ITEM', 'EARN_GOLD', 'BUY_ITEM']
+        for event_code in item_skills:
+            task_specs += [('agent', bp.CountEvent, {'event': event_code, 'N': cnt})
+                           for cnt in INFREQUENT_GOAL]  # less than 10
+
+        # find resource tiles
+        for resource in Harvestable:
+            for reward_to in ['agent', 'team']:
+                task_specs.append((reward_to, bp.CanSeeTile, {'tile_type': resource}))
+
+        # stay alive ... like ... 
for 300 ticks + # i.e., getting incremental reward for each tick alive as an individual or a team + for reward_to in ['agent', 'team']: + for num_tick in STAY_ALIVE_GOAL: + task_specs.append((reward_to, bp.TickGE, {'num_tick': num_tick})) + + # protect the leader: get reward for each tick the leader is alive + task_specs.append(('team', bp.StayAlive, {'target': 'my_team_leader', 'task_cls': OngoingTask})) + + # want the other team or team leader to die + for target in ['left_team', 'left_team_leader', 'right_team', 'right_team_leader']: + task_specs.append(('team', bp.AllDead, {'target': target})) + + # occupy the center tile, assuming the Medium map size + # TODO: it'd be better to have some intermediate targets toward the center + for reward_to in ['agent', 'team']: + task_specs.append((reward_to, bp.OccupyTile, {'row': 80, 'col': 80})) # TODO: get config + + # form a tight formation, for a certain number of ticks + def PracticeFormation(gs, subject, dist, num_tick): + return bp.AllMembersWithinRange(gs, subject, dist) * bp.TickGE(gs, subject, num_tick) + + for dist in [1, 3, 5, 10]: + task_specs += [('team', PracticeFormation, {'dist': dist, 'num_tick': num_tick}) + for num_tick in STAY_ALIVE_GOAL] + + # find the other team leader + for reward_to in ['agent', 'team']: + for target in ['left_team_leader', 'right_team_leader']: + task_specs.append((reward_to, bp.CanSeeAgent, {'target': target})) + + # find the other team (any agent) + for reward_to in ['agent']: # , 'team']: + for target in ['left_team', 'right_team']: + task_specs.append((reward_to, bp.CanSeeGroup, {'target': target})) + + # explore the map -- sum the l-inf distance traveled by all subjects + for dist in [10, 20, 30, 50, 100]: # each agent + task_specs.append(('agent', bp.DistanceTraveled, {'dist': dist})) + for dist in [30, 50, 70, 100, 150, 200, 300, 500]: # summed over all team members + task_specs.append(('team', bp.DistanceTraveled, {'dist': dist})) + + # level up a skill + for skill in SKILLS: + for level in LEVEL_GOAL: + # since this is an agent task, num_agent must be 1 + task_specs.append(('agent', bp.AttainSkill, {'skill': skill, 'level': level, 'num_agent': 1})) + + # make attain skill a team task by varying the number of agents + for skill in SKILLS: + for level in LEVEL_GOAL: + for num_agent in AGENT_NUM_GOAL: + if level + num_agent <= 6 or num_agent == 1: # heuristic prune + task_specs.append(('team', bp.AttainSkill, + {'skill': skill, 'level': level, 'num_agent': num_agent})) + + # practice specific combat style + for style in COMBAT_STYLE: + for cnt in EVENT_NUMBER_GOAL: + task_specs.append(('agent', bp.ScoreHit, {'combat_style': style, 'N': cnt})) + for cnt in TEAM_NUMBER_GOAL: + task_specs.append(('team', bp.ScoreHit, {'combat_style': style, 'N': cnt})) + + # defeat agents of a certain level as a team + for agent_type in ['player', 'npc']: # c.AGENT_TYPE_CONSTRAINT + for level in LEVEL_GOAL: + for num_agent in AGENT_NUM_GOAL: + if level + num_agent <= 6 or num_agent == 1: # heuristic prune + task_specs.append(('team', bp.DefeatEntity, + {'agent_type': agent_type, 'level': level, 'num_agent': num_agent})) + + # hoarding gold -- evaluated on the current gold + for amount in EVENT_NUMBER_GOAL: + task_specs.append(('agent', bp.HoardGold, {'amount': amount})) + for amount in TEAM_NUMBER_GOAL: + task_specs.append(('team', bp.HoardGold, {'amount': amount})) + + # earning gold -- evaluated on the total gold earned by selling items + # does NOT include looted gold + for amount in EVENT_NUMBER_GOAL: + 
task_specs.append(('agent', bp.EarnGold, {'amount': amount})) + for amount in TEAM_NUMBER_GOAL: + task_specs.append(('team', bp.EarnGold, {'amount': amount})) + + # spending gold, by buying items + for amount in EVENT_NUMBER_GOAL: + task_specs.append(('agent', bp.SpendGold, {'amount': amount})) + for amount in TEAM_NUMBER_GOAL: + task_specs.append(('team', bp.SpendGold, {'amount': amount})) + + # making profits by trading -- only buying and selling are counted + for amount in EVENT_NUMBER_GOAL: + task_specs.append(('agent', bp.MakeProfit, {'amount': amount})) + for amount in TEAM_NUMBER_GOAL: + task_specs.append(('team', bp.MakeProfit, {'amount': amount})) + + # managing inventory space + def PracticeInventoryManagement(gs, subject, space, num_tick): + return bp.InventorySpaceGE(gs, subject, space) * bp.TickGE(gs, subject, num_tick) + for space in [2, 4, 8]: + task_specs += [('agent', PracticeInventoryManagement, {'space': space, 'num_tick': num_tick}) + for num_tick in STAY_ALIVE_GOAL] + + # own item, evaluated on the current inventory + for item in ALL_ITEM: + for level in LEVEL_GOAL: + # agent task + for quantity in ITEM_NUM_GOAL: + if level + quantity <= 6 or quantity == 1: # heuristic prune + task_specs.append(('agent', bp.OwnItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # team task + for quantity in TEAM_ITEM_GOAL: + if level + quantity <= 10 or quantity == 1: # heuristic prune + task_specs.append(('team', bp.OwnItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # equip item, evaluated on the current inventory and equipment status + for item in EQUIP_ITEM: + for level in LEVEL_GOAL: + # agent task + task_specs.append(('agent', bp.EquipItem, + {'item': item, 'level': level, 'num_agent': 1})) + + # team task + for num_agent in AGENT_NUM_GOAL: + if level + num_agent <= 6 or num_agent == 1: # heuristic prune + task_specs.append(('team', bp.EquipItem, + {'item': item, 'level': level, 'num_agent': num_agent})) + + # consume items (ration, potion), evaluated based on the event log + for item in i.consumables: + for level in LEVEL_GOAL: + # agent task + for quantity in ITEM_NUM_GOAL: + if level + quantity <= 6 or quantity == 1: # heuristic prune + task_specs.append(('agent', bp.ConsumeItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # team task + for quantity in TEAM_ITEM_GOAL: + if level + quantity <= 10 or quantity == 1: # heuristic prune + task_specs.append(('team', bp.ConsumeItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # harvest items, evaluated based on the event log + for item in HARVEST_ITEM: + for level in LEVEL_GOAL: + # agent task + for quantity in ITEM_NUM_GOAL: + if level + quantity <= 6 or quantity == 1: # heuristic prune + task_specs.append(('agent', bp.HarvestItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # team task + for quantity in TEAM_ITEM_GOAL: + if level + quantity <= 10 or quantity == 1: # heuristic prune + task_specs.append(('team', bp.HarvestItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # list items, evaluated based on the event log + for item in ALL_ITEM: + for level in LEVEL_GOAL: + # agent task + for quantity in ITEM_NUM_GOAL: + if level + quantity <= 6 or quantity == 1: # heuristic prune + task_specs.append(('agent', bp.ListItem, + {'item': item, 'level': level, 'quantity': quantity})) + + # team task + for quantity in TEAM_ITEM_GOAL: + if level + quantity <= 10 or quantity == 1: # heuristic prune + task_specs.append(('team', bp.ListItem, + 
{'item': item, 'level': level, 'quantity': quantity}))
+
+        # buy items, evaluated based on the event log
+        for item in ALL_ITEM:
+            for level in LEVEL_GOAL:
+                # agent task
+                for quantity in ITEM_NUM_GOAL:
+                    if level + quantity <= 6 or quantity == 1:  # heuristic prune
+                        task_specs.append(('agent', bp.BuyItem,
+                                           {'item': item, 'level': level, 'quantity': quantity}))
+
+                # team task
+                for quantity in TEAM_ITEM_GOAL:
+                    if level + quantity <= 10 or quantity == 1:  # heuristic prune
+                        task_specs.append(('team', bp.BuyItem,
+                                           {'item': item, 'level': level, 'quantity': quantity}))
+
+        # fully armed, evaluated based on the current player/inventory status
+        for style in COMBAT_STYLE:
+            for level in LEVEL_GOAL:
+                for num_agent in AGENT_NUM_GOAL:
+                    if level + num_agent <= 6 or num_agent == 1:  # heuristic prune
+                        task_specs.append(('team', bp.FullyArmed,
+                                           {'combat_style': style, 'level': level, 'num_agent': num_agent}))
+
+        packaged_task_specs = []
+        for spec in task_specs:
+            reward_to = spec[0]
+            eval_fn = spec[1]
+            eval_fn_kwargs = spec[2]
+            packaged_task_specs.append(task_spec.TaskSpec(eval_fn, eval_fn_kwargs, reward_to=reward_to))
+
+        return packaged_task_specs
diff --git a/train.py b/train.py
index 9d7c8048..9ee26c34 100644
--- a/train.py
+++ b/train.py
@@ -1,17 +1,19 @@
-# from pdb import set_trace as T
-import importlib
 import argparse
+import importlib
 import inspect
 import logging
-import yaml
-import time
 import sys
+import time
 
 import pufferlib
 import pufferlib.utils
+import yaml
+from syllabus.core import MultiagentSharedCurriculumWrapper, make_multiprocessing_curriculum
+from syllabus.curricula import SequentialCurriculum
 
 from reinforcement_learning import environment
-from train_helper import init_wandb, train, sweep, generate_replay
+from syllabus_task_wrapper import NMMOTaskWrapper
+from train_helper import generate_replay, init_wandb, sweep, train
 
 DEBUG = False
 # See curriculum_generation/manual_curriculum.py for details
@@ -57,15 +59,7 @@ def get_init_args(fn):
     return args
 
 
-# Return env_creator, agent_creator
-def setup_agent(module_name):
-    try:
-        agent_module = importlib.import_module(f"agent_zoo.{module_name}")
-    except ModuleNotFoundError:
-        raise ValueError(f"Agent module {module_name} not found under the agent_zoo directory.")
-
-    env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper)
-
+def setup_agent(agent_module):
     recurrent_policy = getattr(agent_module, "Recurrent", None)
 
     def agent_creator(env, args):
@@ -76,16 +70,7 @@ def agent_creator(env, args):
         else:
             policy = pufferlib.frameworks.cleanrl.Policy(policy)
         return policy.to(args.train.device)
-
-    init_args = {
-        "policy": get_init_args(agent_module.Policy.__init__),
-        "recurrent": get_init_args(agent_module.Recurrent.__init__)
-        if recurrent_policy is not None
-        else {},
-        "reward_wrapper": get_init_args(agent_module.RewardWrapper.__init__),
-    }
-
-    return agent_module, env_creator, agent_creator, init_args
+    return agent_creator
 
 
 def combine_config_args(parser, args, config):
@@ -161,6 +146,33 @@ def update_args(args, mode=None):
     return args
 
 
+def create_sequential_curriculum(task_space):
+    curricula = []
+    stopping = []
+
+    # Stage 1 - Survival
+    stage1 = [0, 1, 2, 3]
+    stopping.append("episode_return>=0.9&episodes>=5000")
+
+    # # Stage 2 - Harvest Equipment
+    stage2 = [4, 5]
+    stopping.append("episode_return>=0.9&episodes>=5000")
+
+    # # Stage 3 - Equip Weapons
+    stage3 = [6, 7]
+    stopping.append("episode_return>=0.9&episodes>=5000")
+
+    # # Stage 4 - Fight
+    stage4 = [8, 9]
+    stopping.append("episode_return>=0.9&episodes>=5000")
+
+    # # Stage 5 - Kill
+    stage5 = [10]
+
+    curricula = [stage1, stage2, stage3, stage4, stage5]
+    return SequentialCurriculum(curricula, stopping, task_space, return_buffer_size=5000)
+
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser(description="Parse environment argument", add_help=False)
@@ -186,6 +198,12 @@ def update_args(args, mode=None):
         default=None,
         help="The index of the task to assign in the curriculum file",
     )
+    parser.add_argument(
+        "--test-curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file"
+    )
+    parser.add_argument(
+        "--syllabus", action="store_true", help="Use Syllabus for curriculum"
+    )
     # parser.add_argument('--baseline', action='store_true', help='Baseline run')
     parser.add_argument(
         "--vectorization",
@@ -203,7 +221,17 @@ def update_args(args, mode=None):
 
     args = parser.parse_known_args()[0].__dict__
     config = load_from_config(args["agent"], debug=args.get("debug", False))
-    agent_module, env_creator, agent_creator, init_args = setup_agent(args["agent"])
+
+    try:
+        agent_module = importlib.import_module(f"agent_zoo.{args['agent']}")
+    except ModuleNotFoundError:
+        raise ValueError(f"Agent module {args['agent']} not found under the agent_zoo directory.")
+
+    init_args = {
+        "policy": get_init_args(agent_module.Policy.__init__),
+        "recurrent": get_init_args(agent_module.Recurrent.__init__) if getattr(agent_module, "Recurrent", None) else {},
+        "reward_wrapper": get_init_args(agent_module.RewardWrapper.__init__),
+    }
 
     # Update config with environment defaults
     config.policy = {**init_args["policy"], **config.policy}
@@ -216,6 +244,23 @@ def update_args(args, mode=None):
     # Perform mode-specific updates
     args = update_args(args, mode=args["mode"])
 
+    sample_env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper, task_wrapper=True)
+
+    # Set up curriculum
+    curriculum = None
+    if args.syllabus:
+        sample_env = sample_env_creator(env=args.env, reward_wrapper=args.reward_wrapper)
+        task_space = NMMOTaskWrapper.task_space
+        curriculum = create_sequential_curriculum(task_space)
+        curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents)
+        curriculum = make_multiprocessing_curriculum(curriculum)
+    else:
+        args.env.curriculum_file_path = args.curriculum
+
+    env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper, curriculum=curriculum)
+
+    agent_creator = setup_agent(agent_module)
+
     if args.train.env_pool is True:
         logging.warning(
            "Env_pool is enabled. This may increase training speed but break determinism."
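
How the new pieces interact when --syllabus is passed, condensed from the train.py changes above (a sketch for reference, not additional code to apply; agent_module and args are assumed to come from the existing agent_zoo import and CLI parsing in train.py):

    from syllabus.core import MultiagentSharedCurriculumWrapper, make_multiprocessing_curriculum
    from reinforcement_learning import environment
    from syllabus_task_wrapper import NMMOTaskWrapper
    from train import create_sequential_curriculum

    # A throwaway, task-wrapped env is built once, only to read the agent list.
    sample_env_creator = environment.make_env_creator(
        reward_wrapper_cls=agent_module.RewardWrapper, task_wrapper=True
    )
    sample_env = sample_env_creator(env=args.env, reward_wrapper=args.reward_wrapper)

    # The staged curriculum is defined over the wrapper's task space, shared across
    # agents, and put behind a multiprocessing interface.
    curriculum = create_sequential_curriculum(NMMOTaskWrapper.task_space)
    curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents)
    curriculum = make_multiprocessing_curriculum(curriculum)

    # Worker envs receive curriculum.get_components() via the
    # PettingZooMultiProcessingSyncWrapper added in make_env_creator; sampled task
    # indices are applied by NMMOTaskWrapper.reset().
    env_creator = environment.make_env_creator(
        reward_wrapper_cls=agent_module.RewardWrapper, curriculum=curriculum
    )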
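Note that the stage index lists in create_sequential_curriculum ([0, 1, 2, 3], [4, 5], [6, 7], [8, 9], [10]) are positions into the list returned by NMMOTaskWrapper.sequential_task_list(), so the two definitions have to stay in sync. A small sanity check along the following lines could guard that coupling; it is illustrative only, assumes TaskSpec exposes eval_fn (as used positionally in the packaging loop of _create_testing_task_list), and builds a throwaway env from the parsed args the same way train.py does:

    import nmmo
    from nmmo.task import base_predicates as bp
    from reinforcement_learning import environment
    from syllabus_task_wrapper import NMMOTaskWrapper

    specs = NMMOTaskWrapper(nmmo.Env(environment.Config(args.env))).task_list
    stages = [[0, 1, 2, 3], [4, 5], [6, 7], [8, 9], [10]]
    assert max(max(stage) for stage in stages) < len(specs)
    assert specs[0].eval_fn is bp.TickGE          # stage 1 starts with survival
    assert specs[10].eval_fn is bp.DefeatEntity   # stage 5 is the kill task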