diff --git a/lib/Conversion/TTNNToEmitC/CMakeLists.txt b/lib/Conversion/TTNNToEmitC/CMakeLists.txt
index bed66c647a..87142614df 100644
--- a/lib/Conversion/TTNNToEmitC/CMakeLists.txt
+++ b/lib/Conversion/TTNNToEmitC/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(TTMLIRTTNNToEmitC
   TTMLIRConversionPassIncGen

   LINK_LIBS PUBLIC
+    MLIRTTTransforms
     MLIRIR
     MLIRPass
     MLIRSCFToEmitC
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
index 9cddd3f219..a6c2cba39e 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
@@ -4,6 +4,7 @@

 #include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"

+#include "ttmlir/Dialect/TT/Transforms/Passes.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNN.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
@@ -15,6 +16,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"

@@ -50,6 +52,12 @@ class TTNNToEmitCTypeConverter : public TypeConverter {
 struct ConvertTTNNToEmitCPass
     : public ttnn::impl::ConvertTTNNToEmitCBase<ConvertTTNNToEmitCPass> {
   void runOnOperation() override {
+    mlir::ModuleOp module = getOperation();
+    // Only run conversion on the top-level ModuleOp.
+    if (module->getParentOp() != nullptr) {
+      return;
+    }
+
     mlir::ConversionTarget target(getContext());

     // EmitC is legal, TTNN is illegal
@@ -65,7 +73,6 @@ struct ConvertTTNNToEmitCPass
     // Add header imports to front of module
     //
     {
-      mlir::ModuleOp module = getOperation();
       OpBuilder builder(module);

       if (module.getBodyRegion().empty()) {
@@ -84,6 +91,17 @@ struct ConvertTTNNToEmitCPass
                                              /*isStandard=*/false);
     }

+    // Unwrap device_module into top-level ModuleOp (if present)
+    {
+      OpPassManager pm(ModuleOp::getOperationName());
+      pm.addPass(tt::createTTUnwrapDeviceModulePass());
+
+      if (failed(runPipeline(pm, module))) {
+        signalPassFailure();
+        return;
+      }
+    }
+
     // TTNN -> EmitC
     //
     {
@@ -111,8 +129,7 @@ struct ConvertTTNNToEmitCPass
     // Apply conversion
     //
-    if (failed(applyFullConversion(getOperation(), target,
-                                   std::move(patterns)))) {
+    if (failed(applyFullConversion(module, target, std::move(patterns)))) {
       signalPassFailure();
       return;
     }
diff --git a/lib/Dialect/TT/Transforms/TTModuleWrap.cpp b/lib/Dialect/TT/Transforms/TTModuleWrap.cpp
index ef2ecdc84d..702384a62b 100644
--- a/lib/Dialect/TT/Transforms/TTModuleWrap.cpp
+++ b/lib/Dialect/TT/Transforms/TTModuleWrap.cpp
@@ -34,6 +34,11 @@ class TTWrapDeviceModulePass
     OpBuilder builder(&getContext());
     auto innerModule = ModuleOp::create(rootModule.getLoc());

+    // Transfer attributes from root module to inner module.
+    for (const auto &attr : rootModule->getAttrs()) {
+      innerModule->setAttr(attr.getName(), attr.getValue());
+    }
+
     innerModule.getBodyRegion().takeBody(rootModule.getBodyRegion());
     rootModule.getRegion().emplaceBlock();
     builder.setInsertionPointToStart(&rootModule.getBodyRegion().front());
@@ -86,6 +91,13 @@ class TTUnwrapDeviceModulePass
     topLevelBody.getOperations().splice(topLevelBody.end(),
                                         innerBody.getOperations());

+    // Also transfer any attributes, e.g. system_desc, device.
+    for (const auto &attr : innerModule->getAttrs()) {
+      if (!rootModule->hasAttr(attr.getName())) {
+        rootModule->setAttr(attr.getName(), attr.getValue());
+      }
+    }
+
     deviceOp->erase();
   }
 };
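
Both passes above copy module attributes by hand (unconditionally when wrapping, only-if-absent when unwrapping). A minimal sketch of how the same transfer could be shared through one helper; `copyMissingAttrs` is hypothetical and not part of this patch:

#include "mlir/IR/Operation.h"

// Hypothetical helper (not in this patch): copy attributes from one op to
// another, skipping names the destination already defines.
static void copyMissingAttrs(mlir::Operation *from, mlir::Operation *to) {
  for (const mlir::NamedAttribute &attr : from->getAttrs()) {
    if (!to->hasAttr(attr.getName())) {
      to->setAttr(attr.getName(), attr.getValue());
    }
  }
}

The wrap pass would call copyMissingAttrs(rootModule, innerModule) (the freshly created inner module has no attributes, so the hasAttr guard is a no-op there), and the unwrap pass copyMissingAttrs(innerModule, rootModule).
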
diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
index c87f5793e1..62ef49ea10 100644
--- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
+++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
@@ -51,7 +51,9 @@ void createLinalgToLLVMPipeline(OpPassManager &manager,
                                 const LinalgToLLVMPipelineOptions &options) {
   // These are initial passes to ensure we start with well-formed linalg dialect
   // operations.
-  manager.addPass(mlir::createCanonicalizerPass());
+  // TODO (#2145): Explore ways to re-enable canonicalizer w/o return values for
+  // linalg funcs.
+  // manager.addPass(mlir::createCanonicalizerPass());
   manager.addPass(mlir::createConvertElementwiseToLinalgPass());
   manager.addPass(mlir::createConvertTensorToLinalgPass());
diff --git a/lib/Dialect/TTNN/Pipelines/CMakeLists.txt b/lib/Dialect/TTNN/Pipelines/CMakeLists.txt
index 6fe28901ee..53681f4b9e 100644
--- a/lib/Dialect/TTNN/Pipelines/CMakeLists.txt
+++ b/lib/Dialect/TTNN/Pipelines/CMakeLists.txt
@@ -5,9 +5,15 @@ add_mlir_dialect_library(MLIRTTNNPipelines
   ${PROJECT_SOURCE_DIR}/include/ttmlir

   LINK_LIBS PUBLIC
+    MLIRLLVMTransforms
+    MLIRTTIRDialect
     MLIRTTUtils
     MLIRTTNNDialect
+    MLIRTTIRPipelines
+    MLIRTTIRTransforms
     MLIRTTNNTransforms
+    MLIRTTNNAnalysis
+    MLIRTTTransforms
     MLIRPass
     MLIRTransforms
 )
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index 385db07ee2..40ececf9a1 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -6,7 +6,11 @@

 #include "ttmlir/Conversion/Passes.h"
 #include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"
+#include "ttmlir/Dialect/LLVM/Transforms/Passes.h"
+#include "ttmlir/Dialect/TT/IR/TTOps.h"
+#include "ttmlir/Dialect/TT/Transforms/Passes.h"
 #include "ttmlir/Dialect/TT/Utils/PopulateArgumentTypes.h"
+#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
 #include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

@@ -138,18 +142,34 @@ void createTTNNPipelineTTIRImplicitBroadcastFoldPassFromString(

 void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
-  createTTNNPipelineTTIRPasses(pm, options);
-  createTTNNPipelineTTIRImplicitBroadcastFoldPass(pm, options);
-  createTTNNPipelineLoweringPasses(pm, options);
-  createTTNNPipelineWorkaroundPass(pm, options);
-  createTTNNPipelineAnalysisPasses(pm, options);
-  createTTNNPipelineLayoutDecompositionPass(pm, options);
-  createTTNNPipelineDeallocPass(pm, options);
+  // Create DeviceModule to wrap all ops.
+  pm.addPass(tt::createTTWrapDeviceModulePass());
+  // Create CPUModuleOp to wrap hoisted ops (if any).
+  pm.addPass(ttir::createTTIRHoistTransform());
+
+  // Run regular TTIR to TTNN pipeline on DeviceModule.
+  OpPassManager &devicePm =
+      pm.nest<tt::DeviceModuleOp>().nest<mlir::ModuleOp>();
+  createTTNNPipelineTTIRPasses(devicePm, options);
+  createTTNNPipelineTTIRImplicitBroadcastFoldPass(devicePm, options);
+  createTTNNPipelineLoweringPasses(devicePm, options);
+  createTTNNPipelineWorkaroundPass(devicePm, options);
+  createTTNNPipelineAnalysisPasses(devicePm, options);
+  createTTNNPipelineLayoutDecompositionPass(devicePm, options);
+  createTTNNPipelineDeallocPass(devicePm, options);
+
+  // Run lowering to LLVM pass on hoisted funcs in CPUModule.
+  OpPassManager &cpuPm = pm.nest<tt::CPUModuleOp>().nest<mlir::ModuleOp>();
+  cpuPm.addPass(createConvertTTIRToLinalgPass());
+  ttir::LinalgToLLVMPipelineOptions linalgToLLVMOptions;
+  ttir::createLinalgToLLVMPipeline(cpuPm, linalgToLLVMOptions);
+  cpuPm.addPass(llvm_util::createLLVMEmitCallingConventionWrapperFuncs());
 }

 void createTTIRToEmitCPipeline(OpPassManager &pm,
                                const TTIRToEmitCPipelineOptions &options) {
   createTTIRToTTNNBackendPipeline(pm, options);
+  pm.addPass(tt::createTTUnwrapDeviceModulePass());
   pm.addPass(createTTNNCreateInputGenerators());
   pm.addPass(createConvertTTNNToEmitCPass());
 }
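
For orientation, a sketch of how the restructured pipeline might be driven programmatically. The header path, namespaces, and default-constructed options are assumptions inferred from this patch, not verified API:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" // assumed header path

// Hypothetical driver for the pipeline registered above.
static mlir::LogicalResult runBackendPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module->getContext());
  mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
  mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
  // After this runs, the IR is expected to be nested roughly as:
  //   builtin.module { tt.device_module { builtin.module { ... } }
  //                    tt.cpu_module { builtin.module { llvm.func ... } } }
  return pm.run(module);
}
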
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index 832857f209..dbf7dbeb72 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -318,6 +318,23 @@ class TTNNModifySignaturesForDylib
   void runOnOperation() final {
     ModuleOp module = getOperation();
+
+    // If we have a nested module structure, we want to use the nested module
+    // inside DeviceModule.
+    tt::DeviceModuleOp deviceModule;
+    for (auto &op : module.getBody()->getOperations()) {
+      deviceModule = llvm::dyn_cast<tt::DeviceModuleOp>(op);
+      if (deviceModule) {
+        break;
+      }
+    }
+    if (deviceModule) {
+      module = dyn_cast_if_present<mlir::ModuleOp>(
+          deviceModule.getBodyRegion().front().front());
+      assert(module &&
+             "Found tt::DeviceModuleOp but it didn't contain a single "
+             "mlir::ModuleOp!");
+    }
     IRRewriter rewriter(&getContext());

     // Ensure that the module has a single region and a single block within that
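
The addition above walks the top-level block by hand to find the device module. An equivalent sketch using Block::getOps<T>(); `getPayloadModule` is hypothetical and assumes the same invariant, that a tt.device_module wraps exactly one builtin.module:

#include "mlir/IR/BuiltinOps.h"
#include "ttmlir/Dialect/TT/IR/TTOps.h"

// Sketch: return the builtin.module nested under tt.device_module, or the
// top-level module when no DeviceModuleOp is present.
static mlir::ModuleOp getPayloadModule(mlir::ModuleOp module) {
  for (mlir::tt::DeviceModuleOp deviceModule :
       module.getBody()->getOps<mlir::tt::DeviceModuleOp>()) {
    return mlir::cast<mlir::ModuleOp>(
        deviceModule.getBodyRegion().front().front());
  }
  return module;
}
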
diff --git a/lib/RegisterAll.cpp b/lib/RegisterAll.cpp
index 9bf45b1676..0143c8c628 100644
--- a/lib/RegisterAll.cpp
+++ b/lib/RegisterAll.cpp
@@ -61,6 +61,7 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
       registry);
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
+  LLVM::registerInlinerInterface(registry);
 }

 void mlir::tt::registerAllExtensions(mlir::DialectRegistry &registry) {
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index bb4018d051..5bd20c1fba 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1769,7 +1769,7 @@ std::shared_ptr<void> ttnnToFlatbuffer(
       toFlatbuffer(cache, mlir::cast<tt::SystemDescAttr>(
                              module->getAttr(tt::SystemDescAttr::name)));
   // Always get debug info for top-level module.
-  auto mlir = toDebugInfo(fbb, "ttnn", module);
+  auto mlir = toDebugInfo(fbb, "ttnn", rootModule);

   std::string cpp;
   llvm::raw_string_ostream os(cpp);
diff --git a/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir b/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir
index ec2664b443..a9fc36d673 100644
--- a/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir
+++ b/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir
@@ -1,5 +1,5 @@
 // RUN: ttmlir-opt --ttir-to-emitc-pipeline="system-desc-path=%system_desc_path%" %s > %direct.mlir
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-create-input-gens --convert-ttnn-to-emitc %s > %indirect.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --tt-unwrap-device-module --ttnn-create-input-gens --convert-ttnn-to-emitc %s > %indirect.mlir
 // RUN: diff %direct.mlir %indirect.mlir
 //
 // This test checks that the (TTIR to EmitC pipeline) is equivalent to (TTIR to TTNN pipeline + dialect conversion from TTNN to EmitC).
diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_hoist_call.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_hoist_call.mlir
index ae704bc8bf..0e8b9dc354 100644
--- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_hoist_call.mlir
+++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_hoist_call.mlir
@@ -1,7 +1,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-module attributes {} {
+module {
   // CHECK-DAG: #{{.*}} = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>>
   // CHECK-DAG: #{{.*}} = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, <interleaved>>
+
+  // CHECK: tt.device_module {
+  // CHECK: builtin.module attributes {{.*}} {
+  // CHECK: func.func @forward
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %{{.*}} = "ttnn.multiply"(%{{.*}}, %{{.*}})
@@ -9,11 +13,12 @@ module attributes {} {
     // CHECK: %{{.*}} = "ttnn.ones"
     %2 = "ttir.ones"() <{shape = array<i32: 64, 128>}> : () -> tensor<64x128xf32>
     // CHECK: %{{.*}} = "ttnn.from_device"(%{{.*}}) : (tensor<[[DIMS:.*]], #{{.*}}>) -> tensor<[[DIMS]], #{{.*}}>
-    // CHECK: %{{.*}} = call @hoisted_func_decl(%{{.*}}, %{{.*}}, %{{.*}})
-    %3 = call @hoisted_func_decl(%arg0, %1, %2) {ttir.hoisted_call} : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: %{{.*}} = call @hoisted_ttir_add_64x128_64x128_64x128_func_decl(%{{.*}}, %{{.*}}, %{{.*}})
+    %3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{.*}} = "ttnn.zeros"
     %4 = "ttir.zeros"() <{shape = array<i32: 64, 128>}> : () -> tensor<64x128xf32>
-    %5 = call @hoisted_func_decl(%arg0, %3, %4) {ttir.hoisted_call} : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: %{{.*}} = call @hoisted_ttir_add_64x128_64x128_64x128_func_decl(%{{.*}}, %{{.*}}, %{{.*}})
+    %5 = "ttir.add"(%arg0, %3, %4) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     %6 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %{{.*}} = "ttnn.to_layout"(%{{.*}}) <{layout = #ttnn.layout<{{.*}}>}> : (tensor<[[DIMS:.*]], #{{.*}}>) -> tensor<[[DIMS]], #{{.*}}>
     // CHECK: %{{.*}} = "ttnn.to_device"(%{{.*}}, %{{.*}}) <{memory_config = {{.*}}}> : (tensor<[[DIMS:.*]], #{{.*}}>, !tt.device<#{{.*}}>) -> tensor<[[DIMS]], #{{.*}}>
@@ -23,5 +28,9 @@ module attributes {} {
     %7 = "ttir.multiply"(%3, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %7 : tensor<64x128xf32>
   }
-  func.func private @hoisted_func_decl(tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+  // CHECK: func.func private @hoisted_ttir_add_64x128_64x128_64x128_func_decl
+  // CHECK: tt.cpu_module {
+  // CHECK: builtin.module {
+  // CHECK: llvm.func @hoisted_ttir_add_64x128_64x128_64x128_func
+  // CHECK: llvm.func @hoisted_ttir_add_64x128_64x128_64x128_func_helper(%arg0: !llvm.ptr)
 }
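
The last two CHECK lines expect both the hoisted kernel and a `_helper` wrapper that takes a single `!llvm.ptr`. Purely as an illustration of that calling convention, a host might load the resulting dylib roughly as follows; only the symbol name comes from the test, while the shared-object name and the argument packing are assumptions:

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Hypothetical consumer of the CPU-hoisted dylib; the .so path is made up.
  void *handle = dlopen("./hoisted_cpu_ops.so", RTLD_NOW);
  if (!handle) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // The helper exposes one opaque pointer argument (the !llvm.ptr above);
  // a real runtime would pack the tensor arguments behind it before calling.
  using WrapperFn = void (*)(void *);
  auto *fn = reinterpret_cast<WrapperFn>(
      dlsym(handle, "hoisted_ttir_add_64x128_64x128_64x128_func_helper"));
  void *packedArgs = nullptr; // built by the runtime in a real flow
  if (fn && packedArgs) {
    fn(packedArgs);
  }
  dlclose(handle);
  return 0;
}
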
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
index 7050ab94a4..94881c147b 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
@@ -598,115 +598,16 @@ def build_graph(module, perf_trace=None):
         if loc:
             loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"]

-    module_op = OpHandler(module.operation)
-    module_attrs = module_op.get_attributes()
-    module_attrs = dict((attr.key, attr.value) for attr in module_attrs)
-    # Add module attributes to the graph as "namespace attributes"
-    group_node_attrs = {}
-    group_node_attrs[module_op.get_namespace()] = module_attrs
-
-    for op in module.body.operations:
-        append_later = []
-        for region in op.regions:
-            for block in region.blocks:
-                for op in block.operations:
-                    # Create all the nodes and constants in the first pass.
-                    operation = OpHandler(op)
-                    graph_node = operation.make_graph_node()
-
-                    if (
-                        operation.named_location in loc_to_perf
-                        and operation.op.name not in EMPTY_OPS
-                    ):
-                        perf_node_data[operation.id] = node_data_builder.NodeDataResult(
-                            loc_to_perf[operation.named_location]
-                        )
-
-                    if op.name not in FILTERED_OPS and op.name in EMPTY_OPS:
-                        append_later.append(graph_node)
-                    elif op.name not in FILTERED_OPS:
-                        graph.nodes.append(graph_node)
-
-                    op_to_graph_node[op] = graph_node
-
-                    for operand in op.operands:
-                        if isinstance(operand, ir.Value) and not isinstance(
-                            operand.owner, ir.Operation
-                        ):
-                            # If the owner is not an op, then it is a constant provided from the toplevel FuncOp.
-
-                            if operand not in operands_in_graph:
-                                # This is a constant and we need to create a node for it.
-                                operand_node = operation.make_constant_node(
-                                    operand.get_name()
-                                )
-                                graph.nodes.append(operand_node)
-                                op_to_graph_node[operand] = operand_node
-                                operands_in_graph.add(operand)
-
-                # This puts the node at the far right when viewing which is a bit more consistant with it being the last operand.
-                for node in append_later:
-                    graph.nodes.append(node)
-
-                for op in block.operations:
-                    # Create all the edges in the second pass.
-                    for operand_index, operand in enumerate(op.operands):
-                        if operand.owner == block:
-                            source_node = op_to_graph_node[operand]
-                        else:
-                            source_node = op_to_graph_node[operand.owner]
-
-                        target_node = op_to_graph_node[op]
-
-                        target_node.incomingEdges.append(
-                            graph_builder.IncomingEdge(
-                                sourceNodeId=source_node.id,
-                                sourceNodeOutputId=str(
-                                    output_connections[source_node.id]
-                                ),
-                                targetNodeInputId=str(operand_index),
-                            )
-                        )
-
-                        output_attrs = []
-                        if isinstance(operand.type, ir.RankedTensorType):
-                            output_attrs = [
-                                graph_builder.KeyValue(
-                                    key="shape", value=str(operand.type.shape)
-                                ),
-                                graph_builder.KeyValue(
-                                    key="dtype", value=str(operand.type.element_type)
-                                ),
-                                graph_builder.KeyValue(
-                                    key="rank", value=str(operand.type.rank)
-                                ),
-                            ]
-                        if hasattr(operand.type, "encoding") and operand.type.encoding:
-                            if "ttnn_layout" in str(operand.type.encoding):
-                                output_attrs.extend(
-                                    AttrHandler.parse_attr(
-                                        operand.type.encoding.get_named("ttnn_layout")
-                                    )
-                                )
-                            else:
-                                # Parse as a standard layout
-                                output_attrs.extend(
-                                    AttrHandler.parse_attr(
-                                        operand.type.encoding.get_named("tt.layout")
-                                    )
-                                )
-                        source_node.outputsMetadata.append(
-                            graph_builder.MetadataItem(
-                                id=str(output_connections[source_node.id]),
-                                attrs=[
-                                    graph_builder.KeyValue(
-                                        key="__tensor_tag", value=str(target_node.label)
-                                    ),
-                                ]
-                                + output_attrs,
-                            )
-                        )
-                        output_connections[source_node.id] += 1
+    # Process the module hierarchy recursively
+    process_module(
+        module,
+        graph,
+        op_to_graph_node,
+        operands_in_graph,
+        output_connections,
+        loc_to_perf,
+        perf_node_data,
+    )

     # Add performance data to the graph color overlay, if it exists
     overlay_data = None
@@ -722,6 +623,237 @@ def build_graph(module, perf_trace=None):
         graphsData={"tt-graph": graph_node_data}
     )

-    graph.groupNodeAttributes = group_node_attrs
     OpHandler.schedule = 0
     return graph, overlay_data
+
+
+def process_module(
+    module,
+    graph,
+    op_to_graph_node,
+    operands_in_graph,
+    output_connections,
+    loc_to_perf,
+    perf_node_data,
+):
+    """
+    Process a module's operations. This only works on the top-level module;
+    nested modules have no `body` attribute, so their operations must be
+    handled by calling process_operations directly.
+
+    Args:
+        module: The module to process
+        graph: The graph being built
+        op_to_graph_node: Mapping from operations to graph nodes
+        operands_in_graph: Set of operands already added to graph
+        output_connections: Tracking of output connections
+        loc_to_perf: Mapping from locations to performance data
+        perf_node_data: Performance data for nodes
+    """
+    module_op = OpHandler(module.operation)
+    module_attrs = module_op.get_attributes()
+    module_attrs = dict((attr.key, attr.value) for attr in module_attrs)
+
+    # Add module attributes to the graph as "namespace attributes"
+    if not graph.groupNodeAttributes:
+        graph.groupNodeAttributes = {}
+
+    # Add this module's namespace attributes
+    namespace = module_op.get_namespace()
+    if namespace not in graph.groupNodeAttributes:
+        graph.groupNodeAttributes[namespace] = module_attrs
+    else:
+        # Merge with existing attributes if namespace already exists
+        graph.groupNodeAttributes[namespace].update(module_attrs)
+
+    # Process operations in this module
+    process_operations(
+        module.body.operations,
+        graph,
+        op_to_graph_node,
+        operands_in_graph,
+        output_connections,
+        loc_to_perf,
+        perf_node_data,
+    )
+
+
+def process_operations(
+    operations,
+    graph,
+    op_to_graph_node,
+    operands_in_graph,
+    output_connections,
+    loc_to_perf,
+    perf_node_data,
+):
+    """
+    Recursively process a list of operations, including handling nested modules.
+
+    Args:
+        operations: List of operations to process
+        graph: The graph being built
+        op_to_graph_node: Mapping from operations to graph nodes
+        operands_in_graph: Set of operands already added to graph
+        output_connections: Tracking of output connections
+        loc_to_perf: Mapping from locations to performance data
+        perf_node_data: Performance data for nodes
+    """
+    append_later = []
+
+    # First pass: create all nodes and constants
+    for op in operations:
+        # Check if this operation is a nested module
+        if is_module_op(op):
+            # Process the nested module's ops recursively
+            process_operations(
+                op.regions[0].blocks[0],
+                graph,
+                op_to_graph_node,
+                operands_in_graph,
+                output_connections,
+                loc_to_perf,
+                perf_node_data,
+            )
+            continue
+
+        # Process regions in the operation
+        for region in op.regions:
+            for block in region.blocks:
+                # Recursively process operations in this block
+                process_operations(
+                    block.operations,
+                    graph,
+                    op_to_graph_node,
+                    operands_in_graph,
+                    output_connections,
+                    loc_to_perf,
+                    perf_node_data,
+                )
+
+        # Create graph node for this operation
+        operation = OpHandler(op)
+        graph_node = operation.make_graph_node()
+
+        if (
+            operation.named_location in loc_to_perf
+            and operation.op.name not in EMPTY_OPS
+        ):
+            perf_node_data[operation.id] = node_data_builder.NodeDataResult(
+                loc_to_perf[operation.named_location]
+            )
+
+        if op.name not in FILTERED_OPS and op.name in EMPTY_OPS:
+            append_later.append(graph_node)
+        elif op.name not in FILTERED_OPS:
+            graph.nodes.append(graph_node)
+
+        op_to_graph_node[op] = graph_node
+
+        # Process operands
+        for operand in op.operands:
+            if isinstance(operand, ir.Value) and not isinstance(
+                operand.owner, ir.Operation
+            ):
+                # If the owner is not an op, then it is a constant provided from the toplevel FuncOp.
+                if operand not in operands_in_graph:
+                    # This is a constant and we need to create a node for it.
+                    operand_node = operation.make_constant_node(operand.get_name())
+                    graph.nodes.append(operand_node)
+                    op_to_graph_node[operand] = operand_node
+                    operands_in_graph.add(operand)
+
+    # Add the nodes that should be appended later
+    for node in append_later:
+        graph.nodes.append(node)
+
+    # Second pass: create all edges
+    for op in operations:
+        # Skip module operations as they've been processed recursively
+        if is_module_op(op):
+            continue
+
+        # Process regions in the operation
+        for region in op.regions:
+            for block in region.blocks:
+                create_edges_for_block(block, op_to_graph_node, output_connections)
+
+
+def create_edges_for_block(block, op_to_graph_node, output_connections):
+    """
+    Create edges between nodes for operations in a block.
+
+    Args:
+        block: The block containing operations
+        op_to_graph_node: Mapping from operations to graph nodes
+        output_connections: Tracking of output connections
+    """
+    for op in block.operations:
+        # Skip module operations as they've been processed recursively
+        if is_module_op(op):
+            continue
+
+        # Create edges for this operation
+        for operand_index, operand in enumerate(op.operands):
+            if operand.owner == block:
+                source_node = op_to_graph_node[operand]
+            else:
+                source_node = op_to_graph_node[operand.owner]
+
+            target_node = op_to_graph_node[op]
+
+            target_node.incomingEdges.append(
+                graph_builder.IncomingEdge(
+                    sourceNodeId=source_node.id,
+                    sourceNodeOutputId=str(output_connections[source_node.id]),
+                    targetNodeInputId=str(operand_index),
+                )
+            )
+
+            output_attrs = []
+            if isinstance(operand.type, ir.RankedTensorType):
+                output_attrs = [
+                    graph_builder.KeyValue(key="shape", value=str(operand.type.shape)),
+                    graph_builder.KeyValue(
+                        key="dtype", value=str(operand.type.element_type)
+                    ),
+                    graph_builder.KeyValue(key="rank", value=str(operand.type.rank)),
+                ]
+            if hasattr(operand.type, "encoding") and operand.type.encoding:
+                if "ttnn_layout" in str(operand.type.encoding):
+                    output_attrs.extend(
+                        AttrHandler.parse_attr(
+                            operand.type.encoding.get_named("ttnn_layout")
+                        )
+                    )
+                else:
+                    # Parse as a standard layout
+                    output_attrs.extend(
+                        AttrHandler.parse_attr(
+                            operand.type.encoding.get_named("tt.layout")
+                        )
+                    )
+            source_node.outputsMetadata.append(
+                graph_builder.MetadataItem(
+                    id=str(output_connections[source_node.id]),
+                    attrs=[
+                        graph_builder.KeyValue(
+                            key="__tensor_tag", value=str(target_node.label)
+                        ),
+                    ]
+                    + output_attrs,
+                )
+            )
+            output_connections[source_node.id] += 1
+
+
+def is_module_op(op):
+    """
+    Check if an operation represents a module.
+
+    Args:
+        op: The operation to check
+
+    Returns:
+        bool: True if the operation is a module, False otherwise
+    """
+    # Check for tt.device_module or builtin.module operations
+    return op.name == "tt.device_module" or op.name == "builtin.module"