diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index c79ee757e9..933526ea69 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -649,7 +649,9 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
   }];
 }
 
-def TTNN_MeanOp : TTNN_ReductionOp<"mean"> {
+def TTNN_MeanOp : TTNN_ReductionOp<"mean",
+    [DeclareOpInterfaceMethods<TTNN_OpModelInterface,
+                               ["getOpConstraints", "getOpRuntime"]>]> {
   let summary = "Mean reduction op.";
   let description = [{
     Mean reduction op.
@@ -952,7 +954,9 @@ def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat", [HasMemoryConfigTrait]> {
   let hasVerifier = 1;
 }
 
-def TTNN_ReshapeOp : TTNN_Op<"reshape"> {
+def TTNN_ReshapeOp : TTNN_Op<"reshape",
+    [DeclareOpInterfaceMethods<TTNN_OpModelInterface,
+                               ["getOpConstraints", "getOpRuntime"]>]> {
   let summary = "Reshape op.";
   let description = [{
     Reshape tensor.
diff --git a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h
index 6dc409d634..27eff0ca5d 100644
--- a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h
+++ b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h
@@ -82,6 +82,44 @@ getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
 
 }; // namespace SoftmaxOpInterface
 
+//===----------------------------------------------------------------------===//
+// MeanOp
+//===----------------------------------------------------------------------===//
+
+namespace MeanOpInterface {
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
+                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                 std::optional<llvm::ArrayRef<int64_t>> dimArg, bool keepDim,
+                 mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+llvm::Expected<size_t>
+getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
+             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+             std::optional<llvm::ArrayRef<int64_t>> dimArg, bool keepDim,
+             mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+}; // namespace MeanOpInterface
+
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+namespace ReshapeOpInterface {
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
+                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                 llvm::ArrayRef<int64_t> outputShape,
+                 mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+llvm::Expected<size_t>
+getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
+             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+             llvm::ArrayRef<int64_t> outputShape,
+             mlir::tt::ttnn::TTNNLayoutAttr outputLayout);
+
+}; // namespace ReshapeOpInterface
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
index 2623db1cc5..d1874441f6 100644
--- a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Operation.h"
 
 #include <cassert>
+#include <optional>
 #include <tuple>
 #include <vector>
 
@@ -24,6 +25,22 @@ llvm::Expected<bool> checkDeviceWorkerGrid(mlir::Operation *op) {
   return op_model::ttnn::Device::getDeviceConstraints(
       deviceAttr.getWorkerGrid());
 }
+
+// Converts an optional MLIR ArrayAttr of reduction dimensions into the
+// optional integer vector the op-model library expects.
+std::optional<llvm::SmallVector<int64_t>>
+convertReductionArg(std::optional<mlir::ArrayAttr> arrayOpt) {
+  if (!arrayOpt.has_value()) {
+    return std::nullopt;
+  }
+
+  llvm::SmallVector<int64_t> reduceDims;
+
+  for (const mlir::Attribute &reduceDim : *arrayOpt) {
+    reduceDims.push_back(mlir::cast<mlir::IntegerAttr>(reduceDim).getInt());
+  }
+
+  return reduceDims;
+}
+
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -41,7 +58,7 @@ ReluOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResults().front().getType()).getShape();
 
-  auto check = detail::checkDeviceWorkerGrid(getOperation());
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
   if (!check) {
     return check.takeError();
   }
@@ -83,7 +100,7 @@ AddOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResult(0).getType()).getShape();
 
-  auto check = detail::checkDeviceWorkerGrid(getOperation());
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
   if (!check) {
     return check.takeError();
   }
@@ -124,7 +141,7 @@ SoftmaxOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResult().getType()).getShape();
 
-  auto check = detail::checkDeviceWorkerGrid(getOperation());
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
   if (!check) {
     return check.takeError();
   }
@@ -148,6 +165,79 @@ SoftmaxOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
       inputShape, inputs[0], getDimension(), outputShape, output);
 }
 
