Release 2.2.1: Hotfix file closing #1754

Merged: 4 commits, Nov 17, 2023
13 changes: 11 additions & 2 deletions docs/misc/changelog.rst
@@ -3,10 +3,16 @@
Changelog
==========

-Release 2.2.0 (2023-11-16)
+Release 2.2.1 (2023-11-17)
--------------------------
**Support for options at reset, bug fixes and better error messages**

.. note::

SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
Please use SB3 v2.2.1 and not v2.2.0.


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff now require a minimum version
@@ -32,7 +38,9 @@ Bug Fixes:
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
-- Fixed ResourceWarning when loading and saving models (files were not closed)
+- Fixed ResourceWarning when loading and saving models (files were not closed). Note that only paths are closed automatically,
+  the behavior stays the same for tempfiles (they need to be closed manually),
+  and the behavior is now consistent when loading/saving the replay buffer
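
In user-facing terms, the ResourceWarning entry above means: when save()/load() are given a path, SB3 now closes the file it opened itself, while a file object supplied by the caller is left open. A minimal sketch of that behavior (model and file names are illustrative only):

import io

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")

# Path: SB3 opens the file itself and now closes it again, so no
# ResourceWarning is emitted when the handle is garbage collected.
model.save("ppo_cartpole")
PPO.load("ppo_cartpole")

# Caller-provided buffer: SB3 leaves it open, closing it remains the
# caller's responsibility (same as for tempfiles).
buffer = io.BytesIO()
model.save(buffer)
buffer.seek(0)
PPO.load(buffer)
assert not buffer.closed
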

`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -76,6 +84,7 @@ Others:
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
- Switched to ``mypy`` only for checking types
- Added tests to check consistency when saving/loading files

Documentation:
^^^^^^^^^^^^^^
69 changes: 39 additions & 30 deletions stable_baselines3/common/save_util.py
@@ -308,28 +308,31 @@ def save_to_zip_file(
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
-    with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
-        # data/params can be None, so do not
-        # try to serialize them blindly
-        if data is not None:
-            serialized_data = data_to_json(data)
-
-        # Create a zip-archive and write our objects there.
-        with zipfile.ZipFile(save_path, mode="w") as archive:
-            # Do not try to save "None" elements
-            if data is not None:
-                archive.writestr("data", serialized_data)
-            if pytorch_variables is not None:
-                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
-                    th.save(pytorch_variables, pytorch_variables_file)
-            if params is not None:
-                for file_name, dict_ in params.items():
-                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
-                        th.save(dict_, param_file)
-            # Save metadata: library version when file was saved
-            archive.writestr("_stable_baselines3_version", sb3.__version__)
-            # Save system info about the current python env
-            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
+    file = open_path(save_path, "w", verbose=0, suffix="zip")
+    # data/params can be None, so do not
+    # try to serialize them blindly
+    if data is not None:
+        serialized_data = data_to_json(data)
+
+    # Create a zip-archive and write our objects there.
+    with zipfile.ZipFile(file, mode="w") as archive:
+        # Do not try to save "None" elements
+        if data is not None:
+            archive.writestr("data", serialized_data)
+        if pytorch_variables is not None:
+            with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
+                th.save(pytorch_variables, pytorch_variables_file)
+        if params is not None:
+            for file_name, dict_ in params.items():
+                with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
+                    th.save(dict_, param_file)
+        # Save metadata: library version when file was saved
+        archive.writestr("_stable_baselines3_version", sb3.__version__)
+        # Save system info about the current python env
+        archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
+
+    if isinstance(save_path, (str, pathlib.Path)):
+        file.close()

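The change above drops the with open_path(...) context manager, which also closed file objects supplied by the caller, in favor of an explicit close guarded by an isinstance check. A minimal sketch of the same pattern outside SB3 (both helper names below are invented for illustration and are not part of the library):

import io
import pathlib
from typing import Union


def _open_if_path(destination: Union[str, pathlib.Path, io.BufferedIOBase]) -> io.BufferedIOBase:
    # Only open the file ourselves when given a path; pass buffers through untouched.
    if isinstance(destination, (str, pathlib.Path)):
        return open(destination, "wb")
    return destination


def write_payload(destination: Union[str, pathlib.Path, io.BufferedIOBase], payload: bytes) -> None:
    file = _open_if_path(destination)
    file.write(payload)
    # Close only what we opened: paths are our responsibility, buffers belong to the caller.
    if isinstance(destination, (str, pathlib.Path)):
        file.close()
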

def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
@@ -344,10 +347,12 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, ver
:param obj: The object to save.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
-    with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
-        # Use protocol>=4 to support saving replay buffers >= 4Gb
-        # See https://docs.python.org/3/library/pickle.html
-        pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL)
+    file = open_path(path, "w", verbose=verbose, suffix="pkl")
+    # Use protocol>=4 to support saving replay buffers >= 4Gb
+    # See https://docs.python.org/3/library/pickle.html
+    pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
+    if isinstance(path, (str, pathlib.Path)):
+        file.close()


def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
@@ -360,8 +365,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
path actually exists. If path is a io.BufferedIOBase the path exists.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
-    with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
-        return pickle.load(file_handler)
+    file = open_path(path, "r", verbose=verbose, suffix="pkl")
+    obj = pickle.load(file)
+    if isinstance(path, (str, pathlib.Path)):
+        file.close()
+    return obj
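
With these two changes, save_to_pkl and load_from_pkl close the file only when they opened it themselves from a path; a buffer owned by the caller survives the round trip. A small usage sketch, assuming the helpers behave as in this diff:

import io

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl

# Path: the helpers open and close the file themselves.
save_to_pkl("example.pkl", {"step": 1})
assert load_from_pkl("example.pkl") == {"step": 1}

# In-memory buffer: it is left open and can be reused by the caller.
stream = io.BytesIO()
save_to_pkl(stream, {"step": 1})
stream.seek(0)
assert load_from_pkl(stream) == {"step": 1}
assert not stream.closed
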


def load_from_zip_file(
@@ -391,14 +399,14 @@ def load_from_zip_file(
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
-    load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
+    file = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # set device to cpu if cuda is not available
    device = get_device(device=device)

    # Open the zip archive and load data
    try:
-        with zipfile.ZipFile(load_path) as archive:
+        with zipfile.ZipFile(file) as archive:
            namelist = archive.namelist()
            # If data or parameters is not in the
            # zip archive, assume they were stored
Expand Down Expand Up @@ -451,5 +459,6 @@ def load_from_zip_file(
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
    finally:
-        load_path.close()
+        if isinstance(load_path, (str, pathlib.Path)):
+            file.close()
    return data, params, pytorch_variables
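
The zip helpers now follow the same contract: only handles they opened from a path are closed, and load_from_zip_file still returns the (data, params, pytorch_variables) triple. A hedged sketch of a round trip through an in-memory buffer (the payload dict is arbitrary):

import io

from stable_baselines3.common.save_util import load_from_zip_file, save_to_zip_file

stream = io.BytesIO()
# data/params/pytorch_variables may each be None; only non-None entries are written to the archive.
save_to_zip_file(stream, data={"note": "hotfix example"}, params=None, pytorch_variables=None)

data, params, pytorch_variables = load_from_zip_file(stream)
assert data["note"] == "hotfix example"
assert not stream.closed  # the caller-owned buffer stays open
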
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0
+2.2.1
27 changes: 27 additions & 0 deletions tests/test_save_load.py
@@ -3,6 +3,7 @@
import json
import os
import pathlib
import tempfile
import warnings
import zipfile
from collections import OrderedDict
@@ -752,7 +753,33 @@ def test_dqn_target_update_interval(tmp_path):
# Turn warnings into errors
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
    # Check behavior of save/load
    # see https://github.com/DLR-RM/stable-baselines3/issues/1751

    # check that files are properly closed
    # Create a PPO agent and save it
    PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
    PPO.load(tmp_path / "dqn_cartpole")

    PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
    PPO.load(str(tmp_path / "dqn_cartpole"))

    # Do the same but in memory, should not close the file
    with tempfile.TemporaryFile() as fp:
        PPO("MlpPolicy", "CartPole-v1").save(fp)
        PPO.load(fp)
        assert not fp.closed

    # Same but with replay buffer
    model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
    model.save_replay_buffer(tmp_path / "replay")
    model.load_replay_buffer(tmp_path / "replay")

    model.save_replay_buffer(str(tmp_path / "replay"))
    model.load_replay_buffer(str(tmp_path / "replay"))

    with tempfile.TemporaryFile() as fp:
        model.save_replay_buffer(fp)
        fp.seek(0)
        model.load_replay_buffer(fp)
        assert not fp.closed
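
The filterwarnings("error") marker is what turns a leaked file handle into a test failure. Outside pytest, the regression from GH#1751 can be checked in a similar way (a rough sketch, not part of the test suite):

import gc
import warnings

from stable_baselines3 import PPO

with warnings.catch_warnings():
    # Escalate unclosed-file warnings into errors, as the pytest marker does.
    warnings.simplefilter("error", ResourceWarning)
    PPO("MlpPolicy", "CartPole-v1").save("ppo_cartpole")
    PPO.load("ppo_cartpole")
    # Unclosed handles usually warn only when garbage collected, so force a
    # collection while the stricter filter is still active.
    gc.collect()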