Skip to content

Commit

Permalink
op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arminaleTT committed Feb 24, 2025
1 parent 867dcb7 commit 34107c9
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "mlir/IR/AffineExpr.h"
#include "gtest/gtest.h"

#include <cstdint>
#include <llvm/ADT/SmallVector.h>
#include <optional>

namespace mlir::tt::ttnn {
Expand Down Expand Up @@ -215,4 +217,75 @@ TEST_F(OpModelBase, MatmulInterface) {
}
}

TEST_F(OpModelBase, MeanInterface) {
// create MeanOp
llvm::SmallVector<int64_t> tensorShapeA = {2048, 1024};
llvm::SmallVector<int64_t> tensorShapeO = {2048, 1024};

auto input = createEmptyTensor(tensorShapeA);
auto output = createEmptyTensor(tensorShapeO);

auto mean = builder.create<MeanOp>(builder.getUnknownLoc(), output.getType(),
::mlir::ValueRange{input});
mean->setAttr(DeviceAttr::name, getFakeDeviceAttr());
mean.setKeepDim(true);
mean.setDimArgAttr(builder.getArrayAttr(
llvm::SmallVector<mlir::Attribute>{builder.getI64IntegerAttr(1)}));

// test mean Op interface
auto constraintsExp = getOpConstraints(mean.getOperation());
if (constraintsExp) {
auto l1 = constraintsExp.get();
const auto &[cb_size, peak_size, output_size] = l1;
EXPECT_EQ(cb_size, 12288);
EXPECT_EQ(peak_size, 2048);
EXPECT_EQ(output_size, 2048);
} else {
FAIL() << "Missing L1 constraints; Error="
<< llvm::toString(constraintsExp.takeError()) << std::endl;
}

auto runtimeExp = getOpRuntime(mean.getOperation());
if (runtimeExp) {
EXPECT_TRUE(runtimeExp.get() > 0);
} else {
FAIL() << llvm::toString(runtimeExp.takeError());
}
}

TEST_F(OpModelBase, reshapeOp) {
// create ReshapeOp
llvm::SmallVector<int64_t> tensorShapeA = {64, 1024};
llvm::SmallVector<int64_t> tensorShapeO = {64 * 4, 1024 / 4};

auto input = createEmptyTensor(tensorShapeA);
auto output = createEmptyTensor(tensorShapeO);

auto reshape = builder.create<ReshapeOp>(
builder.getUnknownLoc(), output.getType(), ::mlir::ValueRange{input});
reshape->setAttr(DeviceAttr::name, getFakeDeviceAttr());
reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector<mlir::Attribute>{
builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)}));

// test mean Op interface
auto constraintsExp = getOpConstraints(reshape.getOperation());
if (constraintsExp) {
auto l1 = constraintsExp.get();
const auto &[cb_size, peak_size, output_size] = l1;
EXPECT_EQ(cb_size, 262144);
EXPECT_EQ(peak_size, 4096);
EXPECT_EQ(output_size, 2048);
} else {
FAIL() << "Missing L1 constraints; Error="
<< llvm::toString(constraintsExp.takeError()) << std::endl;
}

auto runtimeExp = getOpRuntime(reshape.getOperation());
if (runtimeExp) {
EXPECT_TRUE(runtimeExp.get() > 0);
} else {
FAIL() << llvm::toString(runtimeExp.takeError());
}
}

} // namespace mlir::tt::ttnn

0 comments on commit 34107c9

Please sign in to comment.