From 945176a80751887aeb8aa8a8c8e115cadb4abbe1 Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Fri, 7 Jun 2024 07:52:10 -0500 Subject: [PATCH] fix validators, and pass tests locally in Docker --- docker-compose.yml | 10 ++++++++++ examples/albert/utils.py | 8 ++++---- hivemind/dht/schema.py | 24 +++++++++++++++--------- hivemind/optim/progress_tracker.py | 8 ++++---- tests/test_dht_schema.py | 16 ++++++++-------- tests/test_dht_validation.py | 8 ++++---- 6 files changed, 45 insertions(+), 29 deletions(-) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..cd832a628 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,10 @@ +services: + lab: + command: pytest -s -vv /home/hivemind/tests/test_optimizer.py + tty: true + stdin_open: true + build: . + volumes: + - ./:/home/hivemind + environment: + HIVEMIND_LOGLEVEL: DEBUG \ No newline at end of file diff --git a/examples/albert/utils.py b/examples/albert/utils.py index 3aa2cccf1..4d5d0a9b9 100644 --- a/examples/albert/utils.py +++ b/examples/albert/utils.py @@ -1,16 +1,16 @@ from typing import Dict, List, Tuple -from pydantic import StrictFloat, confloat, conint +from pydantic import BaseModel, StrictFloat, confloat, conint from hivemind.dht.crypto import RSASignatureValidator -from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator from hivemind.dht.validation import RecordValidatorBase from hivemind.utils.logging import get_logger logger = get_logger(__name__) -class LocalMetrics(ExtendedBaseModel): +class LocalMetrics(BaseModel): step: conint(ge=0, strict=True) samples_per_second: confloat(ge=0.0, strict=True) samples_accumulated: conint(ge=0, strict=True) @@ -18,7 +18,7 @@ class LocalMetrics(ExtendedBaseModel): mini_steps: conint(ge=0, strict=True) -class MetricSchema(ExtendedBaseModel): +class MetricSchema(BaseModel): metrics: Dict[BytesWithPublicKey, LocalMetrics] diff --git a/hivemind/dht/schema.py b/hivemind/dht/schema.py index 67849eea5..ac463f48c 100644 --- a/hivemind/dht/schema.py +++ b/hivemind/dht/schema.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Type import pydantic +from pydantic.fields import FieldInfo from pydantic_core import CoreSchema, core_schema from hivemind.dht.crypto import RSASignatureValidator @@ -13,18 +14,13 @@ logger = get_logger(__name__) -class ExtendedBaseModel(pydantic.BaseModel): - class Config: - arbitrary_types_allowed = True - - class SchemaValidator(RecordValidatorBase): """ Restricts specified DHT keys to match a Pydantic schema. This allows to enforce types, min/max values, require a subkey to contain a public key, etc. """ - def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None): + def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None): """ :param schema: The Pydantic model (a subclass of pydantic.BaseModel). @@ -54,7 +50,7 @@ def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = Tru # the 'required' property was changed to 'is_required' in 2.0, and also made read-only # @staticmethod - # def _patch_schema(schema: ExtendedBaseModel): + # def _patch_schema(schema: pydantic.BaseModel): # # We set required=False because the validate() interface provides only one key at a time # for field in schema.__fields__.values(): # field.required = False @@ -162,7 +158,7 @@ def __setstate__(self, state): # self._patch_schema(schema) -def conbytes(*, regex: bytes = None, **kwargs) -> Type[ExtendedBaseModel]: +def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]: """ Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does). """ @@ -170,13 +166,23 @@ def conbytes(*, regex: bytes = None, **kwargs) -> Type[ExtendedBaseModel]: compiled_regex = re.compile(regex) if regex is not None else None class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)): + value: bytes = pydantic.Field(**kwargs) + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler) -> CoreSchema: + schema = handler(bytes) + return core_schema.no_info_after_validator_function(cls.match_regex, schema) + @classmethod - @pydantic.validator("*") def match_regex(cls, value: bytes) -> bytes: if compiled_regex is not None and compiled_regex.match(value) is None: raise ValueError(f"Value `{value}` doesn't match regex `{regex}`") return value + @classmethod + def __get_pydantic_config__(cls) -> pydantic.config.ConfigDict: + return pydantic.config.ConfigDict(arbitrary_types_allowed=True) + return ConstrainedBytesWithRegex diff --git a/hivemind/optim/progress_tracker.py b/hivemind/optim/progress_tracker.py index 63987a983..9a6ff66e7 100644 --- a/hivemind/optim/progress_tracker.py +++ b/hivemind/optim/progress_tracker.py @@ -6,10 +6,10 @@ from typing import Dict, Optional import numpy as np -from pydantic import StrictBool, StrictFloat, confloat, conint +from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint from hivemind.dht import DHT -from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, RSASignatureValidator, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger from hivemind.utils.crypto import RSAPrivateKey from hivemind.utils.performance_ema import PerformanceEMA @@ -28,7 +28,7 @@ class GlobalTrainingProgress: next_fetch_time: float -class LocalTrainingProgress(ExtendedBaseModel): +class LocalTrainingProgress(BaseModel): peer_id: bytes epoch: conint(ge=0, strict=True) samples_accumulated: conint(ge=0, strict=True) @@ -37,7 +37,7 @@ class LocalTrainingProgress(ExtendedBaseModel): client_mode: StrictBool -class TrainingProgressSchema(ExtendedBaseModel): +class TrainingProgressSchema(BaseModel): progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]] diff --git a/tests/test_dht_schema.py b/tests/test_dht_schema.py index 74afd365d..8c9fd4aeb 100644 --- a/tests/test_dht_schema.py +++ b/tests/test_dht_schema.py @@ -2,16 +2,16 @@ from typing import Dict import pytest -from pydantic import StrictInt, conint +from pydantic import BaseModel, StrictInt, conint import hivemind from hivemind.dht.node import DHTNode -from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator from hivemind.dht.validation import DHTRecord, RecordValidatorBase from hivemind.utils.timed_storage import get_dht_time -class SampleSchema(ExtendedBaseModel): +class SampleSchema(BaseModel): experiment_name: bytes n_batches: Dict[bytes, conint(ge=0, strict=True)] signed_data: Dict[BytesWithPublicKey, bytes] @@ -94,10 +94,10 @@ async def test_expecting_public_keys(dht_nodes_with_schema): @pytest.mark.forked @pytest.mark.asyncio async def test_keys_outside_schema(dht_nodes_with_schema): - class Schema(ExtendedBaseModel): + class Schema(BaseModel): some_field: StrictInt - class MergedSchema(ExtendedBaseModel): + class MergedSchema(BaseModel): another_field: StrictInt for allow_extra_keys in [False, True]: @@ -121,7 +121,7 @@ class MergedSchema(ExtendedBaseModel): @pytest.mark.forked @pytest.mark.asyncio async def test_prefix(): - class Schema(ExtendedBaseModel): + class Schema(BaseModel): field: StrictInt validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix") @@ -153,11 +153,11 @@ def validate(self, record: DHTRecord) -> bool: # Can't merge with the validator of the different type assert not alice.protocol.record_validator.merge_with(second_validator) - class SecondSchema(ExtendedBaseModel): + class SecondSchema(BaseModel): some_field: StrictInt another_field: str - class ThirdSchema(ExtendedBaseModel): + class ThirdSchema(BaseModel): another_field: StrictInt # Allow it to be a StrictInt as well for schema in [SecondSchema, ThirdSchema]: diff --git a/tests/test_dht_validation.py b/tests/test_dht_validation.py index a0d77b5c4..56420d3d7 100644 --- a/tests/test_dht_validation.py +++ b/tests/test_dht_validation.py @@ -2,21 +2,21 @@ from typing import Dict import pytest -from pydantic import StrictInt +from pydantic import BaseModel, StrictInt import hivemind from hivemind.dht.crypto import RSASignatureValidator from hivemind.dht.protocol import DHTProtocol from hivemind.dht.routing import DHTID -from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator from hivemind.dht.validation import CompositeValidator, DHTRecord -class SchemaA(ExtendedBaseModel): +class SchemaA(BaseModel): field_a: bytes -class SchemaB(ExtendedBaseModel): +class SchemaB(BaseModel): field_b: Dict[BytesWithPublicKey, StrictInt]