Installation currently goes through pip only; a conda environment is still TODO.
- Clone this repo
- pip install -r requirements.txt
- pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- pip install -U numba
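After the install steps, a quick sanity check can save a failed training launch. The snippet below is only a sketch: the `check_packages` and `report` helpers are hypothetical names, and listing `brax` assumes it is pulled in by requirements.txt.

```python
import importlib.util

# Top-level import names of the packages installed above.
# "brax" is an assumption: it is presumably listed in requirements.txt.
REQUIRED = ["jax", "numba", "brax"]


def check_packages(names):
    """Return {name: bool}, True when the package can be located for import."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def report():
    """Print which packages were found and which backend JAX selected."""
    status = check_packages(REQUIRED)
    for name, ok in status.items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
    if status["jax"]:
        import jax

        # A backend of "cpu" means JAX did not pick up the CUDA build.
        print("JAX backend:", jax.default_backend())
        print("JAX devices:", jax.devices())
```

Calling `report()` in a Python shell after installation should list all three packages as found and, on a GPU machine, a backend other than `cpu`.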
Change the environment setup and hyper-parameter settings in brax_rodent_run_ppo.py. Currently the config is:
config = {
    "env_name": env_name,
    "algo_name": "ppo",
    "task_name": "run",
    "num_envs": 2048,
    "num_timesteps": 10_000_000,
    "eval_every": 10_000,
    "episode_length": 1000,
    "num_evals": 1000,
    "batch_size": 512,
    "learning_rate": 3e-4,
    "terminate_when_unhealthy": False,
}
Caveat: on the run.ai cluster with an Nvidia A40, we can only use num_envs = 512.
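One way to handle this caveat without editing the defaults in place is to derive a per-machine config from the base one. This is a minimal sketch, not how the script itself does it: `with_overrides` is a hypothetical helper, and the `env_name` value is a placeholder because the real one is defined elsewhere in the script.

```python
# Placeholder: the real env_name is defined elsewhere in the training script.
env_name = "rodent"

config = {
    "env_name": env_name,
    "algo_name": "ppo",
    "task_name": "run",
    "num_envs": 2048,
    "num_timesteps": 10_000_000,
    "eval_every": 10_000,
    "episode_length": 1000,
    "num_evals": 1000,
    "batch_size": 512,
    "learning_rate": 3e-4,
    "terminate_when_unhealthy": False,
}


def with_overrides(base, **overrides):
    """Return a copy of base with the given keys replaced; base is not modified."""
    unknown = set(overrides) - set(base)
    if unknown:
        # Catch typos such as "nun_envs" instead of silently adding a new key.
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**base, **overrides}


# Config for the A40 caveat: fewer parallel environments.
a40_config = with_overrides(config, num_envs=512)
print(a40_config["num_envs"])  # 512
```

Rejecting unknown keys is a deliberate choice here, so that a misspelled hyper-parameter fails loudly rather than leaving the default in effect.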
Use the following script to run the training:
python brax_rodent_run_ppo.py