From 2b529e57a0b7604e7a9be5a7f3f2346d8818296c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 7 Jan 2025 14:19:05 +0100 Subject: [PATCH] Fix tests and warnings when running locally with a GPU (#2069) * Fix test when GPU is available * Sort file list for consistent results * Ignore A2C warnings too --- pyproject.toml | 2 ++ tests/test_save_load.py | 12 ++++++------ tests/test_vec_envs.py | 1 + 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 89af5a67f..a12e014be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ exclude = """(?x)( env = ["PYTHONHASHSEED=0"] filterwarnings = [ + # A2C/PPO on GPU + "ignore:You are trying to run (PPO|A2C) on the GPU", # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 962088246..61c9feaf5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -758,16 +758,16 @@ def test_no_resource_warning(tmp_path): # check that files are properly closed # Create a PPO agent and save it - PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole") - PPO.load(tmp_path / "dqn_cartpole") + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(tmp_path / "dqn_cartpole") + PPO.load(tmp_path / "dqn_cartpole", device="cpu") - PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole")) - PPO.load(str(tmp_path / "dqn_cartpole")) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(str(tmp_path / "dqn_cartpole")) + PPO.load(str(tmp_path / "dqn_cartpole"), device="cpu") # Do the same but in memory, should not close the file with tempfile.TemporaryFile() as fp: - PPO("MlpPolicy", "CartPole-v1").save(fp) - PPO.load(fp) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp) + PPO.load(fp, device="cpu") assert not fp.closed # Same but with replay buffer diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 5bc11f7f3..7e4e5ec0b 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -659,6 +659,7 @@ def test_video_recorder(tmp_path): # print all videos in video_folder, should be multiple step 0-100, step 1024-1124 video_files = list(map(str, tmp_path.glob("*.mp4"))) + video_files.sort(reverse=True) # Clean up vec_env.close()