diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 80c0b5b83..4a7b40d19 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -115,5 +115,5 @@ def test_nanogpt_block(): # We are checking the estimated memory against a fixed value for consistency. assert max_mem_fw[0] == 262183936 assert sum(max_mem_fw[1].values()) == 135306240 - assert max_mem_bw[0] == 484833280 - assert sum(max_mem_bw[1].values()) == 169915392 + assert max_mem_bw[0] == 375516160 + assert sum(max_mem_bw[1].values()) == 40934400 diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 37c211a01..406af2994 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1943,7 +1943,8 @@ def test_backward_recomputation_decomposed_ops(device): def fn(a): return torch.nn.functional.gelu(a) - jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False) + # Rematerialization would also trigger recomputation here, so pass an empty executors tuple to keep this baseline free of it. + jfn = thunder.jit(fn, executors=(), enable_saved_for_backward_recomputation=False) jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True) a = torch.randn(2, 2, device=device, requires_grad=True) res = jfn(a)