We follow the same guidelines as for the sb3 wrapper:
- GNN based on pytorch-geometric
- Feature extraction via GNN + reduction layer to a fixed number of features
- Observation = Graph, or dict whose values contain at least one Graph
- Action masks are taken into account if available
- The user must use GraphPPO instead of PPO as the algorithm: GraphPPO overrides PPO to change the way observations are converted to pytorch format (a usage sketch follows below)

Worth noticing:
- We use the old API stack, as the RLlib wrapper currently relies on it
- For graph observations, the model is the GNN extractor followed by a FullyConnectedNetwork
- For dict-of-graphs (and other) observations, the model:
  - preprocesses the obs by applying the GNN features extractor to the graph components
  - applies a ComplexInputNetwork to the preprocessed obs
- Action masking is automatically activated according to the domain class (not UnrestrictedActions) and the algo class, as already coded in the RayRLlib wrapper. The algorithm to use is still GraphPPO, since masking is managed by a custom model at the RayRLlib wrapper level.
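For context, a minimal usage sketch (not part of this commit's diff). The RayRLlib constructor arguments (domain_factory, algo_class, train_iterations) reflect the wrapper's usual signature and are assumptions; only the GraphPPO import path is confirmed by the files below, and MyGraphDomain is a placeholder for a user-defined domain.

# Hypothetical usage sketch: solving a scikit-decide domain with graph observations.
# MyGraphDomain is a placeholder for a user-defined domain whose observation space is
# a gymnasium Graph space (or a Dict space containing Graph subspaces).
from skdecide.hub.solver.ray_rllib import RayRLlib
from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO

solver = RayRLlib(
    domain_factory=lambda: MyGraphDomain(),  # placeholder domain factory
    algo_class=GraphPPO,  # use GraphPPO instead of PPO so graph obs are handled
    train_iterations=1,
)
solver.solve()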
Showing 21 changed files with 1,669 additions and 94 deletions.
Empty file.
@@ -0,0 +1 @@
from .ppo.ppo import GraphPPO
Empty file.
@@ -0,0 +1,21 @@
from typing import Optional

from ray.rllib import Policy
from ray.rllib.algorithms import PPO, AlgorithmConfig

from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import (
    PPOTorchGraphPolicy,
)


class GraphPPO(PPO):
    @classmethod
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[type[Policy]]:
        if config["framework"] == "torch":
            return PPOTorchGraphPolicy
        elif config["framework"] == "tf":
            raise NotImplementedError("GraphPPO is only implemented for torch.")
        else:
            raise NotImplementedError("GraphPPO is only implemented for torch.")
7 changes: 7 additions & 0 deletions
skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py
@@ -0,0 +1,7 @@
from ray.rllib.algorithms.ppo import PPOTorchPolicy

from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy


class PPOTorchGraphPolicy(TorchGraphPolicy, PPOTorchPolicy):
    ...
Empty file.
Empty file.
66 changes: 66 additions & 0 deletions
skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py
@@ -0,0 +1,66 @@
import gymnasium as gym
from ray.rllib import SampleBatch
from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import TensorType
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
    is_graph_dict_space,
)


class GraphComplexInputNetwork(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        if not model_config.get("_disable_preprocessor_api"):
            raise ValueError(
                "This model is intended to be used only when preprocessors are disabled."
            )
        if not isinstance(obs_space, gym.spaces.Dict):
            raise ValueError(
                "This model is intended to be used only on dict observation spaces."
            )

        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # one GNN-based feature extractor per graph subspace of the dict obs space
        self.gnn = nn.ModuleDict()
        post_graph_obs_subspaces = dict(obs_space.spaces)
        for k, subspace in obs_space.spaces.items():
            if is_graph_dict_space(subspace):
                submodel_name = f"gnn_{k}"
                gnn = GnnBasedModel(
                    obs_space=subspace,
                    action_space=action_space,
                    num_outputs=None,
                    model_config=model_config,
                    framework="torch",
                    name=submodel_name,
                )
                self.add_module(submodel_name, gnn)
                self.gnn[k] = gnn
                post_graph_obs_subspaces[k] = gnn.features_space

        # the downstream ComplexInputNetwork sees extracted features instead of graphs
        post_graph_obs_space = gym.spaces.Dict(post_graph_obs_subspaces)
        self.post_graph_model = ComplexInputNetwork(
            obs_space=post_graph_obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name="post_graph_model",
        )

    def forward(self, input_dict: SampleBatch, state, seq_lens):
        post_graph_input_dict = input_dict.copy(shallow=True)
        obs = input_dict["obs"]
        post_graph_obs = dict(obs)
        for k, gnn in self.gnn.items():
            post_graph_obs[k] = gnn(SampleBatch({SampleBatch.OBS: obs[k]}))
        post_graph_input_dict["obs"] = post_graph_obs
        return self.post_graph_model(
            input_dict=post_graph_input_dict, state=state, seq_lens=seq_lens
        )

    def value_function(self) -> TensorType:
        return self.post_graph_model.value_function()
@@ -0,0 +1,77 @@
from collections import defaultdict
from typing import Optional

import gymnasium as gym
import numpy as np
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
    convert_dict_space_to_graph_space,
    is_graph_dict_space,
)


class GnnBasedModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: Optional[int],
        model_config: ModelConfigDict,
        name: str,
        **kw,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # config for custom model
        custom_config = defaultdict(
            lambda: None,  # will return None for missing keys
            model_config.get("custom_model_config", {}),
        )

        # gnn-based feature extractor
        features_extractor_kwargs = custom_config.get("features_extractor", {})
        assert is_graph_dict_space(
            obs_space
        ), f"{self.__class__.__name__} can only be applied to Graph observation spaces."
        graph_observation_space = convert_dict_space_to_graph_space(obs_space)
        self.features_extractor = GraphFeaturesExtractor(
            observation_space=graph_observation_space, **features_extractor_kwargs
        )
        self.features_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.features_extractor.features_dim,)
        )

        if num_outputs is None:
            # only feature extraction (e.g. to be used by GraphComplexInputNetwork)
            self.num_outputs = self.features_extractor.features_dim
            self.pred_action_embed_model = None
        else:
            # fully connected network
            self.pred_action_embed_model = FullyConnectedNetwork(
                obs_space=self.features_space,
                action_space=action_space,
                num_outputs=num_outputs,
                model_config=model_config,
                name=name + "_pred_action_embed",
            )

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        features = self.features_extractor(obs)
        if self.pred_action_embed_model is None:
            return features, state
        else:
            return self.pred_action_embed_model(
                input_dict={"obs": features},
                state=state,
                seq_lens=seq_lens,
            )

    def value_function(self):
        return self.pred_action_embed_model.value_function()
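A hedged configuration sketch to show where the feature-extractor kwargs above come from: GnnBasedModel reads them from RLlib's model config under custom_model_config/features_extractor, and GraphComplexInputNetwork additionally requires the preprocessor API to be disabled. Any concrete keyword one would put inside features_extractor is an assumption about GraphFeaturesExtractor, not shown in this diff.

# Hypothetical sketch of the relevant RLlib model config keys (values illustrative).
model_config = {
    "_disable_preprocessor_api": True,  # required by GraphComplexInputNetwork.__init__
    "custom_model_config": {
        # forwarded as **kwargs to GraphFeaturesExtractor by GnnBasedModel
        "features_extractor": {},
    },
}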
Empty file.
116 changes: 116 additions & 0 deletions
skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
@@ -0,0 +1,116 @@
from numbers import Number
from typing import Any, Optional, Union

import gymnasium as gym
import numpy as np
import tree
from ray.rllib import SampleBatch
from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch
from ray.rllib.utils.typing import ViewRequirementsDict


def _pop_graph_items(
    full_dict: dict[Any, Any]
) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]:
    graph_dict = {}
    for k, v in full_dict.items():
        if isinstance(v, gym.spaces.GraphInstance) or (
            isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance)
        ):
            graph_dict[k] = v
    for k in graph_dict:
        full_dict.pop(k)
    return graph_dict


def _split_graph_requirements(
    full_dict: ViewRequirementsDict,
) -> tuple[ViewRequirementsDict, ViewRequirementsDict]:
    graph_dict = {}
    for k, v in full_dict.items():
        if isinstance(v.space, gym.spaces.Graph):
            graph_dict[k] = v
    wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict}
    return graph_dict, wo_graph_dict


class GraphSampleBatch(SampleBatch):
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch with possibly graph obs.

        See `ray.rllib.SampleBatch` for more information.
        """
        # split graph samples from others.
        dict_graphs = _pop_graph_items(kwargs)
        dict_from_args = dict(*args)
        dict_graphs.update(_pop_graph_items(dict_from_args))

        super().__init__(dict_from_args, **kwargs)
        super().update(dict_graphs)

    def copy(self, shallow: bool = False) -> "SampleBatch":
        """Creates a deep or shallow copy of this SampleBatch and returns it.

        Args:
            shallow: Whether the copying should be done shallowly.

        Returns:
            A deep or shallow copy of this SampleBatch object.
        """
        copy_ = dict(self)
        data = tree.map_structure(
            lambda v: (
                np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
            ),
            copy_,
        )
        copy_ = GraphSampleBatch(
            data,
            _time_major=self.time_major,
            _zero_padded=self.zero_padded,
            _max_seq_len=self.max_seq_len,
            _num_grad_updates=self.num_grad_updates,
        )
        copy_.set_get_interceptor(self.get_interceptor)
        copy_.added_keys = self.added_keys
        copy_.deleted_keys = self.deleted_keys
        copy_.accessed_keys = self.accessed_keys
        return copy_

    def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
    ) -> "SampleBatch":
        (
            view_requirements_graphs,
            view_requirements_wo_graphs,
        ) = _split_graph_requirements(view_requirements)
        # w/o graphs
        sample = GraphSampleBatch(
            super().get_single_step_input_dict(view_requirements_wo_graphs, index)
        )
        # handle graphs
        last_mappings = {
            SampleBatch.OBS: SampleBatch.NEXT_OBS,
            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
        }
        for view_col, view_req in view_requirements_graphs.items():
            if view_req.used_for_compute_actions is False:
                continue

            # Create batches of size 1 (single-agent input-dict).
            data_col = view_req.data_col or view_col
            if index == "last":
                data_col = last_mappings.get(data_col, data_col)
                if view_req.shift_from is not None:
                    raise NotImplementedError()
                else:
                    sample[view_col] = self[data_col][-1:]
            else:
                sample[view_col] = self[data_col][
                    index : index + 1 if index != -1 else None
                ]
        return sample
16 changes: 16 additions & 0 deletions
skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py
@@ -0,0 +1,16 @@
import functools

from ray.rllib import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

from skdecide.hub.solver.ray_rllib.gnn.utils.torch_utils import convert_to_torch_tensor


class TorchGraphPolicy(TorchPolicyV2):
    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        if not isinstance(postprocessed_batch, SampleBatch):
            postprocessed_batch = SampleBatch(postprocessed_batch)
        postprocessed_batch.set_get_interceptor(
            functools.partial(convert_to_torch_tensor, device=device or self.device)
        )
        return postprocessed_batch