Support VecEnv for gymnasium.vector.VectorEnv and Brax #1745

Open
2 tasks done
vyeevani opened this issue Nov 10, 2023 · 7 comments
Labels
documentation Improvements or additions to documentation enhancement New feature or request

Comments

@vyeevani

🚀 Feature

It would be nice to have a wrapper that takes a gymnasium.vector.VectorEnv and returns a VecEnv.

Motivation

I want to do highly parallelized, hardware-accelerated simulation. This pretty much leaves Isaac or Brax. Brax has a lighter-weight setup and also runs on TPUs. Stable Baselines3 has well-documented and well-tested implementations of most of the algorithms that I'm interested in using, as well as deep integration with the imitation library. I'd like to use both of these libraries.

Pitch

Brax currently provides a wrapper for legacy OpenAI Gym vectorized environments. I have a request open to support the Gymnasium vectorized API (pretty much just changing the imports to Gymnasium instead of Gym). Stable Baselines3 requires vectorized environments to be implemented against its own VecEnv specification. As far as I can tell, it's pretty simple to convert between the Gymnasium vectorized env API and SB3's representation.

I'd like a wrapper class to be provided that implements VecEnv with an underlying gymnasium vectorized env.
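
A minimal sketch of one conversion such a wrapper would need, assuming the Gymnasium vector API returns separate `terminated` and `truncated` flags per environment and that SB3 expects a single `dones` array plus a per-environment info list carrying a `TimeLimit.truncated` entry. The helper name `merge_done_flags` is hypothetical and not part of either library.

```python
from typing import Any, Dict, List, Sequence, Tuple


def merge_done_flags(
    terminated: Sequence[bool],
    truncated: Sequence[bool],
) -> Tuple[List[bool], List[Dict[str, Any]]]:
    """Collapse Gymnasium-style flags into SB3-style dones and infos.

    Hypothetical helper: accepts any sequences (lists or arrays) of flags.
    """
    dones: List[bool] = []
    infos: List[Dict[str, Any]] = []
    for term, trunc in zip(terminated, truncated):
        term, trunc = bool(term), bool(trunc)
        # SB3 tracks a single "done" flag per environment.
        dones.append(term or trunc)
        # A time-limit truncation is flagged separately so it can be
        # told apart from a real termination.
        infos.append({"TimeLimit.truncated": trunc and not term})
    return dones, infos


dones, infos = merge_done_flags([True, False, False], [False, True, False])
```

A full wrapper would also have to handle resets, terminal observations, and the rest of the info dictionary; this only covers the done-flag bookkeeping.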

Alternatives

The public API already lets users extend the library and write this wrapper themselves, so that would be the main alternative.

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo
  • If I'm requesting a new feature, I have proposed alternatives
@vyeevani vyeevani added the enhancement New feature or request label Nov 10, 2023
@araffin araffin added the duplicate This issue or pull request already exists label Nov 10, 2023
@vyeevani
Author

Where is the duplicate? I searched for it but couldn't find it. I would appreciate a pointer.

@araffin
Member

araffin commented Nov 10, 2023

Partial duplicate of #1568 (comment) and #229

In short: a VecEnvWrapper would indeed be a good idea, but only after gymnasium 1.0 is released and fully tested. Would you be willing to contribute such a wrapper?

Related doc: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym

Related issues: #1712 and #772 (comment)

@vyeevani
Author

vyeevani commented Nov 19, 2023

I ended up stepping away from the Gymnasium APIs and focused on Brax instead.

from typing import Optional, Sequence

# State and Wrapper are needed by AutoResetWrapper2 below.
from brax.envs.base import PipelineEnv, State, Wrapper
from brax.io import image
from gymnasium import spaces
import jax
import jax.numpy as jp
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices

class SB3Wrapper(VecEnv):
  def __init__(self,
               env: PipelineEnv,
               seed: int = 0,
               info_keys: Optional[Sequence[str]] = None,
               backend: Optional[str] = None):
    self._env = env
    self.info_keys = info_keys
    self.metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 1 / self._env.dt
    }
    if not hasattr(self._env, 'batch_size'):
      raise ValueError('underlying env must be batched')

    if not hasattr(self._env, 'episode_length'):
      raise ValueError('underlying env must be wrapped with an episode wrapper')

    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
    obs_space = spaces.Box(-obs, obs, dtype='float32')

    action = jax.tree_util.tree_map(np.array, self._env.sys.actuator.ctrl_range)
    action_space = spaces.Box(action[:, 0], action[:, 1], dtype='float32')

    self.num_envs = self._env.batch_size
    self.observation_space = obs_space
    self.action_space = action_space

    # self.batch_observation_space = utils.batch_space(obs_space, self.num_envs)
    # self.batch_action_space = utils.batch_space(action_space, self.num_envs)

    self.seed(seed)
    self.backend = backend
    self._state = None

    def reset(key):
      key1, key2 = jax.random.split(key)
      state = self._env.reset(key2)
      return state, state.obs, key1

    self._reset = jax.jit(reset, backend=self.backend)

    def step(state, action):
      state = self._env.step(state, action)
      info = {**state.metrics, **state.info}
      return state, state.obs, state.reward, state.done, state.info['truncation'], info

    self._step = jax.jit(step, backend=self.backend)

  def reset(self, **kwargs):
    self._state, obs, self._key = self._reset(self._key)
    return np.array(obs)

  def step_async(self, action):
    self.action = jp.array(action)

  def step_wait(self):
    self._state, obs, reward, done, truncation, info = self._step(self._state, self.action)
    def batch_dict_to_list_dict(batched_dict, keys_to_process):
      return [{} for i in range(self.num_envs)]
      # if keys_to_process is None:
        # return [{} for i in range(self.num_envs)]
      # # Filter the dictionary to only include specified keys that are JAX arrays
      # filtered_dict = {key: batched_dict[key] for key in keys_to_process if key in batched_dict and isinstance(batched_dict[key], jnp.ndarray)}

      # # Find the batch size from the first item in the filtered dictionary
      # batch_size = filtered_dict[next(iter(filtered_dict))].shape[0] if filtered_dict else 0

      # # Create a list of dictionaries for each batch index
      # return [{key: filtered_dict[key][i] for key in filtered_dict} for i in range(batch_size)]
    info = batch_dict_to_list_dict(info, self.info_keys)
    # print(reward)
    return np.array(obs), np.array(reward), np.array(done), info

  def seed(self, seed: int = 0):
    self._key = jax.random.PRNGKey(seed)

  def render(self, mode='human'):
    if mode == 'rgb_array':
      sys, state = self._env.sys, self._state
      if state is None:
        raise RuntimeError('must call reset or step before rendering')
      return image.render_array(sys, state.pipeline_state.take(0), 256, 256)
    else:
      return super().render(mode=mode)  # just raise an exception

  def close(self):
    pass

  def env_is_wrapped(self, wrapper_class):
    return [False] * self.num_envs

  def step(self, actions):
    self.step_async(actions)
    return self.step_wait()

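  def get_attr(self, attr_name, indices: VecEnvIndices = None):
    # All sub-environments live in one batched object, so each one sees the same value.
    return [getattr(self, attr_name)] * self.num_envs

  def set_attr(self, attr_name, value, indices: VecEnvIndices = None):
    setattr(self, attr_name, value)

  def env_method(self, method_name, *method_args, indices: VecEnvIndices = None, **method_kwargs):
    # Call the method once on the batched wrapper and repeat the result per sub-environment.
    result = getattr(self, method_name)(*method_args, **method_kwargs)
    return [result] * self.num_envs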

class AutoResetWrapper2(Wrapper):
  """Automatically resets Brax envs that are done."""
  def reset(self, rng: jax.Array) -> State:
    base_state = self.env.reset(rng)
    info = base_state.info.copy()
    info.update({
        'initial_base_state': base_state,
        'current_base_state': base_state
    })

    return State(
        pipeline_state=base_state.pipeline_state,
        obs=base_state.obs,
        reward=base_state.reward,
        done=base_state.done,
        metrics=base_state.metrics,
        info=info
    )

  def step(self, state: State, action: jax.Array) -> State:
    initial_base_state = state.info['initial_base_state']
    current_base_state = state.info['current_base_state']
    next_base_state = self.env.step(current_base_state, action)

    done = next_base_state.done
    def where_done(x, y):
      return jp.where(done, x, y)

    # Where an episode is done, swap every leaf back to its value from the initial state.
    info = jax.tree_util.tree_map(where_done, initial_base_state.info, next_base_state.info).copy()
    info.update({
        'initial_base_state': initial_base_state,
        'current_base_state': jax.tree_util.tree_map(where_done, initial_base_state, next_base_state),
    })

    return State(
        pipeline_state=jax.tree_util.tree_map(where_done, initial_base_state.pipeline_state, next_base_state.pipeline_state),
        obs=jax.tree_util.tree_map(where_done, initial_base_state.obs, next_base_state.obs),
        reward=jax.tree_util.tree_map(where_done, initial_base_state.reward, next_base_state.reward),
        done=next_base_state.done,
        metrics=jax.tree_util.tree_map(where_done, initial_base_state.metrics, next_base_state.metrics),
        info=info
    )

from brax.envs.wrappers.training import VmapWrapper, EpisodeWrapper
from brax.envs.ant import Ant

episode_length = 1000
backend = 'spring'
batch_size = 1024
action_repeat = 1

# Wrapper order: episode bookkeeping first, then auto-reset, then batching.
env = Ant(backend=backend)
env = EpisodeWrapper(env, episode_length, action_repeat=action_repeat)
env = AutoResetWrapper2(env)
env = VmapWrapper(env, batch_size)

@vyeevani
Author

^ This is really hacky and there's a lot that's terrible about it. It's a high-level sketch of everything that would be needed to get this to work.
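
One piece left stubbed out in the sketch is the `batch_dict_to_list_dict` helper in `step_wait`, which always returns empty dictionaries. A possible plain-Python version is sketched below, under the assumption that every selected entry can be indexed along a leading batch axis of length `num_envs`. The explicit `num_envs` argument and the example key names are illustrative, and the JAX-array type check from the commented-out code is left out.

```python
from typing import Any, Dict, List, Mapping, Optional, Sequence


def batch_dict_to_list_dict(
    batched: Mapping[str, Any],
    keys_to_process: Optional[Sequence[str]],
    num_envs: int,
) -> List[Dict[str, Any]]:
    """Split a dict of batched values into one dict per environment."""
    if not keys_to_process:
        return [{} for _ in range(num_envs)]
    # Keep only the requested keys that are actually present.
    selected = {k: batched[k] for k in keys_to_process if k in batched}
    # Index every selected entry along its leading (batch) axis.
    return [{k: v[i] for k, v in selected.items()} for i in range(num_envs)]


infos = batch_dict_to_list_dict(
    {"x_position": [0.1, 0.2], "reward_ctrl": [-1.0, -2.0]},
    keys_to_process=["x_position", "missing_key"],
    num_envs=2,
)
```

The same indexing works for array-valued entries, although nested entries such as the stored base states would need to be filtered out or flattened first.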

@araffin araffin changed the title from "Support VecEnv for gymnasium.vector.VectorEnv" to "Support VecEnv for gymnasium.vector.VectorEnv and Brax" Nov 20, 2023
@araffin
Member

araffin commented Nov 20, 2023

Hello,
thanks for providing the code =)
Do you need any help to get it to work?
I would be happy to link it in our doc (and maybe integrate it in the zoo or sb3 contrib) as it should be similar to envpool/isaac gym: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym

@araffin araffin added documentation Improvements or additions to documentation and removed duplicate This issue or pull request already exists labels Nov 20, 2023
@jamesheald

@vyeevani did you finalise this into a working version?

@araffin
Member

araffin commented Jan 8, 2025

I have a working version here, but it still needs some polishing: https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7
