diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 7b4f8d184..289ab70be 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -50,6 +50,14 @@ def change_link( enabled=self.enabled, ) + def __eq__(self, other): + retval = ( + self.parent_link == other.parent_link + and (self.position == other.position).all() + and self.enabled == other.enabled + ) + return retval + def __str__(self): return ( f"{self.__class__.__name__}(" @@ -93,6 +101,9 @@ class BoxCollision(CollisionShape): center: npt.NDArray + def __eq__(self, other): + return (self.center == other.center).all() and super().__eq__(other) + @dataclasses.dataclass class SphereCollision(CollisionShape): @@ -105,3 +116,6 @@ class SphereCollision(CollisionShape): """ center: npt.NDArray + + def __eq__(self, other): + return (self.center == other.center).all() and super().__eq__(other) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index f5be265b2..d2912d7f0 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass): def __hash__(self) -> int: return hash(self.__repr__()) + def __eq__(self, other) -> bool: + return ( + self.name == other.name + and self.mass == other.mass + and (self.inertia == other.inertia).all() + and self.index == other.index + and self.parent == other.parent + and (self.pose == other.pose).all() + and self.children == other.children + ) + @property def name_and_index(self) -> str: """ diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9a4f8a43e..7b3021f47 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -34,6 +34,11 @@ class RootPose(NamedTuple): root_position: npt.NDArray = np.zeros(3) root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0]) + def __eq__(self, other): + return (self.root_position == other.root_position).all() and ( + self.root_quaternion == other.root_quaternion + ).all() + @dataclasses.dataclass(frozen=True) class KinematicGraph: