[EmitC] Support MNIST #1663

Merged: 4 commits, Dec 25, 2024
Changes from 3 commits
11 changes: 6 additions & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -903,11 +903,12 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> {
Tensor empty operation
}];

let arguments = (ins Optional<TT_Device>:$device,
TTNN_ShapeAttr:$shape,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_LayoutAttr>:$layout,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let arguments = (ins TTNN_ShapeAttr:$shape,
TT_DataTypeAttr:$dtype,
TTNN_LayoutAttr:$layout,
TT_Device:$device,
TTNN_MemoryConfigAttr:$memory_config);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
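The reworked op takes all five arguments unconditionally, in declaration order. A minimal sketch of a builder call against the new signature (variable names are illustrative; it mirrors the updated conversion in TTIRToTTNN.cpp below):

    rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
        op, resultType, shapeAttr, dTypeAttr, tensorLayoutAttr, device,
        memoryConfigAttr);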
34 changes: 14 additions & 20 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -64,33 +64,27 @@ class TensorEmptyConversionPattern
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);

// If the tensor is not going to device, we can create the op without
// device-specific attributes
// Device
//
ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout();
if (!memLayout) {
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), nullptr,
shapeAttr, dTypeAttr, tensorLayoutAttr, nullptr);

return success();
}

ttnn::BufferType bufferType = layoutAttr.getBufferType();
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

// Create MemoryConfigAttr
//
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
llvm::SmallVector<int64_t> shardShape = layoutAttr.getShardShape();
ttnn::BufferTypeAttr bufferTypeAttr =
ttnn::BufferTypeAttr::get(op.getContext(), layoutAttr.getBufferType());
ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get(
op.getContext(),
ttnn::ShapeAttr::get(op.getContext(), layoutAttr.getShardShape()));
ttnn::TensorMemoryLayoutAttr memLayout =
layoutAttr.getMemLayout() ? layoutAttr.getMemLayout() : nullptr;
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(), ttnn::BufferTypeAttr::get(op.getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op.getContext(), ttnn::ShapeAttr::get(op.getContext(), shardShape)),
memLayout);
op.getContext(), bufferTypeAttr, shardSpecAttr, memLayout);

// Replace op
//
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device,
shapeAttr, dTypeAttr, tensorLayoutAttr, memoryConfigAttr);
op, this->getTypeConverter()->convertType(op.getType()), shapeAttr,
dTypeAttr, tensorLayoutAttr, device, memoryConfigAttr);

return success();
}
174 changes: 126 additions & 48 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -113,6 +113,43 @@ class DefaultOpConversionPattern
}
};

// Eltwise Unary op conversion pattern
//
// Currently, it has to insert nullopts for some parameters that are not
// modelled in the dialect (e.g. memory_config)
//
template <typename SourceOp, typename Adaptor = typename SourceOp::Adaptor>
class EltwiseUnaryOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<SourceOp> {

public:
EltwiseUnaryOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
: TTNNToEmitCBaseOpConversionPattern<SourceOp>(typeConverter, context,
benefit) {}

LogicalResult
matchAndRewrite(SourceOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
llvm::SmallVector<Attribute, 5> attrs;
attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0));
attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter));
attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1));

ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs);

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)),
this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands());

return success();
}
};
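
// Illustration (a sketch, not part of this diff): for a unary op with DPS
// operands (input, output), the args above make the emitted call print
// roughly as
//
//   ttnn::relu(v1, std::nullopt, v2);
//
// IndexAttrs select operands by position, while any other attribute (here
// the std::nullopt standing in for the unmodelled memory_config) is
// printed inline as a literal argument.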

// Eltwise Binary op conversion pattern
//
// Currently, it has to insert nullopts for some parameters that are not
@@ -132,6 +169,7 @@ class EltwiseBinaryOpConversionPattern
LogicalResult
matchAndRewrite(SourceOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
Expand All @@ -152,6 +190,50 @@ class EltwiseBinaryOpConversionPattern
}
};

// Matmul op conversion pattern
//
class MatmulOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<ttnn::MatmulOp> {

public:
MatmulOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: TTNNToEmitCBaseOpConversionPattern<ttnn::MatmulOp>(typeConverter,
context, benefit) {}

LogicalResult
matchAndRewrite(ttnn::MatmulOp matmulOp, ttnn::MatmulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
// an ArrayAttr object holding IndexTypes is created to denote this
//
ArrayAttr arrayAttrs = rewriter.getArrayAttr({
mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
mlir::IntegerAttr::get(rewriter.getIndexType(), 1),
ttnn_to_emitc::utils::convertBoolAttr(
rewriter, BoolAttr::get(rewriter.getContext(), false)),
ttnn_to_emitc::utils::convertBoolAttr(
rewriter, BoolAttr::get(rewriter.getContext(), false)),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
ttnn_to_emitc::utils::createStdNullopt(rewriter),
mlir::IntegerAttr::get(rewriter.getIndexType(), 2),
});

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
matmulOp, this->getTypeConverter()->convertType(matmulOp.getType()),
this->convertOpName(matmulOp), arrayAttrs, nullptr,
adaptor.getOperands());

return success();
}
};
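
// Illustration (a sketch; the parameter mapping assumes ttnn::matmul's
// signature and is not spelled out in this diff): with operands
// (a, b, output), the generated call would look roughly like
//
//   ttnn::matmul(v1, v2, false, false, std::nullopt, std::nullopt,
//                std::nullopt, std::nullopt, std::nullopt, std::nullopt,
//                std::nullopt, v3);
//
// The two BoolAttrs presumably cover the transpose flags, the seven
// nullopts stand in for optional parameters not modelled in the dialect,
// and the final IndexAttr routes the preallocated output tensor.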

// GetDeviceOp conversion pattern
//
class GetDeviceOpConversionPattern
@@ -390,46 +472,42 @@ class EmptyOpConversionPattern
tt::DataTypeAttr dataTypeAttr = srcOp.getDtypeAttr();
ttnn::LayoutAttr layoutAttr = srcOp.getLayoutAttr();

// Find the GetDeviceOp
//
ttnn::GetDeviceOp getDeviceOp;
srcOp->getParentOp()->walk(
[&getDeviceOp](ttnn::GetDeviceOp currGetDeviceOp) {
getDeviceOp = currGetDeviceOp;
});

// Create ttnn::Shape() call
//
emitc::ExpressionOp shapeExpressionOp = ttnn_to_emitc::utils::createShapeOp(
rewriter, shapeAttr, srcOp->getBlock(), srcOp.getLoc());

llvm::SmallVector<Value, 3> operands{
shapeExpressionOp->getResult(0),
};

// If there is a device operand, create tensor on device
// Create operands vector
//
ArrayAttr arrayAttr;
if (adaptor.getDevice()) {
operands.append(1, adaptor.getDevice());
llvm::SmallVector<Value, 3> operands{shapeExpressionOp->getResult(0),
adaptor.getDevice()};

// Create MemoryConfig object first, then pass it to the op
//
emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp(
rewriter, srcOp.getMemoryConfig().value(), srcOp.getLoc());
// Create MemoryConfig object first, then pass it to the op
//
emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp(
rewriter, srcOp.getMemoryConfig(), srcOp.getLoc());

// Concat operands and MemoryConfig object
//
operands.append(1, memCfgOp.getResult(0));
// Concat operands and MemoryConfig object
//
operands.append(1, memCfgOp.getResult(0));

// Create ArrayAttr object holding attributes and pointers to operands
//
arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
rewriter.getIndexAttr(1), // ttnn::Device
rewriter.getIndexAttr(2), // ttnn::MemoryConfig
});
} else {
arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
});
}
// Create ArrayAttr object holding attributes and pointers to operands
//
ArrayAttr arrayAttr = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // ttnn::Shape
ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr),
ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr),
rewriter.getIndexAttr(1), // ttnn::Device
rewriter.getIndexAttr(2), // ttnn::MemoryConfig
});
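
// Illustration (a sketch, not part of this diff; the dtype/layout values
// are made up): with operands (shape, device, memoryConfig), the
// ArrayAttr above makes the call that gets emitted come out roughly as
//
//   ttnn::empty(v1, DataType::BFLOAT16, Layout::TILE, v2, v3);
//
// convertDType/convertLayoutAttr turn the attrs into inline enum
// literals, while the three IndexAttrs select the SSA operands.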

// Finally, convert ttir::EmptyOp to ttnn::EmptyOp
//
@@ -469,14 +547,14 @@ class OnesOpConversionPattern
// Attrs (like shape) need to be instantiated into objects before being
// passed to the op. Therefore:
//
// We first create a ttnn::Shape object (SSA) by calling createShapeOp() and
// add it to the operands vector, but also add an IndexAttr in ArrayAttr to
// reference it (this is an EmitC mechanism that allows for combining Attrs
// and Values when calling an OpaqueOp).
// All the other input params are optional, so we create them on-the-fly
// into the ArrayAttr, whether they are an actual Attr, or a Value pointed
// to by IndexAttr. If they are present, we create the object and pass it to
// the op. If not, we pass std::nullopt.
// We first create a ttnn::Shape object (SSA) by calling createShapeOp()
// and add it to the operands vector, but also add an IndexAttr in
// ArrayAttr to reference it (this is an EmitC mechanism that allows for
// combining Attrs and Values when calling an OpaqueOp). All the other
// input params are optional, so we create them on-the-fly into the
// ArrayAttr, whether they are an actual Attr, or a Value pointed to by
// IndexAttr. If they are present, we create the object and pass it to the
// op. If not, we pass std::nullopt.

// Create ttnn::Shape() call
//
@@ -489,8 +567,8 @@

// Create ArrayAttr object holding attributes and pointers to operands
//
// Params that are Values are added to the operands vector on-the-fly, and a
// corresponding IndexAttr is added to the ArrayAttr to reference them.
// Params that are Values are added to the operands vector on-the-fly, and
// a corresponding IndexAttr is added to the ArrayAttr to reference them.
//
size_t operandIndex = 0;
ArrayAttr arrayAttr = rewriter.getArrayAttr({
@@ -594,8 +672,8 @@ class GetTupleElementOpConversionPattern
getTupleElementOp->getLoc(), rewriter.getIndexType(),
std::to_string(adaptor.getIndex()));

// SubscriptOp also returns an emitc::LValueType, so we wrap the OpaqueType
// with LValueType
// SubscriptOp also returns an emitc::LValueType, so we wrap the
// OpaqueType with LValueType
//
emitc::LValueType lvalueReturnType = emitc::LValueType::get(
emitc::OpaqueType::get(rewriter.getContext(), "ttnn::Tensor"));
@@ -621,9 +699,9 @@ class TupleOpConversionPattern : public OpConversionPattern<tt::TupleOp> {
LogicalResult
matchAndRewrite(tt::TupleOp tupleOp, tt::TupleOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// EmitC doesn't offer a way to create a vector from a list of values, so we
// need to create a utility function that does this. This is achieved by
// using EmitC's VerbatimOp.
// EmitC doesn't offer a way to create a vector from a list of values, so
// we need to create a utility function that does this. This is achieved
// by using EmitC's VerbatimOp.

// Try to find if utility vec creation function is already defined in the
// module. If not, insert it.
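
// A sketch of what such a VerbatimOp-emitted helper could look like (the
// name and element type are illustrative, not taken from this diff):
//
//   template <typename... T>
//   std::vector<ttnn::Tensor> util_create_vec(T &&...t) {
//     return std::vector<ttnn::Tensor>{std::forward<T>(t)...};
//   }
//
// Defining it once per module lets every tuple site construct the vector
// with a single call.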
@@ -708,7 +786,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::BitwiseNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
EltwiseUnaryOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::LeakyReluOp>,
DefaultOpConversionPattern<ttnn::GeluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
@@ -761,7 +839,7 @@
// Matmul ops
//
patterns.add<DefaultOpConversionPattern<ttnn::LinearOp>,
DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
MatmulOpConversionPattern>(typeConverter, ctx);

// Reduction ops
//
52 changes: 8 additions & 44 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -174,9 +174,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() {

// EmptyOp verification
::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {
// ==============================
// === CHECK ATTRIBUTES START ===
// ==============================
// Check that the attributes of the op match the attributes of the output
// tensor type.
//
@@ -192,50 +189,17 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() {

// DataType and Layout
//
if (getLayout().has_value()) {
ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout();
assert(ttnnLayoutEnum == getLayoutAttr().getValue());
}
if (getDtype().has_value()) {
tt::DataType dtype = layoutAttr.getDataType();
assert(dtype == getDtype());
}
assert(getLayout() == layoutAttr.getLayout());
assert(getDtype() == layoutAttr.getDataType());

// MemoryConfig
// Check that op has MemoryConfigAttr set on itself, then compare internal
// attrs with output tensor attrs.
//
if (getMemoryConfig().has_value()) {
ttnn::BufferType bufferType = layoutAttr.getBufferType();
ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr =
layoutAttr.getMemLayout();
assert(bufferType == getMemoryConfig()->getBufferType().getValue());
assert(tensorMemoryLayoutAttr ==
getMemoryConfig()->getTensorMemoryLayout());
}
// Compare internal attrs with output tensor attrs.
//
// ==============================
// ==== CHECK ATTRIBUTES END ====
// ==============================

// ==============================
// === CHECK SIGNATURES START ===
// ==============================
// Check that call-site uses the correct signature. We only allow 2 for now:
// 1. none, Shape, DataType, Layout, none
// 2. Device, Shape, DataType, Layout, MemoryConfig
//
assert(
// 1.
(!getDevice() && getDtype().has_value() && getLayout().has_value() &&
!getMemoryConfig().has_value()) ||
// 2.
(getDevice() && getDtype().has_value() && getLayout().has_value() &&
getMemoryConfig().has_value()));
//
// ==============================
// ==== CHECK SIGNATURES END ====
// ==============================
assert(getMemoryConfig().getBufferType().getValue() ==
layoutAttr.getBufferType());
assert(getMemoryConfig().getTensorMemoryLayout() ==
layoutAttr.getMemLayout());

return success();
}
