Make code-state storing optional and upload recorded videos to wandb.

on_policy_runner.py:
  * learn(): store the code state on the first iteration only when
    self.cfg.get("store_code_state", True) is truthy.
  * log(): when the logger type is "wandb", poll the writer for video files.

wandb_utils.py:
  * __init__(): fall back to entity=None when WANDB_USERNAME is not set, and
    write the wandb run id and name to wandb_info.json in the log directory.
  * get_config(): new accessor returning wandb.config.
  * update_video_files(): new method that uploads *.mp4 files once their size
    has stopped changing. Files that vanish between the directory scan and
    the stat call are skipped, and already uploaded files are not re-examined.

Note: the post-image blob id on the wandb_utils.py "index" line is the one
from the original patch and does not reflect the reviewed content.

diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 9e0a459..937b36d 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -155,7 +155,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
             if it % self.save_interval == 0:
                 self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
             ep_infos.clear()
-            if it == start_iter:
+            if it == start_iter and self.cfg.get("store_code_state", True):
                 # obtain all the diff files
                 git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
                 # if possible store them to wandb
@@ -210,6 +210,10 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
                     "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
                 )
 
+        # Video recording for wandb
+        if self.logger_type == "wandb":
+            self.writer.update_video_files(log_name="Video", fps=30)
+
         str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
 
         if len(locs["rewbuffer"]) > 0:
diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py
index 2868ce9..8552d9c 100644
--- a/rsl_rl/utils/wandb_utils.py
+++ b/rsl_rl/utils/wandb_utils.py
@@ -4,6 +4,8 @@
 from __future__ import annotations
 
 import os
+import json
+import pathlib
 from dataclasses import asdict
 from torch.utils.tensorboard import SummaryWriter
 
@@ -17,7 +19,7 @@ class WandbSummaryWriter(SummaryWriter):
     """Summary writer for Weights and Biases."""
 
     def __init__(self, log_dir: str, flush_secs: int, cfg):
-        super().__init__(log_dir, flush_secs)
+        super().__init__(log_dir=log_dir, flush_secs=flush_secs)
 
         try:
             project = cfg["wandb_project"]
@@ -27,14 +29,16 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
         try:
             entity = os.environ["WANDB_USERNAME"]
         except KeyError:
-            raise KeyError(
-                "Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
-            )
+            entity = None
+            print("`WANDB_USERNAME` is not found! WandB will request your username in the interactive mode.")
 
         wandb.init(project=project, entity=entity)
 
         # Change generated name to project-number format
         wandb.run.name = project + wandb.run.name.split("-")[-1]
+        # Record the wandb run id and name in the log directory
+        with open(os.path.join(log_dir, "wandb_info.json"), "w") as f:
+            json.dump({"wandb_run_id": wandb.run.id, "wandb_run_name": wandb.run.name}, f)
 
         self.name_map = {
             "Train/mean_reward/time": "Train/mean_reward_time",
@@ -45,12 +49,19 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
 
         wandb.log({"log_dir": run_name})
 
+        # Video logging bookkeeper: video path -> {"size": KB, "added": bool, "count": int}
+        self.saved_video_files = {}
+
     def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
         wandb.config.update({"runner_cfg": runner_cfg})
         wandb.config.update({"policy_cfg": policy_cfg})
         wandb.config.update({"alg_cfg": alg_cfg})
         wandb.config.update({"env_cfg": asdict(env_cfg)})
 
+    def get_config(self):
+        """Return the global ``wandb.config`` object."""
+        return wandb.config
+
     def _map_path(self, path):
         if path in self.name_map:
             return self.name_map[path]
@@ -67,6 +78,43 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_sty
         )
         wandb.log({self._map_path(tag): scalar_value}, step=global_step)
 
+    def update_video_files(self, log_name: str, fps: int):
+        """Upload finished ``*.mp4`` recordings found under the log directory to wandb.
+
+        Intended to be called repeatedly. A file is uploaded once, after it is larger
+        than 100 KB and its size has been observed unchanged on more than 10 consecutive
+        calls, so that videos which are still being written are not uploaded.
+
+        Args:
+            log_name: Key under which the video is logged to wandb.
+            fps: Frame rate forwarded to ``wandb.Video``.
+        """
+        for video_file in pathlib.Path(self.log_dir).rglob("*.mp4"):
+            video_path = str(video_file)
+            try:
+                file_size_kb = os.stat(video_path).st_size / 1024
+            except FileNotFoundError:
+                # The file disappeared between the directory scan and the stat call.
+                continue
+            video_info = self.saved_video_files.get(video_path)
+            if video_info is None:
+                # First sighting: start tracking the file, never upload right away.
+                self.saved_video_files[video_path] = {"size": file_size_kb, "added": False, "count": 0}
+                continue
+            if video_info["added"]:
+                # Already uploaded, nothing left to do for this file.
+                continue
+            if video_info["size"] != file_size_kb or file_size_kb <= 100:
+                # Still growing or too small: restart the stability counter.
+                video_info["size"] = file_size_kb
+                video_info["count"] = 0
+            elif video_info["count"] > 10:
+                print(f"[Wandb] Uploading {os.path.basename(video_path)}.")
+                wandb.log({log_name: wandb.Video(video_path, fps=fps)})
+                video_info["added"] = True
+            else:
+                video_info["count"] += 1
+
     def stop(self):
         wandb.finish()
 