From 5b19e37eb183aa23e4bf79611ac9a646242168aa Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 21 Dec 2024 08:24:25 +0100 Subject: [PATCH] Fix video recorder and add test (#2063) * Fix video recorder and add test * Update github CI * Install ffmpeg * Revert "Update github CI" This reverts commit 07791e97fccae4f003b2909428b23f59557d7034. * Skip VecVideoRecorder test on github --- docs/misc/changelog.rst | 8 ++++ .../common/vec_env/vec_video_recorder.py | 23 +++++----- stable_baselines3/version.txt | 2 +- tests/test_vec_envs.py | 44 ++++++++++++++++++- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 43ba33e48..d8ce22df5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,14 @@ Changelog ========== +Release 2.4.1 (2024-12-20) +-------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos + + Release 2.4.0 (2024-11-18) -------------------------- diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index e586f94ab..10542226f 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ + video_name: str + video_path: str + def __init__( self, venv: VecEnv, @@ -50,7 +53,7 @@ def __init__( if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): metadata = temp_env.get_attr("metadata")[0] - else: + else: # pragma: no cover # assume gym interface metadata = temp_env.metadata self.env.metadata = metadata @@ -67,15 +70,12 @@ def __init__( self.step_id = 0 self.video_length = video_length - self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" - self.video_path = os.path.join(self.video_folder, self.video_name) - self.recording = False self.recorded_frames: list[np.ndarray] = [] try: import moviepy # noqa: F401 - except ImportError as e: + except ImportError as e: # pragma: no cover raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e def reset(self) -> VecEnvObs: @@ -85,6 +85,9 @@ def reset(self) -> VecEnvObs: return obs def _start_video_recorder(self) -> None: + # Update video name and path + self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" + self.video_path = os.path.join(self.video_folder, self.video_name) self._start_recording() self._capture_frame() @@ -109,8 +112,6 @@ def _capture_frame(self) -> None: assert self.recording, "Cannot capture a frame, recording wasn't started." frame = self.env.render() - if isinstance(frame, List): - frame = frame[-1] if isinstance(frame, np.ndarray): self.recorded_frames.append(frame) @@ -123,12 +124,12 @@ def _capture_frame(self) -> None: def close(self) -> None: """Closes the wrapper then the video recorder.""" VecEnvWrapper.close(self) - if self.recording: + if self.recording: # pragma: no cover self._stop_recording() def _start_recording(self) -> None: """Start a new recording. If it is already recording, stops the current recording before starting the new one.""" - if self.recording: + if self.recording: # pragma: no cover self._stop_recording() self.recording = True @@ -137,7 +138,7 @@ def _stop_recording(self) -> None: """Stop current recording and saves the video.""" assert self.recording, "_stop_recording was called, but no recording was started" - if len(self.recorded_frames) == 0: + if len(self.recorded_frames) == 0: # pragma: no cover logger.warn("Ignored saving a video as there were zero frames to save.") else: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip @@ -150,5 +151,5 @@ def _stop_recording(self) -> None: def __del__(self) -> None: """Warn the user in case last video wasn't saved.""" - if len(self.recorded_frames) > 0: + if len(self.recorded_frames) > 0: # pragma: no cover logger.warn("Unable to save last video! Did you call close()?") diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 197c4d5c2..005119baa 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0 +2.4.1 diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 3aa52762d..0dc3404c4 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -11,9 +11,17 @@ import pytest from gymnasium import spaces +from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder + +try: + import moviepy + + have_moviepy = True +except ImportError: + have_moviepy = False N_ENVS = 3 VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv] @@ -624,3 +632,37 @@ def test_render(vec_env_class): vec_env.render() vec_env.close() + + +@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed") +def test_video_recorder(tmp_path): + env_id = "CartPole-v1" + video_folder = str(tmp_path) + + vec_env = make_vec_env(env_id, n_envs=1) + + # Wrap to check unwrapping works + vec_env = VecNormalize(vec_env) + + # Record the video starting at the first step + vec_env = VecVideoRecorder( + vec_env, + video_folder, + record_video_trigger=lambda x: x % 65 == 0, + video_length=10, + name_prefix=f"agent-{env_id}", + ) + + model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0) + + model.learn(total_timesteps=128) + + # print all videos in video_folder, should be multiple step 0-100, step 1024-1124 + video_files = list(map(str, tmp_path.glob("*.mp4"))) + + # Clean up + vec_env.close() + + assert len(video_files) == 2 + assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0] + assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]