Support VecEnv for gymnasium.vector.VectorEnv and Brax #1745
Comments
Where is the duplicate? I searched for it but couldn't find it. I would appreciate a pointer.
Partial duplicate of #1568 (comment) and #229.
For short: a
Related doc: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym
Related issues: #1712 and #772 (comment)
I ended up stepping away from the Gymnasium APIs and focused on Brax directly.

from typing import Optional, Sequence

from brax.envs.base import PipelineEnv, State, Wrapper
from brax.io import image
from gymnasium import spaces
import jax
import jax.numpy as jp
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices
class SB3Wrapper(VecEnv):
    """Exposes a batched Brax env through the Stable-Baselines3 VecEnv interface."""

    def __init__(self,
                 env: PipelineEnv,
                 seed: int = 0,
                 info_keys: Optional[Sequence[str]] = None,
                 backend: Optional[str] = None):
        self._env = env
        self.info_keys = info_keys
        if not hasattr(self._env, 'batch_size'):
            raise ValueError('underlying env must be batched')
        if not hasattr(self._env, 'episode_length'):
            raise ValueError('underlying env must be wrapped with an episode wrapper')

        # Spaces describe a single env; the batch dimension is num_envs.
        obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
        obs_space = spaces.Box(-obs, obs, dtype='float32')
        action = jax.tree_util.tree_map(np.array, self._env.sys.actuator.ctrl_range)
        action_space = spaces.Box(action[:, 0], action[:, 1], dtype='float32')
        # VecEnv.__init__ stores num_envs, observation_space and action_space.
        super().__init__(self._env.batch_size, obs_space, action_space)
        # self.batch_observation_space = utils.batch_space(obs_space, self.num_envs)
        # self.batch_action_space = utils.batch_space(action_space, self.num_envs)
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': 1 / self._env.dt
        }

        self.seed(seed)
        self.backend = backend
        self._state = None

        def reset(key):
            key1, key2 = jax.random.split(key)
            state = self._env.reset(key2)
            return state, state.obs, key1

        self._reset = jax.jit(reset, backend=self.backend)

        def step(state, action):
            state = self._env.step(state, action)
            info = {**state.metrics, **state.info}
            return state, state.obs, state.reward, state.done, state.info['truncation'], info

        self._step = jax.jit(step, backend=self.backend)

    def reset(self, **kwargs):
        self._state, obs, self._key = self._reset(self._key)
        return np.array(obs)

    def step_async(self, action):
        self.action = jp.array(action)

    def step_wait(self):
        self._state, obs, reward, done, truncation, info = self._step(self._state, self.action)

        def batch_dict_to_list_dict(batched_dict, keys_to_process):
            # Placeholder: returns one empty info dict per env. The intended logic is kept below.
            return [{} for _ in range(self.num_envs)]
            # if keys_to_process is None:
            #     return [{} for i in range(self.num_envs)]
            # # Filter the dictionary to only include specified keys that are JAX arrays
            # filtered_dict = {key: batched_dict[key] for key in keys_to_process if key in batched_dict and isinstance(batched_dict[key], jp.ndarray)}
            # # Find the batch size from the first item in the filtered dictionary
            # batch_size = filtered_dict[next(iter(filtered_dict))].shape[0] if filtered_dict else 0
            # # Create a list of dictionaries for each batch index
            # return [{key: filtered_dict[key][i] for key in filtered_dict} for i in range(batch_size)]

        info = batch_dict_to_list_dict(info, self.info_keys)
        # print(reward)
        return np.array(obs), np.array(reward), np.array(done), info

    def seed(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    def render(self, mode='human'):
        if mode == 'rgb_array':
            sys, state = self._env.sys, self._state
            if state is None:
                raise RuntimeError('must call reset or step before rendering')
            # Only the first env of the batch is rendered.
            return image.render_array(sys, state.pipeline_state.take(0), 256, 256)
        else:
            return super().render(mode=mode)  # just raise an exception

    def close(self):
        pass

    def env_is_wrapped(self, wrapper_class, indices: VecEnvIndices = None):
        return [False] * self.num_envs

    def step(self, actions):
        self.step_async(actions)
        return self.step_wait()

    # The methods below ignore `indices`: every sub-env shares this wrapper's attributes.
    def get_attr(self, attr_name, indices: VecEnvIndices = None):
        # VecEnv callers expect one entry per env, so repeat the shared value.
        value = getattr(self, attr_name)
        return [value] * self.num_envs

    def set_attr(self, attr_name, value, indices: VecEnvIndices = None):
        setattr(self, attr_name, value)

    def env_method(self, method_name, *method_args, indices: VecEnvIndices = None, **method_kwargs):
        # Call the method once on the wrapper and repeat the result per env.
        result = getattr(self, method_name)(*method_args, **method_kwargs)
        return [result] * self.num_envs

class AutoResetWrapper2(Wrapper):
    """Automatically resets Brax envs that are done."""

    def reset(self, rng: jax.Array) -> State:
        base_state = self.env.reset(rng)
        info = base_state.info.copy()
        # Keep the initial state around so a finished env can be restored to it.
        info.update({
            'initial_base_state': base_state,
            'current_base_state': base_state
        })
        return State(
            pipeline_state=base_state.pipeline_state,
            obs=base_state.obs,
            reward=base_state.reward,
            done=base_state.done,
            metrics=base_state.metrics,
            info=info
        )

    def step(self, state: State, action: jax.Array) -> State:
        initial_base_state = state.info['initial_base_state']
        current_base_state = state.info['current_base_state']
        next_base_state = self.env.step(current_base_state, action)
        done = next_base_state.done

        def where_done(x, y):
            # Pick the initial-state leaf when the env is done, else the stepped leaf.
            return jp.where(done, x, y)

        info = jax.tree_util.tree_map(where_done, initial_base_state.info, next_base_state.info).copy()
        info.update({
            'initial_base_state': initial_base_state,
            'current_base_state': jax.tree_util.tree_map(where_done, initial_base_state, next_base_state),
        })
        return State(
            pipeline_state=jax.tree_util.tree_map(where_done, initial_base_state.pipeline_state, next_base_state.pipeline_state),
            obs=jax.tree_util.tree_map(where_done, initial_base_state.obs, next_base_state.obs),
            reward=jax.tree_util.tree_map(where_done, initial_base_state.reward, next_base_state.reward),
            done=next_base_state.done,
            metrics=jax.tree_util.tree_map(where_done, initial_base_state.metrics, next_base_state.metrics),
            info=info
        )

from brax.envs.ant import Ant
from brax.envs.wrappers.training import EpisodeWrapper, VmapWrapper

episode_length = 1000
backend = 'spring'
batch_size = 1024
action_repeat = 1

env = Ant(backend=backend)
env = EpisodeWrapper(env, episode_length, action_repeat=action_repeat)
env = AutoResetWrapper2(env)
env = VmapWrapper(env, batch_size)
^ This is really hacky and there is a lot wrong with it. Treat it as a high-level sketch of everything that would be needed to get this working.
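The info conversion is stubbed out in the sketch above (it returns empty dicts). As a minimal illustration of what it is meant to do, assuming each batched value can be indexed along its first axis, the dict-of-batches to list-of-dicts step could look like this. Plain lists stand in for JAX arrays here, and the extra `num_envs` parameter is an assumption added so the function does not depend on `self`.

```python
from typing import Any, Dict, List, Optional, Sequence


def batch_dict_to_list_dict(
    batched: Dict[str, Sequence[Any]],
    keys: Optional[Sequence[str]],
    num_envs: int,
) -> List[Dict[str, Any]]:
    """Turn {key: [v_env0, v_env1, ...]} into [{key: v_env0}, {key: v_env1}, ...]."""
    if keys is None:
        # No keys requested: one empty info dict per env.
        return [{} for _ in range(num_envs)]
    # Keep only the requested keys that are actually present.
    selected = {k: batched[k] for k in keys if k in batched}
    # Slice every selected value at index i to build env i's info dict.
    return [{k: v[i] for k, v in selected.items()} for i in range(num_envs)]


batched_info = {"truncation": [0.0, 1.0], "steps": [5, 9], "unused": [1, 2]}
per_env = batch_dict_to_list_dict(batched_info, ["truncation", "steps", "missing"], 2)
```

With real JAX arrays it would probably be cheaper to convert each selected value to NumPy once before slicing, rather than indexing device arrays env by env.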
Hello,
@vyeevani did you finalise this into a working version?
I have a working version here, but it still needs some polishing: https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7
🚀 Feature
It would be nice to have a wrapper that takes a gymnasium.vector.VectorEnv and returns a VecEnv.
Motivation
I want to do highly parallelized, hardware-accelerated simulation. That pretty much leaves Isaac or Brax. Brax has a lighter-weight setup and also runs on TPUs. Stable-Baselines3 has well-documented and tested implementations of most of the algorithms I'm interested in, as well as deep integration with the imitation library. I'd like to use both of these libraries.
Pitch
Brax currently provides a wrapper for legacy OpenAI Gym vectorized environments. I have a request open to support the Gymnasium vectorized API (pretty much just changing the imports from Gym to Gymnasium). Stable-Baselines3 requires vectorized environments to be implemented against its own VecEnv specification. As far as I can tell, it's fairly simple to convert between the Gymnasium vectorized env API and SB3's representation.
I'd like a wrapper class to be provided that implements VecEnv on top of an underlying Gymnasium vectorized env.
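To make the conversion concrete, here is a rough sketch of the step-output part only. Gymnasium's vectorized step returns five values (observations, rewards, terminated, truncated, infos), while SB3's VecEnv returns four (observations, rewards, dones, infos). `to_sb3_step` is a hypothetical helper name, not an SB3 or Gymnasium API, and it assumes the infos have already been split into one dict per env. It uses plain lists so it runs without either library installed.

```python
from typing import Any, Dict, List, Sequence, Tuple


def to_sb3_step(
    obs: Sequence[Any],
    rewards: Sequence[float],
    terminated: Sequence[bool],
    truncated: Sequence[bool],
    infos: Sequence[Dict[str, Any]],
) -> Tuple[List[Any], List[float], List[bool], List[Dict[str, Any]]]:
    """Collapse a Gymnasium-style 5-tuple into the 4-tuple a VecEnv returns."""
    # SB3 has a single "done" flag per env: the episode ended for either reason.
    dones = [bool(t) or bool(u) for t, u in zip(terminated, truncated)]
    out_infos = []
    for info, term, trunc in zip(infos, terminated, truncated):
        info = dict(info)  # copy so the caller's dict is not mutated
        # SB3 looks up this key to tell time-limit truncation from real termination.
        info["TimeLimit.truncated"] = bool(trunc) and not bool(term)
        out_infos.append(info)
    return list(obs), list(rewards), dones, out_infos


obs, rewards, dones, infos = to_sb3_step(
    [0.1, 0.2, 0.3],
    [1.0, 0.0, 2.0],
    [False, True, False],
    [False, False, True],
    [{}, {}, {}],
)
```

A full wrapper would also need to handle resets: a VecEnv is expected to reset finished envs automatically and to store the last observation of the finished episode under `info["terminal_observation"]`, which this sketch leaves out.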
Alternatives
The public API already lets users extend the library and write this wrapper themselves, so that is the main alternative.
Additional context
No response