From 32693249987182f2d691543b5298539c8d901f1f Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 7 Sep 2024 20:42:41 +0200 Subject: [PATCH] make BoundSymbol.from_bsym call bind_postprocess (#1121) --- examples/ggml-quant/thunder_ggmlquant.ipynb | 8 -------- thunder/core/symbol.py | 6 ++++-- thunder/extend/__init__.py | 3 +-- thunder/transforms/quantization.py | 4 ---- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb index bdb07c3816..3bcf53b884 100644 --- a/examples/ggml-quant/thunder_ggmlquant.ipynb +++ b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -474,10 +474,6 @@ " )\n", "\n", " new_computation_trace.bound_symbols.append(mm_bsym)\n", - " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", - " # preserve the \"meta\"-info like source location, header, etc.\n", - " # TODO: switch to a better solution when it is there\n", - " ggmlquant_matmul._bind_postprocess(mm_bsym)\n", " elif bsym.sym == thunder.torch.embedding and id(bsym.args[1]) in quantized_proxies:\n", " assert len(bsym.args) == 7 # torch.linear(input, weight, bias)\n", " assert bsym.args[2] is None and bsym.args[3] is None\n", @@ -496,10 +492,6 @@ " )\n", "\n", " new_computation_trace.bound_symbols.append(emb_bsym)\n", - " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", - " # preserve the \"meta\"-info like source location, header, etc.\n", - " # TODO: switch to a better solution when it is there\n", - " ggmlquant_embed._bind_postprocess(emb_bsym)\n", " else:\n", " new_computation_trace.bound_symbols.append(bsym.from_bsym())\n", "\n", diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 0796953a28..da4fc98c1a 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -398,8 +398,10 @@ def from_bsym(self, **kwargs) -> BoundSymbol: } self_kwargs.update(kwargs) - - return BoundSymbol(**self_kwargs) + bsym = BoundSymbol(**self_kwargs) + if bsym.sym._bind_postprocess: + bsym.sym._bind_postprocess(bsym) + return bsym # NOTE coll must be a Container of "variableified" proxies def has_input(self, coll) -> bool: diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index de7d60af6b..1551bfb880 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -247,11 +247,10 @@ def _meta(*args): return tuple(outputs) def _bind_postprocess(bsym: BoundSymbol) -> None: - bsym.subsymbols = tuple(bsyms) bsym._call_ctx = {name: fn} sym = Symbol(name=name, meta=_meta, is_fusion=True, _bind_postprocess=_bind_postprocess, executor=self) - return sym.bind(*inputs, output=outputs) + return sym.bind(*inputs, output=outputs, subsymbols=tuple(bsyms)) class OperatorExecutor(Executor): diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 8d2534e19b..7c6c8dcde9 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -287,10 +287,6 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo ) new_computation_trace.bound_symbols.append(mm_bsym) - # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to - # preserve the "meta"-info like source location, header, etc. - # TODO: switch to a better solution when it is there - bnb_matmul_nf4._bind_postprocess(mm_bsym) else: new_computation_trace.bound_symbols.append(bsym.from_bsym())