Commit 176616e: gymnasium 1.0.0
belerico committed Jan 16, 2025 (1 parent: aed1289)
Showing 8 changed files with 24 additions and 18 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_sb3.py
```diff
@@ -27,13 +27,13 @@
 # print(sb3.common.evaluation.evaluate_policy(model.policy, env))


-# Stable Baselines3 SAC - LunarLanderContinuous-v2
+# Stable Baselines3 SAC - LunarLanderContinuous-v3
 # Decomment below to run SAC benchmarks

 # if __name__ == "__main__":
 #     with timer("run_time", SumMetric, sync_on_compute=False):
 #         env = sb3.common.vec_env.DummyVecEnv(
-#             [lambda: gym.make("LunarLanderContinuous-v2", render_mode="rgb_array") for _ in range(4)]
+#             [lambda: gym.make("LunarLanderContinuous-v3", render_mode="rgb_array") for _ in range(4)]
 #         )
 #         model = SAC("MlpPolicy", env, verbose=0, device="cpu")
 #         model.learn(total_timesteps=1024 * 64, log_interval=None)
```
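This rename tracks gymnasium 1.0.0, which re-registered the Box2D environments under `v3`, so the old `LunarLanderContinuous-v2` id no longer resolves. A minimal sketch of the renamed environment in use, assuming `gymnasium[box2d]` is installed:

```python
import gymnasium as gym

# "LunarLanderContinuous-v2" is gone in gymnasium 1.0.0; the Box2D
# environments are now registered as v3.
env = gym.make("LunarLanderContinuous-v3", render_mode="rgb_array")
obs, info = env.reset(seed=0)
print(obs.shape)  # (8,): the flat vector observation
env.close()
```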
6 changes: 3 additions & 3 deletions howto/select_observations.md
````diff
@@ -80,13 +80,13 @@ The algorithms that work with only vector observations are reported here:
 * SAC
 * Droq

-For any of them you **must select** only the environments that provide vector observations. For instance, you can train the *SAC* algorithm on the `LunarLanderContinuous-v2` environment, but you cannot train it on the `CarRacing-v2` environment.
+For any of them you **must select** only the environments that provide vector observations. For instance, you can train the *SAC* algorithm on the `LunarLanderContinuous-v3` environment, but you cannot train it on the `CarRacing-v2` environment.

 For these algorithms, you have to specify the *mlp* keys you want to encode. As usual, you have to specify them through the `mlp_keys.encoder` and `mlp_keys.decoder` arguments (in the command or the configs).

-For instance, you can train a SAC agent on the `LunarLanderContinuous-v2` with the following command:
+For instance, you can train a SAC agent on the `LunarLanderContinuous-v3` with the following command:
 ```bash
-python sheeprl.py exp=sac env=gym env.id=LunarLanderContinuous-v2 algo.mlp_keys.encoder=[state]
+python sheeprl.py exp=sac env=gym env.id=LunarLanderContinuous-v3 algo.mlp_keys.encoder=[state]
 ```

````
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -27,7 +27,7 @@ readme = { file = "docs/README.md", content-type = "text/markdown" }
 requires-python = ">=3.8,<3.12"
 classifiers = ["Programming Language :: Python", "Topic :: Scientific/Engineering :: Artificial Intelligence"]
 dependencies = [
-  "gymnasium==0.29.*",
+  "gymnasium==1.0.0",
   "pygame >=2.1.3",
   "moviepy>=1.0.3",
   "tensorboard>=2.10",
```
2 changes: 1 addition & 1 deletion sheeprl/configs/exp/sac.yaml
```diff
@@ -26,7 +26,7 @@ buffer:

 # Environment
 env:
-  id: LunarLanderContinuous-v2
+  id: LunarLanderContinuous-v3

 metric:
   aggregator:
```
2 changes: 1 addition & 1 deletion sheeprl/configs/exp/sac_benchmarks.yaml
```diff
@@ -10,7 +10,7 @@ run_benchmarks: True

 # Environment
 env:
-  id: LunarLanderContinuous-v2
+  id: LunarLanderContinuous-v3
   capture_video: False
   num_envs: 4
```
4 changes: 2 additions & 2 deletions sheeprl/envs/dummy.py
```diff
@@ -18,7 +18,7 @@ def __init__(
         if self._dict_obs_space:
             self.observation_space = gym.spaces.Dict(
                 {
-                    "rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
+                    "rgb": gym.spaces.Box(0, 255, shape=image_size, dtype=np.uint8),
                     "state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
                 }
             )
@@ -43,7 +43,7 @@ def get_obs(self) -> Dict[str, np.ndarray]:
         if self._dict_obs_space:
             return {
                 # to be replaced with np.random.rand
-                "rgb": np.full(self.observation_space["rgb"].shape, self._current_step % 256, dtype=np.uint8),
+                "rgb": np.full(self.observation_space["rgb"].shape, self._current_step % 255, dtype=np.uint8),
                 "state": np.full(self.observation_space["state"].shape, self._current_step, dtype=np.uint8),
             }
         else:
```
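The `256 -> 255` change is needed because `gym.spaces.Box` bounds are inclusive, and gymnasium 1.0.0 validates bounds against the dtype, so an upper bound of 256 is out of range for `uint8`. A minimal sketch of the corrected space, assuming that stricter validation:

```python
import numpy as np
import gymnasium as gym

# Box bounds are inclusive: valid uint8 pixel values span [0, 255].
# An upper bound of 256 cannot be represented in uint8, which gymnasium
# 1.0 rejects (assumption: earlier versions silently accepted it).
rgb_space = gym.spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8)
assert rgb_space.contains(np.full((3, 64, 64), 255, dtype=np.uint8))
```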
2 changes: 1 addition & 1 deletion sheeprl/envs/wrappers.py
```diff
@@ -25,7 +25,7 @@ class MaskVelocityWrapper(gym.ObservationWrapper):
         "MountainCarContinuous-v0": np.array([1]),
         "Pendulum-v1": np.array([2]),
         "LunarLander-v2": np.array([2, 3, 5]),
-        "LunarLanderContinuous-v2": np.array([2, 3, 5]),
+        "LunarLanderContinuous-v3": np.array([2, 3, 5]),
     }

     def __init__(self, env: gym.Env):
```
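For context, this table maps env ids to the observation indices that `MaskVelocityWrapper` hides, so the entry has to follow the id bump. A hypothetical usage sketch (the zeroing behavior is assumed from the index table, not shown in this diff):

```python
import gymnasium as gym
from sheeprl.envs.wrappers import MaskVelocityWrapper

# For LunarLanderContinuous-v3 the wrapper masks indices [2, 3, 5], i.e.
# the linear and angular velocity entries of the 8-dim observation.
env = MaskVelocityWrapper(gym.make("LunarLanderContinuous-v3"))
obs, info = env.reset(seed=0)  # velocity entries assumed zeroed out
```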
20 changes: 13 additions & 7 deletions sheeprl/utils/env.py
```diff
@@ -107,10 +107,14 @@ def thunk() -> gym.Env:
                         f"is allowed in {cfg.env.id}, "
                         f"only the first one is kept: {cfg.algo.cnn_keys.encoder[0]}"
                     )
-                gym.wrappers.pixel_observation.STATE_KEY = cfg.algo.mlp_keys.encoder[0]
-                env = gym.wrappers.PixelObservationWrapper(
-                    env, pixels_only=encoder_mlp_keys_length == 0, pixel_keys=(cfg.algo.cnn_keys.encoder[0],)
+                obs_key = "state"
+                if encoder_mlp_keys_length > 0:
+                    obs_key = cfg.algo.mlp_keys.encoder[0]
+                env = gym.wrappers.AddRenderObservation(
+                    env,
+                    render_only=encoder_mlp_keys_length == 0,
+                    render_key=cfg.algo.cnn_keys.encoder[0],
+                    obs_key=obs_key,
                 )
             else:
                 if encoder_mlp_keys_length > 1:
```
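`PixelObservationWrapper` is gone in gymnasium 1.0.0; its replacement `AddRenderObservation` folds the old module-level `STATE_KEY` patch into an `obs_key` argument and renames `pixels_only`/`pixel_keys` to `render_only`/`render_key`. A minimal sketch of the new wrapper, assuming a render-capable environment and the key names used above:

```python
import gymnasium as gym

# AddRenderObservation exposes env.render() frames as part of the
# observation. With render_only=False the observation becomes a dict
# holding the original observation (obs_key) and the frame (render_key).
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.AddRenderObservation(
    env, render_only=False, render_key="rgb", obs_key="state"
)
obs, info = env.reset(seed=0)
print(sorted(obs.keys()))  # ['rgb', 'state']
```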
```diff
@@ -120,7 +124,7 @@
                         f"only the first one is kept: {cfg.algo.mlp_keys.encoder[0]}"
                     )
                 mlp_key = cfg.algo.mlp_keys.encoder[0]
-                env = gym.wrappers.TransformObservation(env, lambda obs: {mlp_key: obs})
+                env = gym.wrappers.TransformObservation(env, lambda obs: {mlp_key: obs}, None)
                 env.observation_space = gym.spaces.Dict({mlp_key: env.observation_space})
         elif isinstance(env.observation_space, gym.spaces.Box) and 2 <= len(env.observation_space.shape) <= 3:
             # Pixel only observation
```
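The trailing `None` reflects gymnasium 1.0.0's `TransformObservation` signature, which now takes the transformed observation space as a mandatory third argument; passing `None` leaves it unset so it can be assigned manually on the next line. A minimal sketch declaring the space up front instead:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
key = "state"
# gymnasium 1.0 wants the resulting observation space alongside the
# transform function, so the Dict space is declared in the same call.
env = gym.wrappers.TransformObservation(
    env,
    lambda obs: {key: obs},
    gym.spaces.Dict({key: env.observation_space}),
)
obs, info = env.reset(seed=0)
assert key in obs
```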
```diff
@@ -136,7 +140,9 @@
                     "Please set at least one cnn key in the config file: `algo.cnn_keys.encoder=[your_cnn_key]`"
                 )
             cnn_key = cfg.algo.cnn_keys.encoder[0]
-            env = gym.wrappers.TransformObservation(env, lambda obs: {cnn_key: obs})
+            env = gym.wrappers.TransformObservation(
+                env, lambda obs: {cnn_key: obs}, gym.spaces.Dict({cnn_key: env.observation_space})
+            )
             env.observation_space = gym.spaces.Dict({cnn_key: env.observation_space})

         if (
```
```diff
@@ -195,7 +201,7 @@ def transform_obs(obs: Dict[str, Any]):

             return obs

-        env = gym.wrappers.TransformObservation(env, transform_obs)
+        env = gym.wrappers.TransformObservation(env, transform_obs, None)
         for k in cnn_keys:
             env.observation_space[k] = gym.spaces.Box(
                 0, 255, (1 if cfg.env.grayscale else 3, cfg.env.screen_size, cfg.env.screen_size), np.uint8
```
```diff
@@ -222,7 +228,7 @@
         if cfg.env.capture_video and rank == 0 and vector_env_idx == 0 and run_name is not None:
             if cfg.env.grayscale:
                 env = GrayscaleRenderWrapper(env)
-            env = gym.experimental.wrappers.RecordVideoV0(
+            env = gym.wrappers.RecordVideo(
                 env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True
             )
             env.metadata["render_fps"] = env.frames_per_sec
```
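`RecordVideoV0` graduated from `gym.experimental.wrappers` to the stable `gym.wrappers.RecordVideo` in gymnasium 1.0.0. A minimal sketch of the stable wrapper, assuming the default episode trigger:

```python
import gymnasium as gym

# RecordVideo writes .mp4 files into video_folder; disable_logger=True
# silences moviepy's console output, matching the call above.
env = gym.make("LunarLanderContinuous-v3", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, video_folder="videos", disable_logger=True)
obs, info = env.reset(seed=0)
for _ in range(200):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        break
env.close()
```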
