Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic region to loops #2306

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ def TTIRAttachMetalLayout: Pass<"ttir-attach-metal-layout", "::mlir::ModuleOp">
];
}

def TTIRGenericLinearizeMemref: Pass<"ttir-generic-linearize-memref", "::mlir::ModuleOp"> {
let summary = "Linearize memref operands for generic ops.";
let description = [{
This pass takes a nested loop structure over n-dimensional memrefs and linearizes
them into a single dimension. This is a useful because circular buffers in metal
are only one-dimensional.

Example, this pass will convert the following code:
```mlir
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
%0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
}
}
```

Into:
```mlir
%collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
%collapse_shape_0 = memref.collapse_shape %arg3 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
%collapse_shape_1 = memref.collapse_shape %arg4 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
%0 = affine.load %collapse_shape[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %collapse_shape_0[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
affine.store %2, %collapse_shape_1[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
}
}
```
}];
}

def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
let summary = "Tensor tilize all generic ops.";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Allocate.cpp
Broadcast.cpp
Constant.cpp
Generic.cpp
GenericLinearizeMemref.cpp
HoistCPUOps.cpp
Layout.cpp
Transforms.cpp
Expand Down
16 changes: 0 additions & 16 deletions lib/Dialect/TTIR/Transforms/Generic.cpp

This file was deleted.

121 changes: 121 additions & 0 deletions lib/Dialect/TTIR/Transforms/GenericLinearizeMemref.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Utils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRGENERICLINEARIZEMEMREF
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

namespace {
class TTIRGenericLinearizeMemrefRewriter : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;

static bool isLinearizedMemref(BlockArgument arg) {
auto memref = mlir::cast<MemRefType>(arg.getType());
if (memref.getShape().size() == 1) {
return true;
}

return llvm::all_of(arg.getUsers(), [](Operation *user) {
return mlir::isa<memref::CollapseShapeOp>(user);
});
}

static mlir::AffineMap linearizeAffineMap(::mlir::MLIRContext *context,
mlir::AffineMap map,
ArrayRef<int64_t> shape) {
auto evaledShape = ttmlir::utils::evalShape(map, shape);
mlir::AffineExpr indexing = getAffineConstantExpr(0, context);
mlir::AffineExpr volumeExpr = getAffineConstantExpr(1, context);

assert(map.getNumResults() > 0);
for (int i = map.getNumResults() - 1; i >= 0; i--) {
mlir::AffineExpr linearIdx = getAffineDimExpr(i, context);
mlir::AffineExpr dim = getAffineConstantExpr(evaledShape[i], context);
indexing = linearIdx * volumeExpr + indexing;
volumeExpr = volumeExpr * dim;
}

mlir::AffineMap linearResult =
mlir::AffineMap::get(map.getNumResults(), 0, indexing, context);
return linearResult.compose(map);
}

LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const final {
Block *entry = &op.getRegion().front();
rewriter.setInsertionPointToStart(entry);
auto args = entry->getArguments();
if (llvm::all_of(args, isLinearizedMemref)) {
return failure();
}

rewriter.modifyOpInPlace(op, [&]() {
for (auto arg : args) {
if (isLinearizedMemref(arg)) {
continue;
}
auto memref = mlir::cast<MemRefType>(arg.getType());
auto shape = memref.getShape();
auto linearMap = linearizeAffineMap(
rewriter.getContext(), memref.getLayout().getAffineMap(), shape);
SmallVector<ReassociationIndices, 4> collapsedDims = {
llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))};
auto linearizedArg = rewriter.create<memref::CollapseShapeOp>(
arg.getLoc(), arg, collapsedDims);
rewriter.replaceAllUsesExcept(arg, linearizedArg->getResult(0),
linearizedArg);
for (auto *user : linearizedArg->getUsers()) {
if (auto load = mlir::dyn_cast<affine::AffineLoadOp>(user)) {
load.setMap(linearMap.compose(load.getMap()));
} else if (auto store = mlir::dyn_cast<affine::AffineStoreOp>(user)) {
store.setMap(linearMap.compose(store.getMap()));
}
}
}
});

