Skip to content
This repository has been archived by the owner on Oct 28, 2024. It is now read-only.

Commit

Permalink
Merge pull request #9 from allmonday/dev
Browse files Browse the repository at this point in the history
add parent.
  • Loading branch information
allmonday authored Apr 20, 2024
2 parents 89b3a6a + ef89e58 commit 5ad1454
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 16 deletions.
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

# pydantic2-resolve
## v.2.1.1 (2024.4.20)
- new feature: reading parent node in resolve, post and post_default_handler

## v.2.1.0 (2024.04.08)
- bugfix: https://github.com/allmonday/pydantic2-resolve/issues/7
Expand Down
12 changes: 12 additions & 0 deletions pydantic_resolve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _scan_resolve_method(method, field: str):
result = {
'trim_field': field.replace(const.RESOLVE_PREFIX, ''),
'context': False,
'parent': False,
'ancestor_context': False,
'dataloaders': [] # collect func or class, do not create instance
}
Expand All @@ -91,6 +92,9 @@ def _scan_resolve_method(method, field: str):
if signature.parameters.get('ancestor_context'):
result['ancestor_context'] = True

if signature.parameters.get('parent'):
result['parent'] = True

for name, param in signature.parameters.items():
if isinstance(param.default, Depends):
info = {
Expand All @@ -110,6 +114,7 @@ def _scan_post_method(method, field):
result = {
'trim_field': field.replace(const.POST_PREFIX, ''),
'context': False,
'parent': False,
'ancestor_context': False,
'collectors': []
}
Expand All @@ -120,6 +125,9 @@ def _scan_post_method(method, field):

if signature.parameters.get('ancestor_context'):
result['ancestor_context'] = True

if signature.parameters.get('parent'):
result['parent'] = True

for name, param in signature.parameters.items():
if isinstance(param.default, ICollector):
Expand All @@ -136,6 +144,7 @@ def _scan_post_method(method, field):
def _scan_post_default_handler(method):
result = {
'context': False,
'parent': False,
'ancestor_context': False,
}
signature = inspect.signature(method)
Expand All @@ -146,6 +155,9 @@ def _scan_post_default_handler(method):
if signature.parameters.get('ancestor_context'):
result['ancestor_context'] = True

if signature.parameters.get('parent'):
result['parent'] = True

return result


Expand Down
31 changes: 21 additions & 10 deletions pydantic_resolve/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(

self.ancestor_vars = {}
self.collector_contextvars = {}
self.parent_contextvars = {}

# for dataloader which has class attributes, you can assign the value at here
if loader_filters:
Expand Down Expand Up @@ -83,6 +84,11 @@ def _add_values_into_collectors(self, target, kls):
val = getattr(target, field)
collector.add(val)

def _add_parent(self, target):
if not self.parent_contextvars.get('parent'):
self.parent_contextvars['parent'] = contextvars.ContextVar('parent')
self.parent_contextvars['parent'].set(target)

def _add_expose_fields(self, target):
expose_dict: Optional[dict] = getattr(target, const.EXPOSE_TO_DESCENDANT, None)
if expose_dict:
Expand All @@ -107,6 +113,8 @@ def _execute_resolver_method(self, kls, field, method):
params['context'] = self.context
if resolve_param['ancestor_context']:
params['ancestor_context'] = self._prepare_ancestor_context()
if resolve_param['parent']:
params['parent'] = self.parent_contextvars['parent'].get()

for loader in resolve_param['dataloaders']:
cache_key = loader['path']
Expand All @@ -120,9 +128,10 @@ def _execute_post_method(self, target, kls, kls_path, post_field, method):
post_param = core.get_post_params(kls, post_field , self.metadata)
if post_param['context']:
params['context'] = self.context

if post_param['ancestor_context']:
params['ancestor_context'] = self._prepare_ancestor_context()
if post_param['parent']:
params['parent'] = self.parent_contextvars['parent'].get()

alias_map = self.object_collect_alias_map_store.get(id(target), {})
if alias_map:
Expand All @@ -135,13 +144,14 @@ def _execute_post_method(self, target, kls, kls_path, post_field, method):

def _execute_post_default_handler(self, kls, method):
params = {}
resolve_param = core.get_post_default_handler_params(kls, self.metadata)
post_default_param = core.get_post_default_handler_params(kls, self.metadata)

if resolve_param['context']:
if post_default_param['context']:
params['context'] = self.context

if resolve_param['ancestor_context']:
if post_default_param['ancestor_context']:
params['ancestor_context'] = self._prepare_ancestor_context()
if post_default_param['parent']:
params['parent'] = self.parent_contextvars['parent'].get()

ret_val = method(**params)
return ret_val
Expand All @@ -159,20 +169,21 @@ async def _resolve_obj_field(self, target, kls, field, trim_field, method):
val = util.try_parse_data_to_target_field_type(target, trim_field, val)

# continue dive deeper
val = await self._resolve(val)
val = await self._resolve(val, target)

setattr(target, trim_field, val)

async def _resolve(self, target: T) -> T:
async def _resolve(self, target: T, parent) -> T:
if isinstance(target, (list, tuple)):
await asyncio.gather(*[self._resolve(t) for t in target])
await asyncio.gather(*[self._resolve(t, parent) for t in target])

if core.is_acceptable_instance(target):
kls = target.__class__
kls_path = util.get_kls_full_path(kls)

self._prepare_collectors(target, kls)
self._add_expose_fields(target)
self._add_parent(parent)

tasks = []

Expand All @@ -181,7 +192,7 @@ async def _resolve(self, target: T) -> T:
for field, resolve_trim_field, method in resolve_list:
tasks.append(self._resolve_obj_field(target, kls, field, resolve_trim_field, method))
for field, attr_object in attribute_list:
tasks.append(self._resolve(attr_object))
tasks.append(self._resolve(attr_object, target))
await asyncio.gather(*tasks)

# reverse traversal and run post methods
Expand Down Expand Up @@ -217,5 +228,5 @@ async def resolve(self, target: T) -> T:
if has_context and self.context is None:
raise AttributeError('context is missing')

await self._resolve(target)
await self._resolve(target, None)
return target
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic2-resolve"
version = "2.1.0"
version = "2.1.1"
description = "create nested data structure easily"
authors = ["tangkikodo <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions tests/core/test_field_dataclass_anno.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_resolve_params():
'trim_field': 'name',
'context': True,
'ancestor_context': True,
'parent': False,
'dataloaders': [
{
'param': 'loader',
Expand All @@ -99,6 +100,7 @@ def test_resolve_params():
'trim_field': 'zeta',
'context': False,
'ancestor_context': False,
'parent': False,
'dataloaders': [],
}
},
Expand Down
2 changes: 2 additions & 0 deletions tests/core/test_field_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_resolve_params():
'trim_field': 'name',
'context': True,
'ancestor_context': True,
'parent': False,
'dataloaders': [
{
'param': 'loader',
Expand All @@ -151,6 +152,7 @@ def test_resolve_params():
'trim_field': 'zeta',
'context': False,
'ancestor_context': False,
'parent': False,
'dataloaders': [],
}
},
Expand Down
7 changes: 5 additions & 2 deletions tests/core/test_scan_post_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ def post_a(self):
'trim_field': 'a',
'context': False,
'ancestor_context': False,
'parent': False,
'collectors': []
}


def test_scan_post_method_2():
class A(BaseModel):
a: str
def post_a(self, context, ancestor_context):
def post_a(self, context, ancestor_context, parent):
return 2 * self.a

result = _scan_post_method(A.post_a, 'post_a')
Expand All @@ -30,6 +31,7 @@ def post_a(self, context, ancestor_context):
'trim_field': 'a',
'context': True,
'ancestor_context': True,
'parent': True,
'collectors': []
}

Expand All @@ -55,11 +57,12 @@ class A(BaseModel):
def post_a(self, context, ancestor_context, collector=Collector(alias='c_name')):
return 2 * self.a

def post_default_handler(self, context, ancestor_context):
def post_default_handler(self, context, ancestor_context, parent):
return 1

result = _scan_post_default_handler(A.post_default_handler)

assert result['context'] == True
assert result['parent'] == True
assert result['ancestor_context'] == True

7 changes: 5 additions & 2 deletions tests/core/test_scan_resolve_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ def resolve_a(self):
'trim_field': 'a',
'context': False,
'ancestor_context': False,
'parent': False,
'dataloaders': []
}


