diff --git a/chia/_tests/conftest.py b/chia/_tests/conftest.py index 1037356b6d17..ffd20c017edc 100644 --- a/chia/_tests/conftest.py +++ b/chia/_tests/conftest.py @@ -14,7 +14,6 @@ import sysconfig import tempfile from contextlib import AsyncExitStack -from enum import IntEnum from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Tuple, Union import aiohttp @@ -29,7 +28,15 @@ from chia._tests.core.data_layer.util import ChiaRoot from chia._tests.core.node_height import node_height_at_least from chia._tests.simulation.test_simulation import test_constants_modified -from chia._tests.util.misc import BenchmarkRunner, GcMode, RecordingWebServer, TestId, _AssertRuntime, measure_overhead +from chia._tests.util.misc import ( + BenchmarkRunner, + ComparableEnum, + GcMode, + RecordingWebServer, + TestId, + _AssertRuntime, + measure_overhead, +) from chia._tests.util.setup_nodes import ( OldSimulatorsAndWallets, SimulatorsAndWallets, @@ -187,7 +194,7 @@ def get_keychain(): KeyringWrapper.cleanup_shared_instance() -class ConsensusMode(IntEnum): +class ConsensusMode(ComparableEnum): PLAIN = 0 SOFT_FORK_4 = 1 HARD_FORK_2_0 = 2 diff --git a/chia/_tests/util/misc.py b/chia/_tests/util/misc.py index 4b58fc6d6eb6..91d858f8faec 100644 --- a/chia/_tests/util/misc.py +++ b/chia/_tests/util/misc.py @@ -14,6 +14,7 @@ import sys from concurrent.futures import Future from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from statistics import mean from textwrap import dedent @@ -637,3 +638,44 @@ class DataTypeProtocol(Protocol): def unmarshal(cls: Type[T], marshalled: Dict[str, Any]) -> T: ... def marshal(self) -> Dict[str, Any]: ... + + +T_ComparableEnum = TypeVar("T_ComparableEnum", bound="ComparableEnum") + + +class ComparableEnum(Enum): + def __lt__(self: T_ComparableEnum, other: T_ComparableEnum) -> object: + if self.__class__ is not other.__class__: + return NotImplemented + + return self.value.__lt__(other.value) + + def __le__(self: T_ComparableEnum, other: T_ComparableEnum) -> object: + if self.__class__ is not other.__class__: + return NotImplemented + + return self.value.__le__(other.value) + + def __eq__(self: T_ComparableEnum, other: object) -> bool: + if self.__class__ is not other.__class__: + return False + + return cast(bool, self.value.__eq__(cast(T_ComparableEnum, other).value)) + + def __ne__(self: T_ComparableEnum, other: object) -> bool: + if self.__class__ is not other.__class__: + return True + + return cast(bool, self.value.__ne__(cast(T_ComparableEnum, other).value)) + + def __gt__(self: T_ComparableEnum, other: T_ComparableEnum) -> object: + if self.__class__ is not other.__class__: + return NotImplemented + + return self.value.__gt__(other.value) + + def __ge__(self: T_ComparableEnum, other: T_ComparableEnum) -> object: + if self.__class__ is not other.__class__: + return NotImplemented + + return self.value.__ge__(other.value)