diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3067c3a8..cbbea960 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
hooks:
- id: flake8
args:
- - '--per-file-ignores=*/__init__.py:F401'
+ - "--per-file-ignores=*/__init__.py:F401"
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-line-length=456
@@ -64,6 +64,6 @@ repos:
language: node
pass_filenames: false
types: [python]
- additional_dependencies: ["pyright"]
+ additional_dependencies: ["pyright@1.1.347"]
args:
- --project=pyproject.toml
diff --git a/docs/_static/videos/breakable-bottles.gif b/docs/_static/videos/breakable-bottles.gif
index 4f901d79..ddca154e 100644
Binary files a/docs/_static/videos/breakable-bottles.gif and b/docs/_static/videos/breakable-bottles.gif differ
diff --git a/docs/_static/videos/deep-sea-treasure-concave.gif b/docs/_static/videos/deep-sea-treasure-concave.gif
index 4bad0688..a0a76216 100644
Binary files a/docs/_static/videos/deep-sea-treasure-concave.gif and b/docs/_static/videos/deep-sea-treasure-concave.gif differ
diff --git a/docs/_static/videos/deep-sea-treasure-mirrored.gif b/docs/_static/videos/deep-sea-treasure-mirrored.gif
new file mode 100644
index 00000000..dc5c4994
Binary files /dev/null and b/docs/_static/videos/deep-sea-treasure-mirrored.gif differ
diff --git a/docs/_static/videos/deep-sea-treasure.gif b/docs/_static/videos/deep-sea-treasure.gif
index 6f5b5d81..b06f7471 100644
Binary files a/docs/_static/videos/deep-sea-treasure.gif and b/docs/_static/videos/deep-sea-treasure.gif differ
diff --git a/docs/_static/videos/four-room.gif b/docs/_static/videos/four-room.gif
index 24282928..8641a8f3 100644
Binary files a/docs/_static/videos/four-room.gif and b/docs/_static/videos/four-room.gif differ
diff --git a/docs/_static/videos/fruit-tree.gif b/docs/_static/videos/fruit-tree.gif
new file mode 100644
index 00000000..483af1e2
Binary files /dev/null and b/docs/_static/videos/fruit-tree.gif differ
diff --git a/docs/_static/videos/minecart-deterministic.gif b/docs/_static/videos/minecart-deterministic.gif
index 3559f6ba..39ad4172 100644
Binary files a/docs/_static/videos/minecart-deterministic.gif and b/docs/_static/videos/minecart-deterministic.gif differ
diff --git a/docs/_static/videos/minecart.gif b/docs/_static/videos/minecart.gif
index 6242074e..0c3e99dc 100644
Binary files a/docs/_static/videos/minecart.gif and b/docs/_static/videos/minecart.gif differ
diff --git a/docs/_static/videos/mo-halfcheetah.gif b/docs/_static/videos/mo-halfcheetah.gif
index bf4cf6be..3fe0efc6 100644
Binary files a/docs/_static/videos/mo-halfcheetah.gif and b/docs/_static/videos/mo-halfcheetah.gif differ
diff --git a/docs/_static/videos/mo-hopper.gif b/docs/_static/videos/mo-hopper.gif
index 402fc20a..8677eecf 100644
Binary files a/docs/_static/videos/mo-hopper.gif and b/docs/_static/videos/mo-hopper.gif differ
diff --git a/docs/_static/videos/mo-lunar-lander.gif b/docs/_static/videos/mo-lunar-lander.gif
index abb23939..2051d754 100644
Binary files a/docs/_static/videos/mo-lunar-lander.gif and b/docs/_static/videos/mo-lunar-lander.gif differ
diff --git a/docs/_static/videos/mo-mountaincar.gif b/docs/_static/videos/mo-mountaincar.gif
index d0b7db84..c9a8d4ba 100644
Binary files a/docs/_static/videos/mo-mountaincar.gif and b/docs/_static/videos/mo-mountaincar.gif differ
diff --git a/docs/_static/videos/mo-mountaincarcontinuous.gif b/docs/_static/videos/mo-mountaincarcontinuous.gif
index 9af3f1e4..81d96abb 100644
Binary files a/docs/_static/videos/mo-mountaincarcontinuous.gif and b/docs/_static/videos/mo-mountaincarcontinuous.gif differ
diff --git a/docs/_static/videos/mo-reacher.gif b/docs/_static/videos/mo-reacher.gif
index 4f2d878d..cd2d6086 100644
Binary files a/docs/_static/videos/mo-reacher.gif and b/docs/_static/videos/mo-reacher.gif differ
diff --git a/docs/_static/videos/mo-supermario.gif b/docs/_static/videos/mo-supermario.gif
index b60f379d..f87efc03 100644
Binary files a/docs/_static/videos/mo-supermario.gif and b/docs/_static/videos/mo-supermario.gif differ
diff --git a/docs/_static/videos/resource-gathering.gif b/docs/_static/videos/resource-gathering.gif
index 2187043b..916630bc 100644
Binary files a/docs/_static/videos/resource-gathering.gif and b/docs/_static/videos/resource-gathering.gif differ
diff --git a/docs/_static/videos/water-reservoir.gif b/docs/_static/videos/water-reservoir.gif
index c0709c40..5ebbbcdf 100644
Binary files a/docs/_static/videos/water-reservoir.gif and b/docs/_static/videos/water-reservoir.gif differ
diff --git a/docs/environments/all-environments.md b/docs/environments/all-environments.md
index 4d2f011a..0975216a 100644
--- a/docs/environments/all-environments.md
+++ b/docs/environments/all-environments.md
@@ -7,25 +7,26 @@ title: "Environments"
MO-Gymnasium includes environments taken from the MORL literature, as well as multi-objective versions of classical environments, such as Mujoco.
-| Env | Obs/Action spaces | Objectives | Description |
-|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------|---------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [`deep-sea-treasure-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure/) | Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
-| [`deep-sea-treasure-concave-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-concave/) | Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Vamplew et al. 2010](https://link.springer.com/article/10.1007/s10994-010-5232-5). |
-| [`resource-gathering-v0`](https://mo-gymnasium.farama.org/environments/resource-gathering/) | Discrete / Discrete | `[enemy, gold, gem]` | Agent must collect gold or gem. Enemies have a 10% chance of killing the agent. From [Barret & Narayanan 2008](https://dl.acm.org/doi/10.1145/1390156.1390162). |
-| [`fishwood-v0`](https://mo-gymnasium.farama.org/environments/fishwood/) | Discrete / Discrete | `[fish_amount, wood_amount]` | ESR environment, the agent must collect fish and wood to light a fire and eat. From [Roijers et al. 2018](https://www.researchgate.net/publication/328718263_Multi-objective_Reinforcement_Learning_for_the_Expected_Utility_of_the_Return). |
-| [`breakable-bottles-v0`](https://mo-gymnasium.farama.org/environments/breakable-bottles/) | Discrete (Dictionary) / Discrete | `[time_penalty, bottles_delivered, potential]` | Gridworld with 5 cells. The agents must collect bottles from the source location and deliver to the destination. From [Vamplew et al. 2021](https://www.sciencedirect.com/science/article/pii/S0952197621000336). |
-| [`fruit-tree-v0`](https://mo-gymnasium.farama.org/environments/fruit-tree/) | Discrete / Discrete | `[nutri1, ..., nutri6]` | Full binary tree of depth d=5,6 or 7. Every leaf contains a fruit with a value for the nutrients Protein, Carbs, Fats, Vitamins, Minerals and Water. From [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
-| [`water-reservoir-v0`](https://mo-gymnasium.farama.org/environments/water-reservoir/) | Continuous / Continuous | `[cost_flooding, deficit_water]` | A Water reservoir environment. The agent executes a continuous action, corresponding to the amount of water released by the dam. From [Pianosi et al. 2013](https://iwaponline.com/jh/article/15/2/258/3425/Tree-based-fitted-Q-iteration-for-multi-objective). |
-| [`four-room-v0`](https://mo-gymnasium.farama.org/environments/four-room/) | Discrete / Discrete | `[item1, item2, item3]` | Agent must collect three different types of items in the map and reach the goal. From [Alegre et al. 2022](https://proceedings.mlr.press/v162/alegre22a.html). |
-| [`mo-mountaincar-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincar/) | Continuous / Discrete | `[time_penalty, reverse_penalty, forward_penalty]` | Classic Mountain Car env, but with extra penalties for the forward and reverse actions. From [Vamplew et al. 2011](https://www.researchgate.net/publication/220343783_Empirical_evaluation_methods_for_multiobjective_reinforcement_learning_algorithms). |
-| [`mo-mountaincarcontinuous-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincarcontinuous/) | Continuous / Continuous | `[time_penalty, fuel_consumption_penalty]` | Continuous Mountain Car env, but with penalties for fuel consumption. |
-| [`mo-lunar-lander-v2`](https://mo-gymnasium.farama.org/environments/mo-lunar-lander/) | Continuous / Discrete or Continuous | `[landed, shaped_reward, main_engine_fuel, side_engine_fuel]` | MO version of the `LunarLander-v2` [environment](https://gymnasium.farama.org/environments/box2d/lunar_lander/). Objectives defined similarly as in [Hung et al. 2022](https://openreview.net/forum?id=AwWaBXLIJE). |
-| [`minecart-v0`](https://mo-gymnasium.farama.org/environments/minecart/) | Continuous or Image / Discrete | `[ore1, ore2, fuel]` | Agent must collect two types of ores and minimize fuel consumption. From [Abels et al. 2019](https://arxiv.org/abs/1809.07803v2). |
-| [`mo-highway-v0`](https://mo-gymnasium.farama.org/environments/mo-highway/) and `mo-highway-fast-v0` | Continuous / Discrete | `[speed, right_lane, collision]` | The agent's objective is to reach a high speed while avoiding collisions with neighbouring vehicles and staying on the rightest lane. From [highway-env](https://github.com/eleurent/highway-env). |
-| [`mo-supermario-v0`](https://mo-gymnasium.farama.org/environments/mo-supermario/) | Image / Discrete | `[x_pos, time, death, coin, enemy]` | [:warning: SuperMarioBrosEnv support is limited.] Multi-objective version of [SuperMarioBrosEnv](https://github.com/Kautenja/gym-super-mario-bros). Objectives are defined similarly as in [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
-| [`mo-reacher-v4`](https://mo-gymnasium.farama.org/environments/mo-reacher/) | Continuous / Discrete | `[target_1, target_2, target_3, target_4]` | Mujoco version of `mo-reacher-v0`, based on `Reacher-v4` [environment](https://gymnasium.farama.org/environments/mujoco/reacher/). |
-| [`mo-hopper-v4`](https://mo-gymnasium.farama.org/environments/mo-hopper/) | Continuous / Continuous | `[velocity, height, energy]` | Multi-objective version of [Hopper-v4](https://gymnasium.farama.org/environments/mujoco/hopper/) env. |
-| [`mo-halfcheetah-v4`](https://mo-gymnasium.farama.org/environments/mo-halfcheetah/) | Continuous / Continuous | `[velocity, energy]` | Multi-objective version of [HalfCheetah-v4](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) env. Similar to [Xu et al. 2020](https://github.com/mit-gfx/PGMORL). |
+| Env | Obs/Action spaces | Objectives | Description |
+|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------|---------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [`deep-sea-treasure-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure/) | Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
+| [`deep-sea-treasure-concave-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-concave/) | Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Vamplew et al. 2010](https://link.springer.com/article/10.1007/s10994-010-5232-5). |
+| [`deep-sea-treasure-mirrored-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-mirrored/) | Discrete / Discrete | `[treasure, time_penalty]` | Harder version of the concave DST [Felten et al. 2022](https://www.scitepress.org/Papers/2022/109891/109891.pdf). |
+| [`resource-gathering-v0`](https://mo-gymnasium.farama.org/environments/resource-gathering/) | Discrete / Discrete | `[enemy, gold, gem]` | Agent must collect gold or gem. Enemies have a 10% chance of killing the agent. From [Barret & Narayanan 2008](https://dl.acm.org/doi/10.1145/1390156.1390162). |
+| [`fishwood-v0`](https://mo-gymnasium.farama.org/environments/fishwood/) | Discrete / Discrete | `[fish_amount, wood_amount]` | ESR environment, the agent must collect fish and wood to light a fire and eat. From [Roijers et al. 2018](https://www.researchgate.net/publication/328718263_Multi-objective_Reinforcement_Learning_for_the_Expected_Utility_of_the_Return). |
+| [`breakable-bottles-v0`](https://mo-gymnasium.farama.org/environments/breakable-bottles/) | Discrete (Dictionary) / Discrete | `[time_penalty, bottles_delivered, potential]` | Gridworld with 5 cells. The agents must collect bottles from the source location and deliver to the destination. From [Vamplew et al. 2021](https://www.sciencedirect.com/science/article/pii/S0952197621000336). |
+| [`fruit-tree-v0`](https://mo-gymnasium.farama.org/environments/fruit-tree/) | Discrete / Discrete | `[nutri1, ..., nutri6]` | Full binary tree of depth d=5,6 or 7. Every leaf contains a fruit with a value for the nutrients Protein, Carbs, Fats, Vitamins, Minerals and Water. From [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
+| [`water-reservoir-v0`](https://mo-gymnasium.farama.org/environments/water-reservoir/) | Continuous / Continuous | `[cost_flooding, deficit_water]` | A Water reservoir environment. The agent executes a continuous action, corresponding to the amount of water released by the dam. From [Pianosi et al. 2013](https://iwaponline.com/jh/article/15/2/258/3425/Tree-based-fitted-Q-iteration-for-multi-objective). |
+| [`four-room-v0`](https://mo-gymnasium.farama.org/environments/four-room/) | Discrete / Discrete | `[item1, item2, item3]` | Agent must collect three different types of items in the map and reach the goal. From [Alegre et al. 2022](https://proceedings.mlr.press/v162/alegre22a.html). |
+| [`mo-mountaincar-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincar/) | Continuous / Discrete | `[time_penalty, reverse_penalty, forward_penalty]` | Classic Mountain Car env, but with extra penalties for the forward and reverse actions. From [Vamplew et al. 2011](https://www.researchgate.net/publication/220343783_Empirical_evaluation_methods_for_multiobjective_reinforcement_learning_algorithms). |
+| [`mo-mountaincarcontinuous-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincarcontinuous/) | Continuous / Continuous | `[time_penalty, fuel_consumption_penalty]` | Continuous Mountain Car env, but with penalties for fuel consumption. |
+| [`mo-lunar-lander-v2`](https://mo-gymnasium.farama.org/environments/mo-lunar-lander/) | Continuous / Discrete or Continuous | `[landed, shaped_reward, main_engine_fuel, side_engine_fuel]` | MO version of the `LunarLander-v2` [environment](https://gymnasium.farama.org/environments/box2d/lunar_lander/). Objectives defined similarly as in [Hung et al. 2022](https://openreview.net/forum?id=AwWaBXLIJE). |
+| [`minecart-v0`](https://mo-gymnasium.farama.org/environments/minecart/) | Continuous or Image / Discrete | `[ore1, ore2, fuel]` | Agent must collect two types of ores and minimize fuel consumption. From [Abels et al. 2019](https://arxiv.org/abs/1809.07803v2). |
+| [`mo-highway-v0`](https://mo-gymnasium.farama.org/environments/mo-highway/) and `mo-highway-fast-v0` | Continuous / Discrete | `[speed, right_lane, collision]` | The agent's objective is to reach a high speed while avoiding collisions with neighbouring vehicles and staying on the rightest lane. From [highway-env](https://github.com/eleurent/highway-env). |
+| [`mo-supermario-v0`](https://mo-gymnasium.farama.org/environments/mo-supermario/) | Image / Discrete | `[x_pos, time, death, coin, enemy]` | [:warning: SuperMarioBrosEnv support is limited.] Multi-objective version of [SuperMarioBrosEnv](https://github.com/Kautenja/gym-super-mario-bros). Objectives are defined similarly as in [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). |
+| [`mo-reacher-v4`](https://mo-gymnasium.farama.org/environments/mo-reacher/) | Continuous / Discrete | `[target_1, target_2, target_3, target_4]` | Mujoco version of `mo-reacher-v0`, based on `Reacher-v4` [environment](https://gymnasium.farama.org/environments/mujoco/reacher/). |
+| [`mo-hopper-v4`](https://mo-gymnasium.farama.org/environments/mo-hopper/) | Continuous / Continuous | `[velocity, height, energy]` | Multi-objective version of [Hopper-v4](https://gymnasium.farama.org/environments/mujoco/hopper/) env. |
+| [`mo-halfcheetah-v4`](https://mo-gymnasium.farama.org/environments/mo-halfcheetah/) | Continuous / Continuous | `[velocity, energy]` | Multi-objective version of [HalfCheetah-v4](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) env. Similar to [Xu et al. 2020](https://github.com/mit-gfx/PGMORL). |
```{toctree}
diff --git a/mo_gymnasium/envs/deep_sea_treasure/__init__.py b/mo_gymnasium/envs/deep_sea_treasure/__init__.py
index 65799cdc..152d4f54 100644
--- a/mo_gymnasium/envs/deep_sea_treasure/__init__.py
+++ b/mo_gymnasium/envs/deep_sea_treasure/__init__.py
@@ -1,6 +1,9 @@
from gymnasium.envs.registration import register
-from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import CONCAVE_MAP
+from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import (
+ CONCAVE_MAP,
+ MIRRORED_MAP,
+)
register(
@@ -15,3 +18,10 @@
max_episode_steps=100,
kwargs={"dst_map": CONCAVE_MAP},
)
+
+register(
+ id="deep-sea-treasure-mirrored-v0",
+ entry_point="mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure:DeepSeaTreasure",
+ max_episode_steps=100,
+ kwargs={"dst_map": MIRRORED_MAP},
+)
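A minimal usage sketch of the newly registered map, assuming the `mo_gymnasium.make` entry point used elsewhere in this PR (the sampled action and returned values are illustrative only):

```python
import mo_gymnasium as mo_gym

# "deep-sea-treasure-mirrored-v0" is registered above with max_episode_steps=100
env = mo_gym.make("deep-sea-treasure-mirrored-v0")
obs, info = env.reset(seed=42)
# rewards are vectors ordered as [treasure, time_penalty], like the other DST variants
obs, vec_reward, terminated, truncated, info = env.step(env.action_space.sample())
```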
diff --git a/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py b/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
index efd95750..374c15a0 100644
--- a/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
+++ b/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
@@ -68,6 +68,23 @@
np.array([124.0, -19]),
]
+# As in Felten et al. 2022, same PF as concave, just harder map
+MIRRORED_MAP = np.array(
+ [
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, -10, -10, 2.0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, 3.0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, -10, -10, 5.0, 8.0, 16.0, 0, 0, 0, 0],
+ [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
+ [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
+ [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 24.0, 50.0, 0, 0],
+ [0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0],
+ [0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 74.0, 0],
+ [0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 124.0],
+ ]
+)
+
class DeepSeaTreasure(gym.Env, EzPickle):
"""
@@ -96,7 +113,7 @@ class DeepSeaTreasure(gym.Env, EzPickle):
The episode terminates when the agent reaches a treasure.
## Arguments
- - dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use `mo_gymnasium.make("DeepSeaTreasure-v0", dst_map=CONCAVE_MAP).`
+    - dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use `mo_gymnasium.make("deep-sea-treasure-v0", dst_map=CONCAVE_MAP)` or `dst_map=MIRRORED_MAP`.
- float_state: if True, the state is a 2D continuous box with values in [0.0, 1.0] for the x and y coordinates of the submarine.
## Credits
@@ -115,8 +132,18 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
# The map of the deep sea treasure (convex version)
self.sea_map = dst_map
- self._pareto_front = CONVEX_FRONT if np.all(dst_map == DEFAULT_MAP) else CONCAVE_FRONT
- assert self.sea_map.shape == DEFAULT_MAP.shape, "The map's shape must be 11x11"
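+ # Identify which of the built-in maps was supplied; any other map is rejected below.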
+ if dst_map.shape[0] == DEFAULT_MAP.shape[0] and dst_map.shape[1] == DEFAULT_MAP.shape[1]:
+ if np.all(dst_map == DEFAULT_MAP):
+ self.map_name = "convex"
+ elif np.all(dst_map == CONCAVE_MAP):
+ self.map_name = "concave"
+ else:
+ raise ValueError("Invalid map")
+ elif np.all(dst_map == MIRRORED_MAP):
+ self.map_name = "mirrored"
+ else:
+ raise ValueError("Invalid map")
+ self._pareto_front = CONVEX_FRONT if self.map_name == "convex" else CONCAVE_FRONT
self.dir = {
0: np.array([-1, 0], dtype=np.int32), # up
@@ -130,7 +157,7 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
if self.float_state:
self.observation_space = Box(low=0.0, high=1.0, shape=(2,), dtype=obs_type)
else:
- self.observation_space = Box(low=0, high=10, shape=(2,), dtype=obs_type)
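+ # maps can now be wider than the default 11x11 grid, so bound the coordinates by the map width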
+ self.observation_space = Box(low=0, high=len(self.sea_map[0]), shape=(2,), dtype=obs_type)
# action space specification: 1 dimension, 0 up, 1 down, 2 left, 3 right
self.action_space = Discrete(4)
@@ -144,11 +171,15 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
self.current_state = np.array([0, 0], dtype=np.int32)
# pygame
- self.window_size = (min(64 * self.sea_map.shape[1], 512), min(64 * self.sea_map.shape[0], 512))
+ ratio = self.sea_map.shape[1] / self.sea_map.shape[0]
+ padding = 10
+ self.pix_inside = (min(64 * self.sea_map.shape[1], 512) * ratio, min(64 * self.sea_map.shape[0], 512))
+ # adding some padding on the sides
+ self.window_size = (self.pix_inside[0] + 2 * padding, self.pix_inside[1])
# The size of a single grid square in pixels
self.pix_square_size = (
- self.window_size[1] // self.sea_map.shape[1] + 1,
- self.window_size[0] // self.sea_map.shape[0] + 1,
+ self.pix_inside[0] // self.sea_map.shape[1] + 1,
+ self.pix_inside[1] // self.sea_map.shape[0] + 1, # watch out for axis inversions here
)
self.window = None
self.clock = None
@@ -257,7 +288,12 @@ def _get_state(self):
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)
- self.current_state = np.array([0, 0], dtype=np.int32)
+ if self.map_name == "convex" or self.map_name == "concave":
+ self.current_state = np.array([0, 0], dtype=np.int32)
+ elif self.map_name == "mirrored":
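+ # the mirrored map is wider (20 columns), so the submarine starts in the middle of the top row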
+ self.current_state = np.array([0, 10], dtype=np.int32)
+ else:
+ raise ValueError("Invalid map")
self.step_count = 0.0
state = self._get_state()
if self.render_mode == "human":
diff --git a/mo_gymnasium/envs/fruit_tree/assets/agent.png b/mo_gymnasium/envs/fruit_tree/assets/agent.png
new file mode 100644
index 00000000..8027fcb1
Binary files /dev/null and b/mo_gymnasium/envs/fruit_tree/assets/agent.png differ
diff --git a/mo_gymnasium/envs/fruit_tree/assets/node_blue.png b/mo_gymnasium/envs/fruit_tree/assets/node_blue.png
new file mode 100644
index 00000000..17645780
Binary files /dev/null and b/mo_gymnasium/envs/fruit_tree/assets/node_blue.png differ
diff --git a/mo_gymnasium/envs/fruit_tree/fruit_tree.py b/mo_gymnasium/envs/fruit_tree/fruit_tree.py
index fac03c06..7d978c38 100644
--- a/mo_gymnasium/envs/fruit_tree/fruit_tree.py
+++ b/mo_gymnasium/envs/fruit_tree/fruit_tree.py
@@ -1,8 +1,10 @@
# Environment from https://github.com/RunzheYang/MORL/blob/master/synthetic/envs/fruit_tree.py
-from typing import List
+from os import path
+from typing import List, Optional
import gymnasium as gym
import numpy as np
+import pygame
from gymnasium import spaces
from gymnasium.utils import EzPickle
@@ -264,16 +266,16 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""
- def __init__(self, depth=6):
+ metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
+
+ def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)
+ self.render_mode = render_mode
self.reward_dim = 6
self.tree_depth = depth # zero based depth
branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim))
- # fruits = np.random.randn(2**self.tree_depth, self.reward_dim)
- # fruits = np.abs(fruits) / np.linalg.norm(fruits, 2, 1, True)
- # print(fruits*10)
fruits = np.array(FRUITS[str(depth)])
self.tree = np.concatenate([branches, fruits])
@@ -288,9 +290,35 @@ def __init__(self, depth=6):
self.current_state = np.array([0, 0], dtype=np.int32)
self.terminal = False
+ # pygame
+ self.row_height = 20
+ self.top_margin = 15
+
+ # Add margin at the bottom to account for the node rewards
+ self.window_size = (1200, self.row_height * self.tree_depth + 150)
+ self.window_padding = 15 # padding on the left and right of the window
+ self.node_square_size = np.array([10, 10], dtype=np.int32)
+ self.font_size = 12
+ pygame.font.init()
+ self.font = pygame.font.SysFont(None, self.font_size)
+
+ self.window = None
+ self.clock = None
+ self.node_img = None
+ self.agent_img = None
+
def get_ind(self, pos):
+ """Given the pos = current_state = [row_ind, pos_in_row]
+ return the index of the node in the tree array"""
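+ # e.g. pos = [2, 1] -> 2**2 - 1 + 1 = 4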
return int(2 ** pos[0] - 1) + pos[1]
+ def ind_to_state(self, ind):
+ """Given the index of the node in the tree array return the
+ current_state = [row_ind, pos_in_row]"""
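+ # e.g. ind = 4 -> x = int(log2(5)) = 2, y = 4 - 2**2 + 1 = 1, i.e. [2, 1] (inverse of get_ind)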
+ x = int(np.log2(ind + 1))
+ y = ind - 2**x + 1
+ return np.array([x, y], dtype=np.int32)
+
def get_tree_value(self, pos):
return np.array(self.tree[self.get_ind(pos)], dtype=np.float32)
@@ -325,5 +353,120 @@ def step(self, action):
reward = self.get_tree_value(self.current_state)
if self.current_state[0] == self.tree_depth:
self.terminal = True
-
return self.current_state.copy(), reward, self.terminal, False, {}
+
+ def get_pos_in_window(self, row, index_in_row):
+ """Given the row and index_in_row of the node
+ calculate its position in the window in pixels"""
+ window_width = self.window_size[0] - 2 * self.window_padding
+ distance_between_nodes = window_width / (2 ** (row))
+ pos_x = self.window_padding + (index_in_row + 0.5) * distance_between_nodes
+ pos_y = row * self.row_height
+ return np.array([pos_x, pos_y])
+
+ def render(self):
+ if self.render_mode is None:
+ assert self.spec is not None
+ gym.logger.warn(
+ "You are calling render method without specifying render mode. "
+ "You can specify the render_mode at initialization, "
+ f'e.g. mo_gym.make("{self.spec.id}", render_mode="rgb_array")'
+ )
+ return
+
+ if self.clock is None and self.render_mode == "human":
+ self.clock = pygame.time.Clock()
+
+ if self.window is None:
+ pygame.init()
+
+ if self.render_mode == "human":
+ pygame.display.init()
+ pygame.display.set_caption("Fruit Tree")
+ self.window = pygame.display.set_mode(self.window_size)
+ self.clock.tick(self.metadata["render_fps"])
+ else:
+ self.window = pygame.Surface(self.window_size)
+
+ if self.node_img is None:
+ filename = path.join(path.dirname(__file__), "assets", "node_blue.png")
+ self.node_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)
+ self.node_img = pygame.transform.flip(self.node_img, flip_x=True, flip_y=False)
+
+ if self.agent_img is None:
+ filename = path.join(path.dirname(__file__), "assets", "agent.png")
+ self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)
+
+ canvas = pygame.Surface(self.window_size)
+ canvas.fill((255, 255, 255)) # White
+
+ # draw branches
+ for ind, node in enumerate(self.tree):
+ row, index_in_row = self.ind_to_state(ind)
+ node_pos = self.get_pos_in_window(row, index_in_row)
+ if row < self.tree_depth:
+ # Get children's positions and draw branches
+ child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
+ child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)
+ half_square = self.node_square_size / 2
+ pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child1_pos + half_square, 1)
+ pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child2_pos + half_square, 1)
+
+ for ind, node in enumerate(self.tree):
+ row, index_in_row = self.ind_to_state(ind)
+ if (row, index_in_row) == tuple(self.current_state):
+ img = self.agent_img
+ font_color = (164, 0, 0) # Red digits for agent node
+ else:
+ img = self.node_img
+ if ind % 2:
+ font_color = (250, 128, 114) # Salmon
+ else:
+ font_color = (45, 72, 101) # Dark Blue
+
+ node_pos = self.get_pos_in_window(row, index_in_row)
+
+ canvas.blit(img, np.array(node_pos))
+
+ # Print node values at the bottom of the tree
+ if row == self.tree_depth:
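+ # stagger the printed values of every other leaf so neighbouring labels do not overlap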
+ odd_nodes_values_offset = 0.5 * (ind % 2)
+ values_imgs = [self.font.render(f"{val:.2f}", True, font_color) for val in node]
+ for i, val_img in enumerate(values_imgs):
+ canvas.blit(val_img, node_pos + np.array([-5, (i + 1 + odd_nodes_values_offset) * 1.5 * self.font_size]))
+
+ background = pygame.Surface(self.window_size)
+ background.fill((255, 255, 255)) # White
+ background.blit(canvas, (0, self.top_margin))
+
+ self.window.blit(background, (0, 0))
+
+ if self.render_mode == "human":
+ pygame.event.pump()
+ pygame.display.update()
+ self.clock.tick(self.metadata["render_fps"])
+ elif self.render_mode == "rgb_array":
+ return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
+
+
+
+if __name__ == "__main__":
+ import time
+
+ import mo_gymnasium as mo_gym
+
+ env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
+ env.reset()
+ while True:
+ env.render()
+ obs, r, terminal, truncated, info = env.step(env.action_space.sample())
+ if terminal or truncated:
+ env.render()
+ time.sleep(2)
+ env.reset()
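Beyond the interactive `__main__` demo above, a minimal sketch of grabbing a single frame through the new `rgb_array` mode (assuming the `fruit-tree-v0` id listed in the docs table):

```python
import mo_gymnasium as mo_gym

env = mo_gym.make("fruit-tree-v0", depth=5, render_mode="rgb_array")
env.reset(seed=0)
frame = env.render()  # (height, width, 3) uint8 array built from the pygame canvas
print(frame.shape)
```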
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 4443789c..7e338be4 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -64,8 +64,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
if env_spec.nondeterministic is True:
return
- env_1 = env_spec.make(disable_env_checker=True)
- env_2 = env_spec.make(disable_env_checker=True)
+ env_1 = mo_gym.make(env_spec.id)
+ env_2 = mo_gym.make(env_spec.id)
env_1 = mo_gym.LinearReward(env_1)
env_2 = mo_gym.LinearReward(env_2)