Add constraints and runtime API to transpose #2322

Open · wants to merge 2 commits into main
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -879,7 +879,9 @@ def TTNN_SoftmaxOp : TTNN_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def TTNN_TransposeOp : TTNN_Op<"transpose"> {
+def TTNN_TransposeOp : TTNN_Op<"transpose",
+    [DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
+> {
   let summary = "Transpose op.";
   let description = [{
     Transpose tensor along two given dimensions.
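Note: the DeclareOpInterfaceMethods trait above asks TableGen to declare the two
interface hooks directly on the generated TransposeOp class. Roughly, the
generated declarations look like this (a sketch inferred from the implementations
in TTNNOpModelInterface.cpp below, not the verbatim tblgen output):

  llvm::Expected<std::tuple<size_t, size_t, size_t>>
  getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
                   const TTNNLayoutAttr &output);
  llvm::Expected<size_t>
  getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
               const TTNNLayoutAttr &output);
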
17 changes: 17 additions & 0 deletions include/ttmlir/OpModel/TTNN/TTNNOpModel.h
@@ -120,6 +120,23 @@ getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
 
 }; // namespace ReshapeOpInterface
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+namespace TransposeOpInterface {
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
+                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0,
+                 const int dim1, mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+llvm::Expected<size_t>
+getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
+             mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0,
+             const int dim1, mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+}; // namespace TransposeOpInterface
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
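For context, the intended call pattern for the new interface looks roughly like
this (a minimal sketch; inputShape, inputLayout, and outputLayout are assumed to
be a valid shape/layout combination):

  // Query L1 memory constraints for transposing dims 0 and 1.
  auto constraints = op_model::ttnn::TransposeOpInterface::getOpConstraints(
      inputShape, inputLayout, /*dim0=*/0, /*dim1=*/1, outputLayout);
  if (!constraints)
    return constraints.takeError();
  // Tuple order matches the unit tests below: (cb_size, peak_size, output_size).
  auto [cbSize, peakSize, outputSize] = constraints.get();

  // Query the estimated runtime for the same configuration.
  auto runtime = op_model::ttnn::TransposeOpInterface::getOpRuntime(
      inputShape, inputLayout, /*dim0=*/0, /*dim1=*/1, outputLayout);
  if (!runtime)
    return runtime.takeError();
  size_t runtimeEstimate = runtime.get();
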
33 changes: 33 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
@@ -238,6 +238,39 @@ ReshapeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
       outputShape, output);
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp - TTNN Op Model Interface
+//===----------------------------------------------------------------------===//
+
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+TransposeOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
+                              const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
azecevicTT (Contributor) commented on Feb 28, 2025:

You can use getInput().getType() (you can check the names of the operands/attributes
in TTNNOps.td; an appropriately named getter is generated for every operand), and it
will directly give you the RankedTensorType. getOperand() is reserved for more
generic uses where you don't know exactly which operation you are working on.
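A minimal sketch of the suggested change (assuming the operand is declared as
`input` in TTNNOps.td, so a getInput() getter returning RankedTensorType is
generated):

  const auto inputShape = getInput().getType().getShape();
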

+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
+  if (!check) {
+    return check.takeError();
+  }
+
+  return op_model::ttnn::TransposeOpInterface::getOpConstraints(
+      inputShape, inputs[0], getDim0(), getDim1(), output);
+}
+
+llvm::Expected<size_t>
+TransposeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
+                          const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
Contributor commented:

Same here.

return op_model::ttnn::TransposeOpInterface::getOpRuntime(
inputShape, inputs[0], getDim0(), getDim1(), output);
}

//===----------------------------------------------------------------------===//
// MatmulOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions lib/OpModel/TTNN/MetalHeaders.h
@@ -62,6 +62,7 @@
 #include "ttnn/graph/graph_query_op_runtime.hpp"
 #include "ttnn/graph/graph_trace_utils.hpp"
 #include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
+#include "ttnn/operations/data_movement/transpose/transpose.hpp"
 #include "ttnn/operations/eltwise/binary/binary.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
 #include "ttnn/operations/matmul/matmul.hpp"
65 changes: 65 additions & 0 deletions lib/OpModel/TTNN/TTNNOpModelLib.cpp
@@ -566,6 +566,71 @@ ReshapeOpInterface::getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
 #endif // TTMLIR_ENABLE_OPMODEL
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+TransposeOpInterface::getOpConstraints(
+    llvm::ArrayRef<int64_t> inputShape,
+    mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0, const int dim1,
+    mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto transposeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                             const int dim0, const int dim1,
+                             mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // Open the device; it will be closed at the end of the function.
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // Prepare IO specs.
+    const auto [inputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout));
+
+    // Run op constraint query.
+    return ::ttnn::graph::query_op_constraints(
+        ::ttnn::transpose, device, inputSpec, dim0, dim1,
+        conversion::getMemoryConfig(outputLayout));
+  };
+
+  return operation::getOpConstraints("TransposeOpInterface", transposeOpQuery,
+                                     inputShape, inputLayout, dim0, dim1,
+                                     outputLayout);
+#else
+  return std::make_tuple(0, 0, 0);
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
+llvm::Expected<size_t> TransposeOpInterface::getOpRuntime(
+    llvm::ArrayRef<int64_t> inputShape,
+    mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0, const int dim1,
+    mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto transposeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                             const int dim0, const int dim1,
+                             mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // Open the device; it will be closed at the end of the function.
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // Prepare IO specs.
+    const auto [inputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout));
+
+    // Run op runtime query.
+    return ::ttnn::graph::query_op_runtime(
+        ::ttnn::transpose, device, inputSpec, dim0, dim1,
+        conversion::getMemoryConfig(outputLayout));
+  };
+
+  return operation::getOpRuntime("TransposeOpInterface", transposeOpQuery,
+                                 inputShape, inputLayout, dim0, dim1,
+                                 outputLayout);
+#else
+  return llvm::createStringError("Not Implemented");
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
39 changes: 39 additions & 0 deletions test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp
@@ -324,6 +324,45 @@ TEST_F(OpModelTest, Reshape) {
   EXPECT_TRUE(runtimeExp.get() > 0);
 }
 
+TEST_F(OpModelTest, Transpose) {
+  const llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
+  const auto workerGrid = CreateWorkerGrid(gridShapeHwN300);
+  const mlir::tt::ttnn::TTNNLayoutAttr layoutDRAM =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM,
+                        mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
+  const mlir::tt::ttnn::TTNNLayoutAttr layoutL1 =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1,
+                        mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
+  auto legalExp = Device::getDeviceConstraints(workerGrid);
+  EXPECT_TRUE(static_cast<bool>(legalExp));
Contributor commented:

Can you add one test case where the provided layout is invalid?
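A hypothetical sketch of such a negative case (how to construct a layout the op
model rejects, e.g. invalidLayout below, depends on the fixture's helpers; the
expectation is only that the failure surfaces as an llvm::Error):

  // Sketch: an unsatisfiable layout should yield an error, not a tuple.
  auto invalidExp = TransposeOpInterface::getOpConstraints(
      tensorShape, layoutDRAM, 0, 1, /*outputLayout=*/invalidLayout);
  EXPECT_FALSE(static_cast<bool>(invalidExp));
  llvm::consumeError(invalidExp.takeError());
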
+  auto constraintsExp = TransposeOpInterface::getOpConstraints(
+      tensorShape, layoutDRAM, 0, 1, layoutDRAM);
+  EXPECT_TRUE(static_cast<bool>(constraintsExp));
+  auto [cb_size, peak_size, output_size] = constraintsExp.get();
+  EXPECT_EQ(cb_size, 8192);
+  EXPECT_EQ(output_size, 0);
+  EXPECT_EQ(peak_size, 0);
+
+  auto runtimeExp = TransposeOpInterface::getOpRuntime(tensorShape, layoutDRAM,
+                                                       0, 1, layoutDRAM);
+  EXPECT_TRUE(static_cast<bool>(runtimeExp));
+  EXPECT_TRUE(runtimeExp.get() > 0);
+
+  constraintsExp = TransposeOpInterface::getOpConstraints(
+      tensorShape, layoutDRAM, 0, 1, layoutL1);
+  EXPECT_TRUE(static_cast<bool>(constraintsExp));
+  std::tie(cb_size, peak_size, output_size) = constraintsExp.get();
+  EXPECT_EQ(cb_size, 8192);
+  EXPECT_EQ(output_size, 2048);
+  EXPECT_EQ(peak_size, 2048);
+
+  runtimeExp = TransposeOpInterface::getOpRuntime(tensorShape, layoutDRAM, 0, 1,
+                                                  layoutL1);
+  EXPECT_TRUE(static_cast<bool>(runtimeExp));
+  EXPECT_TRUE(runtimeExp.get() > 0);
+}
+
 TEST_F(OpModelTest, SoftmaxSharded) {
   const llvm::SmallVector<int64_t> tensorShape = {16 * workerCoresN300 * 32,
                                                   32};
35 changes: 34 additions & 1 deletion test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
@@ -266,7 +266,7 @@ TEST_F(OpModelBase, reshapeOp) {
   reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector<mlir::Attribute>{
       builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)}));
 
-  // test mean Op interface
+  // test reshape Op interface
   auto constraintsExp = getOpConstraints(reshape.getOperation());
   if (constraintsExp) {
     auto l1 = constraintsExp.get();
@@ -287,4 +287,37 @@
   }
 }
 
+TEST_F(OpModelBase, transposeOp) {
+  // create TransposeOp
+  llvm::SmallVector<int64_t> tensorShapeA = {64, 1024};
+  llvm::SmallVector<int64_t> tensorShapeO = {1024, 64};
+
+  auto input = createEmptyTensor(tensorShapeA);
+  auto output = createEmptyTensor(tensorShapeO);
+
+  auto transpose = builder.create<TransposeOp>(builder.getUnknownLoc(),
+                                               output.getType(), input, 0, 1);
+  transpose->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+
+  // test transpose Op interface
+  auto constraintsExp = getOpConstraints(transpose.getOperation());
+  if (constraintsExp) {
+    auto l1 = constraintsExp.get();
+    const auto &[cb_size, peak_size, output_size] = l1;
+    EXPECT_EQ(cb_size, 8192);
+    EXPECT_EQ(peak_size, 2048);
+    EXPECT_EQ(output_size, 2048);
+  } else {
+    FAIL() << "Missing L1 constraints; Error="
+           << llvm::toString(constraintsExp.takeError()) << std::endl;
+  }
+
+  auto runtimeExp = getOpRuntime(transpose.getOperation());
+  if (runtimeExp) {
+    EXPECT_TRUE(runtimeExp.get() > 0);
+  } else {
+    FAIL() << llvm::toString(runtimeExp.takeError());
+  }
+}
+
 } // namespace mlir::tt::ttnn