diff --git a/changelog.md b/changelog.md index 6897b15..401b1d4 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,7 @@ # Changelog -# pydantic2-resolve +## v.2.1.1 (2024.4.20) +- new feature: reading parent node in resolve, post and post_default_handler ## v.2.1.0 (2024.04.08) - bugfix: https://github.com/allmonday/pydantic2-resolve/issues/7 diff --git a/pydantic_resolve/core.py b/pydantic_resolve/core.py index f06005a..bc2c8ac 100644 --- a/pydantic_resolve/core.py +++ b/pydantic_resolve/core.py @@ -80,6 +80,7 @@ def _scan_resolve_method(method, field: str): result = { 'trim_field': field.replace(const.RESOLVE_PREFIX, ''), 'context': False, + 'parent': False, 'ancestor_context': False, 'dataloaders': [] # collect func or class, do not create instance } @@ -91,6 +92,9 @@ def _scan_resolve_method(method, field: str): if signature.parameters.get('ancestor_context'): result['ancestor_context'] = True + if signature.parameters.get('parent'): + result['parent'] = True + for name, param in signature.parameters.items(): if isinstance(param.default, Depends): info = { @@ -110,6 +114,7 @@ def _scan_post_method(method, field): result = { 'trim_field': field.replace(const.POST_PREFIX, ''), 'context': False, + 'parent': False, 'ancestor_context': False, 'collectors': [] } @@ -120,6 +125,9 @@ def _scan_post_method(method, field): if signature.parameters.get('ancestor_context'): result['ancestor_context'] = True + + if signature.parameters.get('parent'): + result['parent'] = True for name, param in signature.parameters.items(): if isinstance(param.default, ICollector): @@ -136,6 +144,7 @@ def _scan_post_method(method, field): def _scan_post_default_handler(method): result = { 'context': False, + 'parent': False, 'ancestor_context': False, } signature = inspect.signature(method) @@ -146,6 +155,9 @@ def _scan_post_default_handler(method): if signature.parameters.get('ancestor_context'): result['ancestor_context'] = True + if signature.parameters.get('parent'): + result['parent'] = True + return result diff --git a/pydantic_resolve/resolver.py b/pydantic_resolve/resolver.py index 3197651..c4d16c2 100644 --- a/pydantic_resolve/resolver.py +++ b/pydantic_resolve/resolver.py @@ -28,6 +28,7 @@ def __init__( self.ancestor_vars = {} self.collector_contextvars = {} + self.parent_contextvars = {} # for dataloader which has class attributes, you can assign the value at here if loader_filters: @@ -83,6 +84,11 @@ def _add_values_into_collectors(self, target, kls): val = getattr(target, field) collector.add(val) + def _add_parent(self, target): + if not self.parent_contextvars.get('parent'): + self.parent_contextvars['parent'] = contextvars.ContextVar('parent') + self.parent_contextvars['parent'].set(target) + def _add_expose_fields(self, target): expose_dict: Optional[dict] = getattr(target, const.EXPOSE_TO_DESCENDANT, None) if expose_dict: @@ -107,6 +113,8 @@ def _execute_resolver_method(self, kls, field, method): params['context'] = self.context if resolve_param['ancestor_context']: params['ancestor_context'] = self._prepare_ancestor_context() + if resolve_param['parent']: + params['parent'] = self.parent_contextvars['parent'].get() for loader in resolve_param['dataloaders']: cache_key = loader['path'] @@ -120,9 +128,10 @@ def _execute_post_method(self, target, kls, kls_path, post_field, method): post_param = core.get_post_params(kls, post_field , self.metadata) if post_param['context']: params['context'] = self.context - if post_param['ancestor_context']: params['ancestor_context'] = self._prepare_ancestor_context() + if post_param['parent']: + params['parent'] = self.parent_contextvars['parent'].get() alias_map = self.object_collect_alias_map_store.get(id(target), {}) if alias_map: @@ -135,13 +144,14 @@ def _execute_post_method(self, target, kls, kls_path, post_field, method): def _execute_post_default_handler(self, kls, method): params = {} - resolve_param = core.get_post_default_handler_params(kls, self.metadata) + post_default_param = core.get_post_default_handler_params(kls, self.metadata) - if resolve_param['context']: + if post_default_param['context']: params['context'] = self.context - - if resolve_param['ancestor_context']: + if post_default_param['ancestor_context']: params['ancestor_context'] = self._prepare_ancestor_context() + if post_default_param['parent']: + params['parent'] = self.parent_contextvars['parent'].get() ret_val = method(**params) return ret_val @@ -159,13 +169,13 @@ async def _resolve_obj_field(self, target, kls, field, trim_field, method): val = util.try_parse_data_to_target_field_type(target, trim_field, val) # continue dive deeper - val = await self._resolve(val) + val = await self._resolve(val, target) setattr(target, trim_field, val) - async def _resolve(self, target: T) -> T: + async def _resolve(self, target: T, parent) -> T: if isinstance(target, (list, tuple)): - await asyncio.gather(*[self._resolve(t) for t in target]) + await asyncio.gather(*[self._resolve(t, parent) for t in target]) if core.is_acceptable_instance(target): kls = target.__class__ @@ -173,6 +183,7 @@ async def _resolve(self, target: T) -> T: self._prepare_collectors(target, kls) self._add_expose_fields(target) + self._add_parent(parent) tasks = [] @@ -181,7 +192,7 @@ async def _resolve(self, target: T) -> T: for field, resolve_trim_field, method in resolve_list: tasks.append(self._resolve_obj_field(target, kls, field, resolve_trim_field, method)) for field, attr_object in attribute_list: - tasks.append(self._resolve(attr_object)) + tasks.append(self._resolve(attr_object, target)) await asyncio.gather(*tasks) # reverse traversal and run post methods @@ -217,5 +228,5 @@ async def resolve(self, target: T) -> T: if has_context and self.context is None: raise AttributeError('context is missing') - await self._resolve(target) + await self._resolve(target, None) return target \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index c9c8f84..6583c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pydantic2-resolve" -version = "2.1.0" +version = "2.1.1" description = "create nested data structure easily" authors = ["tangkikodo "] readme = "README.md" diff --git a/tests/core/test_field_dataclass_anno.py b/tests/core/test_field_dataclass_anno.py index e6f87b8..422efbc 100644 --- a/tests/core/test_field_dataclass_anno.py +++ b/tests/core/test_field_dataclass_anno.py @@ -87,6 +87,7 @@ def test_resolve_params(): 'trim_field': 'name', 'context': True, 'ancestor_context': True, + 'parent': False, 'dataloaders': [ { 'param': 'loader', @@ -99,6 +100,7 @@ def test_resolve_params(): 'trim_field': 'zeta', 'context': False, 'ancestor_context': False, + 'parent': False, 'dataloaders': [], } }, diff --git a/tests/core/test_field_pydantic.py b/tests/core/test_field_pydantic.py index 0fca686..67c532b 100644 --- a/tests/core/test_field_pydantic.py +++ b/tests/core/test_field_pydantic.py @@ -139,6 +139,7 @@ def test_resolve_params(): 'trim_field': 'name', 'context': True, 'ancestor_context': True, + 'parent': False, 'dataloaders': [ { 'param': 'loader', @@ -151,6 +152,7 @@ def test_resolve_params(): 'trim_field': 'zeta', 'context': False, 'ancestor_context': False, + 'parent': False, 'dataloaders': [], } }, diff --git a/tests/core/test_scan_post_method.py b/tests/core/test_scan_post_method.py index 4b4b785..9c8912f 100644 --- a/tests/core/test_scan_post_method.py +++ b/tests/core/test_scan_post_method.py @@ -14,6 +14,7 @@ def post_a(self): 'trim_field': 'a', 'context': False, 'ancestor_context': False, + 'parent': False, 'collectors': [] } @@ -21,7 +22,7 @@ def post_a(self): def test_scan_post_method_2(): class A(BaseModel): a: str - def post_a(self, context, ancestor_context): + def post_a(self, context, ancestor_context, parent): return 2 * self.a result = _scan_post_method(A.post_a, 'post_a') @@ -30,6 +31,7 @@ def post_a(self, context, ancestor_context): 'trim_field': 'a', 'context': True, 'ancestor_context': True, + 'parent': True, 'collectors': [] } @@ -55,11 +57,12 @@ class A(BaseModel): def post_a(self, context, ancestor_context, collector=Collector(alias='c_name')): return 2 * self.a - def post_default_handler(self, context, ancestor_context): + def post_default_handler(self, context, ancestor_context, parent): return 1 result = _scan_post_default_handler(A.post_default_handler) assert result['context'] == True + assert result['parent'] == True assert result['ancestor_context'] == True diff --git a/tests/core/test_scan_resolve_method.py b/tests/core/test_scan_resolve_method.py index 22b6ab5..c63ef43 100644 --- a/tests/core/test_scan_resolve_method.py +++ b/tests/core/test_scan_resolve_method.py @@ -15,6 +15,7 @@ def resolve_a(self): 'trim_field': 'a', 'context': False, 'ancestor_context': False, + 'parent': False, 'dataloaders': [] } @@ -22,7 +23,7 @@ def resolve_a(self): def test_scan_resolve_method_2(): class A(BaseModel): a: str - def resolve_a(self, context, ancestor_context): + def resolve_a(self, context, ancestor_context, parent): return 2 * self.a result = _scan_resolve_method(A.resolve_a, 'resolve_a') @@ -31,6 +32,7 @@ def resolve_a(self, context, ancestor_context): 'trim_field': 'a', 'context': True, 'ancestor_context': True, + 'parent': True, 'dataloaders': [] } @@ -42,7 +44,7 @@ async def batch_loader_fn(self, keys): class A(BaseModel): a: str - def resolve_a(self, context, ancestor_context, loader=LoaderDepend(Loader)): + def resolve_a(self, context, ancestor_context, parent, loader=LoaderDepend(Loader)): return 2 * self.a result = _scan_resolve_method(A.resolve_a, 'resolve_a') @@ -51,6 +53,7 @@ def resolve_a(self, context, ancestor_context, loader=LoaderDepend(Loader)): 'trim_field': 'a', 'context': True, 'ancestor_context': True, + 'parent': True, 'dataloaders': [ { 'param': 'loader', diff --git a/tests/resolver/test_39_parent.py b/tests/resolver/test_39_parent.py new file mode 100644 index 0000000..a135cf0 --- /dev/null +++ b/tests/resolver/test_39_parent.py @@ -0,0 +1,76 @@ +from __future__ import annotations +import pytest +from typing import Optional, List +from pydantic import BaseModel +from pydantic_resolve import Resolver + + +class Base(BaseModel): + name: str + + child: Optional[Child] = None + def resolve_child(self): + return Child() + + children: List[Child] = [] + def resolve_children(self): + return [Child()] + + parent: Optional[str] = '123' + def resolve_parent(self, parent): + return parent + +class Child(BaseModel): + pname: str = '' + def resolve_pname(self, parent: Base): + return parent.name + + pname2: str = '' + def resolve_pname2(self, parent: Base): + return parent.name + + +@pytest.mark.asyncio +async def test_parent(): + b = Base(name='kikodo') + b = await Resolver().resolve(b) + assert b.parent is None # parent of root is none + + assert b.name == 'kikodo' + assert b.child.pname == 'kikodo' + assert b.child.pname2 == 'kikodo' # work with obj + + assert b.children[0].pname == 'kikodo' + assert b.children[0].pname2 == 'kikodo' # work with list + + +class Tree(BaseModel): + name: str + + path: str = '' + def resolve_path(self, parent): + if parent is not None: + return f'{parent.path}/{self.name}' + return self.name + children: List[Tree] = [] + +@pytest.mark.asyncio +async def test_tree(): + data = dict(name="a", children=[ + dict(name="b", children=[ + dict(name="c") + ]), + dict(name="d", children=[ + dict(name="c") + ]) + ]) + data = await Resolver().resolve(Tree(**data)) + assert data.model_dump() == dict(name="a", path="a", children=[ + dict(name="b", path="a/b", children=[ + dict(name="c", path="a/b/c", children=[]) + ]), + dict(name="d", path="a/d", children=[ + dict(name="c", path="a/d/c", children=[]) + ]) + ]) + print(data.model_dump_json(indent=2))