Enabled consteval on params, unsqueeze 0dim tensor, TODO: Add flag whether to enable on pure constants or params too
AleksKnezevic authored and LPanosTT committed Dec 13, 2024
1 parent 8c79856 commit cf9922a
Showing 2 changed files with 32 additions and 8 deletions.
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -8,6 +8,7 @@
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    torch._dynamo.config.inline_inbuilt_nn_modules = False
     torch.manual_seed(0)
     yield
     torch._dynamo.reset()
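Editor's note, not part of the commit: a minimal sketch of why the fixture pins this flag. With inline_inbuilt_nn_modules set to False, parameters of built-in nn modules stay in the traced graph as get_attr nodes instead of being lifted to graph inputs, which is what lets the constant-folding pass in passes.py fold subgraphs that depend only on parameters. The Tiny module and the default torch.compile backend below are illustrative assumptions.

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x) * 2.0

# Keep nn.Module params as graph attributes (not lifted inputs), so passes like
# split_const_subgraphs can treat param-only subgraphs as constants.
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch.manual_seed(0)

compiled = torch.compile(Tiny())  # default backend here; tt-torch plugs in its own
print(compiled(torch.randn(2, 4)).shape)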
39 changes: 31 additions & 8 deletions tt_torch/dynamo/passes.py
@@ -77,34 +77,56 @@ def apply_decompositions(
     if decompositions is None:
         return gm
 
-    gm = make_fx(
-        functionalize(gm),
-        decomposition_table=decompositions,
-    )(*example_inputs)
+    with torch.no_grad():
+        decompositions = get_decompositions(decompose_ops)
+        fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(example_inputs)
+        fake_tensor_mode.allow_non_fake_inputs = True
+        gm = make_fx(
+            gm,
+            tracing_mode="symbolic",
+            _allow_non_fake_inputs=True,
+            decomposition_table=decompositions,
+        )(*example_inputs)
 
     return gm
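An aside from the editor, not from this commit: a small, self-contained sketch of what the retrace above does. make_fx re-traces the callable under a decomposition table so composite ops are lowered to simpler aten ops before later passes run. The silu example and the choice of op are assumptions made for illustration.

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.silu(x)

# Ask for the registered decomposition of aten.silu (x * sigmoid(x)).
decomps = get_decompositions([torch.ops.aten.silu])
gm = make_fx(f, decomposition_table=decomps, tracing_mode="symbolic")(torch.randn(4))
print(gm.graph)  # silu no longer appears; it has been decomposed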


+def bypass_redundant_getitem(gm):
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and "getitem" in node.name:
+            if isinstance(node.args[0], tuple):
+                idx = node.args[1]
+                if isinstance(idx, int):
+                    node.replace_all_uses_with(node.args[0][idx])
+    return gm


 def constant_fold(gm, example_inputs):
     gm = const_fold.split_const_subgraphs(gm)
     gm.run_folding()
     graph_constants = []
 
     # run the module to generate the consteval constants
     _ = gm(*example_inputs)
     for node in gm.graph.nodes:
         if node.op == "get_attr" and node.name == "_fx_const_folded_attrs":
             gm.graph.inserting_before(node)
             # loop through the get_item nodes
             if isinstance(gm._FX_CONST_FOLDED_ATTRS, torch.Tensor):
                 placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs")
                 node.replace_all_uses_with(placeholder)
+                if len(gm._FX_CONST_FOLDED_ATTRS.data.shape) == 0:
+                    graph_constants = (
+                        gm._FX_CONST_FOLDED_ATTRS.data
+                    ) = gm._FX_CONST_FOLDED_ATTRS.data.unsqueeze(0)
                 graph_constants = [gm._FX_CONST_FOLDED_ATTRS.data]
             else:
-                for idx, (key, value) in enumerate(node.users.items()):
+                for idx, key in enumerate(node.users.keys()):
                     placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs_{idx}")
                     key.replace_all_uses_with(placeholder)
+                for param in gm._FX_CONST_FOLDED_ATTRS:
+                    if len(param.data.shape) == 0:
+                        param.data = param.data.unsqueeze(0)
                 graph_constants = [param.data for param in gm._FX_CONST_FOLDED_ATTRS]
 
     gm.graph.eliminate_dead_code()
     return gm, graph_constants
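For context (an editor's sketch, not part of the commit): split_const_subgraphs moves every subgraph that depends only on constants, including module parameters once they are kept as attributes, into a folded submodule; materializing it populates _FX_CONST_FOLDED_ATTRS, which the pass above then re-exposes as graph placeholders and returns as graph_constants. The toy module below is an assumption for illustration.

import torch
from torch.fx.experimental import const_fold

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x + self.w * 2.0  # "self.w * 2.0" depends only on a parameter

traced = torch.fx.symbolic_trace(M())
folded = const_fold.split_const_subgraphs(traced)
folded.run_folding()  # materializes the folded constant(s)
print(folded._FX_CONST_FOLDED_ATTRS)  # the precomputed tensor, here w * 2.0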


@@ -115,6 +137,7 @@ def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
         gm, graph_constants = constant_fold(gm, example_inputs)
     else:
         graph_constants = []
+    gm = bypass_redundant_getitem(gm)
     reduce_graph(gm)
     run_shape_prop(gm, example_inputs + graph_constants)
     return gm, graph_constants
