Skip to content

Commit

Permalink
[lookaside][torch.autograd.Function lookaside] use shallow_copy iff…
Browse files Browse the repository at this point in the history
… forward is empty (#1485)
  • Loading branch information
crcrpar authored Dec 10, 2024
1 parent 087637f commit 9de5434
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,12 @@ def _convert_pytorchfunc_to_thundertrace(
return wrapped_func_result, None

trace = TraceCtx()
trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
bsyms = active_jit_ctx.computation_trace.pop_scope()
trace.bound_symbols.extend(bsyms)
func_result = unwrap(wrapped_func_result)
if shallow_copy_output:
if shallow_copy_output and not bsyms:
from thunder.core.baseutils import sequencify

out_to_shallow_copy: dict[Variable, TensorProxy] = {}
for a in sequencify(func_result):
shallow_copy_of_a = prims.shallow_copy.meta(a)
Expand Down

0 comments on commit 9de5434

Please sign in to comment.