From dba0baa4910fdf548564d445495f6fe4fb8f45cb Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 7 Jan 2025 11:57:54 +0100 Subject: [PATCH] Fix mypy error (#2067) * Fix mypy error * Ignore new errors --- stable_baselines3/common/save_util.py | 2 +- stable_baselines3/common/vec_env/vec_normalize.py | 4 ++-- tests/test_vec_envs.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 3da23451b..d26d1e5d1 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -181,7 +181,7 @@ def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = No @functools.singledispatch def open_path( path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None -) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]: +) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom]: """ Opens a path for reading or writing with a preferred suffix and raises debug information. If the provided path is a derivative of io.BufferedIOBase it ensures that the file diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 5f0ee1c25..439243f9f 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -241,7 +241,7 @@ def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[ assert self.norm_obs_keys is not None # Only normalize the specified keys for key in self.norm_obs_keys: - obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) + obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) # type: ignore[call-overload] else: assert isinstance(self.obs_rms, RunningMeanStd) obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32) @@ -265,7 +265,7 @@ def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Unio if isinstance(obs, dict) and isinstance(self.obs_rms, dict): assert self.norm_obs_keys is not None for key in self.norm_obs_keys: - obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) + obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) # type: ignore[call-overload] else: assert isinstance(self.obs_rms, RunningMeanStd) obs_ = self._unnormalize_obs(obs, self.obs_rms) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 15c6b32c8..5bc11f7f3 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder try: - import moviepy + import moviepy # noqa: F401 have_moviepy = True except ImportError: