diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index c79ee757e9..0bdce11c05 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -63,7 +63,7 @@ def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
                        Optional:$device);
   let results = (outs AnyRankedTensor:$result);
 
-  let hasCanonicalizeMethod = 1;
+  let hasCanonicalizer = 1;
 }
 
 def TTNN_TypecastOp : TTNN_Op<"typecast"> {
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 015b68b166..812254c349 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Traits.h"
 
-#include 
 #include 
 #include 
 
@@ -1052,60 +1051,92 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
 //===----------------------------------------------------------------------===//
 // ToLayoutOp canonicalization
-// ToLayoutOp can be canonicalized if the previous op is also a ToLayoutOp. The
-// previous op can be merged with the current ToLayoutOp op if the previous op
-// has only one use. df - data format, l - layout, ms - memory space, tml -
-// tensor memory layout
-//
-//          |
-//  -----------------------
-//  |     ToLayoutOp      |                     |
-//  | df1, l1, ms1, tml1  |         -----------------------
-//  -----------------------         |     ToLayoutOp      |
-//            |             -->     | df2, l1, ms2, tml1  |
-//            |                     -----------------------
-//  -----------------------                    |
-//  |     ToLayoutOp      |
-//  |      df2, ms2       |
-//  -----------------------
-//            |
-//
-::mlir::LogicalResult
-mlir::tt::ttnn::ToLayoutOp::canonicalize(ToLayoutOp toLayoutOp,
-                                         PatternRewriter &rewriter) {
-  // Get the input operand and verify that the previous op is toLayoutOp
-  ToLayoutOp previousToLayoutOp =
-      toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();
-
-  // NOLINTNEXTLINE
-  if (!previousToLayoutOp) {
-    return mlir::failure();
-  }
+void mlir::tt::ttnn::ToLayoutOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // ToLayoutOp can be folded if its input has the same layout as the output of
+  // toLayoutOp.
+  patterns.add(+[](mlir::tt::ttnn::ToLayoutOp toLayoutOp,
+                   mlir::PatternRewriter &rewriter) {
+    RankedTensorType previousType = toLayoutOp.getInput().getType();
+    TTNNLayoutAttr previousLayout =
+        mlir::dyn_cast<TTNNLayoutAttr>(previousType.getEncoding());
+    // Verify if input tensor has layout attribute.
+    if (!previousLayout) {
+      return mlir::failure();
+    }
 
-  // Check if the previous op has only one use. We can only merge if the
-  // previous op has single use.
-  if (!previousToLayoutOp->hasOneUse()) {
-    return mlir::failure();
-  }
+    RankedTensorType currentType = toLayoutOp.getType();
+    TTNNLayoutAttr currentLayout =
+        mlir::dyn_cast<TTNNLayoutAttr>(currentType.getEncoding());
+    // Verify if the output tensor has layout attribute.
+    if (!currentLayout) {
+      return mlir::failure();
+    }
+
+    // Verify that the layouts are the same.
+    if (previousLayout != currentLayout) {
+      return mlir::failure();
+    }
 
-  // Replace the previous op with the merged ToLayoutOp
-  Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
-      previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
-      toLayoutOp.getLayoutAttr(),
-      toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
-                                : previousToLayoutOp.getDtypeAttr(),
-      toLayoutOp.getMemoryConfigAttr()
-          ? toLayoutOp.getMemoryConfigAttr()
-          : previousToLayoutOp.getMemoryConfigAttr(),
-      toLayoutOp.getDevice());
+    rewriter.replaceAllUsesWith(toLayoutOp, toLayoutOp->getOperand(0));
+    rewriter.eraseOp(toLayoutOp);
+    return mlir::success();
+  });
+
+  // Two consecutive ToLayoutOps can be merged if the previous op has only one
+  // use.
+  // df - data format, l - layout, ms - memory
+  // space, tml - tensor memory layout
+  //
+  //          |
+  //  -----------------------
+  //  |     ToLayoutOp      |                     |
+  //  | df1, l1, ms1, tml1  |         -----------------------
+  //  -----------------------         |     ToLayoutOp      |
+  //            |             -->     | df2, l1, ms2, tml1  |
+  //            |                     -----------------------
+  //  -----------------------                    |
+  //  |     ToLayoutOp      |
+  //  |      df2, ms2       |
+  //  -----------------------
+  //            |
+  //
+  patterns.add(+[](mlir::tt::ttnn::ToLayoutOp toLayoutOp,
+                   mlir::PatternRewriter &rewriter) {
+    // Get the input operand and verify that the previous op is toLayoutOp.
+    ToLayoutOp previousToLayoutOp =
+        toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();
+
+    // NOLINTNEXTLINE
+    if (!previousToLayoutOp) {
+      return mlir::failure();
+    }
 
-  // Replace all uses of the current op with the merged ToLayoutOp
-  rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);
+    // Check if the previous op has only one use. We can only merge if the
+    // previous op has single use.
+    if (!previousToLayoutOp->hasOneUse()) {
+      return mlir::failure();
+    }
 
-  // Erase the current op
-  rewriter.eraseOp(toLayoutOp);
+    // Replace the previous op with the merged ToLayoutOp.
+    Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
+        previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
+        toLayoutOp.getLayoutAttr(),
+        toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
+                                  : previousToLayoutOp.getDtypeAttr(),
+        toLayoutOp.getMemoryConfigAttr()
+            ? toLayoutOp.getMemoryConfigAttr()
+            : previousToLayoutOp.getMemoryConfigAttr(),
+        toLayoutOp.getDevice());
 
-  return mlir::success();
+    // Replace all uses of the current op with the merged ToLayoutOp.
+    rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);
+
+    // Erase the current op.
+    rewriter.eraseOp(toLayoutOp);
+
+    return mlir::success();
+  });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
index 480e8ff9ef..45a78350ba 100644
--- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
+++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -374,10 +374,10 @@ class TTNNLayoutDPSOperandsRewriter
       }
 
       // TTNN mesh shard expects host input and output
-      // TODO(#2102): This can be removed once the workaround pass can correctly
-      // handle cannonicalization of toLayout ops. Currently the workaround pass
-      // cannot detect redundant toLayout ops as a result of forcing the output
-      // layout and removing them.
+      // TODO(#2291): This can be removed once the workaround pass can correctly
+      // handle canonicalization of toLayout ops (#2102). Currently the
+      // workaround pass cannot detect redundant toLayout ops as a result of
+      // forcing the output layout and removing them.
      if (mlir::isa(op.getOperation())) {
        modified = changeLayoutToHost(op, operand, rewriter, isDPSResult);
        continue;
diff --git a/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir b/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir
index 73f68a5f27..79ef044d3b 100644
--- a/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir
+++ b/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir
@@ -66,7 +66,7 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
     return %2 : tensor<32x32xf32, #ttnn_layout5>
   }
-  func.func @merge_to_layout_op_4x(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> {
+  func.func @merge_to_layout_op_4x(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> {
     // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
     // CHECK: "ttnn.to_layout"(%arg0, %0)
     // CHECK-SAME: dtype = #tt.supportedDataTypes
     // CHECK-SAME: layout = #ttnn.layout
@@ -81,4 +81,12 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
     %4 = "ttnn.to_layout"(%3, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5>
     return %4 : tensor<32x32xf32, #ttnn_layout5>
   }
+
+  func.func @fold_to_layout_op(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> {
+    // Verify folding of to_layout_op.
+    %0 = "ttnn.to_layout"(%arg0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout>
+    // CHECK-NOT: "ttnn.to_layout"
+    return %0 : tensor<32x32xbf16, #ttnn_layout>
+    // CHECK: return %arg0 : tensor<32x32xbf16
+  }
 }
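
Note on the hook change in TTNNOps.td: in MLIR ODS, "let hasCanonicalizeMethod = 1;" declares a single static canonicalize(ToLayoutOp, PatternRewriter &) entry point, whereas "let hasCanonicalizer = 1;" declares getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *). That is what lets this change register two independent rewrites (the same-layout fold and the consecutive-op merge) instead of one monolithic method; both are picked up by the standard --canonicalize pass. Below is a minimal sketch of how another transform (for example, a workaround pass like the one referenced in the TODO) could reuse the registered patterns directly, assuming the stock MLIR greedy driver (applyPatternsAndFoldGreedily, renamed applyPatternsGreedily in newer MLIR) and this repo's usual header layout; the helper name is illustrative and not part of this change:

  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

  // Hypothetical helper: applies ToLayoutOp's canonicalization patterns
  // (same-layout fold + consecutive-op merge) to everything nested under
  // `root`, without running the full canonicalizer pass.
  static mlir::LogicalResult foldRedundantToLayoutOps(mlir::Operation *root) {
    mlir::MLIRContext *ctx = root->getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::tt::ttnn::ToLayoutOp::getCanonicalizationPatterns(patterns, ctx);
    // The greedy driver reports failure only if the rewrites do not converge.
    return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }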