Skip to content

Commit

Permalink
failing test case as starter
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 5, 2024
1 parent 21c2af8 commit 4e05dae
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING

import pytest
import torch
from torch.utils import _pytree as pytree

import thunder
from thunder.tests.framework import instantiate
from thunder.core.proxies import SubclassTensorProxy
from thunder.tests.framework import instantiate, nvFuserExecutor
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
from typing import Any
from thunder.core.symbol import BoundSymbol


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -163,6 +167,7 @@ def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:

dtype = torch.float32
shape = (2, 2)

x = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

Expand All @@ -180,3 +185,41 @@ def g(x: torch.Tensor) -> ScaleTensorSubclass:
expected = g(x)
actual = jitted(x)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_simple_math(executor, device, _):

def f(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:

y = ScaleTensorSubclass(data, scale)
out = x + y
return out

jitted = executor.make_callable(f)

dtype = torch.float32
shape = (2, 2)
x = ScaleTensorSubclass(
make_tensor(shape, device=device, dtype=dtype),
make_tensor((), device=device, dtype=dtype),
)
data = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

expected = f(x, data, scale)
actual = jitted(x, data, scale)
if executor == nvFuserExecutor:
with pytest.raises(Exception):
assert type(expected) is type(actual)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
else:
assert type(expected) is type(actual)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

return_bsym: BoundSymbol = thunder.last_traces(jitted)[-1].bound_symbols[-1]
return_proxy = return_bsym.flat_args[0]
# FIXME(crcrpar): Implement a trace transform that corrects the output type of bsyms involving tensor subclasses
assert not isinstance(return_proxy, SubclassTensorProxy)

0 comments on commit 4e05dae

Please sign in to comment.