diff --git a/scripts/train_happo.py b/scripts/train_happo.py new file mode 100644 index 000000000..b27a7548f --- /dev/null +++ b/scripts/train_happo.py @@ -0,0 +1,15 @@ +import jax +from mava.algorithms.happo import HAPPO +from mava.configs.happo_config import HAPPOConfig +from mava.environments import create_environment +from mava.trainers import Trainer + +def main(): + config = HAPPOConfig() + environment = create_environment(config) + algorithm = HAPPO(config) + trainer = Trainer(config, environment, algorithm) + trainer.train() + +if __name__ == "__main__": + main()