diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e986b134b..758615c39 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a1 (WIP) +Release 2.4.0a2 (WIP) -------------------------- Breaking Changes: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 48adc0106..e828d3c3d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a1 +2.4.0a2 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1dc178bd8..c7df7b26f 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -804,9 +804,8 @@ def test_save_load_net_arch_none(tmp_path): Test that the model is loaded correctly when net_arch is manually set to None. See GH#1928 """ - policy_kwargs = dict(net_arch=None) - model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs) - model.learn(100) - model.save(tmp_path / "ppo.zip") + PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=None)).save(tmp_path / "ppo.zip") model = PPO.load(tmp_path / "ppo.zip") + # None has been replaced by the default net arch + assert model.policy.net_arch is not None os.remove(tmp_path / "ppo.zip")