Skip to content

Commit

Permalink
Add explicit __eq__ operators for classes that have ndarray attribute…
Browse files Browse the repository at this point in the history
…s and could be used as Static jax_dataclasses attributes
  • Loading branch information
traversaro committed Mar 11, 2024
1 parent 43cb39c commit 90f861c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def change_link(
enabled=self.enabled,
)

def __eq__(self, other):
retval = (
self.parent_link == other.parent_link
and (self.position == other.position).all()
and self.enabled == other.enabled
)
return retval

def __str__(self):
return (
f"{self.__class__.__name__}("
Expand Down Expand Up @@ -93,6 +101,9 @@ class BoxCollision(CollisionShape):

center: npt.NDArray

def __eq__(self, other):
return (self.center == other.center).all() and super().__eq__(other)


@dataclasses.dataclass
class SphereCollision(CollisionShape):
Expand All @@ -105,3 +116,6 @@ class SphereCollision(CollisionShape):
"""

center: npt.NDArray

def __eq__(self, other):
return (self.center == other.center).all() and super().__eq__(other)
11 changes: 11 additions & 0 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass):
def __hash__(self) -> int:
return hash(self.__repr__())

def __eq__(self, other) -> bool:
return (
self.name == other.name
and self.mass == other.mass
and (self.inertia == other.inertia).all()
and self.index == other.index
and self.parent == other.parent
and (self.pose == other.pose).all()
and self.children == other.children
)

@property
def name_and_index(self) -> str:
"""
Expand Down
5 changes: 5 additions & 0 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class RootPose(NamedTuple):
root_position: npt.NDArray = np.zeros(3)
root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])

def __eq__(self, other):
return (self.root_position == other.root_position).all() and (
self.root_quaternion == other.root_quaternion
).all()


@dataclasses.dataclass(frozen=True)
class KinematicGraph:
Expand Down

0 comments on commit 90f861c

Please sign in to comment.