Skip to content

Commit

Permalink
feat: generalizable abs, typeable supports_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed May 15, 2024
1 parent 3d05654 commit 4a564ee
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 49 deletions.
74 changes: 33 additions & 41 deletions mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,46 @@
pass


def supports_dict_callback(ctx):
original_function_type = ctx.arg_types[0][0]
# def supports_dict_callback(ctx):
# original_function_type = ctx.arg_types[0][0]

if not isinstance(original_function_type, CallableType):
ctx.api.fail("Argument to 'supports_dict' must be callable", ctx.context)
return original_function_type
# if not isinstance(original_function_type, CallableType):
# ctx.api.fail("Argument to 'supports_dict' must be callable", ctx.context)
# return original_function_type

if not original_function_type.arg_types:
ctx.api.fail("Function must have at least one argument", ctx.context)
return original_function_type
# if not original_function_type.arg_types:
# ctx.api.fail("Function must have at least one argument", ctx.context)
# return original_function_type

original_arg_type = original_function_type.arg_types[0]
# original_arg_type = original_function_type.arg_types[0]

if isinstance(original_arg_type, AnyType):
ctx.api.fail("The first argument must be annotated", ctx.context)
return original_function_type
# if isinstance(original_arg_type, AnyType):
# ctx.api.fail("The first argument must be annotated", ctx.context)
# return original_function_type

# if (
# original_arg_type.type.fullname == "builtins.dict"
# or original_arg_type.type.fullname == "typing.Mapping"
# ):
# ctx.api.fail(
# "The first argument must not be of type 'dict' or 'Mapping'", ctx.context
# )
# return True

original_ret_type = original_function_type.ret_type
# original_ret_type = original_function_type.ret_type

str_type = ctx.api.named_type("builtins.str")
dict_type = ctx.api.named_type("builtins.dict")
mapping_type = ctx.api.named_type("typing.Mapping")
# str_type = ctx.api.named_type("builtins.str")
# dict_type = ctx.api.named_type("builtins.dict")
# mapping_type = ctx.api.named_type("typing.Mapping")

dict_instance = Instance(
dict_type.type,
args=[str_type, original_arg_type],
)
mapping_instance = Instance(mapping_type.type, args=[str_type, original_ret_type])
# dict_instance = Instance(
# dict_type.type,
# args=[str_type, original_arg_type],
# )
# mapping_instance = Instance(mapping_type.type, args=[str_type, original_ret_type])

overload_2 = original_function_type.copy_modified(
arg_types=[mapping_instance] + original_function_type.arg_types[1:],
arg_kinds=[ARG_POS] + original_function_type.arg_kinds[1:],
arg_names=["data"] + original_function_type.arg_names[1:],
ret_type=dict_instance,
)
# overload_2 = original_function_type.copy_modified(
# arg_types=[mapping_instance] + original_function_type.arg_types[1:],
# arg_kinds=[ARG_POS] + original_function_type.arg_kinds[1:],
# arg_names=["data"] + original_function_type.arg_names[1:],
# ret_type=dict_instance,
# )

overloaded_type = Overloaded([original_function_type, overload_2])
# overloaded_type = Overloaded([original_function_type, overload_2])

return overloaded_type
# return overloaded_type


def task_maker_cls_callback(ctx): # pragma: no cover # type: ignore
Expand Down Expand Up @@ -161,10 +153,10 @@ def get_class_decorator_hook_2(
return flow_schema_cls_callback
return None

def get_function_hook(self, fullname: str):
if fullname == "zetta_utils.tensor_ops.common.supports_dict":
return supports_dict_callback
return None
# def get_function_hook(self, fullname: str):
# if fullname == "zetta_utils.tensor_ops.common.supports_dict":
# return supports_dict_callback
# return None


def plugin(version): # pragma: no cover
Expand Down
51 changes: 43 additions & 8 deletions zetta_utils/tensor_ops/common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# pylint: disable=missing-docstring
from typing import (
Any,
Callable,
Generic,
Literal,
Mapping,
Optional,
Sequence,
SupportsIndex,
TypeVar,
Union,
overload,
)

import attrs
import einops
import numpy as np
import tinybrain
Expand All @@ -24,20 +28,39 @@
T = TypeVar("T")


def supports_dict(func: Callable[Concatenate[T, P], T]):
def wrapper(data, *args: P.args, **kwargs: P.kwargs):
if isinstance(data, Mapping): # pylint: disable=all # p311
return {k: func(v, *args, **kwargs) for k, v in data.items()}
@attrs.frozen
class DictSupportingTensorOp(Generic[P, TensorTypeVar]):
fn: Callable[Concatenate[TensorTypeVar, P], TensorTypeVar]

@overload
def __call__(self, data: TensorTypeVar, *args: P.args, **kwargs: P.kwargs) -> TensorTypeVar:
...

@overload
def __call__(
self, data: Mapping[Any, TensorTypeVar], *args: P.args, **kwargs: P.kwargs
) -> dict[Any, TensorTypeVar]:
...

def __call__(self, data, *args: P.args, **kwargs: P.kwargs):
if isinstance(data, Mapping):
return {k: self.fn(v, *args, **kwargs) for k, v in data.items()}
else:
return func(data, *args, **kwargs)
return self.fn(data, *args, **kwargs)


return wrapper
def supports_dict(
fn: Callable[Concatenate[TensorTypeVar, P], TensorTypeVar]
) -> DictSupportingTensorOp[P, TensorTypeVar]:
return DictSupportingTensorOp[P, TensorTypeVar](fn)


@builder.register("rearrange")
@supports_dict
def rearrange(data: TensorTypeVar, **kwargs) -> TensorTypeVar: # pragma: no cover
return einops.rearrange(tensor=data, **kwargs) # type: ignore # bad typing by einops
def rearrange(data: TensorTypeVar, pattern: str, **kwargs) -> TensorTypeVar: # pragma: no cover
return einops.rearrange( # type: ignore # bad typing by einops
tensor=data, pattern=pattern, **kwargs
)


@builder.register("reduce")
Expand Down Expand Up @@ -611,3 +634,15 @@ def tensor_op_chain(
for step in steps:
result = step(result)
return result


@builder.register("abs")
@typechecked
@supports_dict
def abs( # pragma: no cover # pylint: disable=redefined-builtin
data: TensorTypeVar,
) -> TensorTypeVar:
if isinstance(data, torch.Tensor):
return data.abs()
else:
return np.abs(data)

0 comments on commit 4a564ee

Please sign in to comment.