
Commit

WIP consteval, only works for pure constants, not params, tt-torch does not handle 0dim tensor inputs
AleksKnezevic authored and LPanosTT committed Dec 13, 2024
1 parent 9bf2e89 commit 8c79856
Showing 3 changed files with 107 additions and 84 deletions.
69 changes: 13 additions & 56 deletions tt_torch/dynamo/backend.py
@@ -37,56 +37,6 @@
import sys


-def run_shape_prop(gm, example_inputs):
-    shape_prop = torch.fx.passes.shape_prop.ShapeProp(gm)
-    if shape_prop.fake_mode is not None:
-        fake_args = [
-            shape_prop.fake_mode.from_tensor(act, static_shapes=True)
-            if isinstance(act, torch.Tensor)
-            else act
-            for act in example_inputs
-        ]
-    else:
-        fake_args = example_inputs
-    shape_prop.run(*fake_args)
-
-
-def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]):
-    # Reduce the graph to only the nodes that are used
-
-    # Traverse up the graph from output nodes to populate consumed nodes set
-    graph = (
-        module_or_graph.graph
-        if isinstance(module_or_graph, torch.fx.GraphModule)
-        else module_or_graph
-    )
-    consumed = set()
-    working_nodes = []
-    for node in graph.nodes:
-        if node.op == "output":
-            working_nodes.append(node)
-            consumed.add(node)
-
-    while len(working_nodes) > 0:
-        node = working_nodes.pop(0)
-        if not isinstance(node, torch.fx.Node):
-            continue
-        for arg in node.all_input_nodes:
-            if arg not in consumed:
-                consumed.add(arg)
-                working_nodes.append(arg)
-
-    for node in reversed(graph.nodes):
-        if node not in consumed:
-            graph.erase_node(node)
-
-    if len(graph.nodes) == 1:
-        for node in graph.nodes:
-            if node.op == "output":
-                # Remove the output node if it's the only one
-                graph.erase_node(node)
-
-
def import_graph(graph: torch.fx.GraphModule):
    context = Context()
    torch_dialect.register_dialect(context)
@@ -133,9 +83,17 @@ def execute_process(receiver, sender, exec_event):


class Executor:
-    def __init__(self, gm, compiler_config=None, required_pcc=0.99, required_atol=1e-2):
+    def __init__(
+        self,
+        gm,
+        graph_constants,
+        compiler_config=None,
+        required_pcc=0.99,
+        required_atol=1e-2,
+    ):
        self.gm = gm
        self.binary = None
+        self.graph_constants = tuple(graph_constants)
        if compiler_config is None:
            compiler_config = CompilerConfig()
        self.compiler_config = compiler_config
@@ -442,7 +400,7 @@ def __call__(self, *inputs):
                # No conversion required.
                new_inputs = new_inputs + ((input),)

-        inputs = new_inputs
+        inputs = new_inputs + self.graph_constants

        if self.compiler_config.compile_depth == CompileDepth.EXECUTE:
            assert self.binary is not None, "Binary must be set for EXECUTE mode"
@@ -457,14 +415,13 @@


def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
-    gm = pass_pipeline(gm, example_inputs)
-    reduce_graph(gm)
+    gm, graph_constants = pass_pipeline(gm, example_inputs, compiler_config)
    gm.graph.print_tabular()
-    run_shape_prop(gm, example_inputs)
-    executor = Executor(gm, compiler_config)
+    executor = Executor(gm, graph_constants, compiler_config)
    if compiler_config.compile_depth in (
        CompileDepth.EXECUTE_OP_BY_OP,
        CompileDepth.COMPILE_OP_BY_OP,
+        CompileDepth.TORCH_FX,
    ):
        return executor

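The Executor changes above establish a simple calling convention: the values produced by constant folding become additional graph inputs, appended after the user-provided inputs at call time. A minimal sketch of that convention (hypothetical tensors, not part of this diff):

import torch

# Hypothetical values illustrating the calling convention introduced above:
# folded constants become trailing graph inputs appended by the Executor.
graph_constants = (torch.ones(2, 2) * 3.0,)   # produced by constant folding
user_inputs = (torch.zeros(2, 2),)            # what the caller provides
all_inputs = user_inputs + graph_constants    # what the compiled graph receives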
105 changes: 83 additions & 22 deletions tt_torch/dynamo/passes.py
@@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental import const_fold
from torch._decomp import get_decompositions
from torch.func import functionalize
-from typing import List, Optional
-import traceback
+from typing import List, Optional, Union

from .decompositions import (
    DecompositionTable,
@@ -15,6 +15,56 @@
)


+def run_shape_prop(gm, example_inputs):
+    shape_prop = torch.fx.passes.shape_prop.ShapeProp(gm)
+    if shape_prop.fake_mode is not None:
+        fake_args = [
+            shape_prop.fake_mode.from_tensor(act, static_shapes=True)
+            if isinstance(act, torch.Tensor)
+            else act
+            for act in example_inputs
+        ]
+    else:
+        fake_args = example_inputs
+    shape_prop.run(*fake_args)
+
+
+def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]):
+    # Reduce the graph to only the nodes that are used
+
+    # Traverse up the graph from output nodes to populate consumed nodes set
+    graph = (
+        module_or_graph.graph
+        if isinstance(module_or_graph, torch.fx.GraphModule)
+        else module_or_graph
+    )
+    consumed = set()
+    working_nodes = []
+    for node in graph.nodes:
+        if node.op == "output":
+            working_nodes.append(node)
+            consumed.add(node)
+
+    while len(working_nodes) > 0:
+        node = working_nodes.pop(0)
+        if not isinstance(node, torch.fx.Node):
+            continue
+        for arg in node.all_input_nodes:
+            if arg not in consumed:
+                consumed.add(arg)
+                working_nodes.append(arg)
+
+    for node in reversed(graph.nodes):
+        if node not in consumed:
+            graph.erase_node(node)
+
+    if len(graph.nodes) == 1:
+        for node in graph.nodes:
+            if node.op == "output":
+                # Remove the output node if it's the only one
+                graph.erase_node(node)
+
+
def apply_decompositions(
    gm: torch.fx.GraphModule,
    example_inputs,
@@ -35,25 +85,36 @@ def apply_decompositions(
    return gm


-def pass_pipeline(gm: torch.fx.GraphModule, example_inputs):
-    decompose_ops = DEFAULT_DECOMPOSITION_TABLE
-    decompose_ops.update(CUSTOM_DECOMPOSITION_TABLE)
-    try:
-        # Convert SymInt to concrete int if possible
-        concrete_inputs = []
-        for inp in example_inputs:
-            if isinstance(inp, torch.Tensor):
-                # Convert any SymInt dimensions to concrete integers
-                concrete_shape = tuple(
-                    int(dim) if hasattr(dim, "node") else dim for dim in inp.shape
-                )
-                concrete_inp = inp.view(concrete_shape)
-                concrete_inputs.append(concrete_inp)
-            else:
-                concrete_inputs.append(inp)
-
-        return apply_decompositions(gm, concrete_inputs, decompose_ops)
-    except Exception as e:
-        print(f"Pass pipeline error: {e}")
-        print(traceback.format_exc())
-        raise
+def constant_fold(gm, example_inputs):
+    gm = const_fold.split_const_subgraphs(gm)
+
+    # run the module to generate the consteval constants
+    _ = gm(*example_inputs)
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and node.name == "_fx_const_folded_attrs":
+            gm.graph.inserting_before(node)
+            # loop through the get_item nodes
+            if isinstance(gm._FX_CONST_FOLDED_ATTRS, torch.Tensor):
+                placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs")
+                node.replace_all_uses_with(placeholder)
+                graph_constants = [gm._FX_CONST_FOLDED_ATTRS.data]
+            else:
+                for idx, (key, value) in enumerate(node.users.items()):
+                    placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs_{idx}")
+                    key.replace_all_uses_with(placeholder)
+                    graph_constants = [param.data for param in gm._FX_CONST_FOLDED_ATTRS]
+
+    gm.graph.eliminate_dead_code()
+    return gm, graph_constants
+
+
+def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
+    decompose_ops = DEFAULT_DECOMPOSITIONS
+    gm = apply_decompositions(gm, example_inputs, decompose_ops)  # type: ignore
+    if compiler_config.enable_costeval:
+        gm, graph_constants = constant_fold(gm, example_inputs)
+    else:
+        graph_constants = []
+    reduce_graph(gm)
+    run_shape_prop(gm, example_inputs + graph_constants)
+    return gm, graph_constants
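For context, the new constant_fold pass builds on torch.fx's const_fold utility: split_const_subgraphs moves every node that depends only on constants into a side graph, and running the module once materializes the folded values under _FX_CONST_FOLDED_ATTRS. The sketch below is not part of this commit (the module and shapes are made up); it shows that behaviour on a toy module — only the pure-constant subgraph is folded, which is why, as the commit message notes, this works for pure constants but not for parameters:

import torch
from torch.fx.experimental import const_fold


class Toy(torch.nn.Module):
    def forward(self, x):
        # Pure-constant subgraph: depends on neither the input nor any
        # parameter, so const_fold can precompute it.
        c = torch.ones(3) * 2.0
        return x + c


gm = torch.fx.symbolic_trace(Toy())
folded = const_fold.split_const_subgraphs(gm)

# Folding is materialized lazily; calling the module once (the new pass does
# `_ = gm(*example_inputs)`) populates _FX_CONST_FOLDED_ATTRS.
out = folded(torch.zeros(3))
print(folded._FX_CONST_FOLDED_ATTRS)  # the precomputed constant (a tensor of twos)
print(out)                            # same values, since x is all zeros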
17 changes: 11 additions & 6 deletions tt_torch/tools/utils.py
@@ -12,12 +12,13 @@


class CompileDepth(Enum):
-    TORCH_MLIR = 1
-    STABLEHLO = 2
-    TTNN_IR = 3
-    COMPILE_OP_BY_OP = 4
-    EXECUTE_OP_BY_OP = 5
-    EXECUTE = 6
+    TORCH_FX = 1
+    TORCH_MLIR = 2
+    STABLEHLO = 3
+    TTNN_IR = 4
+    COMPILE_OP_BY_OP = 5
+    EXECUTE_OP_BY_OP = 6
+    EXECUTE = 7


class OpCompilationStatus(IntEnum):
@@ -125,6 +126,7 @@ def __init__(self):
        self.results_path = "results/models/"
        self.single_op_timeout = 5
        self.enable_intermediate_verification = False
+        self.enable_costeval = False

        self.apply_environment_overrides()

@@ -135,6 +137,9 @@ def apply_environment_overrides(self):
        verify_intermediates = os.environ.get("TT_TORCH_VERIFY_INTERMEDIATES")
        if verify_intermediates:
            self.enable_intermediate_verification = True
+        enable_costeval = os.environ.get("TT_TORCH_CONSTEVAL")
+        if enable_costeval:
+            self.enable_costeval = True

    def save_unique_ops(self):
        unique_op_dict = {}
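Consteval is opt-in. A hypothetical way to turn it on, either through the environment override added above or directly on a CompilerConfig (note the flag is spelled enable_costeval in this diff while the environment variable is TT_TORCH_CONSTEVAL); how the config is handed to torch.compile is not shown here:

import os

# Option 1: environment variable, read by apply_environment_overrides()
# when a CompilerConfig is constructed.
os.environ["TT_TORCH_CONSTEVAL"] = "1"

from tt_torch.tools.utils import CompilerConfig, CompileDepth

cc = CompilerConfig()                     # picks up TT_TORCH_CONSTEVAL
# Option 2: set the flag explicitly on the config.
cc.enable_costeval = True
cc.compile_depth = CompileDepth.TORCH_FX  # new stage: stop after the FX-level passes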

