Skip to content

Commit

Permalink
make BoundSymbol.from_bsym call bind_postprocess (#1121)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Sep 7, 2024
1 parent 7b6b2ca commit 3269324
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 16 deletions.
8 changes: 0 additions & 8 deletions examples/ggml-quant/thunder_ggmlquant.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,6 @@
" )\n",
"\n",
" new_computation_trace.bound_symbols.append(mm_bsym)\n",
" # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n",
" # preserve the \"meta\"-info like source location, header, etc.\n",
" # TODO: switch to a better solution when it is there\n",
" ggmlquant_matmul._bind_postprocess(mm_bsym)\n",
" elif bsym.sym == thunder.torch.embedding and id(bsym.args[1]) in quantized_proxies:\n",
" assert len(bsym.args) == 7 # torch.linear(input, weight, bias)\n",
" assert bsym.args[2] is None and bsym.args[3] is None\n",
Expand All @@ -496,10 +492,6 @@
" )\n",
"\n",
" new_computation_trace.bound_symbols.append(emb_bsym)\n",
" # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n",
" # preserve the \"meta\"-info like source location, header, etc.\n",
" # TODO: switch to a better solution when it is there\n",
" ggmlquant_embed._bind_postprocess(emb_bsym)\n",
" else:\n",
" new_computation_trace.bound_symbols.append(bsym.from_bsym())\n",
"\n",
Expand Down
6 changes: 4 additions & 2 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,10 @@ def from_bsym(self, **kwargs) -> BoundSymbol:
}

self_kwargs.update(kwargs)

return BoundSymbol(**self_kwargs)
bsym = BoundSymbol(**self_kwargs)
if bsym.sym._bind_postprocess:
bsym.sym._bind_postprocess(bsym)
return bsym

# NOTE coll must be a Container of "variableified" proxies
def has_input(self, coll) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions thunder/extend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,10 @@ def _meta(*args):
return tuple(outputs)

def _bind_postprocess(bsym: BoundSymbol) -> None:
bsym.subsymbols = tuple(bsyms)
bsym._call_ctx = {name: fn}

sym = Symbol(name=name, meta=_meta, is_fusion=True, _bind_postprocess=_bind_postprocess, executor=self)
return sym.bind(*inputs, output=outputs)
return sym.bind(*inputs, output=outputs, subsymbols=tuple(bsyms))


class OperatorExecutor(Executor):
Expand Down
4 changes: 0 additions & 4 deletions thunder/transforms/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,6 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo
)

new_computation_trace.bound_symbols.append(mm_bsym)
# we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to
# preserve the "meta"-info like source location, header, etc.
# TODO: switch to a better solution when it is there
bnb_matmul_nf4._bind_postprocess(mm_bsym)
else:
new_computation_trace.bound_symbols.append(bsym.from_bsym())

Expand Down

0 comments on commit 3269324

Please sign in to comment.