diff --git a/tests/models/resnet/test_resnet.py b/tests/models/resnet/test_resnet.py
index 781edf29..c51c5718 100644
--- a/tests/models/resnet/test_resnet.py
+++ b/tests/models/resnet/test_resnet.py
@@ -35,6 +35,7 @@ def test_resnet(record_property, mode, op_by_op):
     cc.consteval_parameters = True
     if op_by_op:
         cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
+        cc.op_by_op_backend = OpByOpBackend.STABLEHLO
     tester = ThisTester(
         model_name,
         mode,
diff --git a/tt_torch/dynamo/executor.py b/tt_torch/dynamo/executor.py
index 3be3c0fb..e49b0038 100644
--- a/tt_torch/dynamo/executor.py
+++ b/tt_torch/dynamo/executor.py
@@ -351,103 +351,3 @@ def run_op(self, binary, *inputs):
         file_stderr.close()
 
         return outputs, stderr_data
-
-    def run_gm_op_by_op(self, *inputs):
-        node_to_tensor = {}
-        input_index = 0
-        outputs = []
-        num_nodes = len(self.gm.graph.nodes)
-        out_degree = {}
-        for idx, node in enumerate(self.gm.graph.nodes):
-            print(f"Compiling {idx}/{num_nodes}: {node.target}")
-            out_degree[node] = len(node.users)
-            if node.op == "placeholder":
-                node_to_tensor[node] = inputs[input_index]
-                input_index += 1
-            elif node.op == "get_attr":
-                for buffer in self.gm.named_buffers():
-                    if buffer[0] == node.target:
-                        node_to_tensor[node] = buffer[1]
-                        break
-            elif node.op == "call_function":
-                args = []
-                for arg in node.args:
-                    if isinstance(arg, torch.fx.node.Node):
-                        args.append(node_to_tensor[arg])
-                    elif isinstance(arg, list):
-                        args.append(
-                            [
-                                node_to_tensor[a]
-                                if isinstance(a, torch.fx.node.Node)
-                                else a
-                                for a in arg
-                            ]
-                        )
-                    else:
-                        args.append(arg)
-                binary = None
-                if self.compiler_config.op_by_op_backend == OpByOpBackend.TORCH:
-                    try:
-                        binary, op = self.compile_op(node, *args, **node.kwargs)
-                    except Exception as e:
-                        binary = None
-                        print(
-                            f"Failed to compile {idx}/{num_nodes}: {node.target}: {e}"
-                        )
-
-                if (
-                    self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
-                    and binary is not None
-                ):
-                    try:
-                        calculated, runtime_stack_dump = self.run_op(binary, *args)
-                        self.compiler_config.unique_ops[
-                            op.unique_key()
-                        ].runtime_stack_dump = runtime_stack_dump
-
-                        print(f"Ran: {idx}/{num_nodes}: {node.target}")
-                        if calculated is None:
-                            raise ValueError("Failed to execute")
-                        op.compilation_status = OpCompilationStatus.EXECUTED
-                        tensor = node.target(*args, **node.kwargs)
-                        if self.compiler_config.verify_op_by_op:
-                            atol = calculate_atol(calculated, tensor)
-                            op.atol = atol
-                            if atol > self.required_atol:
-                                print(f"atol too high for {idx}: {atol}")
-                            pcc = calculate_pcc(calculated, tensor)
-                            op.pcc = pcc
-                            if pcc < self.required_pcc:
-                                print(f"pcc too low for {idx}: {pcc}")
-                    except Exception as e:
-                        print(
-                            f"Failed to execute {idx}/{num_nodes}: {node.target}: {e}"
-                        )
-                        tensor = node.target(*args, **node.kwargs)
-                else:
-                    tensor = node.target(*args, **node.kwargs)
-                node_to_tensor[node] = tensor
-            elif node.op == "output":
-                args = node.args[0]
-                output_tensors = [node_to_tensor[arg] for arg in args]
-                outputs = output_tensors
-            args_set = set()
-            for arg in node.args:
-                if arg in args_set:
-                    continue
-                args_set.add(arg)
-                if isinstance(arg, torch.fx.node.Node):
-                    out_degree[arg] -= 1
-                    if out_degree[arg] == 0 and arg.op != "output":
-                        del node_to_tensor[arg]
-                        out_degree.pop(arg)
-
-        self.compiler_config.save_unique_ops()
-        if self.execute_process is not None:
-            self.execute_process.terminate()
-            self.execute_process = None
-        if self.stderror_redirected:
-            os.unlink(self.file_stderr.name)
-            self.stderror_redirected = False
-
-        return outputs
diff --git a/tt_torch/dynamo/shlo_backend.py b/tt_torch/dynamo/shlo_backend.py
index 59315029..781f9190 100644
--- a/tt_torch/dynamo/shlo_backend.py
+++ b/tt_torch/dynamo/shlo_backend.py
@@ -261,7 +261,7 @@ def __call__(self, *inputs):
         if self.compiler_config.compile_depth == CompileDepth.COMPILE_OP_BY_OP:
             self.compile_shlo_op_by_op()
             if self.gm is not None:
-                return self.run_gm_op_by_op(*(inputs + self.graph_constants))
+                return self.gm(*inputs)
             return  # return nothing
         else:
             assert False, "Invalid compile depth"
diff --git a/tt_torch/dynamo/torch_backend.py b/tt_torch/dynamo/torch_backend.py
index 55243536..68d13cf1 100644
--- a/tt_torch/dynamo/torch_backend.py
+++ b/tt_torch/dynamo/torch_backend.py
@@ -84,6 +84,102 @@ def get_node_name(self, node):
         name = node.target.name() if hasattr(node.target, "name") else node.name
         return name
 
+    def run_gm_op_by_op(self, *inputs):
+        node_to_tensor = {}
+        input_index = 0
+        outputs = []
+        num_nodes = len(self.gm.graph.nodes)
+        out_degree = {}
+        for idx, node in enumerate(self.gm.graph.nodes):
+            print(f"Compiling {idx}/{num_nodes}: {node.target}")
+            out_degree[node] = len(node.users)
+            if node.op == "placeholder":
+                node_to_tensor[node] = inputs[input_index]
+                input_index += 1
+            elif node.op == "get_attr":
+                for buffer in self.gm.named_buffers():
+                    if buffer[0] == node.target:
+                        node_to_tensor[node] = buffer[1]
+                        break
+            elif node.op == "call_function":
+                args = []
+                for arg in node.args:
+                    if isinstance(arg, torch.fx.node.Node):
+                        args.append(node_to_tensor[arg])
+                    elif isinstance(arg, list):
+                        args.append(
+                            [
+                                node_to_tensor[a]
+                                if isinstance(a, torch.fx.node.Node)
+                                else a
+                                for a in arg
+                            ]
+                        )
+                    else:
+                        args.append(arg)
+                try:
+                    binary, op = self.compile_op(node, *args, **node.kwargs)
+                except Exception as e:
+                    binary = None
+                    print(f"Failed to compile {idx}/{num_nodes}: {node.target}: {e}")
+
+                if (
+                    self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
+                    and binary is not None
+                ):
+                    try:
+                        calculated, runtime_stack_dump = self.run_op(binary, *args)
+                        self.compiler_config.unique_ops[
+                            op.unique_key()
+                        ].runtime_stack_dump = runtime_stack_dump
+
+                        print(f"Ran: {idx}/{num_nodes}: {node.target}")
+                        if calculated is None:
+                            raise ValueError("Failed to execute")
+                        op.compilation_status = OpCompilationStatus.EXECUTED
+                        tensor = node.target(*args, **node.kwargs)
+                        if self.compiler_config.verify_op_by_op:
+                            atol = calculate_atol(calculated, tensor)
+                            op.atol = atol
+                            if atol > self.required_atol:
+                                print(f"atol too high for {idx}: {atol}")
+                            pcc = calculate_pcc(calculated, tensor)
+                            op.pcc = pcc
+                            if pcc < self.required_pcc:
+                                print(f"pcc too low for {idx}: {pcc}")
+                    except Exception as e:
+                        print(
+                            f"Failed to execute {idx}/{num_nodes}: {node.target}: {e}"
+                        )
+                        tensor = node.target(*args, **node.kwargs)
+                else:
+                    tensor = node.target(*args, **node.kwargs)
+                node_to_tensor[node] = tensor
+            elif node.op == "output":
+                args = node.args[0]
+                output_tensors = [node_to_tensor[arg] for arg in args]
+                outputs = output_tensors
+            args_set = set()
+            for arg in node.args:
+                if arg in args_set:
+                    continue
+                args_set.add(arg)
+                if isinstance(arg, torch.fx.node.Node):
+                    out_degree[arg] -= 1
+                    if out_degree[arg] == 0 and arg.op != "output":
+                        del node_to_tensor[arg]
+                        out_degree.pop(arg)
+
+        self.compiler_config.save_unique_ops()
+        if self.execute_process is not None:
+            self.execute_process.terminate()
+            self.execute_process = None
+        if self.stderror_redirected:
+            os.unlink(self.file_stderr.name)
+            self.stderror_redirected = False
+
+        return outputs
+
     def get_stable_hlo_graph(self, node, inputs, **kwargs):
         input_shapes_and_constants = self.get_input_shapes_and_constants(inputs)
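
Note: after this change, run_gm_op_by_op lives only on the torch backend; the StableHLO backend's COMPILE_OP_BY_OP path falls back to running the original torch.fx GraphModule directly via self.gm(*inputs). A minimal sketch of the configuration the updated ResNet test now exercises follows; the import path (tt_torch.tools.utils) is an assumption based on how the test suite names CompileDepth and OpByOpBackend, not something this diff shows.

# Minimal sketch, assuming CompilerConfig, CompileDepth, and OpByOpBackend
# are importable from tt_torch.tools.utils (import path assumed).
from tt_torch.tools.utils import CompilerConfig, CompileDepth, OpByOpBackend

cc = CompilerConfig()
cc.consteval_parameters = True
# Run the model one op at a time, compiling and executing each op separately.
cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
# New in this diff: the op-by-op flow can be driven from the StableHLO backend;
# OpByOpBackend.TORCH keeps using run_gm_op_by_op on the torch backend.
cc.op_by_op_backend = OpByOpBackend.STABLEHLO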