Skip to content

Commit

Permalink
support dynamic getitem in dictionary for symbolic values (#1450)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jjsjann123 and pre-commit-ci[bot] authored Dec 2, 2024
1 parent 8eddb9f commit 20216b8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
4 changes: 3 additions & 1 deletion thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2513,7 +2513,9 @@ def __getitem__(self, key):
except Exception as e:
return do_raise(e)

populate_single_dict_item_wrapper(uv, self, key.value)
from thunder.core.proxies import Proxy

populate_single_dict_item_wrapper(uv, self, key if isinstance(key.value, Proxy) else key.value)
v = self.item_wrappers[key.value]
assert uv is v.value or uv is v.original_value, f"value for {key.value} out of sync {uv} {v.value}"
return v
Expand Down
31 changes: 27 additions & 4 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,6 @@ def forward(self, x):
ids=("remove_duplicate=False", "remove_duplicate=True"),
)
def test_named_params_and_named_buffers(prefix, recurse, remove_duplicate):

buffer_tensor = torch.tensor([1.0])

class SubMod(torch.nn.Module):
Expand Down Expand Up @@ -1143,7 +1142,6 @@ def test_custom_autograd_function():
from torch.testing._internal.common_utils import gradcheck

class MyFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x * 2.0
Expand Down Expand Up @@ -1206,7 +1204,6 @@ def forward(self, x):


def test_autograd_function_apply():

def forward(ctx, x):
saved_for_backward = (x,)
return x.sin(), saved_for_backward
Expand Down Expand Up @@ -1275,7 +1272,6 @@ def my_sin_with_wrong_backward(x):


def test_autograd_function_empty_forward():

class Fn(torch.autograd.Function):
@staticmethod
def forward(self, x):
Expand Down Expand Up @@ -1464,3 +1460,30 @@ def foo(a):
expected = foo(a)

assert_close(actual, expected)


def test_cache_symbolic_values_dict():
def foo(a, v):
return a[v].relu()

jfoo = thunder.jit(foo, cache="symbolic values")

a = {
2: torch.randn(2, 3, 8, requires_grad=True, device="cpu"),
5: torch.randn(4, 8, requires_grad=True, device="cpu"),
}

actual = jfoo(a, 2)
expected = foo(a, 2)

assert_close(actual, expected)

b = {
"a": torch.randn(2, 8, requires_grad=True, device="cpu"),
"b": torch.randn(7, requires_grad=True, device="cpu"),
}

actual = jfoo(b, "b")
expected = foo(b, "b")

assert_close(actual, expected)

0 comments on commit 20216b8

Please sign in to comment.