Adding ToLayout op folding
sdjordjevicTT committed Feb 26, 2025
1 parent 8f932ca commit 95d69d0
Showing 4 changed files with 95 additions and 56 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -63,7 +63,7 @@ def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
Optional<TT_Device>:$device);
let results = (outs AnyRankedTensor:$result);

let hasCanonicalizeMethod = 1;
let hasCanonicalizer = 1;
}

def TTNN_TypecastOp : TTNN_Op<"typecast"> {
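The switch from hasCanonicalizeMethod to hasCanonicalizer changes which hook ODS declares on the generated op class. A rough sketch of the two declarations, based on upstream MLIR ODS conventions rather than on anything shown in this commit:

// With `let hasCanonicalizeMethod = 1;` ODS declares a single-op hook:
static ::mlir::LogicalResult canonicalize(ToLayoutOp op,
                                          ::mlir::PatternRewriter &rewriter);

// With `let hasCanonicalizer = 1;` ODS instead declares a hook that can
// register any number of rewrite patterns, which this commit uses to add both
// a fold pattern and a merge pattern:
static void getCanonicalizationPatterns(::mlir::RewritePatternSet &patterns,
                                        ::mlir::MLIRContext *context);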
131 changes: 81 additions & 50 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -12,7 +12,6 @@

#include "mlir/Dialect/Traits.h"

#include <llvm/ADT/ArrayRef.h>
#include <numeric>
#include <optional>

@@ -1052,60 +1051,92 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
//===----------------------------------------------------------------------===//

// ToLayoutOp canonicalization
// ToLayoutOp can be canonicalized if the previous op is also a ToLayoutOp. The
// previous op can be merged with the current ToLayoutOp if the previous op
// has only one use. df - data format, l - layout, ms - memory space, tml -
// tensor memory layout
//
// |
// -----------------------
// | ToLayoutOp | |
// | df1, l1, ms1, tml1 | -----------------------
// ----------------------- | ToLayoutOp |
// | --> | df2, l1, ms2, tml1 |
// | -----------------------
// ----------------------- |
// | ToLayoutOp |
// | df2, ms2 |
// -----------------------
// |
//
::mlir::LogicalResult
mlir::tt::ttnn::ToLayoutOp::canonicalize(ToLayoutOp toLayoutOp,
PatternRewriter &rewriter) {
// Get the input operand and verify that the previous op is toLayoutOp
ToLayoutOp previousToLayoutOp =
toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();

// NOLINTNEXTLINE
if (!previousToLayoutOp) {
return mlir::failure();
}
void mlir::tt::ttnn::ToLayoutOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
// ToLayoutOp can be folded if its input has the same layout as the output of
// toLayoutOp.
patterns.add(+[](mlir::tt::ttnn::ToLayoutOp toLayoutOp,
mlir::PatternRewriter &rewriter) {
RankedTensorType previousType = toLayoutOp.getInput().getType();
TTNNLayoutAttr previousLayout =
mlir::dyn_cast<TTNNLayoutAttr>(previousType.getEncoding());
// Verify that the input tensor has a layout attribute.
if (!previousLayout) {
return mlir::failure();
}

// Check if the previous op has only one use. We can only merge if the
// previous op has a single use.
if (!previousToLayoutOp->hasOneUse()) {
return mlir::failure();
}
RankedTensorType currentType = toLayoutOp.getType();
TTNNLayoutAttr currentLayout =
mlir::dyn_cast<TTNNLayoutAttr>(currentType.getEncoding());
// Verify that the output tensor has a layout attribute.
if (!currentLayout) {
return mlir::failure();
}

// Verify that the layouts are the same.
if (previousLayout != currentLayout) {
return mlir::failure();
}

// Replace the previous op with the merged ToLayoutOp
Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
toLayoutOp.getLayoutAttr(),
toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
: previousToLayoutOp.getDtypeAttr(),
toLayoutOp.getMemoryConfigAttr()
? toLayoutOp.getMemoryConfigAttr()
: previousToLayoutOp.getMemoryConfigAttr(),
toLayoutOp.getDevice());
rewriter.replaceAllUsesWith(toLayoutOp, toLayoutOp->getOperand(0));
rewriter.eraseOp(toLayoutOp);
return mlir::success();
});

// Two consecutive ToLayoutOps can be merged if the previous op has only one
// use.
// df - data format, l - layout, ms - memory
// space, tml - tensor memory layout
//
// |
// -----------------------
// | ToLayoutOp | |
// | df1, l1, ms1, tml1 | -----------------------
// ----------------------- | ToLayoutOp |
// | --> | df2, l1, ms2, tml1 |
// | -----------------------
// ----------------------- |
// | ToLayoutOp |
// | df2, ms2 |
// -----------------------
// |
//
patterns.add(+[](mlir::tt::ttnn::ToLayoutOp toLayoutOp,
mlir::PatternRewriter &rewriter) {
// Get the input operand and verify that the previous op is toLayoutOp.
ToLayoutOp previousToLayoutOp =
toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();

// NOLINTNEXTLINE
if (!previousToLayoutOp) {
return mlir::failure();
}

// Replace all uses of the current op with the merged ToLayoutOp
rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);
// Check if the previous op has only one use. We can only merge if the
// previous op has a single use.
if (!previousToLayoutOp->hasOneUse()) {
return mlir::failure();
}

// Erase the current op
rewriter.eraseOp(toLayoutOp);
// Replace the previous op with the merged ToLayoutOp.
Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
toLayoutOp.getLayoutAttr(),
toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
: previousToLayoutOp.getDtypeAttr(),
toLayoutOp.getMemoryConfigAttr()
? toLayoutOp.getMemoryConfigAttr()
: previousToLayoutOp.getMemoryConfigAttr(),
toLayoutOp.getDevice());

return mlir::success();
// Replace all uses of the current op with the merged ToLayoutOp.
rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);

// Erase the current op.
rewriter.eraseOp(toLayoutOp);

return mlir::success();
});
}

//===----------------------------------------------------------------------===//
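For orientation, the patterns registered above are picked up by any pass that requests ToLayoutOp's canonicalization patterns, most notably the generic canonicalizer. A minimal driver sketch, not part of this commit and with header paths assumed:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

// Collects ToLayoutOp's canonicalization patterns (the fold pattern and the
// merge pattern defined in this commit) and applies them greedily until the
// IR stops changing.
static mlir::LogicalResult
runToLayoutCanonicalization(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  mlir::tt::ttnn::ToLayoutOp::getCanonicalizationPatterns(patterns,
                                                          module.getContext());
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}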
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -374,10 +374,10 @@ class TTNNLayoutDPSOperandsRewriter
}

// TTNN mesh shard expects host input and output
// TODO(#2102): This can be removed once the workaround pass can correctly
// handle canonicalization of toLayout ops. Currently the workaround pass
// cannot detect redundant toLayout ops as a result of forcing the output
// layout and removing them.
// TODO(#2291): This can be removed once the workaround pass can correctly
// handle canonicalization of toLayout ops (#2102). Currently the
// workaround pass cannot detect redundant toLayout ops as a result of
// forcing the output layout and removing them.
if (mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
modified = changeLayoutToHost(op, operand, rewriter, isDPSResult);
continue;
@@ -66,7 +66,7 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
return %2 : tensor<32x32xf32, #ttnn_layout5>
}

func.func @merge_to_layout_op_4x(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> {
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
@@ -81,4 +81,12 @@
%4 = "ttnn.to_layout"(%3, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5>
return %4 : tensor<32x32xf32, #ttnn_layout5>
}

func.func @fold_to_layout_op(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> {
// Verify folding of the to_layout op.
%0 = "ttnn.to_layout"(%arg0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout>
// CHECK-NOT: "ttnn.to_layout"
return %0 : tensor<32x32xbf16, #ttnn_layout>
// CHECK: return %arg0 : tensor<32x32xbf16
}
}
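Tests like the ones above are normally driven through ttmlir-opt with the canonicalizer enabled; a sketch of the lit RUN line (the actual directive sits at the top of the test file and is not shown in this hunk):

// RUN: ttmlir-opt --canonicalize %s | FileCheck %s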
