Skip to content

Commit

Permalink
fix validators, and pass tests locally in Docker
Browse files (browse the repository at this point in the history)
  • Loading branch information
Vectorrent committed Jun 7, 2024
1 parent b254425 commit 945176a
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 29 deletions.
10 changes: 10 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Compose definition for running the optimizer tests locally in a container.
services:
  lab:
    # Run the optimizer test module verbosely (-vv) with output capture off (-s).
    command: pytest -s -vv /home/hivemind/tests/test_optimizer.py
    # Allocate a TTY and keep stdin open for the container.
    tty: true
    stdin_open: true
    # Build the image using the directory containing this file as the context.
    build: .
    volumes:
      # Bind-mount the working tree at the path the test command reads from.
      - ./:/home/hivemind
    environment:
      # NOTE(review): presumably raises hivemind's log verbosity — confirm
      # against hivemind.utils.logging.
      HIVEMIND_LOGLEVEL: DEBUG
8 changes: 4 additions & 4 deletions examples/albert/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from typing import Dict, List, Tuple

from pydantic import StrictFloat, confloat, conint
from pydantic import BaseModel, StrictFloat, confloat, conint

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class LocalMetrics(ExtendedBaseModel):
class LocalMetrics(BaseModel):
step: conint(ge=0, strict=True)
samples_per_second: confloat(ge=0.0, strict=True)
samples_accumulated: conint(ge=0, strict=True)
loss: StrictFloat
mini_steps: conint(ge=0, strict=True)


class MetricSchema(ExtendedBaseModel):
class MetricSchema(BaseModel):
metrics: Dict[BytesWithPublicKey, LocalMetrics]


Expand Down
24 changes: 15 additions & 9 deletions hivemind/dht/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional, Type

import pydantic
from pydantic.fields import FieldInfo
from pydantic_core import CoreSchema, core_schema

from hivemind.dht.crypto import RSASignatureValidator
Expand All @@ -13,18 +14,13 @@
logger = get_logger(__name__)


class ExtendedBaseModel(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True


class SchemaValidator(RecordValidatorBase):
"""
Restricts specified DHT keys to match a Pydantic schema.
This allows enforcing types and min/max values, requiring a subkey to contain a public key, etc.
"""

def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
"""
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).
Expand Down Expand Up @@ -54,7 +50,7 @@ def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = Tru

# the 'required' property was changed to 'is_required' in 2.0, and also made read-only
# @staticmethod
# def _patch_schema(schema: ExtendedBaseModel):
# def _patch_schema(schema: pydantic.BaseModel):
# # We set required=False because the validate() interface provides only one key at a time
# for field in schema.__fields__.values():
# field.required = False
Expand Down Expand Up @@ -162,21 +158,31 @@ def __setstate__(self, state):
# self._patch_schema(schema)


def conbytes(*, regex: bytes = None, **kwargs) -> Type[ExtendedBaseModel]:
def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:
"""
Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
"""

compiled_regex = re.compile(regex) if regex is not None else None

class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
value: bytes = pydantic.Field(**kwargs)

@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler) -> CoreSchema:
schema = handler(bytes)
return core_schema.no_info_after_validator_function(cls.match_regex, schema)

@classmethod
@pydantic.validator("*")
def match_regex(cls, value: bytes) -> bytes:
if compiled_regex is not None and compiled_regex.match(value) is None:
raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
return value

@classmethod
def __get_pydantic_config__(cls) -> pydantic.config.ConfigDict:
return pydantic.config.ConfigDict(arbitrary_types_allowed=True)

return ConstrainedBytesWithRegex


Expand Down
8 changes: 4 additions & 4 deletions hivemind/optim/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import Dict, Optional

import numpy as np
from pydantic import StrictBool, StrictFloat, confloat, conint
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, RSASignatureValidator, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA
Expand All @@ -28,7 +28,7 @@ class GlobalTrainingProgress:
next_fetch_time: float


class LocalTrainingProgress(ExtendedBaseModel):
class LocalTrainingProgress(BaseModel):
peer_id: bytes
epoch: conint(ge=0, strict=True)
samples_accumulated: conint(ge=0, strict=True)
Expand All @@ -37,7 +37,7 @@ class LocalTrainingProgress(ExtendedBaseModel):
client_mode: StrictBool


class TrainingProgressSchema(ExtendedBaseModel):
class TrainingProgressSchema(BaseModel):
progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


Expand Down
16 changes: 8 additions & 8 deletions tests/test_dht_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from typing import Dict

import pytest
from pydantic import StrictInt, conint
from pydantic import BaseModel, StrictInt, conint

import hivemind
from hivemind.dht.node import DHTNode
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import DHTRecord, RecordValidatorBase
from hivemind.utils.timed_storage import get_dht_time


class SampleSchema(ExtendedBaseModel):
class SampleSchema(BaseModel):
experiment_name: bytes
n_batches: Dict[bytes, conint(ge=0, strict=True)]
signed_data: Dict[BytesWithPublicKey, bytes]
Expand Down Expand Up @@ -94,10 +94,10 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
class Schema(ExtendedBaseModel):
class Schema(BaseModel):
some_field: StrictInt

class MergedSchema(ExtendedBaseModel):
class MergedSchema(BaseModel):
another_field: StrictInt

for allow_extra_keys in [False, True]:
Expand All @@ -121,7 +121,7 @@ class MergedSchema(ExtendedBaseModel):
@pytest.mark.forked
@pytest.mark.asyncio
async def test_prefix():
class Schema(ExtendedBaseModel):
class Schema(BaseModel):
field: StrictInt

validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")
Expand Down Expand Up @@ -153,11 +153,11 @@ def validate(self, record: DHTRecord) -> bool:
# Can't merge with the validator of the different type
assert not alice.protocol.record_validator.merge_with(second_validator)

class SecondSchema(ExtendedBaseModel):
class SecondSchema(BaseModel):
some_field: StrictInt
another_field: str

class ThirdSchema(ExtendedBaseModel):
class ThirdSchema(BaseModel):
another_field: StrictInt # Allow it to be a StrictInt as well

for schema in [SecondSchema, ThirdSchema]:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dht_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
from typing import Dict

import pytest
from pydantic import StrictInt
from pydantic import BaseModel, StrictInt

import hivemind
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.routing import DHTID
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import CompositeValidator, DHTRecord


class SchemaA(ExtendedBaseModel):
class SchemaA(BaseModel):
field_a: bytes


class SchemaB(ExtendedBaseModel):
class SchemaB(BaseModel):
field_b: Dict[BytesWithPublicKey, StrictInt]


Expand Down

0 comments on commit 945176a

Please sign in to comment.