From afd52649330a4c0028399075876ab6712b611b74 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 8 Jan 2025 22:11:43 +0100 Subject: [PATCH] fix test_grad::test_forward_and_backward_from_trace --- thunder/tests/test_grad.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index dcea739d2..37c211a01 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1180,14 +1180,12 @@ def func(a, b, *, c): a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True) - jfn = thunder.jit(func) - cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c) - initial_trace = cd.computation_traces[0] + initial_trace = trace(inline_trace=False)(func, a, b, c=c) wrapped_trace = wrap_return_value_together_with_arguments(initial_trace) fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace) fw = executor.make_callable(fw_trace) bw = executor.make_callable(bw_trace) - fw_out, saved_for_backward = fw(*inps) + fw_out, saved_for_backward = fw(a, b, c=c) initial_trace = trace()(value_and_grad(func), a, b, c=c) expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True) @@ -1197,7 +1195,6 @@ def func(a, b, *, c): output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"]) bw_out = bw(saved_for_backward, output_grads) - expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"]) torch.testing.assert_close(bw_out, expected_grads)