Skip to content

Commit

Permalink
Fix tests and warnings when running locally with a GPU (#2069)
Browse files Browse the repository at this point in the history
* Fix test when GPU is available

* Sort file list for consistent results

* Ignore A2C warnings too
  • Loading branch information
araffin authored Jan 7, 2025
1 parent dba0baa commit 2b529e5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ exclude = """(?x)(
env = ["PYTHONHASHSEED=0"]

filterwarnings = [
# A2C/PPO on GPU
"ignore:You are trying to run (PPO|A2C) on the GPU",
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gymnasium warnings
Expand Down
12 changes: 6 additions & 6 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,16 +758,16 @@ def test_no_resource_warning(tmp_path):

# check that files are properly closed
# Create a PPO agent and save it
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
PPO.load(tmp_path / "dqn_cartpole")
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(tmp_path / "dqn_cartpole")
PPO.load(tmp_path / "dqn_cartpole", device="cpu")

PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
PPO.load(str(tmp_path / "dqn_cartpole"))
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(str(tmp_path / "dqn_cartpole"))
PPO.load(str(tmp_path / "dqn_cartpole"), device="cpu")

# Do the same but in memory, should not close the file
with tempfile.TemporaryFile() as fp:
PPO("MlpPolicy", "CartPole-v1").save(fp)
PPO.load(fp)
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp)
PPO.load(fp, device="cpu")
assert not fp.closed

# Same but with replay buffer
Expand Down
1 change: 1 addition & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def test_video_recorder(tmp_path):

# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
video_files = list(map(str, tmp_path.glob("*.mp4")))
video_files.sort(reverse=True)

# Clean up
vec_env.close()
Expand Down

0 comments on commit 2b529e5

Please sign in to comment.