-
Notifications
You must be signed in to change notification settings - Fork 204
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add serialization, base model, typing, and some other goodness
- Loading branch information
Showing
4 changed files
with
107 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pydantic import BaseModel | ||
from typing import Optional | ||
import json | ||
from pathlib import Path | ||
from typing import Type, TypeVar | ||
|
||
T = TypeVar("T", bound="KilnBaseModel") | ||
|
||
|
||
class KilnBaseModel(BaseModel): | ||
version: int = 1 | ||
path: Optional[Path] = None | ||
|
||
@classmethod | ||
def load_from_file(cls: Type[T], path: Path) -> T: | ||
with open(path, "r") as file: | ||
m = cls.model_validate_json(file.read(), strict=True) | ||
m.path = path | ||
return m | ||
|
||
def save_to_file(self) -> None: | ||
if self.path is None: | ||
raise ValueError( | ||
f"Cannot save to file because 'path' is not set. Class: {self.__class__.__name__}, " | ||
f"id: {getattr(self, 'id', None)}, path: {self.path}" | ||
) | ||
data = self.model_dump(exclude={"path"}) | ||
with open(self.path, "w") as file: | ||
json.dump(data, file, indent=4) | ||
|
||
print(f"Project saved to {self.path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,5 @@ | ||
from pydantic import BaseModel | ||
from typing import Optional | ||
from .basemodel import KilnBaseModel | ||
|
||
|
||
class KilnProject(BaseModel): | ||
version: int = 1 | ||
class KilnProject(KilnBaseModel): | ||
name: str | ||
path: Optional[str] = None | ||
|
||
def __init__(self, name: str, path: Optional[str] = None): | ||
# TODO: learn about pydantic init | ||
super().__init__(name=name, path=path) | ||
if path is not None and name is not None: | ||
# path and name are mutually exclusive for constructor, name comes from path if passed | ||
raise ValueError("path and name are mutually exclusive") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import json | ||
import pytest | ||
from kiln_ai.datamodel.basemodel import KilnBaseModel | ||
|
||
|
||
@pytest.fixture | ||
def test_file(tmp_path): | ||
test_file_path = tmp_path / "test_model.json" | ||
data = {"version": 1} | ||
|
||
with open(test_file_path, "w") as file: | ||
json.dump(data, file, indent=4) | ||
|
||
return test_file_path | ||
|
||
|
||
def test_load_from_file(test_file): | ||
model = KilnBaseModel.load_from_file(test_file) | ||
assert model.version == 1 | ||
assert model.path == test_file | ||
|
||
|
||
def test_save_to_file(test_file): | ||
model = KilnBaseModel(path=test_file) | ||
model.save_to_file() | ||
|
||
with open(test_file, "r") as file: | ||
data = json.load(file) | ||
|
||
assert data["version"] == 1 | ||
|
||
|
||
def test_save_to_file_without_path(): | ||
model = KilnBaseModel() | ||
with pytest.raises(ValueError): | ||
model.save_to_file() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import json | ||
import pytest | ||
from kiln_ai.datamodel.project import KilnProject | ||
|
||
|
||
@pytest.fixture | ||
def test_file(tmp_path): | ||
test_file_path = tmp_path / "test_project.json" | ||
data = {"version": 1, "name": "Test Project"} | ||
|
||
with open(test_file_path, "w") as file: | ||
json.dump(data, file, indent=4) | ||
|
||
return test_file_path | ||
|
||
|
||
def test_load_from_file(test_file): | ||
project = KilnProject.load_from_file(test_file) | ||
assert project.version == 1 | ||
assert project.name == "Test Project" | ||
assert project.path == test_file | ||
|
||
|
||
def test_save_to_file(test_file): | ||
project = KilnProject(name="Test Project", path=test_file) | ||
project.save_to_file() | ||
|
||
with open(test_file, "r") as file: | ||
data = json.load(file) | ||
|
||
assert data["version"] == 1 | ||
assert data["name"] == "Test Project" | ||
|
||
|
||
def test_save_to_file_without_path(): | ||
project = KilnProject(name="Test Project") | ||
with pytest.raises(ValueError): | ||
project.save_to_file() |