Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix video recorder and add test #2063

Merged
merged 5 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.5.0a0 (WIP)
Release 2.5.0a1 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -42,6 +42,14 @@ Documentation:
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)

Release 2.4.1 (2024-12-20)
--------------------------

Bug Fixes:
^^^^^^^^^^
- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos


Release 2.4.0 (2024-11-18)
--------------------------

Expand Down
23 changes: 12 additions & 11 deletions stable_baselines3/common/vec_env/vec_video_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: Prefix to the video name
"""

video_name: str
video_path: str

def __init__(
self,
venv: VecEnv,
Expand All @@ -50,7 +53,7 @@ def __init__(

if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
metadata = temp_env.get_attr("metadata")[0]
else:
else: # pragma: no cover # assume gym interface
metadata = temp_env.metadata

self.env.metadata = metadata
Expand All @@ -67,15 +70,12 @@ def __init__(
self.step_id = 0
self.video_length = video_length

self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)

self.recording = False
self.recorded_frames: list[np.ndarray] = []

try:
import moviepy # noqa: F401
except ImportError as e:
except ImportError as e: # pragma: no cover
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e

def reset(self) -> VecEnvObs:
Expand All @@ -85,6 +85,9 @@ def reset(self) -> VecEnvObs:
return obs

def _start_video_recorder(self) -> None:
# Update video name and path
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)
self._start_recording()
self._capture_frame()

Expand All @@ -109,8 +112,6 @@ def _capture_frame(self) -> None:
assert self.recording, "Cannot capture a frame, recording wasn't started."

frame = self.env.render()
if isinstance(frame, list):
frame = frame[-1]

if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
Expand All @@ -123,12 +124,12 @@ def _capture_frame(self) -> None:
def close(self) -> None:
"""Closes the wrapper then the video recorder."""
VecEnvWrapper.close(self)
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()

def _start_recording(self) -> None:
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()

self.recording = True
Expand All @@ -137,7 +138,7 @@ def _stop_recording(self) -> None:
"""Stop current recording and saves the video."""
assert self.recording, "_stop_recording was called, but no recording was started"

if len(self.recorded_frames) == 0:
if len(self.recorded_frames) == 0: # pragma: no cover
logger.warn("Ignored saving a video as there were zero frames to save.")
else:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
Expand All @@ -150,5 +151,5 @@ def _stop_recording(self) -> None:

def __del__(self) -> None:
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0:
if len(self.recorded_frames) > 0: # pragma: no cover
logger.warn("Unable to save last video! Did you call close()?")
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0a0
2.5.0a1
44 changes: 43 additions & 1 deletion tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
import pytest
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder

try:
import moviepy

have_moviepy = True
except ImportError:
have_moviepy = False

N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
Expand Down Expand Up @@ -624,3 +632,37 @@ def test_render(vec_env_class):
vec_env.render()

vec_env.close()


@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
def test_video_recorder(tmp_path):
env_id = "CartPole-v1"
video_folder = str(tmp_path)

vec_env = make_vec_env(env_id, n_envs=1)

# Wrap to check unwrapping works
vec_env = VecNormalize(vec_env)

# Record the video starting at the first step
vec_env = VecVideoRecorder(
vec_env,
video_folder,
record_video_trigger=lambda x: x % 65 == 0,
video_length=10,
name_prefix=f"agent-{env_id}",
)

model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)

model.learn(total_timesteps=128)

# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
video_files = list(map(str, tmp_path.glob("*.mp4")))

# Clean up
vec_env.close()

assert len(video_files) == 2
assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0]
assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]
Loading