Skip to content

Commit

Permalink
WIP: class based api. types
Browse files Browse the repository at this point in the history
  • Loading branch information
skovbasa committed Oct 19, 2024
1 parent 2bfd8c3 commit 7aa8559
Show file tree
Hide file tree
Showing 8 changed files with 703 additions and 19 deletions.
Empty file added hiku/classes/__init__.py
Empty file.
160 changes: 160 additions & 0 deletions hiku/classes/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import dataclasses as dc
import typing
from collections.abc import Hashable
from functools import partial

import hiku.types
from hiku.classes.strings import to_snake_case
from hiku.directives import SchemaDirective

"""
@node(...)
class Human:
id: int = field(...)
droid: ref[Droid] = field_link(...)
"""

_T = typing.TypeVar("_T", bound=Hashable)


class NodeProto(typing.Protocol[_T]):
__key__: _T


_TNode = typing.TypeVar("_TNode", bound=NodeProto)


@dc.dataclass
class HikuNode:
name: str
fields: "list[_HikuField] | list[_HikuFieldLink] | list[_HikuField | _HikuFieldLink]"
description: str | None
directives: list[SchemaDirective] | None
implements: list[str] | None


def node(
cls: type[_TNode] | None = None,
*,
name: str | None = None,
description: str | None = None,
directives: list[SchemaDirective] | None = None,
# TODO(s.kovbasa): handle interfaces from mro
implements: list[str] | None = None,
) -> typing.Callable[[type[_TNode]], type[_TNode]] | type[_TNode]:
# TODO(s.kovbasa): add validation and stuff

def _wrap_cls(
cls: type[_TNode],
name: str | None,
description: str | None,
directives: list[SchemaDirective] | None,
implements: list[str] | None,
) -> type[_TNode]:
setattr(
cls,
"__hiku_node__",
HikuNode(
name=name or cls.__name__,
fields=_get_fields(cls),
description=description,
directives=directives,
implements=implements,
),
)
return cls

_do_wrap = partial(
_wrap_cls,
name=name,
description=description,
directives=directives,
implements=implements,
)

if cls is None:
return _do_wrap

return _do_wrap(cls)


def _get_fields(
cls: type[_TNode],
) -> "list[_HikuField] | list[_HikuFieldLink] | list[_HikuField | _HikuFieldLink]":
# TODO(s.kovbasa): handle name and type from annotations
# TODO(s.kovbasa): first process fields, then links; resolve link requires
return []


@dc.dataclass
class _HikuField:
func: typing.Callable
name: str | None
typ: type
options: object | None
description: str | None
deprecated: str | None
directives: typing.Sequence[SchemaDirective] | None


def field(
func: typing.Callable | None = None,
*,
options: object | None = None,
name: str | None = None,
description: str | None = None,
deprecated: str | None = None,
directives: list | None = None,
) -> typing.Any:
return _HikuField(
func=func or resolve_getattr,
name=name,
typ=None, # type: ignore
options=options,
description=description,
deprecated=deprecated,
directives=directives,
)


@dc.dataclass
class _HikuFieldLink:
func: typing.Callable
name: str | None
typ: type
requires_func: typing.Callable[[], tuple] | None
options: object | None
description: str | None
deprecated: str | None
directives: typing.Sequence[SchemaDirective] | None


def field_link(
func: typing.Callable | None = None,
*,
options: object | None = None,
requires: typing.Callable[[], tuple[typing.Any, ...]] | None,
name: str | None = None,
description: str | None = None,
deprecated: str | None = None,
directives: list | None = None,
) -> typing.Any:
return _HikuFieldLink(
func=func or direct_link,
name=name,
typ=None, # type: ignore
requires_func=requires,
options=options,
description=description,
deprecated=deprecated,
directives=directives,
)


def resolve_getattr(fields, tuples) -> list[list]:
field_names = [to_snake_case(f.name) for f in fields]
return [[getattr(t, f_name) for f_name in field_names] for t in tuples]


def direct_link(ids):
return ids
10 changes: 10 additions & 0 deletions hiku/classes/strings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import re

UPPER_CAMEL_CASE_BOUNDS_RE = re.compile(r"(.)([A-Z][a-z]+)")
LOWER_CAMEL_CASE_BOUNDS_RE = re.compile(r"([a-z0-9])([A-Z])")


# http://stackoverflow.com/a/1176023/1072990
def to_snake_case(name: str) -> str:
s1 = UPPER_CAMEL_CASE_BOUNDS_RE.sub(r"\1_\2", name)
return LOWER_CAMEL_CASE_BOUNDS_RE.sub(r"\1_\2", s1).lower()
202 changes: 202 additions & 0 deletions hiku/classes/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import dataclasses as dc
import importlib
import inspect
import types
import typing
from collections.abc import Hashable

import hiku.graph
import hiku.types

_T = typing.TypeVar("_T", bound=Hashable)


class NodeProto(typing.Protocol[_T]):
__key__: _T


