Skip to content

Commit

Permalink
Fix mypy error (#2067)
Browse files Browse the repository at this point in the history
* Fix mypy error

* Ignore new errors
  • Loading branch information
araffin authored Jan 7, 2025
1 parent 57e8b97 commit dba0baa
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = No
@functools.singledispatch
def open_path(
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom]:
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/vec_env/vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[
assert self.norm_obs_keys is not None
# Only normalize the specified keys
for key in self.norm_obs_keys:
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) # type: ignore[call-overload]
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
Expand All @@ -265,7 +265,7 @@ def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Unio
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
for key in self.norm_obs_keys:
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) # type: ignore[call-overload]
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._unnormalize_obs(obs, self.obs_rms)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder

try:
import moviepy
import moviepy # noqa: F401

have_moviepy = True
except ImportError:
Expand Down

0 comments on commit dba0baa

Please sign in to comment.