+//===----------------------------------------------------------------------===//
+// MeanOp - TTNN Op Model Interface
+//===----------------------------------------------------------------------===//
+
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+MeanOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
+                         const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
+
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
+  if (!check) {
+    return check.takeError();
+  }
+
+  return op_model::ttnn::MeanOpInterface::getOpConstraints(
+      inputShape, inputs[0], detail::convertReductionArg(getDimArg()),
+      getKeepDim(), output);
+}
+
+llvm::Expected<size_t>
+MeanOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
+                     const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
+
+  return op_model::ttnn::MeanOpInterface::getOpRuntime(
+      inputShape, inputs[0], detail::convertReductionArg(getDimArg()),
+      getKeepDim(), output);
+}
+
+//===----------------------------------------------------------------------===//
+// ReshapeOp - TTNN Op Model Interface
+//===----------------------------------------------------------------------===//
+
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+ReshapeOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
+                            const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
+
+  const auto outputShape =
+      mlir::cast<RankedTensorType>(getResult().getType()).getShape();
+
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
+  if (!check) {
+    return check.takeError();
+  }
+
+  return op_model::ttnn::ReshapeOpInterface::getOpConstraints(
+      inputShape, inputs[0], outputShape, output);
+}
+
+llvm::Expected<size_t>
+ReshapeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
+                        const TTNNLayoutAttr &output) {
+  assert(inputs.size() == 1);
+
+  const auto inputShape =
+      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
+  const auto outputShape =
+      mlir::cast<RankedTensorType>(getResult().getType()).getShape();
+
+  return op_model::ttnn::ReshapeOpInterface::getOpRuntime(inputShape, inputs[0],
+                                                          outputShape, output);
+}
+
 //===----------------------------------------------------------------------===//
 // MatmulOp - TTNN Op Model Interface
 //===----------------------------------------------------------------------===//
@@ -165,7 +255,7 @@ MatmulOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
   const auto outputShape =
       mlir::cast<RankedTensorType>(getResult().getType()).getShape();
 
-  auto check = detail::checkDeviceWorkerGrid(getOperation());
+  llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
   if (!check) {
     return check.takeError();
   }
diff --git a/lib/OpModel/TTNN/Conversion.cpp b/lib/OpModel/TTNN/Conversion.cpp
index 3b88c4ccab..fcb22e4da5 100644
--- a/lib/OpModel/TTNN/Conversion.cpp
+++ b/lib/OpModel/TTNN/Conversion.cpp
@@ -2,6 +2,9 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include <cstddef>
+#include <cstdint>
+#include <optional>
 #ifdef TTMLIR_ENABLE_OPMODEL
 #include "Conversion.hpp"
 
@@ -150,6 +153,11 @@ ::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef<int64_t> shape,
   return ::ttnn::TensorSpec(getShape(shape), getTensorLayout(layout));
 }
 
+::ttnn::SmallVector<int>
+convertLLVMSmallVecToTTNNSmallVec(const ::llvm::ArrayRef<int64_t> vec) {
+  return ::ttnn::SmallVector<int>(vec.begin(), vec.end());
+}
+
 } // namespace conversion
 } // namespace mlir::tt::op_model::ttnn
 #endif // TTMLIR_ENABLE_OPMODEL
diff --git a/lib/OpModel/TTNN/Conversion.hpp b/lib/OpModel/TTNN/Conversion.hpp
index be2a070802..0dd37bfcb6 100644
--- a/lib/OpModel/TTNN/Conversion.hpp
+++ b/lib/OpModel/TTNN/Conversion.hpp
@@ -43,6 +43,9 @@ getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout);
 ::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef<int64_t> shape,
                                  const mlir::tt::ttnn::TTNNLayoutAttr &layout);
 
+::ttnn::SmallVector<int>
+convertLLVMSmallVecToTTNNSmallVec(const ::llvm::ArrayRef<int64_t> vec);
+
 } // namespace conversion
 } // namespace mlir::tt::op_model::ttnn
 
