Skip to content

Commit

Permalink
revert to v1 api
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 10, 2024
1 parent bc5a52f commit 32a1c68
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/albert/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Tuple

from pydantic import BaseModel, StrictFloat, confloat, conint
from pydantic.v1 import BaseModel, StrictFloat, confloat, conint

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
Expand Down
31 changes: 11 additions & 20 deletions hivemind/dht/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import re
from typing import Any, Dict, Optional, Type

import pydantic
from pydantic.fields import FieldInfo
from pydantic_core import CoreSchema, core_schema
import pydantic.v1 as pydantic

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.protocol import DHTProtocol
Expand Down Expand Up @@ -43,19 +41,18 @@ def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = Tr
self._schemas = [schema]

self._key_id_to_field_name = {}
for field_name in schema.__fields__.keys():
raw_key = f"{prefix}_{field_name}" if prefix is not None else field_name
self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field_name
for field in schema.__fields__.values():
raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
self._allow_extra_keys = allow_extra_keys

@staticmethod
def _patch_schema(schema: pydantic.BaseModel):
# We set required=False because the validate() interface provides only one key at a time
for field_name, field_info in schema.__fields__.items():
field_info = pydantic.Field(default=None)
field_info = pydantic.Field(is_required=False)
for field in schema.__fields__.values():
field.required = False

schema.model_config.update({"extra": pydantic.Extra.forbid})
schema.Config.extra = pydantic.Extra.forbid

def validate(self, record: DHTRecord) -> bool:
"""
Expand Down Expand Up @@ -98,7 +95,7 @@ def validate(self, record: DHTRecord) -> bool:
validation_errors = []
for schema in self._schemas:
try:
parsed_record = schema.model_validate(record)
parsed_record = schema.parse_obj(record)
except pydantic.ValidationError as e:
if not self._is_failed_due_to_extra_field(e):
validation_errors.append(e)
Expand Down Expand Up @@ -166,23 +163,17 @@ def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:
compiled_regex = re.compile(regex) if regex is not None else None

class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
value: bytes = pydantic.Field(**kwargs)

@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler) -> CoreSchema:
schema = handler.generate_schema(bytes)
return core_schema.no_info_after_validator_function(cls.match_regex, schema)
def __get_validators__(cls):
yield from super().__get_validators__()
yield cls.match_regex

@classmethod
def match_regex(cls, value: bytes) -> bytes:
if compiled_regex is not None and compiled_regex.match(value) is None:
raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
return value

# @classmethod
# def __get_pydantic_config__(cls) -> pydantic.config.ConfigDict:
# return pydantic.config.ConfigDict(arbitrary_types_allowed=True)

return ConstrainedBytesWithRegex


Expand Down
2 changes: 1 addition & 1 deletion hivemind/optim/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
from pydantic.v1 import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dht_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict

import pytest
from pydantic import BaseModel, StrictInt, conint
from pydantic.v1 import BaseModel, StrictInt, conint

import hivemind
from hivemind.dht.node import DHTNode
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dht_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict

import pytest
from pydantic import BaseModel, StrictInt
from pydantic.v1 import BaseModel, StrictInt

import hivemind
from hivemind.dht.crypto import RSASignatureValidator
Expand Down

0 comments on commit 32a1c68

Please sign in to comment.