Skip to content

Commit

Permalink
avoid flattening non-tensor args of subclass ctor
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Dec 21, 2024
1 parent 409798d commit e9027ba
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,12 +1896,19 @@ def __init__(self, *args, **kwargs):
kwarg_non_tensors = kwargs.pop("non_tensors", [])
subclass_type = kwargs.pop("subclass_type", None)

has_name_before_init = hasattr(self, "_name")
# If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass`
# where `self` should already have gotten its name.
flat_args, spec = tree_flatten((args, kwargs))
tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
has_name_before_init = hasattr(self, "_name")
tensors: list[TensorProxy] = []
non_tensors: list[Any] = []
for t in args + tuple(kwargs.values()):
if type(t) is SubclassTensorProxy:
continue
if type(t) is TensorProxy:
tensors.append(t)
else:
non_tensors.append(t)

is_dunder_init_following_make_wrapper_subclass: bool = False
if tensors:
Expand Down

0 comments on commit e9027ba

Please sign in to comment.