Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return hint for EnumerableSpace.get_elements() and simplify implementations #450

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/nocycle_grid_goal_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import getopt
import sys
from collections.abc import Iterable
from collections.abc import Sequence
from enum import IntEnum
from math import sqrt
from typing import NamedTuple, Optional
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, state=None, num_cols=0, num_rows=0):
self.num_cols = num_cols
self.num_rows = num_rows

def get_elements(self) -> Iterable[int]:
def get_elements(self) -> Sequence[int]:
if self.state is None:
return [a for a in MyActions]
else:
Expand Down
10 changes: 5 additions & 5 deletions skdecide/builders/domain/scheduling/scheduling_domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import random
from collections.abc import Iterable
from collections.abc import Sequence
from enum import Enum
from itertools import product
from typing import Optional
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def __init__(self, domain: SchedulingDomain, state: State):
self.state = state
self.elements = self._get_elements()

def _get_elements(self) -> Iterable[T]:
def _get_elements(self) -> Sequence[T]:
choices = [
SchedulingActionEnum.START,
SchedulingActionEnum.PAUSE,
Expand Down Expand Up @@ -1282,7 +1282,7 @@ def _get_elements(self) -> Iterable[T]:
)
return list_action

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self.elements

def sample(self) -> T:
Expand All @@ -1299,7 +1299,7 @@ def __init__(self, domain: SchedulingDomain, state: State):
self.state = state
self.elements = self._get_elements()

def _get_elements(self) -> Iterable[T]:
def _get_elements(self) -> Sequence[T]:
choices = [
SchedulingActionEnum.START,
SchedulingActionEnum.PAUSE,
Expand Down Expand Up @@ -1382,7 +1382,7 @@ def _get_elements(self) -> Iterable[T]:
)
return list_action

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self.elements

def sample(self) -> T:
Expand Down
4 changes: 2 additions & 2 deletions skdecide/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def contains(self, x: T) -> bool:
class EnumerableSpace(Space[T]):
"""A space which elements can be enumerated."""

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.

# Returns
Expand All @@ -132,7 +132,7 @@ def contains(self, x: T) -> bool:
class EmptySpace(EnumerableSpace[T]):
"""An (enumerable) empty space."""

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return ()


Expand Down
4 changes: 2 additions & 2 deletions skdecide/hub/domain/graph_domain/GraphDomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import random
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Optional

import networkx as nx
Expand Down Expand Up @@ -44,7 +44,7 @@ def contains(self, x: T) -> bool:
def __init__(self, l: list[object]):
self.l = l

def get_elements(self) -> Iterable[object]:
def get_elements(self) -> Sequence[object]:
return self.l


Expand Down
26 changes: 10 additions & 16 deletions skdecide/hub/space/gym/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def __init__(self, n, element_class=int):
super().__init__(gym_space=gym_spaces.Discrete(n))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.

# Returns
The elements of this space.
"""
return np.array(list(range(self._gym_space.n)), dtype=np.int64)
return range(self._gym_space.n)

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -118,16 +118,13 @@ def __init__(self, nvec, element_class=np.ndarray):
super().__init__(gym_space=gym_spaces.MultiDiscrete(nvec))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.

# Returns
The elements of this space.
"""
return np.array(
list(itertools.product(*[list(range(n)) for n in self._gym_space.nvec])),
dtype=np.int64,
)
return tuple(itertools.product(*(range(n) for n in self._gym_space.nvec)))

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -155,16 +152,13 @@ def __init__(self, n, element_class=np.ndarray):
super().__init__(gym_space=gym_spaces.MultiBinary(n))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.

# Returns
The elements of this space.
"""
return np.array(
list(itertools.product(*[(1, 0) for _ in range(self._gym_space.n)])),
dtype=np.int8,
)
return tuple(itertools.product(*((1, 0) for _ in range(self._gym_space.n))))

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -311,14 +305,14 @@ def __init__(self, enum_class: EnumMeta) -> None:
enum_class: The enumeration class for creating the Gym Discrete space (gym.spaces.Discrete) to wrap.
"""
self._enum_class = enum_class
self._list_enum = list(enum_class)
self._list_enum = tuple(enum_class)
gym_space = gym_spaces.Discrete(len(enum_class))
super().__init__(gym_space)

def contains(self, x: T) -> bool:
return isinstance(x, self._enum_class)

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._list_enum

def sample(self) -> T:
Expand Down Expand Up @@ -367,7 +361,7 @@ def __init__(self, elements: Iterable[T]) -> None:
def contains(self, x: T) -> bool:
return x in self._elements

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._elements

def sample(self) -> T:
Expand Down Expand Up @@ -417,7 +411,7 @@ def __init__(self, elements: Iterable[T]) -> None:
def contains(self, x: T) -> bool:
return x in self._elements

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._elements

def sample(self) -> T:
Expand Down
Loading