From 16bbe283c09b8b69b892ff6654648d31c09a58ee Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 19 Dec 2024 15:08:10 +0100 Subject: [PATCH] fix: test_splitter_autograd_function --- thunder/tests/test_dynamo.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 4bcd2333f4..1c90f5b0be 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -274,7 +274,7 @@ def forward(ctx, x): @staticmethod def backward(ctx, g): (x,) = ctx.saved_tensors - return g * torch.cos(x) + return g * torch.cos(x) * 100 def func(x): y = torch.cos(x) + Sin.apply(x) @@ -286,9 +286,16 @@ def func(x): actual = cfunc(x) backend = cfunc._backend - targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) - assert any(target.startswith("thunder_") for target in targets) - assert any(target.startswith("inductor_") for target in targets) + assert len(backend.subgraph_infos) == 1 # no graph break in dynamo + subgraph_info = backend.subgraph_infos[0] + assert len(subgraph_info.split_reasons) == 0 # no split + assert len(subgraph_info.thunder_compiled_fns) == 1 + jfunc = subgraph_info.thunder_compiled_fns[0] + trc = last_traces(jfunc)[0] + assert any( + isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply") + for bsym in trc.bound_symbols + ) # Verify forward pass torch.testing.assert_close(actual, expected)