Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable constraints and runtime measurement for Reshape and Mean Ops #2273

Merged
merged 15 commits into from
Feb 26, 2025
8 changes: 6 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,9 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
}];
}

def TTNN_MeanOp : TTNN_ReductionOp<"mean"> {
def TTNN_MeanOp : TTNN_ReductionOp<"mean",
[DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
> {
let summary = "Mean reduction op.";
let description = [{
Mean reduction op.
Expand Down Expand Up @@ -952,7 +954,9 @@ def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat", [HasMemoryConfigTrait]> {
let hasVerifier = 1;
}

def TTNN_ReshapeOp : TTNN_Op<"reshape"> {
def TTNN_ReshapeOp : TTNN_Op<"reshape",
[DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
> {
let summary = "Reshape op.";
let description = [{
Reshape tensor.
Expand Down
38 changes: 38 additions & 0 deletions include/ttmlir/OpModel/TTNN/TTNNOpModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,44 @@ getOpRuntime(llvm::ArrayRef<int64_t> inputShape,

}; // namespace SoftmaxOpInterface

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

namespace MeanOpInterface {
// Queries the backend op model for the constraints of a mean reduction with
// the given input/output layouts. On success returns three size_t figures
// reported by the op model (presumably circular-buffer / peak / output L1
// memory usage — confirm against the implementation in TTNNOpModel.cpp).
// `dimArg` carries the reduction dimensions; std::nullopt presumably means
// "reduce over all dimensions" — verify with the op model backend.
llvm::Expected<std::tuple<size_t, size_t, size_t>>
getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
                 std::optional<llvm::ArrayRef<int64_t>> dimArg, bool keepDim,
                 mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

// Measures the runtime of a mean reduction for the given shapes/layouts.
// Units are whatever the op model's runtime query reports (see
// graph_query_op_runtime) — confirm before comparing across ops.
llvm::Expected<size_t>
getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
             std::optional<llvm::ArrayRef<int64_t>> dimArg, bool keepDim,
             mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

}; // namespace MeanOpInterface

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace ReshapeOpInterface {
// Queries the backend op model for the constraints of a reshape from
// `inputShape`/`inputLayout` to `outputShape`/`outputLayout`. On success
// returns three size_t figures reported by the op model (presumably
// circular-buffer / peak / output L1 memory usage — confirm against the
// implementation in TTNNOpModel.cpp).
llvm::Expected<std::tuple<size_t, size_t, size_t>>
getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
                 mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
                 llvm::ArrayRef<int64_t> outputShape,
                 mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

// Measures the runtime of a reshape for the given shapes/layouts. Units are
// whatever the op model's runtime query reports (see graph_query_op_runtime).
llvm::Expected<size_t>
getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
             mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
             llvm::ArrayRef<int64_t> outputShape,
             mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

}; // namespace ReshapeOpInterface

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
98 changes: 94 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/Operation.h"

#include <cassert>
#include <cstdint>
#include <optional>
#include <tuple>

Expand All @@ -24,6 +25,22 @@ llvm::Expected<bool> checkDeviceWorkerGrid(mlir::Operation *op) {
return op_model::ttnn::Device::getDeviceConstraints(
deviceAttr.getWorkerGrid());
}

// Converts an optional ArrayAttr of IntegerAttrs into an optional vector of
// int64_t reduction dimensions. An absent attribute is propagated unchanged.
std::optional<llvm::SmallVector<int64_t>>
convertReductionArg(std::optional<mlir::ArrayAttr> arrayOpt) {
  if (!arrayOpt) {
    return std::nullopt;
  }

  llvm::SmallVector<int64_t> dims;
  dims.reserve(arrayOpt->size());
  // Each element is expected to be an IntegerAttr holding one dimension.
  for (mlir::Attribute attr : *arrayOpt) {
    dims.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
  }

  return dims;
}

} // namespace detail

//===----------------------------------------------------------------------===//
Expand All @@ -41,7 +58,7 @@ ReluOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const auto outputShape =
mlir::cast<RankedTensorType>(getResults().front().getType()).getShape();

auto check = detail::checkDeviceWorkerGrid(getOperation());
llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
return check.takeError();
}
Expand Down Expand Up @@ -83,7 +100,7 @@ AddOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const auto outputShape =
mlir::cast<RankedTensorType>(getResult(0).getType()).getShape();

auto check = detail::checkDeviceWorkerGrid(getOperation());
llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
return check.takeError();
}
Expand Down Expand Up @@ -124,7 +141,7 @@ SoftmaxOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();

auto check = detail::checkDeviceWorkerGrid(getOperation());
llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
return check.takeError();
}
Expand All @@ -148,6 +165,79 @@ SoftmaxOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
inputShape, inputs[0], getDimension(), outputShape, output);
}

//===----------------------------------------------------------------------===//
// MeanOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//

llvm::Expected<std::tuple<size_t, size_t, size_t>>
MeanOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
                         const TTNNLayoutAttr &output) {
  // Mean is a unary reduction: exactly one input layout is expected.
  assert(inputs.size() == 1);

  // Reject configurations whose device worker grid the op model cannot
  // handle before issuing the (more expensive) constraint query.
  llvm::Expected<bool> deviceCheck =
      detail::checkDeviceWorkerGrid(getOperation());
  if (!deviceCheck) {
    return deviceCheck.takeError();
  }

  const llvm::ArrayRef<int64_t> inputTensorShape =
      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();

  // Forward the reduction dims (ArrayAttr -> SmallVector) and keep-dim flag
  // to the backend op model query.
  return op_model::ttnn::MeanOpInterface::getOpConstraints(
      inputTensorShape, inputs[0], detail::convertReductionArg(getDimArg()),
      getKeepDim(), output);
}

llvm::Expected<size_t>
MeanOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
                     const TTNNLayoutAttr &output) {
  // Mean is a unary reduction: exactly one input layout is expected.
  assert(inputs.size() == 1);

  const llvm::ArrayRef<int64_t> inputTensorShape =
      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();

  // Delegate the runtime measurement to the backend op model.
  return op_model::ttnn::MeanOpInterface::getOpRuntime(
      inputTensorShape, inputs[0], detail::convertReductionArg(getDimArg()),
      getKeepDim(), output);
}

//===----------------------------------------------------------------------===//
// ReshapeOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//

llvm::Expected<std::tuple<size_t, size_t, size_t>>
ReshapeOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
                            const TTNNLayoutAttr &output) {
  // Reshape takes a single tensor operand.
  assert(inputs.size() == 1);

  // Reject configurations whose device worker grid the op model cannot
  // handle before issuing the (more expensive) constraint query.
  llvm::Expected<bool> deviceCheck =
      detail::checkDeviceWorkerGrid(getOperation());
  if (!deviceCheck) {
    return deviceCheck.takeError();
  }

  const llvm::ArrayRef<int64_t> srcShape =
      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
  const llvm::ArrayRef<int64_t> dstShape =
      mlir::cast<RankedTensorType>(getResult().getType()).getShape();

  return op_model::ttnn::ReshapeOpInterface::getOpConstraints(
      srcShape, inputs[0], dstShape, output);
}

