Add maskable GraphPPO based on sb3_contrib.MaskablePPO
nhuet authored and fteicht committed Jan 10, 2025
1 parent 248f6d0 commit dcf2120
Showing 11 changed files with 481 additions and 77 deletions.
36 changes: 27 additions & 9 deletions examples/gnn_sb3_jsp.py
@@ -5,25 +5,25 @@
from gymnasium.spaces import Box, Graph, GraphInstance

from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import Domain
from skdecide.hub.domain.gym import GymDomain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO
from skdecide.hub.space.gym import DiscreteSpace, GymSpace, ListSpace
from skdecide.utils import rollout

# JSP graph env


class D(Domain):
class D(GymDomain):
T_state = GraphInstance # Type of states
T_observation = T_state # Type of observations
T_event = int # Type of events
T_value = float # Type of transition values (rewards or costs)
T_info = None # Type of additional information in environment outcome


class GraphJspDomain(GymDomain, D):
class GraphJspDomain(D):
_gym_env: DisjunctiveGraphJspEnv

def __init__(self, gym_env):
@@ -45,14 +45,13 @@ def _get_applicable_actions_from(
) -> D.T_agent[Space[D.T_event]]:
return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0])

def _is_applicable_action_from(
self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state]
) -> bool:
return self._gym_env.valid_action_mask()[action]

def _state_reset(self) -> D.T_state:
return self._np_state2graph_state(super()._state_reset())

def _get_action_space_(self) -> D.T_agent[Space[D.T_event]]:
# overridden to get an enumerable space
return DiscreteSpace(n=self._gym_env.action_space.n)

def _get_observation_space_(self) -> Space[D.T_observation]:
if self._gym_env.normalize_observation_space:
original_graph_space = Graph(
@@ -135,3 +134,22 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:

solver.solve()
rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1)

# solver with sb3-MaskableGraphPPO
domain_factory = lambda: GraphJspDomain(gym_env=jsp_env)
with StableBaseline(
domain_factory=domain_factory,
algo_class=MaskableGraphPPO,
baselines_policy="GraphInputPolicy",
learn_config={"total_timesteps": 100},
use_action_masking=True,
) as solver:

solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=100,
num_episodes=1,
use_applicable_actions=True,
)
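For reference, the applicable actions exposed by the domain are exactly the nonzero entries of the underlying environment's mask. A minimal sketch, assuming the `jsp_env` instance created earlier in this example (its construction is not shown in this diff):

import numpy as np

domain = GraphJspDomain(gym_env=jsp_env)
domain.reset()

# the enumerable space mirrors _get_applicable_actions_from() above:
# the indices where valid_action_mask() is nonzero
print(jsp_env.valid_action_mask())
print(list(domain.get_applicable_actions().get_elements()))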
42 changes: 40 additions & 2 deletions examples/gnn_sb3_maze.py
@@ -1,16 +1,16 @@
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
from gymnasium.spaces import Box, Discrete, Graph, GraphInstance

from skdecide.builders.domain import Renderable, UnrestrictedActions
from skdecide.core import Space, Value
from skdecide.core import Mask, Space, Value
from skdecide.domains import DeterministicPlanningDomain
from skdecide.hub.domain.maze import Maze
from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.utils import rollout

@@ -158,6 +158,25 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any:
maze_memory = self._graph2mazestate(memory)
self.maze_domain._render_from(memory=maze_memory, **kwargs)

def _get_action_mask(
self, memory: Optional[D.T_memory[D.T_state]] = None
) -> D.T_agent[Mask]:
# overridden since by default it is only 1's (inheriting from UnrestrictedActions)
# we could also override only _get_applicable_actions(), but it is more
# computationally efficient to implement get_action_mask() directly
if memory is None:
memory = self._memory
mazestate_memory = self._graph2mazestate(memory)
return np.array(
[
self._graph2mazestate(
self._get_next_state(action=action, memory=memory)
)
!= mazestate_memory
for action in self._get_action_space().get_elements()
]
)

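Because an action is marked applicable only when it actually changes the maze state (i.e. the move is not blocked by a wall), the mask can be inspected directly on a domain instance. A minimal sketch, assuming the surrounding GraphMaze class and the imported DEFAULT_MAZE:

domain = GraphMaze(maze_str=DEFAULT_MAZE)
domain.reset()

# one boolean per action of the enumerable action space;
# moves blocked by a wall come back as False
mask = domain.get_action_mask()
for action, applicable in zip(domain.get_action_space().get_elements(), mask):
    print(action, bool(applicable))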

MAZE = """
+-+-+-+-+o+-+-+--+-+-+
@@ -201,3 +220,22 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any:

solver.solve()
rollout(domain=domain_factory(), solver=solver, max_steps=max_steps, num_episodes=1)

# solver with sb3-MaskableGraphPPO
domain_factory = lambda: GraphMaze(maze_str=MAZE)
with StableBaseline(
domain_factory=domain_factory,
algo_class=MaskableGraphPPO,
baselines_policy="GraphInputPolicy",
learn_config={"total_timesteps": 100},
use_action_masking=True,
) as solver:

solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=100,
num_episodes=1,
use_applicable_actions=True,
)
20 changes: 17 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ numpy = { version = "^1.20.1", optional = true }
matplotlib = { version = ">=3.3.4", optional = true }
joblib = { version = ">=1.0.1", optional = true }
stable-baselines3 = { version = ">=2.0.0", optional = true }
sb3_contrib = { version = ">=2.3", optional = true }
ray = { extras = ["rllib"], version = ">=2.9.0, <2.38", optional = true }
discrete-optimization = { version = ">=0.5.0" }
openap = { version = ">=1.5", python = ">=3.9", optional = true }
@@ -105,6 +106,7 @@ solvers = [
"joblib",
"ray",
"stable-baselines3",
"sb3_contrib",
"unified-planning",
"up-tamer",
"up-fast-downward",
@@ -122,6 +124,7 @@ all = [
"joblib",
"ray",
"stable-baselines3",
"sb3_contrib",
"openap",
"pygeodesy",
"unified-planning",
49 changes: 49 additions & 0 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
@@ -5,6 +5,11 @@
import torch as th
import torch_geometric as thg
from gymnasium import spaces
from sb3_contrib.common.maskable.buffers import (
MaskableDictRolloutBuffer,
MaskableRolloutBuffer,
MaskableRolloutBufferSamples,
)
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.type_aliases import RolloutBufferSamples
@@ -246,6 +251,50 @@ def _get_observations_samples(
}


class _BaseMaskableRolloutBuffer:

tensor_names = [
"actions",
"values",
"log_probs",
"advantages",
"returns",
"action_masks",
]

def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
"""
:param action_masks: Masks applied to constrain the choice of possible actions.
"""
if action_masks is not None:
self.action_masks[self.pos] = action_masks.reshape(
(self.n_envs, self.mask_dims)
)

super().add(*args, **kwargs)

def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> MaskableRolloutBufferSamples:
samples_wo_action_masks = super()._get_samples(batch_inds=batch_inds, env=env)
return MaskableRolloutBufferSamples(
*samples_wo_action_masks,
action_masks=self.action_masks[batch_inds].reshape(-1, self.mask_dims),
)


class MaskableGraphRolloutBuffer(
_BaseMaskableRolloutBuffer, GraphRolloutBuffer, MaskableRolloutBuffer
):
...


class MaskableDictGraphRolloutBuffer(
_BaseMaskableRolloutBuffer, DictGraphRolloutBuffer, MaskableDictRolloutBuffer
):
...


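These two classes rely on cooperative multiple inheritance: listed first among the bases, _BaseMaskableRolloutBuffer intercepts add() and _get_samples(), records the action masks, and defers everything else to the graph rollout buffer via super(). A minimal sketch of the same pattern, with hypothetical names rather than the library's classes:

class MaskMixin:
    def add(self, *args, action_masks=None, **kwargs):
        print("store mask:", action_masks)  # the mixin handles the mask first...
        super().add(*args, **kwargs)        # ...then defers to the next class in the MRO

class GraphBuffer:
    def add(self, *args, **kwargs):
        print("store graph observation:", *args)

class MaskableGraphBuffer(MaskMixin, GraphBuffer):
    ...

MaskableGraphBuffer().add("obs", action_masks=[1, 0, 1])
# store mask: [1, 0, 1]
# store graph observation: obs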
T = TypeVar("T")


(next changed file: the GNN on-policy algorithm module; file header not rendered in this view)
@@ -3,6 +3,11 @@
import numpy as np
import torch as th
from gymnasium import spaces
from sb3_contrib.common.maskable.buffers import (
MaskableDictRolloutBuffer,
MaskableRolloutBuffer,
)
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
@@ -18,6 +23,13 @@
class GraphOnPolicyAlgorithm(OnPolicyAlgorithm):
"""Base class for On-Policy algorithms (ex: A2C/PPO) with graph observations."""

support_action_masking = False
"""Whether this algorithm supports action masking.
Useful to share the code between algorithms.
"""

def __init__(
self,
policy: Union[str, type[ActorCriticPolicy]],
@@ -54,19 +66,21 @@ def collect_rollouts(
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
n_rollout_steps: int,
use_masking: bool = False,
) -> bool:
"""
Collect experiences using the current policy and fill a ``RolloutBuffer``.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
This method is largely identical to the implementation found in the parent class.
This method is largely identical to the implementation found in the parent class and MaskablePPO.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
:param n_rollout_steps: Number of experiences to collect per environment
:param use_masking: Whether to use invalid action masks during training
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
@@ -75,8 +89,23 @@
self.policy.set_training_mode(False)

n_steps = 0
action_masks = None
rollout_buffer.reset()

if (
use_masking
and self.support_action_masking
and not is_masking_supported(env)
):
raise ValueError(
"Environment does not support action masking. Consider using ActionMasker wrapper"
)

if use_masking and not self.support_action_masking:
raise ValueError(
f"The algorithm {self.__class__.__name__} does not support action masking."
)

# Sample new weights for the state dependent exploration
if self.use_sde:
self.policy.reset_noise(env.num_envs)
@@ -96,7 +125,15 @@
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)

actions, values, log_probs = self.policy(obs_tensor)
if use_masking and self.support_action_masking:
action_masks = get_action_masks(env)

if self.support_action_masking:
actions, values, log_probs = self.policy(
obs_tensor, action_masks=action_masks
)
else:
actions, values, log_probs = self.policy(obs_tensor)
actions = actions.cpu().numpy()

# Rescale and perform action
@@ -145,14 +182,27 @@
terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type]
rewards[idx] += self.gamma * terminal_value

rollout_buffer.add(
self._last_obs, # type: ignore[arg-type]
actions,
rewards,
self._last_episode_starts, # type: ignore[arg-type]
values,
log_probs,
)
if isinstance(
rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer)
):
rollout_buffer.add(
self._last_obs, # type: ignore[arg-type]
actions,
rewards,
self._last_episode_starts, # type: ignore[arg-type]
values,
log_probs,
action_masks=action_masks,
)
else:
rollout_buffer.add(
self._last_obs, # type: ignore[arg-type]
actions,
rewards,
self._last_episode_starts, # type: ignore[arg-type]
values,
log_probs,
)
self._last_obs = new_obs # type: ignore[assignment]
self._last_episode_starts = dones

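The use_masking path follows the MaskablePPO contract: is_masking_supported(env) only returns True when the environment exposes an action_masks() method, which sb3_contrib users typically obtain from the ActionMasker wrapper named in the error message above. A minimal sketch of that contract on a plain Gymnasium environment (CartPole is only a stand-in; mask_fn is a hypothetical mask function):

import gymnasium as gym
import numpy as np
from sb3_contrib.common.maskable.utils import is_masking_supported
from sb3_contrib.common.wrappers import ActionMasker

def mask_fn(env: gym.Env) -> np.ndarray:
    # hypothetical mask: allow every action
    return np.ones(env.action_space.n, dtype=bool)

env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
assert is_masking_supported(env)  # collect_rollouts(use_masking=True) relies on this check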
