
Commit 8ba1e0f
Move run_gm_op_by_op to torch backend; do self.gm(*inputs) when torch is run op by op in stablehlo backend
ddilbazTT committed Feb 25, 2025
1 parent 3721dc6 commit 8ba1e0f
Showing 4 changed files with 98 additions and 101 deletions.
1 change: 1 addition & 0 deletions tests/models/resnet/test_resnet.py
@@ -35,6 +35,7 @@ def test_resnet(record_property, mode, op_by_op):
     cc.consteval_parameters = True
     if op_by_op:
         cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
+        cc.op_by_op_backend = OpByOpBackend.STABLEHLO
     tester = ThisTester(
         model_name,
         mode,
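
For context, a minimal usage sketch of how the two op-by-op backends are selected (a hypothetical snippet: CompilerConfig construction is assumed, and only the field, enum, and depth names visible in this diff are used):

    # Hypothetical sketch -- only names visible in this diff are assumed.
    cc = CompilerConfig()
    cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP

    # Run each op through the StableHLO backend (as this test now does)...
    cc.op_by_op_backend = OpByOpBackend.STABLEHLO
    # ...or through the torch backend, which owns run_gm_op_by_op after
    # this commit.
    cc.op_by_op_backend = OpByOpBackend.TORCH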
100 changes: 0 additions & 100 deletions tt_torch/dynamo/executor.py
@@ -351,103 +351,3 @@ def run_op(self, binary, *inputs):
file_stderr.close()

return outputs, stderr_data

def run_gm_op_by_op(self, *inputs):
node_to_tensor = {}
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
out_degree = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
out_degree[node] = len(node.users)
if node.op == "placeholder":
node_to_tensor[node] = inputs[input_index]
input_index += 1
elif node.op == "get_attr":
for buffer in self.gm.named_buffers():
if buffer[0] == node.target:
node_to_tensor[node] = buffer[1]
break
elif node.op == "call_function":
args = []
for arg in node.args:
if isinstance(arg, torch.fx.node.Node):
args.append(node_to_tensor[arg])
elif isinstance(arg, list):
args.append(
[
node_to_tensor[a]
if isinstance(a, torch.fx.node.Node)
else a
for a in arg
]
)
else:
args.append(arg)
binary = None
if self.compiler_config.op_by_op_backend == OpByOpBackend.TORCH:
try:
binary, op = self.compile_op(node, *args, **node.kwargs)
except Exception as e:
binary = None
print(
f"Failed to compile {idx}/{num_nodes}: {node.target}: {e}"
)

if (
self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
and binary is not None
):
try:
calculated, runtime_stack_dump = self.run_op(binary, *args)
self.compiler_config.unique_ops[
op.unique_key()
].runtime_stack_dump = runtime_stack_dump

print(f"Ran: {idx}/{num_nodes}: {node.target}")
if calculated is None:
raise ValueError("Failed to execute")
op.compilation_status = OpCompilationStatus.EXECUTED
tensor = node.target(*args, **node.kwargs)
if self.compiler_config.verify_op_by_op:
atol = calculate_atol(calculated, tensor)
op.atol = atol
if atol > self.required_atol:
print(f"atol too high for {idx}: {atol}")
pcc = calculate_pcc(calculated, tensor)
op.pcc = pcc
if pcc < self.required_pcc:
print(f"pcc too low for {idx}: {pcc}")
except Exception as e:
print(
f"Failed to execute {idx}/{num_nodes}: {node.target}: {e}"
)
tensor = node.target(*args, **node.kwargs)
else:
tensor = node.target(*args, **node.kwargs)
node_to_tensor[node] = tensor
elif node.op == "output":
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors
args_set = set()
for arg in node.args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
out_degree[arg] -= 1
if out_degree[arg] == 0 and arg.op != "output":
del node_to_tensor[arg]
out_degree.pop(arg)

self.compiler_config.save_unique_ops()
if self.execute_process is not None:
self.execute_process.terminate()
self.execute_process = None
if self.stderror_redirected:
os.unlink(self.file_stderr.name)
self.stderror_redirected = False

return outputs
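
The deleted method reappears verbatim in tt_torch/dynamo/torch_backend.py below. Its per-op verification compares the device result against eager PyTorch using an absolute-tolerance check (calculate_atol) and a Pearson correlation coefficient check (calculate_pcc). A minimal sketch of a PCC computation of that kind (a hypothetical helper, not tt_torch's implementation):

    import torch

    def pearson_cc(calculated: torch.Tensor, golden: torch.Tensor) -> float:
        # Flatten both tensors and read the off-diagonal entry of the
        # 2x2 correlation matrix.
        stacked = torch.stack(
            [calculated.flatten().double(), golden.flatten().double()]
        )
        return torch.corrcoef(stacked)[0, 1].item()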
2 changes: 1 addition & 1 deletion tt_torch/dynamo/shlo_backend.py
@@ -261,7 +261,7 @@ def __call__(self, *inputs):
     if self.compiler_config.compile_depth == CompileDepth.COMPILE_OP_BY_OP:
         self.compile_shlo_op_by_op()
         if self.gm is not None:
-            return self.run_gm_op_by_op(*(inputs + self.graph_constants))
+            return self.gm(*inputs)
         return  # return nothing
     else:
         assert False, "Invalid compile depth"
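
After this change the StableHLO backend no longer re-walks the FX graph itself when the torch backend owns op-by-op execution; it simply calls the GraphModule, which executes the traced graph eagerly with stock PyTorch kernels. A standalone illustration (unrelated to tt_torch internals):

    import torch

    def f(x):
        return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(f)  # a torch.fx.GraphModule
    y = gm(torch.randn(2, 2))        # runs the traced graph, same result as f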
96 changes: 96 additions & 0 deletions tt_torch/dynamo/torch_backend.py
@@ -84,6 +84,102 @@ def get_node_name(self, node):
name = node.target.name() if hasattr(node.target, "name") else node.name
return name

def run_gm_op_by_op(self, *inputs):
node_to_tensor = {}
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
out_degree = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
out_degree[node] = len(node.users)
if node.op == "placeholder":
node_to_tensor[node] = inputs[input_index]
input_index += 1
elif node.op == "get_attr":
for buffer in self.gm.named_buffers():
if buffer[0] == node.target:
node_to_tensor[node] = buffer[1]
break
elif node.op == "call_function":
args = []
for arg in node.args:
if isinstance(arg, torch.fx.node.Node):
args.append(node_to_tensor[arg])
elif isinstance(arg, list):
args.append(
[
node_to_tensor[a]
if isinstance(a, torch.fx.node.Node)
else a
for a in arg
]
)
else:
args.append(arg)
try:
binary, op = self.compile_op(node, *args, **node.kwargs)
except Exception as e:
binary = None
print(f"Failed to compile {idx}/{num_nodes}: {node.target}: {e}")

if (
self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
and binary is not None
):
try:
calculated, runtime_stack_dump = self.run_op(binary, *args)
self.compiler_config.unique_ops[
op.unique_key()
].runtime_stack_dump = runtime_stack_dump

print(f"Ran: {idx}/{num_nodes}: {node.target}")
if calculated is None:
raise ValueError("Failed to execute")
op.compilation_status = OpCompilationStatus.EXECUTED
tensor = node.target(*args, **node.kwargs)
if self.compiler_config.verify_op_by_op:
atol = calculate_atol(calculated, tensor)
op.atol = atol
if atol > self.required_atol:
print(f"atol too high for {idx}: {atol}")
pcc = calculate_pcc(calculated, tensor)
op.pcc = pcc
if pcc < self.required_pcc:
print(f"pcc too low for {idx}: {pcc}")
except Exception as e:
print(
f"Failed to execute {idx}/{num_nodes}: {node.target}: {e}"
)
tensor = node.target(*args, **node.kwargs)
else:
tensor = node.target(*args, **node.kwargs)
node_to_tensor[node] = tensor
elif node.op == "output":
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors
args_set = set()
for arg in node.args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
out_degree[arg] -= 1
if out_degree[arg] == 0 and arg.op != "output":
del node_to_tensor[arg]
out_degree.pop(arg)

self.compiler_config.save_unique_ops()
if self.execute_process is not None:
self.execute_process.terminate()
self.execute_process = None
if self.stderror_redirected:
os.unlink(self.file_stderr.name)
self.stderror_redirected = False

return outputs

def get_stable_hlo_graph(self, node, inputs, **kwargs):

input_shapes_and_constants = self.get_input_shapes_and_constants(inputs)
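
The method added above follows the standard FX interpreter pattern (compare torch.fx.Interpreter): walk graph.nodes in topological order, bind placeholder nodes to the incoming inputs, look up get_attr targets in the module's buffers, evaluate call_function nodes eagerly, and return the output node's arguments; the out_degree bookkeeping additionally frees each intermediate tensor once its last consumer has run. A stripped-down sketch of that pattern (a hypothetical standalone function, without tt_torch's compilation, verification, or memory bookkeeping):

    import torch
    import torch.fx

    def run_node_by_node(gm: torch.fx.GraphModule, *inputs):
        env = {}  # maps each fx.Node to its computed value
        it = iter(inputs)
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                env[node] = next(it)
            elif node.op == "get_attr":
                env[node] = dict(gm.named_buffers())[node.target]
            elif node.op == "call_function":
                # Substitute every fx.Node in args/kwargs with its value.
                args = torch.fx.node.map_arg(node.args, lambda n: env[n])
                kwargs = torch.fx.node.map_arg(node.kwargs, lambda n: env[n])
                env[node] = node.target(*args, **kwargs)
            elif node.op == "output":
                return [env[a] for a in node.args[0]]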