def test_scan_resolve_method_2():
class A(BaseModel):
a: str
def resolve_a(self, context, ancestor_context):
def resolve_a(self, context, ancestor_context, parent):
return 2 * self.a

result = _scan_resolve_method(A.resolve_a, 'resolve_a')
Expand All @@ -31,6 +32,7 @@ def resolve_a(self, context, ancestor_context):
'trim_field': 'a',
'context': True,
'ancestor_context': True,
'parent': True,
'dataloaders': []
}

Expand All @@ -42,7 +44,7 @@ async def batch_loader_fn(self, keys):

class A(BaseModel):
a: str
def resolve_a(self, context, ancestor_context, loader=LoaderDepend(Loader)):
def resolve_a(self, context, ancestor_context, parent, loader=LoaderDepend(Loader)):
return 2 * self.a

result = _scan_resolve_method(A.resolve_a, 'resolve_a')
Expand All @@ -51,6 +53,7 @@ def resolve_a(self, context, ancestor_context, loader=LoaderDepend(Loader)):
'trim_field': 'a',
'context': True,
'ancestor_context': True,
'parent': True,
'dataloaders': [
{
'param': 'loader',
Expand Down
76 changes: 76 additions & 0 deletions tests/resolver/test_39_parent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations
import pytest
from typing import Optional, List
from pydantic import BaseModel
from pydantic_resolve import Resolver


class Base(BaseModel):
name: str

child: Optional[Child] = None
def resolve_child(self):
return Child()

children: List[Child] = []
def resolve_children(self):
return [Child()]

parent: Optional[str] = '123'
def resolve_parent(self, parent):
return parent

class Child(BaseModel):
pname: str = ''
def resolve_pname(self, parent: Base):
return parent.name

pname2: str = ''
def resolve_pname2(self, parent: Base):
return parent.name


@pytest.mark.asyncio
async def test_parent():
b = Base(name='kikodo')
b = await Resolver().resolve(b)
assert b.parent is None # parent of root is none

assert b.name == 'kikodo'
assert b.child.pname == 'kikodo'
assert b.child.pname2 == 'kikodo' # work with obj

assert b.children[0].pname == 'kikodo'
assert b.children[0].pname2 == 'kikodo' # work with list


class Tree(BaseModel):
name: str

path: str = ''
def resolve_path(self, parent):
if parent is not None:
return f'{parent.path}/{self.name}'
return self.name
children: List[Tree] = []

@pytest.mark.asyncio
async def test_tree():
data = dict(name="a", children=[
dict(name="b", children=[
dict(name="c")
]),
dict(name="d", children=[
dict(name="c")
])
])
data = await Resolver().resolve(Tree(**data))
assert data.model_dump() == dict(name="a", path="a", children=[
dict(name="b", path="a/b", children=[
dict(name="c", path="a/b/c", children=[])
]),
dict(name="d", path="a/d", children=[
dict(name="c", path="a/d/c", children=[])
])
])
print(data.model_dump_json(indent=2))

0 comments on commit 5ad1454

Please sign in to comment.