Skip to content

Commit

Permalink
fix: test_splitter_autograd_function
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Dec 19, 2024
1 parent 0a6a39b commit 16bbe28
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def forward(ctx, x):
@staticmethod
def backward(ctx, g):
(x,) = ctx.saved_tensors
return g * torch.cos(x)
return g * torch.cos(x) * 100

def func(x):
y = torch.cos(x) + Sin.apply(x)
Expand All @@ -286,9 +286,16 @@ def func(x):
actual = cfunc(x)

backend = cfunc._backend
targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
assert any(target.startswith("thunder_") for target in targets)
assert any(target.startswith("inductor_") for target in targets)
assert len(backend.subgraph_infos) == 1 # no graph break in dynamo
subgraph_info = backend.subgraph_infos[0]
assert len(subgraph_info.split_reasons) == 0 # no split
assert len(subgraph_info.thunder_compiled_fns) == 1
jfunc = subgraph_info.thunder_compiled_fns[0]
trc = last_traces(jfunc)[0]
assert any(
isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply")
for bsym in trc.bound_symbols
)

# Verify forward pass
torch.testing.assert_close(actual, expected)
Expand Down

0 comments on commit 16bbe28

Please sign in to comment.