diff --git a/predicators/args.py b/predicators/args.py index 2de1ec07ed..56a085a454 100644 --- a/predicators/args.py +++ b/predicators/args.py @@ -28,6 +28,7 @@ def create_arg_parser(env_required: bool = True, parser.add_argument("--make_failure_videos", action="store_true") parser.add_argument("--make_interaction_videos", action="store_true") parser.add_argument("--make_demo_videos", action="store_true") + parser.add_argument("--make_cogman_videos", action="store_true") parser.add_argument("--load_approach", action="store_true") # In the case of online learning approaches, load_approach by itself # will try to load an approach on *every* online learning cycle. diff --git a/predicators/cogman.py b/predicators/cogman.py index e73b48bfda..62a28f5193 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -11,13 +11,14 @@ import logging from typing import Callable, List, Optional, Sequence, Set +from predicators import utils from predicators.approaches import BaseApproach from predicators.execution_monitoring import BaseExecutionMonitor from predicators.perception import BasePerceiver from predicators.settings import CFG from predicators.structs import Action, Dataset, EnvironmentTask, GroundAtom, \ InteractionRequest, InteractionResult, LowLevelTrajectory, Metrics, \ - Observation, State, Task + Observation, State, Task, Video class CogMan: @@ -32,13 +33,18 @@ def __init__(self, approach: BaseApproach, perceiver: BasePerceiver, self._current_goal: Optional[Set[GroundAtom]] = None self._override_policy: Optional[Callable[[State], Action]] = None self._termination_fn: Optional[Callable[[State], bool]] = None + self._current_env_task: Optional[EnvironmentTask] = None self._episode_state_history: List[State] = [] self._episode_action_history: List[Action] = [] + self._episode_images: Video = [] + self._episode_num = -1 def reset(self, env_task: EnvironmentTask) -> None: """Start a new episode of environment interaction.""" logging.info("[CogMan] Reset called.") + self._episode_num += 1 task = self._perceiver.reset(env_task) + self._current_env_task = env_task self._current_goal = task.goal self._reset_policy(task) self._exec_monitor.reset(task) @@ -46,10 +52,19 @@ def reset(self, env_task: EnvironmentTask) -> None: self._approach.get_execution_monitoring_info()) self._episode_state_history = [task.init] self._episode_action_history = [] + self._episode_images = [] + if CFG.make_cogman_videos: + imgs = self._perceiver.render_mental_images(task.init, env_task) + self._episode_images.extend(imgs) def step(self, observation: Observation) -> Optional[Action]: """Receive an observation and produce an action, or None for done.""" state = self._perceiver.step(observation) + if CFG.make_cogman_videos: + assert self._current_env_task is not None + imgs = self._perceiver.render_mental_images( + state, self._current_env_task) + self._episode_images.extend(imgs) # Replace the first step because the state was already added in reset(). if not self._episode_action_history: self._episode_state_history[0] = state @@ -86,6 +101,10 @@ def finish_episode(self, observation: Observation) -> None: self._episode_action_history): state = self._perceiver.step(observation) self._episode_state_history.append(state) + if CFG.make_cogman_videos: + save_prefix = utils.get_config_path_str() + outfile = f"{save_prefix}__cogman__episode{self._episode_num}.mp4" + utils.save_video(outfile, self._episode_images) # The methods below provide an interface to the approach. In the future, # we may want to move some of these methods into cogman properly, e.g., diff --git a/predicators/perception/base_perceiver.py b/predicators/perception/base_perceiver.py index bca657e78b..3bf36e23d1 100644 --- a/predicators/perception/base_perceiver.py +++ b/predicators/perception/base_perceiver.py @@ -2,7 +2,8 @@ import abc -from predicators.structs import EnvironmentTask, Observation, State, Task +from predicators.structs import EnvironmentTask, Observation, State, Task, \ + Video class BasePerceiver(abc.ABC): @@ -20,3 +21,8 @@ def reset(self, env_task: EnvironmentTask) -> Task: @abc.abstractmethod def step(self, observation: Observation) -> State: """Produce a State given the current and past observations.""" + + @abc.abstractmethod + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + """Create mental images for the given observation.""" diff --git a/predicators/perception/kitchen_perceiver.py b/predicators/perception/kitchen_perceiver.py index e8b6dfdb48..0aa179b90b 100644 --- a/predicators/perception/kitchen_perceiver.py +++ b/predicators/perception/kitchen_perceiver.py @@ -3,7 +3,7 @@ from predicators.envs.kitchen import KitchenEnv from predicators.perception.base_perceiver import BasePerceiver from predicators.structs import EnvironmentTask, GroundAtom, Observation, \ - State, Task + State, Task, Video class KitchenPerceiver(BasePerceiver): @@ -49,3 +49,7 @@ def step(self, observation: Observation) -> State: def _observation_to_state(self, obs: Observation) -> State: return KitchenEnv.state_info_to_state(obs["state_info"]) + + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + raise NotImplementedError("Mental images not implemented for kitchen") diff --git a/predicators/perception/sokoban_perceiver.py b/predicators/perception/sokoban_perceiver.py index 428c9c21b6..e633d4c802 100644 --- a/predicators/perception/sokoban_perceiver.py +++ b/predicators/perception/sokoban_perceiver.py @@ -8,7 +8,7 @@ from predicators.envs.sokoban import SokobanEnv from predicators.perception.base_perceiver import BasePerceiver from predicators.structs import EnvironmentTask, GroundAtom, Object, \ - Observation, State, Task + Observation, State, Task, Video # Each observation is a tuple of four 2D boolean masks (numpy arrays). # The order is: free, goals, boxes, player. @@ -95,3 +95,7 @@ def _get_object_name(r: int, c: int, type_name: str) -> str: state = utils.create_state_from_dict(state_dict) return state + + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + raise NotImplementedError("Mental images not implemented for sokoban") diff --git a/predicators/perception/trivial_perceiver.py b/predicators/perception/trivial_perceiver.py index 345d44f63a..bed17f1b0c 100644 --- a/predicators/perception/trivial_perceiver.py +++ b/predicators/perception/trivial_perceiver.py @@ -1,7 +1,10 @@ """A trivial perceiver that assumes observations are already states.""" +from predicators.envs import get_or_create_env from predicators.perception.base_perceiver import BasePerceiver -from predicators.structs import EnvironmentTask, Observation, State, Task +from predicators.settings import CFG +from predicators.structs import EnvironmentTask, Observation, State, Task, \ + Video class TrivialPerceiver(BasePerceiver): @@ -17,3 +20,10 @@ def reset(self, env_task: EnvironmentTask) -> Task: def step(self, observation: Observation) -> State: assert isinstance(observation, State) return observation + + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + # Use the environment's render function by default. + assert isinstance(observation, State) + env = get_or_create_env(CFG.env) + return env.render_state(observation, env_task) diff --git a/tests/envs/test_sokoban.py b/tests/envs/test_sokoban.py index 0f0e120637..c0caf0008f 100644 --- a/tests/envs/test_sokoban.py +++ b/tests/envs/test_sokoban.py @@ -78,6 +78,8 @@ def test_sokoban(): imgs = env.render() assert len(imgs) == 1 task = perceiver.reset(env_task) + with pytest.raises(NotImplementedError): + perceiver.render_mental_images(env_task.init_obs, env_task) state = task.init atoms = utils.abstract(state, env.predicates) num_boxes = len({a for a in atoms if a.predicate == IsBox}) diff --git a/tests/test_main.py b/tests/test_main.py index 7f0a7a3a0a..6b6f85d051 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -117,9 +117,9 @@ def test_main(): eval_traj_dir = os.path.join(parent_dir, "_fake_trajs") sys.argv = [ "dummy", "--env", "cover", "--approach", "oracle", "--seed", "123", - "--make_test_videos", "--num_test_tasks", "1", "--video_dir", - video_dir, "--results_dir", results_dir, "--eval_trajectories_dir", - eval_traj_dir + "--make_test_videos", "--make_cogman_videos", "--num_test_tasks", "1", + "--video_dir", video_dir, "--results_dir", results_dir, + "--eval_trajectories_dir", eval_traj_dir ] main() # Test making videos of failures and local logging.