updated readme
kywch committed May 31, 2024
1 parent (83438ab), commit 9130265
Showing 9 changed files with 77 additions and 13 deletions.
README.md: 60 additions, 0 deletions
@@ -8,3 +8,63 @@ Baselines for Neural MMO (neuralmmo.github.io) -- new users should treat this re
</a>

[Documentation](https://neuralmmo.github.io "Neural MMO Documentation") is hosted on GitHub Pages.

## Installation

```
pip install -e .[dev]
```
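
The `dev` extra installs the optional development dependencies declared in `pyproject.toml`.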

## Training

To verify that the installation was successful, run the following command (in `--debug` mode):

```
python train.py --debug --no-track
```

To log the training run to Weights & Biases, edit the `wandb` section in `config.yaml` and remove `--no-track` from the command line. The `config.yaml` file holds the project's configuration settings.
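
For example, once `--no-track` is removed, the debug run above becomes a tracked run:

```
python train.py --debug
```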

### Agent zoo and your custom policy

This baseline comes with four models under the `agent_zoo` directory: `neurips23_start_kit`, `yaofeng`, `takeru`, and `hybrid`. To use one of them, pass its name with the `-a` argument:

```
python train.py -a hybrid
```

You can also add your own policy by creating a new module under the `agent_zoo` directory; the module must define `Policy`, `Recurrent`, and `RewardWrapper` classes.
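
A minimal module sketch is shown below; the class bodies, base classes, and constructor signatures here are placeholder assumptions, so copy the real interfaces from an existing model such as `neurips23_start_kit`:

```
# agent_zoo/my_policy/__init__.py -- hypothetical layout; copy the actual
# base classes and signatures from agent_zoo/neurips23_start_kit.

class Policy:
    """Maps observations to action logits and a value estimate."""
    def __init__(self, env, *args, **kwargs):
        ...

class Recurrent:
    """Optional recurrent (e.g., LSTM) wrapper around the policy."""
    def __init__(self, env, policy, *args, **kwargs):
        ...

class RewardWrapper:
    """Environment wrapper that shapes the task reward during training."""
    def __init__(self, env, *args, **kwargs):
        ...
```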

### Curriculum Learning using Syllabus

The training script supports automatic curriculum learning using the [Syllabus](https://github.com/RyanNavillus/Syllabus) library. To use it, add `--syllabus` to the command line.

```
python train.py --syllabus
```

## Replay generation

The `policies` directory contains a set of trained policies. To use your own models, create a directory and copy their checkpoint files into it. To generate a replay, run the following command:

```
python train.py -m replay -p policies
```

The replay file ends with `.replay.lzma`. You can view the replay using the [web viewer](https://kywch.github.io/nmmo-client/).
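
The same command also works with your own checkpoint directory (the paths below are hypothetical):

```
mkdir my_policies
cp runs/my_run/policy_000500.pt my_policies/  # hypothetical checkpoint path
python train.py -m replay -p my_policies
```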

## Evaluation

The evaluation script supports two modes, `pvp` and `pve`. The `pve` mode spawns all agents from a single policy; the `pvp` mode spawns groups of agents, each group controlled by a different policy.

To evaluate models in the `policies` directory, run the following command:

```
python evaluate.py policies pvp -r 10
```
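
To evaluate in `pve` mode instead, swap the mode argument, assuming the CLI takes both modes in the same position:

```
python evaluate.py policies pve -r 10
```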

Using `-r 10` generates 10 results JSON files in the same directory, each containing the results of 200 episodes. The task-completion metrics can then be viewed with:

```
python analysis/proc_eval_result.py policies
```
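
This writes TSV summaries, such as `score_summary.tsv`, `score_task_summary.tsv`, and `score_category_summary.tsv`, into the same directory; these are the files updated under `policies/` in this commit.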
analysis/proc_eval_result.py: 7 additions, 5 deletions
@@ -103,7 +103,7 @@ def process_eval_files(policy_store_dir, eval_prefix):
"seed": random_seed,
"count": summary["length"]["count"],
"length": summary["length"]["mean"],
"score": summary["avg_progress"],
"task_progress": summary["avg_progress"],
"weighted_score": summary["weighted_score"],
}
)
@@ -121,13 +121,13 @@ def process_eval_files(policy_store_dir, eval_prefix):
"mode": mode,
"seed": random_seed,
"count": task_data["count"],
"score": task_data["mean"],
"task_progress": task_data["mean"],
}
)

summ_df = pl.DataFrame(summ_policy).sort(["policy_name", "mode", "seed"])
summ_grp = summ_df.group_by(["policy_name", "mode"]).agg(
pl.col("score").mean(),
pl.col("task_progress").mean(),
pl.col("weighted_score").mean(),
)
summ_grp = summ_grp.sort("weighted_score", descending=True)
@@ -139,13 +139,15 @@ def process_eval_files(policy_store_dir, eval_prefix):

task_df = pl.DataFrame(summ_task).sort(["mode", "category", "task_name", "policy_name", "seed"])
task_grp = task_df.group_by(["mode", "category", "task_name", "policy_name"]).agg(
pl.col("score").mean()
pl.col("task_progress").mean()
)
task_grp = task_grp.sort(["mode", "category", "task_name", "policy_name"])
task_grp.write_csv(
os.path.join(policy_store_dir, "score_task_summary.tsv"), separator="\t", float_precision=6
)
- cate_grp = task_df.group_by(["mode", "category", "policy_name"]).agg(pl.col("score").mean())
+ cate_grp = task_df.group_by(["mode", "category", "policy_name"]).agg(
+     pl.col("task_progress").mean()
+ )
cate_grp = cate_grp.sort(["mode", "category", "policy_name"])
cate_grp.write_csv(
os.path.join(policy_store_dir, "score_category_summary.tsv"),
policies/score_by_seed.tsv: 1 addition, 1 deletion
@@ -1,4 +1,4 @@
- policy_name mode seed count length score weighted_score
+ policy_name mode seed count length task_progress weighted_score
baseline_10M pvp 1 2875 113.036870 0.057293 7.711149
baseline_10M pvp 12196392 2856 113.172619 0.053504 6.621138
baseline_10M pvp 19525770 2856 111.803221 0.051267 7.079278
policies/score_by_task_seed.tsv: 1 addition, 1 deletion
@@ -1,4 +1,4 @@
- category task_name weight policy_name mode seed count score
+ category task_name weight policy_name mode seed count task_progress
combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent 5.555556 baseline_10M pvp 1 44 0.096591
combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent 5.555556 baseline_10M pvp 12196392 51 0.093137
combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent 5.555556 baseline_10M pvp 19525770 41 0.078049
policies/score_category_summary.tsv: 1 addition, 1 deletion
@@ -1,4 +1,4 @@
- mode category policy_name score
+ mode category policy_name task_progress
pvp combat baseline_10M 0.060323
pvp combat learner 0.022414
pvp combat takeru_100M 0.141534
policies/score_summary.tsv: 1 addition, 1 deletion
@@ -1,4 +1,4 @@
- policy_name mode score weighted_score
+ policy_name mode task_progress weighted_score
yaofeng_200M pvp 0.260633 34.119108
yaofeng_100M pvp 0.213844 28.731941
yaofeng_50M pvp 0.167206 22.670012
policies/score_task_summary.tsv: 1 addition, 1 deletion
@@ -1,4 +1,4 @@
- mode category task_name policy_name score
+ mode category task_name policy_name task_progress
pvp combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent baseline_10M 0.090335
pvp combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent learner 0.039319
pvp combat curriculum/Task_CountEvent_(event:PLAYER_KILL_N:20)_reward_to:agent takeru_100M 0.198185
pyproject.toml: 2 additions, 2 deletions
@@ -4,7 +4,7 @@ requires = ["pip>=23.0", "setuptools>=61.0", "wheel"]
[project]
name = "nmmo2-baselines"
version = "0.1.0"
description = "Neural MMO 2.1 baselines"
description = "Neural MMO 2023 competition baselines"
keywords = []
classifiers = [
"Natural Language :: English",
@@ -16,7 +16,7 @@ classifiers = [
]
dependencies = [
"accelerate==0.27.2",
"nmmo@git+https://github.com/kywch/nmmo-environment", # WIP nmmo 2.1
"nmmo@git+https://github.com/NeuralMMO/environment@2.1",
"polars==0.20.21",
"pufferlib[nmmo]>=0.7.3",
"psutil==5.9.8",
train_helper.py: 3 additions, 1 deletion
@@ -140,7 +140,9 @@ def generate_replay(args, env_creator, agent_creator, stop_when_all_complete_tas
# Add the policy names to agent names
if len(policies) > 1:
for policy_id, samp in data.policy_pool.sample_idxs.items():
- policy_name = data.policy_pool.current_policies[policy_id]["name"]
+ policy_name = "learner"
+ if policy_id in data.policy_pool.current_policies:
+     policy_name = data.policy_pool.current_policies[policy_id]["name"]
for idx in samp:
agent_id = idx + 1 # agents are 0-indexed in policy_pool, but 1-indexed in nmmo
nmmo_env.realm.players[agent_id].name = f"{policy_name}_{agent_id}"
