diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py
index 992d1e27ec..6557df1802 100644
--- a/thunder/core/interpreter.py
+++ b/thunder/core/interpreter.py
@@ -2513,7 +2513,9 @@ def __getitem__(self, key):
             except Exception as e:
                 return do_raise(e)
 
-        populate_single_dict_item_wrapper(uv, self, key.value)
+        from thunder.core.proxies import Proxy
+
+        populate_single_dict_item_wrapper(uv, self, key if isinstance(key.value, Proxy) else key.value)
         v = self.item_wrappers[key.value]
         assert uv is v.value or uv is v.original_value, f"value for {key.value} out of sync {uv} {v.value}"
         return v
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index c7e30a0c58..b9e304e681 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1029,7 +1029,6 @@ def forward(self, x):
     ids=("remove_duplicate=False", "remove_duplicate=True"),
 )
 def test_named_params_and_named_buffers(prefix, recurse, remove_duplicate):
-
     buffer_tensor = torch.tensor([1.0])
 
     class SubMod(torch.nn.Module):
@@ -1143,7 +1142,6 @@ def test_custom_autograd_function():
     from torch.testing._internal.common_utils import gradcheck
 
     class MyFunction(torch.autograd.Function):
-
         @staticmethod
         def forward(ctx, x: torch.Tensor) -> torch.Tensor:
             return x * 2.0
@@ -1206,7 +1204,6 @@ def forward(self, x):
 
 
 def test_autograd_function_apply():
-
     def forward(ctx, x):
         saved_for_backward = (x,)
         return x.sin(), saved_for_backward
@@ -1275,7 +1272,6 @@ def my_sin_with_wrong_backward(x):
 
 
 def test_autograd_function_empty_forward():
-
     class Fn(torch.autograd.Function):
         @staticmethod
         def forward(self, x):
@@ -1464,3 +1460,30 @@ def foo(a):
     expected = foo(a)
 
     assert_close(actual, expected)
+
+
+def test_cache_symbolic_values_dict():
+    def foo(a, v):
+        return a[v].relu()
+
+    jfoo = thunder.jit(foo, cache="symbolic values")
+
+    a = {
+        2: torch.randn(2, 3, 8, requires_grad=True, device="cpu"),
+        5: torch.randn(4, 8, requires_grad=True, device="cpu"),
+    }
+
+    actual = jfoo(a, 2)
+    expected = foo(a, 2)
+
+    assert_close(actual, expected)
+
+    b = {
+        "a": torch.randn(2, 8, requires_grad=True, device="cpu"),
+        "b": torch.randn(7, requires_grad=True, device="cpu"),
+    }
+
+    actual = jfoo(b, "b")
+    expected = foo(b, "b")
+
+    assert_close(actual, expected)