Skip to content

Commit

Permalink
test: add (failing) test coverage for field validators
Browse files Browse the repository at this point in the history
  • Loading branch information
lmmx committed Jan 24, 2024
1 parent 892f499 commit 9dab81b
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions projects/fal/tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dill
import dill._dill as dill_serialization
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator


def build_pydantic_model(
Expand Down Expand Up @@ -77,8 +77,9 @@ class Input(BaseModel):
validators (for the purpose of testing).
"""

prompt: str = Field()
prompt: str = ...
num_steps: int = Field(default=2, ge=1, le=10)
epochs: int = 10
validation_counter: int = 0

def steps_x2(self) -> int:
Expand All @@ -87,9 +88,15 @@ def steps_x2(self) -> int:
Computes double of the `num_steps` field value."""
return self.num_steps * 2

@field_validator("epochs")
@classmethod
def triple_epochs(cls, v: int) -> int:
"""A field validator that multiplies the validated field value by 10."""
return v * 3

@model_validator(mode="after")
def validate_num_steps(self) -> None:
"""A model post-validator."""
def increment(self) -> None:
"""A model post-validator that increments a counter."""
self.validation_counter += 100


Expand All @@ -108,9 +115,9 @@ def deserialise_pydantic_model():
"""
dill.settings["recurse"] = True
serialized_cls = dill.dumps(Input)
print("====== DESERIALIZING =====")
print("===== DESERIALIZING =====")
model_cls = dill.loads(serialized_cls)
print("======== RUNNING =====")
print("===== INSTANTIATING =====")
model = model_cls(prompt="a")
return model

Expand All @@ -122,7 +129,8 @@ def validate_deserialisation(model: Input) -> None:
assert prompt == "a", f"Prompt not retrieved: expected 'a' got {prompt!r}"
assert steps == 2, f"Steps not retrieved: expected 2 got {steps!r}"
assert steps_x2 == 4, f"Incorrect `steps_x2()`: expected 4 got {steps_x2}"
assert model.validation_counter == 100
assert model.epochs == 30, "The `validate_epochs` field validator didn't run"
assert model.validation_counter == 100, "The `increment` model validator didn't run"
return


Expand Down

0 comments on commit 9dab81b

Please sign in to comment.