return success();
}
};
} // namespace

namespace {
class TTIRGenericLinearizeMemref
: public impl::TTIRGenericLinearizeMemrefBase<TTIRGenericLinearizeMemref> {
public:
using impl::TTIRGenericLinearizeMemrefBase<
TTIRGenericLinearizeMemref>::TTIRGenericLinearizeMemrefBase;

void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRGenericLinearizeMemrefRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsGreedily(getOperation(), patternSet))) {
signalPassFailure();
}
}
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::tt::ttir::TTIRDialect>();
registry.insert<mlir::tt::TTDialect>();
registry.insert<mlir::arith::ArithDialect>();
}
};
} // namespace

} // namespace mlir::tt::ttir
5 changes: 5 additions & 0 deletions lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "ttmlir/Dialect/TTMetal/Pipelines/TTMetalPipelines.h"

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

#include "ttmlir/Conversion/Passes.h"
Expand Down Expand Up @@ -79,6 +81,9 @@ void createTTIRToTTMetalBackendPipeline(
// pm.addPass(mlir::tt::ttir::createTTIRGenericRegion());
if (options.version > 0) {
createTTIRBufferizationPipeline(pm);
pm.addPass(mlir::createConvertLinalgToAffineLoopsPass());
pm.addPass(mlir::tt::ttir::createTTIRGenericLinearizeMemref());
pm.addPass(mlir::createLowerAffinePass());
} else {
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
{
Expand Down
5 changes: 3 additions & 2 deletions lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::affine::AffineDialect, mlir::scf::SCFDialect,
mlir::cf::ControlFlowDialect, mlir::tosa::TosaDialect,
mlir::vector::VectorDialect, mlir::memref::MemRefDialect,
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect,
mlir::LLVM::LLVMDialect>();

Expand Down
47 changes: 47 additions & 0 deletions test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: ttmlir-opt --ttir-generic-linearize-memref %s | FileCheck %s

#l1_ = #tt.memory_space<l1>
#map = affine_map<(d0, d1) -> (d0, d1)>
#parallel = #tt.iterator_type<parallel>

func.func @add(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
"ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array<i32: 2, 0, 1>, operand_cb_mapping = array<i64>}> ({
^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>):
// CHECK: = memref.collapse_shape %arg3
// CHECK: = memref.collapse_shape %arg4
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
// CHECK: = affine.load %collapse_shape[%arg5 * 4 + %arg6]
// CHECK: = affine.load %collapse_shape_0[%arg5 * 4 + %arg6]
%0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_add"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
// CHECK: affine.store %2, %collapse_shape_1[%{{.*}} * 4 + %{{.*}}]
affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
}
}
}) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> ()
return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
}

func.func @addT(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1T: memref<1x1x4x2x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
"ttir.generic"(%arg0, %arg1T, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array<i32: 2, 0, 1>, operand_cb_mapping = array<i64>}> ({
^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<4x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>):
// CHECK: = memref.collapse_shape %arg3
// CHECK: = memref.collapse_shape %arg4
affine.for %arg5 = 0 to 2 {
affine.for %arg6 = 0 to 4 {
// CHECK: = affine.load %collapse_shape[%arg5 * 4 + %arg6]
// CHECK: = affine.load %collapse_shape_0[%arg6 * 2 + %arg5]
%0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
%1 = affine.load %arg3[%arg6, %arg5] : memref<4x2x!tt.tile<32x32, f32>, #l1_>
%2 = "ttir.tile_add"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
// CHECK: affine.store %2, %collapse_shape_1[%{{.*}} * 4 + %{{.*}}]
affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
}
}
}) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x4x2x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> ()
return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>
}
Loading