From 34107c9e942f642637381e1b4710f273a880be90 Mon Sep 17 00:00:00 2001 From: Armin Ale Date: Mon, 24 Feb 2025 23:15:56 +0000 Subject: [PATCH] op tests --- .../OpModel/TTNN/Op/TestOpModelInterface.cpp | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp index d4372ff0ec..eaf7dec71b 100644 --- a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp +++ b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp @@ -11,6 +11,8 @@ #include "mlir/IR/AffineExpr.h" #include "gtest/gtest.h" +#include +#include #include namespace mlir::tt::ttnn { @@ -215,4 +217,75 @@ TEST_F(OpModelBase, MatmulInterface) { } } +TEST_F(OpModelBase, MeanInterface) { + // create MeanOp + llvm::SmallVector tensorShapeA = {2048, 1024}; + llvm::SmallVector tensorShapeO = {2048, 1024}; + + auto input = createEmptyTensor(tensorShapeA); + auto output = createEmptyTensor(tensorShapeO); + + auto mean = builder.create(builder.getUnknownLoc(), output.getType(), + ::mlir::ValueRange{input}); + mean->setAttr(DeviceAttr::name, getFakeDeviceAttr()); + mean.setKeepDim(true); + mean.setDimArgAttr(builder.getArrayAttr( + llvm::SmallVector{builder.getI64IntegerAttr(1)})); + + // test mean Op interface + auto constraintsExp = getOpConstraints(mean.getOperation()); + if (constraintsExp) { + auto l1 = constraintsExp.get(); + const auto &[cb_size, peak_size, output_size] = l1; + EXPECT_EQ(cb_size, 12288); + EXPECT_EQ(peak_size, 2048); + EXPECT_EQ(output_size, 2048); + } else { + FAIL() << "Missing L1 constraints; Error=" + << llvm::toString(constraintsExp.takeError()) << std::endl; + } + + auto runtimeExp = getOpRuntime(mean.getOperation()); + if (runtimeExp) { + EXPECT_TRUE(runtimeExp.get() > 0); + } else { + FAIL() << llvm::toString(runtimeExp.takeError()); + } +} + +TEST_F(OpModelBase, reshapeOp) { + // create ReshapeOp + llvm::SmallVector tensorShapeA = {64, 1024}; + llvm::SmallVector tensorShapeO = {64 * 4, 1024 / 4}; + + auto input = createEmptyTensor(tensorShapeA); + auto output = createEmptyTensor(tensorShapeO); + + auto reshape = builder.create( + builder.getUnknownLoc(), output.getType(), ::mlir::ValueRange{input}); + reshape->setAttr(DeviceAttr::name, getFakeDeviceAttr()); + reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector{ + builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)})); + + // test mean Op interface + auto constraintsExp = getOpConstraints(reshape.getOperation()); + if (constraintsExp) { + auto l1 = constraintsExp.get(); + const auto &[cb_size, peak_size, output_size] = l1; + EXPECT_EQ(cb_size, 262144); + EXPECT_EQ(peak_size, 4096); + EXPECT_EQ(output_size, 2048); + } else { + FAIL() << "Missing L1 constraints; Error=" + << llvm::toString(constraintsExp.takeError()) << std::endl; + } + + auto runtimeExp = getOpRuntime(reshape.getOperation()); + if (runtimeExp) { + EXPECT_TRUE(runtimeExp.get() > 0); + } else { + FAIL() << llvm::toString(runtimeExp.takeError()); + } +} + } // namespace mlir::tt::ttnn