diff --git a/lib/OpModel/TTNN/MetalHeaders.h b/lib/OpModel/TTNN/MetalHeaders.h
index 98bfe78e37..cbf929dd3b 100644
--- a/lib/OpModel/TTNN/MetalHeaders.h
+++ b/lib/OpModel/TTNN/MetalHeaders.h
@@ -61,10 +61,12 @@
 #include "ttnn/graph/graph_query_op_constraints.hpp"
 #include "ttnn/graph/graph_query_op_runtime.hpp"
 #include "ttnn/graph/graph_trace_utils.hpp"
+#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
 #include "ttnn/operations/eltwise/binary/binary.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
 #include "ttnn/operations/matmul/matmul.hpp"
 #include "ttnn/operations/normalization/softmax/softmax.hpp"
+#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/tensor_spec.hpp"
 #include "ttnn/tensor/types.hpp"
diff --git a/lib/OpModel/TTNN/TTNNOpModelLib.cpp b/lib/OpModel/TTNN/TTNNOpModelLib.cpp
index 102eddbc99..f8b0de05ec 100644
--- a/lib/OpModel/TTNN/TTNNOpModelLib.cpp
+++ b/lib/OpModel/TTNN/TTNNOpModelLib.cpp
@@ -412,6 +412,160 @@ SoftmaxOpInterface::getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
 #endif // TTMLIR_ENABLE_OPMODEL
 }
 
+//===----------------------------------------------------------------------===//
+// MeanOp
+//===----------------------------------------------------------------------===//
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+MeanOpInterface::getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
+                                  mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                                  std::optional<llvm::ArrayRef<int64_t>> dimArg,
+                                  bool keepDim,
+                                  mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto meanOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                        mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                        std::optional<llvm::ArrayRef<int64_t>> dimArg,
+                        bool keepDim,
+                        mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // open device; will close it at the end of function
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // prepare io specs
+    auto [inputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout));
+    auto memConfig = conversion::getMemoryConfig(outputLayout);
+
+    std::optional<::ttnn::SmallVector<int>> dimArgConverted;
+    if (dimArg) {
+      dimArgConverted =
+          conversion::convertLLVMSmallVecToTTNNSmallVec(dimArg.value());
+    } else {
+      dimArgConverted = std::nullopt;
+    }
+
+    // run op constraint query
+    return ::ttnn::graph::query_op_constraints(
+        ::ttnn::mean, device, inputSpec, dimArgConverted, keepDim, memConfig);
+  };
+
+  return operation::getOpConstraints("MeanOpInterface", meanOpQuery, inputShape,
+                                     inputLayout, dimArg, keepDim,
+                                     outputLayout);
+#else
+  return llvm::createStringError("Not Implemented");
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
+llvm::Expected<size_t>
+MeanOpInterface::getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
+                              mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                              std::optional<llvm::ArrayRef<int64_t>> dimArg,
+                              bool keepDim,
+                              mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto meanOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                        mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                        std::optional<llvm::ArrayRef<int64_t>> dimArg,
+                        bool keepDim,
+                        mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // open device; will close it at the end of function
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // prepare io specs
+    auto [inputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout));
+    auto memConfig = conversion::getMemoryConfig(outputLayout);
+
+    std::optional<::ttnn::SmallVector<int>> dimArgConverted;
+    if (dimArg) {
+      dimArgConverted =
+          conversion::convertLLVMSmallVecToTTNNSmallVec(dimArg.value());
+    } else {
+      dimArgConverted = std::nullopt;
+    }
+
+    // run op runtime query
+    return ::ttnn::graph::query_op_runtime(::ttnn::mean, device, inputSpec,
+                                           dimArgConverted, keepDim, memConfig);
+  };
+
+  return operation::getOpRuntime("MeanOpInterface", meanOpQuery, inputShape,
+                                 inputLayout, dimArg, keepDim, outputLayout);
+#else
+  return llvm::createStringError("Not Implemented");
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+llvm::Expected<std::tuple<size_t, size_t, size_t>>
+ReshapeOpInterface::getOpConstraints(
+    llvm::ArrayRef<int64_t> inputShape,
+    mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+    llvm::ArrayRef<int64_t> outputShape,
+    mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto reshapeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                           mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                           llvm::ArrayRef<int64_t> outputShape,
+                           mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // open device; will close it at the end of function
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // prepare io specs
+    const auto [inputSpec, outputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout),
+        std::make_tuple(outputShape, outputLayout));
+
+    // run op constraint query
+    return ::ttnn::graph::query_op_constraints(
+        ::ttnn::reshape, device, inputSpec, conversion::getShape(outputShape),
+        outputSpec.tensor_layout().get_memory_config());
+  };
+
+  return operation::getOpConstraints("ReshapeOpInterface", reshapeOpQuery,
+                                     inputShape, inputLayout, outputShape,
+                                     outputLayout);
+#else
+  // NOTE: returning std::make_tuple(0, 0, 0) here would silently report
+  // success when the op model is disabled; report an error like every other
+  // op interface does.
+  return llvm::createStringError("Not Implemented");
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
+llvm::Expected<size_t>
+ReshapeOpInterface::getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
+                                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                                 llvm::ArrayRef<int64_t> outputShape,
+                                 mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+#ifdef TTMLIR_ENABLE_OPMODEL
+  auto reshapeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
+                           mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
+                           llvm::ArrayRef<int64_t> outputShape,
+                           mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
+    // open device; will close it at the end of function
+    ::tt::tt_metal::v0::IDevice *device =
+        SingletonDeviceContext::getInstance().getDevice();
+
+    // prepare io specs
+    const auto [inputSpec, outputSpec] = detail::convertToTensorSpec(
+        device, std::make_tuple(inputShape, inputLayout),
+        std::make_tuple(outputShape, outputLayout));
+
+    return ::ttnn::graph::query_op_runtime(
+        ::ttnn::reshape, device, inputSpec, conversion::getShape(outputShape),
+        outputSpec.tensor_layout().get_memory_config());
+  };
+
+  return operation::getOpRuntime("ReshapeOpInterface", reshapeOpQuery,
+                                 inputShape, inputLayout, outputShape,
+                                 outputLayout);
+#else
+  return llvm::createStringError("Not Implemented");
+#endif // TTMLIR_ENABLE_OPMODEL
+}
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp b/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp
index dd3894041b..2dfd84c6e6 100644
--- a/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp
+++ b/test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/SmallVector.h"
 
 #include "gtest/gtest.h"
