Skip to content

Commit

Permalink
update more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Jan 9, 2025
1 parent 5ce5f60 commit ed1af68
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions thunder/tests/test_examine_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ def test_nanogpt_block():
# We are checking the estimated memory against a fixed value for consistency.
assert max_mem_fw[0] == 262183936
assert sum(max_mem_fw[1].values()) == 135306240
assert max_mem_bw[0] == 484833280
assert sum(max_mem_bw[1].values()) == 169915392
assert max_mem_bw[0] == 375516160
assert sum(max_mem_bw[1].values()) == 40934400
3 changes: 2 additions & 1 deletion thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,7 +1943,8 @@ def test_backward_recomputation_decomposed_ops(device):
def fn(a):
return torch.nn.functional.gelu(a)

jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False)
# rematerialization will also trigger recomputation here.
jfn = thunder.jit(fn, executors=(), enable_saved_for_backward_recomputation=False)
jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True)
a = torch.randn(2, 2, device=device, requires_grad=True)
res = jfn(a)
Expand Down

0 comments on commit ed1af68

Please sign in to comment.