Skip to content

Commit

Permalink
Implement a GNN PPO for ray-rllib
Browse files Browse the repository at this point in the history
We follow the same guidelines as for the sb3 wrapper:
- GNN based on pytorch-geometric
- Feature extraction via GNN + reduction layer to a fixed number of
  feature
- Observation = Graph or dict whose values contains at least one Graph
- Action masks are taken into account if available
- User must use GraphPPO instead of PPO as algorithm: GraphPPO overrides
  PPO to change the way obs is converted to pytorch format

Worth noticing:
- We use the old api stack as the RLlib wrapper is currently using it
- For graph observations, the model is gnn extractor followed by a FullyConnectedNetwork
- For dict of graphs (and other) observations, the model is
  - preprocess obs by using gnn features extractor for graph components
  - apply to the prepreocessed obs a ComplexInputNetwork
- action masking is automatically activated according to domain class
  (not UnrestrictedActions) and algo class, as it was already coded in
  RayRLlib wrapper. The algo to be used is still GraphPPO as masking is
  managed by a custom model at RayRLlib wrapper level.
  • Loading branch information
nhuet authored and fteicht committed Jan 17, 2025
1 parent 9d825f7 commit 324ebf4
Show file tree
Hide file tree
Showing 21 changed files with 1,669 additions and 94 deletions.
45 changes: 36 additions & 9 deletions skdecide/hub/solver/ray_rllib/custom_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gymnasium.spaces import flatten_space
from ray.rllib import SampleBatch
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import (
Expand All @@ -9,6 +10,15 @@
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, unbatch
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN

from skdecide.hub.solver.ray_rllib.gnn.models.torch.complex_input_net import (
GraphComplexInputNetwork,
)
from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
is_graph_dict_multiinput_space,
is_graph_dict_space,
)

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

Expand Down Expand Up @@ -98,8 +108,20 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name, **k
self.action_ids_shifted = torch.arange(1, num_outputs + 1, dtype=torch.int64)
self.true_obs_space = model_config["custom_model_config"]["true_obs_space"]

self.pred_action_embed_model = TorchFullyConnectedNetwork(
flatten_space(self.true_obs_space),
if is_graph_dict_space(self.true_obs_space):
pred_action_embed_model_cls = GnnBasedModel
self.obs_with_graph = True
embed_model_obs_space = self.true_obs_space
elif is_graph_dict_multiinput_space(self.true_obs_space):
pred_action_embed_model_cls = GraphComplexInputNetwork
self.obs_with_graph = True
embed_model_obs_space = self.true_obs_space
else:
pred_action_embed_model_cls = TorchFullyConnectedNetwork
self.obs_with_graph = False
embed_model_obs_space = flatten_space(self.true_obs_space)
self.pred_action_embed_model = pred_action_embed_model_cls(
embed_model_obs_space,
action_space,
model_config["custom_model_config"]["action_embed_size"],
model_config,
Expand All @@ -115,16 +137,21 @@ def forward(self, input_dict, state, seq_lens):
# Extract the available actions mask tensor from the observation.
valid_avail_actions_mask = input_dict["obs"]["valid_avail_actions_mask"]

# Unbatch true observations before flattening them
unbatched_true_obs = unbatch(input_dict["obs"]["true_obs"])
if self.obs_with_graph:
# use directly the obs (already converted at proper format by custom `convert_to_torch_tensor`)
embed_model_obs = input_dict["obs"]["true_obs"]
else:
# Unbatch true observations before flattening them
embed_model_obs = torch.stack(
[
flatten_to_single_ndarray(o)
for o in unbatch(input_dict["obs"]["true_obs"])
]
)

# Compute the predicted action embedding
pred_action_embed, _ = self.pred_action_embed_model(
{
"obs": torch.stack(
[flatten_to_single_ndarray(o) for o in unbatched_true_obs]
)
}
SampleBatch({SampleBatch.OBS: embed_model_obs})
)

# Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
Expand Down
Empty file.
1 change: 1 addition & 0 deletions skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .ppo.ppo import GraphPPO
Empty file.
21 changes: 21 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Optional

from ray.rllib import Policy
from ray.rllib.algorithms import PPO, AlgorithmConfig

from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import (
PPOTorchGraphPolicy,
)


class GraphPPO(PPO):
@classmethod
def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[type[Policy]]:
if config["framework"] == "torch":
return PPOTorchGraphPolicy
elif config["framework"] == "tf":
raise NotImplementedError("GraphPPO implemented for torch context")
else:
raise NotImplementedError("GraphPPO implemented for torch context")
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ray.rllib.algorithms.ppo import PPOTorchPolicy

from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy


class PPOTorchGraphPolicy(TorchGraphPolicy, PPOTorchPolicy):
...
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import gymnasium as gym
from ray.rllib import SampleBatch
from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import TensorType
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
is_graph_dict_space,
)


class GraphComplexInputNetwork(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
if not model_config.get("_disable_preprocessor_api"):
raise ValueError(
"This model is intent to be used only when preprocessors are disabled."
)
if not isinstance(obs_space, gym.spaces.Dict):
raise ValueError(
"This model is intent to be used only on dict observation space."
)

nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs, model_config, name)

self.gnn = nn.ModuleDict()
post_graph_obs_subspaces = dict(obs_space.spaces)
for k, subspace in obs_space.spaces.items():
if is_graph_dict_space(subspace):
submodel_name = f"gnn_{k}"
gnn = GnnBasedModel(
obs_space=subspace,
action_space=action_space,
num_outputs=None,
model_config=model_config,
framework="torch",
name=submodel_name,
)
self.add_module(submodel_name, gnn)
self.gnn[k] = gnn
post_graph_obs_subspaces[k] = gnn.features_space

post_graph_obs_space = gym.spaces.Dict(post_graph_obs_subspaces)
self.post_graph_model = ComplexInputNetwork(
obs_space=post_graph_obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=model_config,
name="post_graph_model",
)

def forward(self, input_dict: SampleBatch, state, seq_lens):
post_graph_input_dict = input_dict.copy(shallow=True)
obs = input_dict["obs"]
post_graph_obs = dict(obs)
for k, gnn in self.gnn.items():
post_graph_obs[k] = gnn(SampleBatch({SampleBatch.OBS: obs[k]}))
post_graph_input_dict["obs"] = post_graph_obs
return self.post_graph_model(
input_dict=post_graph_input_dict, state=state, seq_lens=seq_lens
)

def value_function(self) -> TensorType:
return self.post_graph_model.value_function()
77 changes: 77 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from collections import defaultdict
from typing import Optional

import gymnasium as gym
import numpy as np
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
convert_dict_space_to_graph_space,
is_graph_dict_space,
)


class GnnBasedModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: Optional[int],
model_config: ModelConfigDict,
name: str,
**kw,
):
nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs, model_config, name)

# config for custom model
custom_config = defaultdict(
lambda: None, # will return None for missing keys
model_config.get("custom_model_config", {}),
)

# gnn-based feature extractor
features_extractor_kwargs = custom_config.get("features_extractor", {})
assert is_graph_dict_space(
obs_space
), f"{self.__class__.__name__} can only be applied to Graph observation spaces."
graph_observation_space = convert_dict_space_to_graph_space(obs_space)
self.features_extractor = GraphFeaturesExtractor(
observation_space=graph_observation_space, **features_extractor_kwargs
)
self.features_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(self.features_extractor.features_dim,)
)

if num_outputs is None:
# only feature extraction (e.g. to be used by GraphComplexInputNetwork)
self.num_outputs = self.features_extractor.features_dim
self.pred_action_embed_model = None
else:
# fully connected network
self.pred_action_embed_model = FullyConnectedNetwork(
obs_space=self.features_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=model_config,
name=name + "_pred_action_embed",
)

def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs"]
features = self.features_extractor(obs)
if self.pred_action_embed_model is None:
return features, state
else:
return self.pred_action_embed_model(
input_dict={"obs": features},
state=state,
seq_lens=seq_lens,
)

def value_function(self):
return self.pred_action_embed_model.value_function()
Empty file.
116 changes: 116 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from numbers import Number
from typing import Any, Optional, Union

import gymnasium as gym
import numpy as np
import tree
from ray.rllib import SampleBatch
from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch
from ray.rllib.utils.typing import ViewRequirementsDict


def _pop_graph_items(
full_dict: dict[Any, Any]
) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]:
graph_dict = {}
for k, v in full_dict.items():
if isinstance(v, gym.spaces.GraphInstance) or (
isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance)
):
graph_dict[k] = v
for k in graph_dict:
full_dict.pop(k)
return graph_dict


def _split_graph_requirements(
full_dict: ViewRequirementsDict,
) -> tuple[ViewRequirementsDict, ViewRequirementsDict]:
graph_dict = {}
for k, v in full_dict.items():
if isinstance(v.space, gym.spaces.Graph):
graph_dict[k] = v
wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict}
return graph_dict, wo_graph_dict


class GraphSampleBatch(SampleBatch):
def __init__(self, *args, **kwargs):
"""Constructs a sample batch with possibly graph obs.
See `ray.rllib.SampleBatch` for more information.
"""
# split graph samples from others.
dict_graphs = _pop_graph_items(kwargs)
dict_from_args = dict(*args)
dict_graphs.update(_pop_graph_items(dict_from_args))

super().__init__(dict_from_args, **kwargs)
super().update(dict_graphs)

def copy(self, shallow: bool = False) -> "SampleBatch":
"""Creates a deep or shallow copy of this SampleBatch and returns it.
Args:
shallow: Whether the copying should be done shallowly.
Returns:
A deep or shallow copy of this SampleBatch object.
"""
copy_ = dict(self)
data = tree.map_structure(
lambda v: (
np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
),
copy_,
)
copy_ = GraphSampleBatch(
data,
_time_major=self.time_major,
_zero_padded=self.zero_padded,
_max_seq_len=self.max_seq_len,
_num_grad_updates=self.num_grad_updates,
)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_

def get_single_step_input_dict(
self,
view_requirements: ViewRequirementsDict,
index: Union[str, int] = "last",
) -> "SampleBatch":
(
view_requirements_graphs,
view_requirements_wo_graphs,
) = _split_graph_requirements(view_requirements)
# w/o graphs
sample = GraphSampleBatch(
super().get_single_step_input_dict(view_requirements_wo_graphs, index)
)
# handle graphs
last_mappings = {
SampleBatch.OBS: SampleBatch.NEXT_OBS,
SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
}
for view_col, view_req in view_requirements_graphs.items():
if view_req.used_for_compute_actions is False:
continue

# Create batches of size 1 (single-agent input-dict).
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
if view_req.shift_from is not None:
raise NotImplementedError()
else:
sample[view_col] = self[data_col][-1:]
else:
sample[view_col] = self[data_col][
index : index + 1 if index != -1 else None
]
return sample
16 changes: 16 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import functools

from ray.rllib import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

from skdecide.hub.solver.ray_rllib.gnn.utils.torch_utils import convert_to_torch_tensor


class TorchGraphPolicy(TorchPolicyV2):
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
if not isinstance(postprocessed_batch, SampleBatch):
postprocessed_batch = SampleBatch(postprocessed_batch)
postprocessed_batch.set_get_interceptor(
functools.partial(convert_to_torch_tensor, device=device or self.device)
)
return postprocessed_batch
Loading

0 comments on commit 324ebf4

Please sign in to comment.