llvm::Expected<size_t>
ReshapeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
                        const TTNNLayoutAttr &output) {
  // Reshape takes a single tensor operand.
  assert(inputs.size() == 1);

  const llvm::ArrayRef<int64_t> srcShape =
      mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
  const llvm::ArrayRef<int64_t> dstShape =
      mlir::cast<RankedTensorType>(getResult().getType()).getShape();

  // Delegate the runtime measurement to the backend op model.
  return op_model::ttnn::ReshapeOpInterface::getOpRuntime(srcShape, inputs[0],
                                                          dstShape, output);
}

//===----------------------------------------------------------------------===//
// MatmulOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//
Expand All @@ -165,7 +255,7 @@ MatmulOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();

auto check = detail::checkDeviceWorkerGrid(getOperation());
llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
return check.takeError();
}
Expand Down
8 changes: 8 additions & 0 deletions lib/OpModel/TTNN/Conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include <optional>
#include <stdexcept>
#ifdef TTMLIR_ENABLE_OPMODEL
#include "Conversion.hpp"

Expand Down Expand Up @@ -150,6 +153,11 @@ ::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef<int64_t> shape,
return ::ttnn::TensorSpec(getShape(shape), getTensorLayout(layout));
}

// Copies an LLVM ArrayRef of int64_t into a TTNN SmallVector of int.
// NOTE(review): each element is narrowed from int64_t to int, exactly as the
// iterator-range construction this replaces did — confirm callers never pass
// values that overflow int.
::ttnn::SmallVector<int>
convertLLVMSmallVecToTTNNSmallVec(const ::llvm::ArrayRef<int64_t> vec) {
  ::ttnn::SmallVector<int> converted;
  for (const int64_t value : vec) {
    converted.push_back(static_cast<int>(value));
  }
  return converted;
}

} // namespace conversion
} // namespace mlir::tt::op_model::ttnn
#endif // TTMLIR_ENABLE_OPMODEL
3 changes: 3 additions & 0 deletions lib/OpModel/TTNN/Conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout);
::ttnn::TensorSpec getTensorSpec(const ::llvm::ArrayRef<int64_t> shape,
const mlir::tt::ttnn::TTNNLayoutAttr &layout);

// Converts an LLVM ArrayRef of int64_t into a TTNN SmallVector of int.
// NOTE(review): elements are narrowed from int64_t to int by the
// iterator-range construction — confirm callers never pass values that
// overflow int.
::ttnn::SmallVector<int>
convertLLVMSmallVecToTTNNSmallVec(const ::llvm::ArrayRef<int64_t> vec);

} // namespace conversion
} // namespace mlir::tt::op_model::ttnn

Expand Down
2 changes: 2 additions & 0 deletions lib/OpModel/TTNN/MetalHeaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@
#include "ttnn/graph/graph_query_op_constraints.hpp"
#include "ttnn/graph/graph_query_op_runtime.hpp"
#include "ttnn/graph/graph_trace_utils.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_spec.hpp"
#include "ttnn/tensor/types.hpp"
Expand Down
Loading
Loading