_TNode = typing.TypeVar("_TNode", bound=NodeProto)


@dc.dataclass
class raw_type:
"""
Helps to update hiku types gradually. E.g.
id: typing.Annotated[str, hiku.raw_type(hiku.types.ID)]
some_field: typing.Annotated[None, hiku.raw_type(TypeRef["Product"])]
"""

typ: hiku.types.GenericMeta

def apply(
self,
container: hiku.types.OptionalMeta | hiku.types.SequenceMeta,
) -> typing.Self:
return dc.replace(self, typ=container[self.typ])

def __hash__(self) -> int:
return hash(self.typ)


class lazy:
"""
Allows for a lazy type resolve when circular imports are encountered.
Lazy resolvers are processed during Graph.__init__
"""

module: str
package: str | None

def __init__(self, module: str):
self.module = module
self.package = None

if module.startswith("."):
current_frame = inspect.currentframe()
assert current_frame is not None
assert current_frame.f_back is not None

self.package = current_frame.f_back.f_globals["__package__"]


class ref(typing.Generic[_TNode]):
"""Represents a reference to another object type.
Is needed in case we someday plan to implement proper mypy checks - this way
we can make use of ref object as a thin wrapper around type's __key__
"""


_BUILTINS_TO_HIKU = {
int: hiku.types.Integer,
float: hiku.types.Float,
str: hiku.types.String,
bool: hiku.types.Boolean,
}


@dc.dataclass
class _LazyTypeRef:
"""strawberry-like impl for lazy type refs"""

classname: str
module: str
package: str | None
containers: (
list[hiku.types.OptionalMeta | hiku.types.SequenceMeta] | None
) = None

@property
def typ(self) -> hiku.types.GenericMeta:
module = importlib.import_module(self.module, self.package)
cls = module.__dict__[self.classname]

type_ref = hiku.types.TypeRef[cls.__hiku_node__.name]

containers = reversed(self.containers or [])
for c in containers:
type_ref = c[type_ref]

return type_ref

def apply(
self,
container: hiku.types.OptionalMeta | hiku.types.SequenceMeta,
) -> typing.Self:
return dc.replace(
self,
containers=[container] + (self.containers or []),
)


class _HikuTypeWrapperProto(typing.Protocol):

@property
def typ(self) -> hiku.types.GenericMeta: ...

def apply(
self, container: hiku.types.OptionalMeta | hiku.types.SequenceMeta
) -> typing.Self: ...


def to_hiku_type(typ: type, lazy_: lazy | None = None) -> _HikuTypeWrapperProto:
if typ in _BUILTINS_TO_HIKU:
return raw_type(_BUILTINS_TO_HIKU[typ])

origin = typing.get_origin(typ)
args = typing.get_args(typ)

if origin is typing.Annotated:
metadata = typ.__metadata__

raw_types = []
lazy_refs = []
for val in metadata:
if isinstance(val, raw_type):
raw_types.append(val)
elif isinstance(val, lazy):
lazy_refs.append(val)

if lazy_refs and raw_types:
raise ValueError("lazy and raw_type are not composable")

if len(raw_types) > 1:
raise ValueError("more than 1 raw_type")

if len(raw_types) == 1:
return raw_types[0]

if len(lazy_refs) > 1:
raise ValueError("more than 1 lazy reference")

if len(lazy_refs) == 1:
lazy_typeref = to_hiku_type(typ.__origin__, lazy_refs[0])
if not isinstance(lazy_typeref, _LazyTypeRef):
raise ValueError("lazy can only be used with ref types")

return lazy_typeref

return to_hiku_type(args[0])

# new optionals
if origin in (typing.Union, types.UnionType):
if len(args) != 2 or types.NoneType not in args:
raise ValueError("unions are allowed only as optional types")

next_type = [a for a in args if a is not types.NoneType][0]
arg = to_hiku_type(next_type, lazy_)
return arg.apply(hiku.types.Optional)

# old optionals
if origin is typing.Optional:
arg = to_hiku_type(args[0], lazy_)
return arg.apply(hiku.types.Optional)

# lists
if origin in (list, typing.List):
if len(args) == 0:
raise ValueError("naked lists not allowed")

next_type = args[0]
arg = to_hiku_type(next_type, lazy_)
return arg.apply(hiku.types.Sequence)

if origin is ref:
ref_ = args[0]
if isinstance(ref_, typing.ForwardRef):
if lazy_ is None:
raise ValueError("need to use hiku.lazy for lazy imports")

return _LazyTypeRef(
classname=ref_.__forward_arg__,
module=lazy_.module,
package=lazy_.package,
)

if not hasattr(ref_, "__hiku_node__"):
raise ValueError("expected ref arg to be a @node")

return raw_type(hiku.types.TypeRef[ref_.__hiku_node__.name])

raise ValueError("invalid hiku type")
Loading

0 comments on commit 7aa8559

Please sign in to comment.