+#include <optional>
 #include <tuple>
 
 namespace mlir::tt::op_model::ttnn {
@@ -146,6 +147,77 @@ INSTANTIATE_TEST_SUITE_P(
                              mlir::tt::ttnn::BufferType::L1},
         detail::ExpectedResult{false})));
 
+class OpModelReductionParam
+    : public OpModelTest,
+      public testing::WithParamInterface<
+          std::tuple<detail::TestTensor,                        // input
+                     detail::TestTensor,                        // output
+                     std::optional<llvm::SmallVector<int64_t>>, // dim arg
+                     bool,                                      // keep dim
+                     detail::ExpectedResult>> {};
+
+TEST_P(OpModelReductionParam, Reduction) {
+  auto params = GetParam();
+  const auto [inputShape, inputTensorLayout, inputBufferType,
+              inputVirtualGrid] = std::get<0>(params);
+
+  const auto [outputShape, outputTensorLayout, outputBufferType,
+              outputVirtualGrid] = std::get<1>(params);
+  const auto dimArg = std::get<2>(params);
+  const auto keepDim = std::get<3>(params);
+  const auto [expectedLegal, expectedCbSize, expectedPeakSize,
+              expectedOutputSize] = std::get<4>(params);
+
+  const mlir::tt::ttnn::TTNNLayoutAttr inputLayout = CreateTiledLayout(
+      inputShape, inputBufferType, inputTensorLayout, inputVirtualGrid);
+  const mlir::tt::ttnn::TTNNLayoutAttr outputLayout = CreateTiledLayout(
+      outputShape, outputBufferType, outputTensorLayout, outputVirtualGrid);
+
+  auto constraintsExp = MeanOpInterface::getOpConstraints(
+      inputShape, inputLayout, dimArg, keepDim, outputLayout);
+  // Manually cast to bool because EXPECT_TRUE requires a const bool operator
+  // which llvm::Expected does not have
+  EXPECT_EQ(static_cast<bool>(constraintsExp), expectedLegal);
+  if (expectedLegal) {
+    const auto [cbSize, peakSize, outputSize] = constraintsExp.get();
+    EXPECT_EQ(cbSize, expectedCbSize);
+    EXPECT_EQ(peakSize, expectedPeakSize);
+    EXPECT_EQ(outputSize, expectedOutputSize);
+  } else {
+    // Must clean up the error
+    llvm::consumeError(constraintsExp.takeError());
+  }
+
+  auto runtimeExp = MeanOpInterface::getOpRuntime(inputShape, inputLayout,
+                                                  dimArg, keepDim,
+                                                  outputLayout);
+  EXPECT_EQ(static_cast<bool>(runtimeExp), expectedLegal);
+  if (expectedLegal) {
+    EXPECT_TRUE(runtimeExp.get() > 0);
+  } else {
+    llvm::consumeError(runtimeExp.takeError());
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    MeanTests, OpModelReductionParam,
+    ::testing::Values(
+        std::make_tuple(detail::interleavedN300X1024Dram,
+                        detail::interleavedN300X1024Dram,
+                        llvm::SmallVector<int64_t>{1}, true,
+                        detail::ExpectedResult{true, 12288, 0, 0}),
+        std::make_tuple(detail::interleavedN300X1024Dram,
+                        detail::interleavedN300X1024Dram,
+                        llvm::SmallVector<int64_t>{1, 2}, false,
+                        detail::ExpectedResult{false, 0, 0, 0}),
+        std::make_tuple(detail::interleavedN300X1024Dram,
+                        detail::interleavedN300X1024Dram,
+                        llvm::SmallVector<int64_t>{1, 0}, false,
+                        detail::ExpectedResult{true, 12288, 0, 0}),
+        std::make_tuple(detail::interleavedN300X1024L1,
+                        detail::interleavedN300X1024Dram,
+                        llvm::SmallVector<int64_t>{1}, false,
+                        detail::ExpectedResult{true, 12288, 0, 0})));
+
 TEST_F(OpModelTest, SoftmaxInterleaved) {
   const llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
   const auto workerGrid = CreateWorkerGrid(gridShapeHwN300);
@@ -213,6 +285,45 @@ TEST_F(OpModelTest, SoftmaxInterleaved) {
   }
 }
 
+TEST_F(OpModelTest, Reshape) {
+  const llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
+  const auto workerGrid = CreateWorkerGrid(gridShapeHwN300);
+  const mlir::tt::ttnn::TTNNLayoutAttr layoutDRAM =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM,
+                        mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
+  const mlir::tt::ttnn::TTNNLayoutAttr layoutL1 =
+      CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1,
+                        mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
+  auto legalExp = Device::getDeviceConstraints(workerGrid);
+  EXPECT_TRUE(static_cast<bool>(legalExp));
+
+  auto constraintsExp = ReshapeOpInterface::getOpConstraints(
+      tensorShape, layoutDRAM, {workerCoresN300 * 4, 256}, layoutDRAM);
+  EXPECT_TRUE(static_cast<bool>(constraintsExp));
+  auto [cb_size, peak_size, output_size] = constraintsExp.get();
+  EXPECT_EQ(cb_size, 262144);
+  EXPECT_EQ(output_size, 0);
+  EXPECT_EQ(peak_size, 0);
+
+  auto runtimeExp = ReshapeOpInterface::getOpRuntime(
+      tensorShape, layoutDRAM, {workerCoresN300 * 4, 256}, layoutDRAM);
+  EXPECT_TRUE(static_cast<bool>(runtimeExp));
+  EXPECT_TRUE(runtimeExp.get() > 0);
+
+  constraintsExp = ReshapeOpInterface::getOpConstraints(
+      tensorShape, layoutDRAM, {workerCoresN300 * 4, 256}, layoutL1);
+  EXPECT_TRUE(static_cast<bool>(constraintsExp));
+  std::tie(cb_size, peak_size, output_size) = constraintsExp.get();
+  EXPECT_EQ(cb_size, 262144);
+  EXPECT_EQ(output_size, 2048);
+  EXPECT_EQ(peak_size, 4096);
+
+  runtimeExp = ReshapeOpInterface::getOpRuntime(
+      tensorShape, layoutDRAM, {workerCoresN300 * 4, 256}, layoutL1);
+  EXPECT_TRUE(static_cast<bool>(runtimeExp));
+  EXPECT_TRUE(runtimeExp.get() > 0);
+}
+
 TEST_F(OpModelTest, SoftmaxSharded) {
   const llvm::SmallVector<int64_t> tensorShape = {16 * workerCoresN300 * 32,
                                                   32};
diff --git a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
index d4372ff0ec..27536e8bc6 100644
--- a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
+++ b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/AffineExpr.h"
 
 #include "gtest/gtest.h"
+#include <optional>
 #include <tuple>
 
 namespace mlir::tt::ttnn {
@@ -215,4 +216,75 @@ TEST_F(OpModelBase, MatmulInterface) {
   }
 }
 
+TEST_F(OpModelBase, MeanInterface) {
+  // create MeanOp
+  llvm::SmallVector<int64_t> tensorShapeA = {2048, 1024};
+  llvm::SmallVector<int64_t> tensorShapeO = {2048, 1024};
+
+  auto input = createEmptyTensor(tensorShapeA);
+  auto output = createEmptyTensor(tensorShapeO);
+
+  auto mean = builder.create<MeanOp>(builder.getUnknownLoc(), output.getType(),
+                                     ::mlir::ValueRange{input});
+  mean->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+  mean.setKeepDim(true);
+  mean.setDimArgAttr(builder.getArrayAttr(
+      llvm::SmallVector<mlir::Attribute>{builder.getI64IntegerAttr(1)}));
+
+  // test mean Op interface
+  auto constraintsExp = getOpConstraints(mean.getOperation());
+  if (constraintsExp) {
+    auto l1 = constraintsExp.get();
+    const auto &[cb_size, peak_size, output_size] = l1;
+    EXPECT_EQ(cb_size, 12288);
+    EXPECT_EQ(peak_size, 2048);
+    EXPECT_EQ(output_size, 2048);
+  } else {
+    FAIL() << "Missing L1 constraints; Error="
+           << llvm::toString(constraintsExp.takeError()) << std::endl;
+  }
+
+  auto runtimeExp = getOpRuntime(mean.getOperation());
+  if (runtimeExp) {
+    EXPECT_TRUE(runtimeExp.get() > 0);
+  } else {
+    FAIL() << llvm::toString(runtimeExp.takeError());
+  }
+}
+
+TEST_F(OpModelBase, reshapeOp) {
+  // create ReshapeOp
+  llvm::SmallVector<int64_t> tensorShapeA = {64, 1024};
+  llvm::SmallVector<int64_t> tensorShapeO = {64 * 4, 1024 / 4};
+
+  auto input = createEmptyTensor(tensorShapeA);
+  auto output = createEmptyTensor(tensorShapeO);
+
+  auto reshape = builder.create<ReshapeOp>(
+      builder.getUnknownLoc(), output.getType(), ::mlir::ValueRange{input});
+  reshape->setAttr(DeviceAttr::name, getFakeDeviceAttr());
+  reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector<mlir::Attribute>{
+      builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)}));
+
+  // test reshape Op interface
+  auto constraintsExp = getOpConstraints(reshape.getOperation());
+  if (constraintsExp) {
+    auto l1 = constraintsExp.get();
+    const auto &[cb_size, peak_size, output_size] = l1;
+    EXPECT_EQ(cb_size, 262144);
+    EXPECT_EQ(peak_size, 4096);
+    EXPECT_EQ(output_size, 2048);
+  } else {
+    FAIL() << "Missing L1 constraints; Error="
+           << llvm::toString(constraintsExp.takeError()) << std::endl;
+  }
+
+  auto runtimeExp = getOpRuntime(reshape.getOperation());
+  if (runtimeExp) {
+    EXPECT_TRUE(runtimeExp.get() > 0);
+  } else {
+    FAIL() << llvm::toString(runtimeExp.takeError());
+  }
+}
+
 } // namespace mlir::tt::ttnn