Adding memory config to a reshape op #2275

Open · wants to merge 2 commits into base: main
7 changes: 4 additions & 3 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -953,16 +953,17 @@ def TTNN_ConcatOp : TTNN_Op<"concat", [HasMemoryConfigTrait]> {
   let hasVerifier = 1;
 }
 
-def TTNN_ReshapeOp : TTNN_Op<"reshape",
-    [DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
+def TTNN_ReshapeOp : TTNN_Op<"reshape", [HasMemoryConfigTrait,
+    DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
 > {
   let summary = "Reshape op.";
   let description = [{
     Reshape tensor.
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
-                       I32ArrayAttr:$shape);
+                       I32ArrayAttr:$shape,
+                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

Contributor: Can you update the ttnntoemitc conversion as well?

Contributor Author: Done this for the concat op as well. :) Please take one more look at the EmitC conversion changes.


   let results = (outs AnyRankedTensor:$result);
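With HasMemoryConfigTrait and the optional attribute in place, the generated builder takes one extra argument; every create<ttnn::ReshapeOp> call site updated in this PR passes nullptr for it. A minimal sketch of the two call shapes, assuming the usual ODS-generated builder signature (the helper itself is hypothetical, not part of this diff):

// Hypothetical helper illustrating the new builder argument. Passing a null
// MemoryConfigAttr leaves the optional attribute unset, which is exactly what
// the /* memory_config */ nullptr call sites below do.
ttnn::ReshapeOp buildReshape(mlir::OpBuilder &builder, mlir::Location loc,
                             mlir::RankedTensorType outputType,
                             mlir::Value input,
                             llvm::ArrayRef<int32_t> newShape,
                             mlir::tt::ttnn::MemoryConfigAttr memCfg) {
  // memCfg may be null (attribute omitted) or a concrete config, e.g. one
  // taken from another op via getMemoryConfigAttr().
  return builder.create<mlir::tt::ttnn::ReshapeOp>(
      loc, outputType, input, builder.getI32ArrayAttr(newShape), memCfg);
}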
@@ -95,7 +95,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
         llvm::SmallVector<int32_t>(outputType.getShape()));
 
     rewriter.replaceOpWithNewOp<mlir::tt::ttnn::ReshapeOp>(
-        srcOp, outputType, newReduceOp, shapeAttr);
+        srcOp, outputType, newReduceOp, shapeAttr, /* memory_config */ nullptr);
   }
 
   // Determine if the workaround is required.
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -267,6 +267,7 @@ table ReshapeOp {
   in: tt.target.ttnn.TensorRef;
   out: tt.target.ttnn.TensorRef;
   shape: [int32];
+  memory_config: tt.target.ttnn.MemoryConfig;
 }
 
 table RepeatOp {
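Because memory_config is a table field, it is optional at the binary level: serializing offset 0 omits it entirely (see the TTNNToFlatbuffer.cpp change below), and the flatc-generated C++ accessor then returns nullptr. A sketch of the consuming side, assuming standard flatc-generated accessors (the include path is a guess):

#include "ttmlir/Target/TTNN/program_generated.h" // assumed generated header

// Runtime-side sketch: branch on whether the optional field was serialized.
void applyReshapeMemoryConfig(const tt::target::ttnn::ReshapeOp *op) {
  const tt::target::ttnn::MemoryConfig *memCfg = op->memory_config();
  if (memCfg != nullptr) {
    // Field present: honor the requested memory config.
  } else {
    // Field omitted (offset 0 at serialization time): keep default behavior.
  }
}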
6 changes: 3 additions & 3 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -671,7 +671,7 @@ class ReshapeOpConversionPattern : public OpConversionPattern<ttir::ReshapeOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getShape());
+        adaptor.getInput(), adaptor.getShape(), /* memory_config */ nullptr);
     return success();
   }
 };
@@ -731,7 +731,7 @@ class SqueezeOpConversionPattern : public OpConversionPattern<ttir::SqueezeOp> {
     // Replace the SqueezeOp with a ReshapeOp
     rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), shapeAttr);
+        adaptor.getInput(), shapeAttr, /* memory_config */ nullptr);
 
     return success();
   }
@@ -854,7 +854,7 @@ class UnsqueezeOpConversionPattern
     // Replace the UnsqueezeOp with a ReshapeOp
     rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), shapeAttr);
+        adaptor.getInput(), shapeAttr, /* memory_config */ nullptr);
 
     return success();
   }
5 changes: 3 additions & 2 deletions lib/Conversion/TTIRToTTNN/Utils.cpp
@@ -27,8 +27,9 @@ ttnn::ReshapeOp generateReshape(mlir::TypedValue<mlir::RankedTensorType> input,
       newShape, inputType.getElementType(), outputLayoutAttr);
 
   llvm::SmallVector<int32_t> newShapeI32(newShape.begin(), newShape.end());
-  return rewriter.create<ttnn::ReshapeOp>(
-      input.getLoc(), outputType, input, rewriter.getI32ArrayAttr(newShapeI32));
+  return rewriter.create<ttnn::ReshapeOp>(input.getLoc(), outputType, input,
+                                          rewriter.getI32ArrayAttr(newShapeI32),
+                                          /* memory_config */ nullptr);
 }
 
 ttnn::ReshapeOp
60 changes: 45 additions & 15 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -511,19 +511,34 @@ class ReshapeOpConversionPattern
       ttnn::ReshapeOp>::TTNNToEmitCBaseOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ttnn::ReshapeOp srcOp, OpAdaptor adaptor,
+  matchAndRewrite(ttnn::ReshapeOp reshapeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Create operands vector
+    //
+    llvm::SmallVector<Value, 2> operands{
+        adaptor.getOperands()[0], // Input tensor
+    };
+
     // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
     // an ArrayAttr object holding IndexTypes is created to denote this.
    //
-    ArrayAttr arrayAttrs = rewriter.getArrayAttr(
-        {rewriter.getIndexAttr(0), ttnn_to_emitc::utils::convertArrayAttrToSpan(
-                                       rewriter, srcOp.getShapeAttr())});
+    ArrayAttr arrayAttrs = rewriter.getArrayAttr({
+        rewriter.getIndexAttr(0), // Input tensor
+        ttnn_to_emitc::utils::convertArrayAttrToSpan(
+            rewriter, reshapeOp.getShapeAttr()), // Shape span
+        reshapeOp.getMemoryConfig()
+            ? (operands.append(1, ttnn_to_emitc::utils::createMemoryConfigOp(
+                                      rewriter, reshapeOp.getMemoryConfigAttr(),
+                                      reshapeOp.getLoc())
+                                      ->getResult(0)),
+               mlir::cast<Attribute>(rewriter.getIndexAttr(1)))
+            : ttnn_to_emitc::utils::createStdNullopt(
+                  rewriter) // ttnn::MemoryConfig
+    });
 
     rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
-        srcOp, this->getTypeConverter()->convertType(srcOp.getType()),
-        this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands());
+        reshapeOp, this->getTypeConverter()->convertType(reshapeOp.getType()),
+        this->convertOpName(reshapeOp), arrayAttrs, nullptr, operands);
 
     return success();
   }
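For orientation, this is roughly the C++ the rewritten CallOpaqueOp prints. The arrayAttrs list maps each call argument either to an operand (by index) or to an inlined attribute, and the sketch assumes ttnn::reshape takes a span-like shape plus an optional MemoryConfig as its third argument, which is what the layout above implies; variable names, the shape spelling, and the config value are illustrative only:

// memory_config absent: attrs are [operand 0, shape span, std::nullopt], so
// the call has a single tensor operand.
ttnn::Tensor v2 = ttnn::reshape(v1, std::vector<int32_t>{2, 4, 8}, std::nullopt);

// memory_config present: createMemoryConfigOp materializes the config as a
// preceding statement, and the index-1 attribute routes it in as operand 1.
ttnn::MemoryConfig v3 = ttnn::L1_MEMORY_CONFIG; // illustrative value
ttnn::Tensor v4 = ttnn::reshape(v1, std::vector<int32_t>{2, 4, 8}, v3);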
@@ -566,7 +581,7 @@ class ConcatOpConversionPattern
       ttnn::ConcatOp>::TTNNToEmitCBaseOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ttnn::ConcatOp srcOp, OpAdaptor adaptor,
+  matchAndRewrite(ttnn::ConcatOp concatOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     // ttnn::concat op requires a `std::vector<>` of `Tensor` objects, but we
@@ -575,23 +590,38 @@
     // by creating a utility function within the IR that converts a list of
     // `Tensor` objects into a `std::vector<ttnn::Tensor>`.
 
-    ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, srcOp);
+    ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, concatOp);
 
     mlir::emitc::CallOpaqueOp vectorOp = rewriter.create<emitc::CallOpaqueOp>(
-        srcOp.getLoc(),
+        concatOp.getLoc(),
         emitc::OpaqueType::get(rewriter.getContext(),
                                "std::vector<ttnn::Tensor>"),
         ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr, nullptr,
         adaptor.getInputs());
 
-    ArrayAttr arrayAttrs = rewriter.getArrayAttr(
-        {mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
-         srcOp.getDimAttr()});
+    // Create operands vector
+    //
+    llvm::SmallVector<Value, 2> operands{
+        vectorOp->getResult(0), // Input vector of tensors
+    };
+
+    ArrayAttr arrayAttrs = rewriter.getArrayAttr({
+        mlir::IntegerAttr::get(rewriter.getIndexType(),
+                               0), // Input vector of tensors
+        concatOp.getDimAttr(), // Concat dimension
+        concatOp.getMemoryConfig()
+            ? (operands.append(1, ttnn_to_emitc::utils::createMemoryConfigOp(
+                                      rewriter, concatOp.getMemoryConfigAttr(),
+                                      concatOp.getLoc())
+                                      ->getResult(0)),
+               mlir::cast<Attribute>(rewriter.getIndexAttr(1)))
+            : ttnn_to_emitc::utils::createStdNullopt(
+                  rewriter) // ttnn::MemoryConfig
+    });
 
     rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
-        srcOp, this->getTypeConverter()->convertType(srcOp.getType()),
-        this->convertOpName(srcOp), arrayAttrs, nullptr,
-        ValueRange(vectorOp->getResults()));
+        concatOp, this->getTypeConverter()->convertType(concatOp.getType()),
+        this->convertOpName(concatOp), arrayAttrs, nullptr, operands);
 
     return success();
   }
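The concat path is analogous, except operand 0 is the std::vector<ttnn::Tensor> produced by the generated helper. A hedged sketch of the emitted calls, with util_create_vec standing in for whatever kCreateVectorFunctionName names and the config value illustrative:

// Inputs are first packed into a vector by the IR-level helper function.
std::vector<ttnn::Tensor> v3 = util_create_vec(v1, v2);

// attrs [operand 0, dim attribute, operand 1 or std::nullopt] yield one of:
ttnn::Tensor v4 = ttnn::concat(v3, /*dim=*/0, std::nullopt);
ttnn::Tensor v5 = ttnn::concat(v3, /*dim=*/0, ttnn::L1_MEMORY_CONFIG);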
63 changes: 30 additions & 33 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -691,61 +691,58 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
   ::mlir::RankedTensorType inputType = getInput().getType();
   ::mlir::RankedTensorType outputType = getOutput().getType();
   auto shape = getShape();
-  int64_t shape_size = static_cast<int64_t>(shape.size());
+  int64_t shapeSize = static_cast<int64_t>(shape.size());
 
-  // Check that the shape size matches the rank of the output tensor
-  if (shape_size != static_cast<int64_t>(outputType.getRank())) {
-    return emitOpError("Shape attribute size must match output tensor rank");
+  // Check that the shape attribute is non-empty.
+  if (shapeSize == 0) {
+    return emitOpError("Shape attribute must be non-empty");
   }
 
-  // Check that the shape attribute is non-empty
-  if (shape_size == 0) {
-    return emitOpError("Shape attribute must be non-empty");
+  // Check that the shape size matches the rank of the output tensor.
+  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
+    return emitOpError() << "Shape attribute size " << shapeSize
+                         << " must match output tensor rank "
+                         << outputType.getRank();
   }
 
-  // Cardinality of the input and output tensors must be the same
+  // Cardinality of the input and output tensors must be the same.
   if (inputType.getNumElements() != outputType.getNumElements()) {
-    return emitOpError(
-        "Input and output tensors must have the same number of elements");
+    return emitOpError() << "Input tensor number of elements "
+                         << inputType.getNumElements()
+                         << " and output tensor number of elements "
+                         << outputType.getNumElements() << " must be the same";
   }
 
-  bool has_negative = false;
-  int64_t known_dim_product = 1;
+  bool hasNegative = false;
   auto outputShape = outputType.getShape();
 
-  // Check that all dimensions are positive except for at most one -1
-  // Check that the non-negative dimensions match the output tensor shape
-  // Calculate the product of the known dimensions
-  for (int64_t i = 0; i < shape_size; i++) {
-    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();
+  // Check that all dimensions are positive except for at most one -1.
+  // Check that the non-negative dimensions match the output tensor shape.
+  // Calculate the product of the known dimensions.
+  for (int64_t i = 0; i < shapeSize; i++) {
+    int64_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();
 
-    if (dim_value == -1) {
-      if (has_negative) {
+    if (dimValue == -1) {
+      if (hasNegative) {
         return emitOpError("Shape attribute must have at most one -1 element");
       }
-      has_negative = true;
+      hasNegative = true;
    } else {
-      if (dim_value <= 0) {
+      if (dimValue <= 0) {
        return emitOpError(
            "All dimensions must be positive except the one with -1");
      }
 
-      // Ensure that the non-negative dimensions match the output tensor shape
-      if (dim_value != outputShape[i]) {
-        return emitOpError("Shape attribute must match the output tensor shape "
-                           "for dimensions that are not -1");
+      // Ensure that the non-negative dimensions match the output tensor shape.
+      if (dimValue != outputShape[i]) {
+        return emitOpError()
+               << "Shape attribute " << dimValue
+               << " must match the output tensor shape " << outputShape[i]
+               << " at index " << i << " for dimension that is not -1";
      }
-
-      known_dim_product *= dim_value;
    }
  }
 
-  // If there's a -1, ensure that it can be inferred correctly
-  if (has_negative && inputType.getNumElements() % known_dim_product != 0) {
-    return emitOpError("Invalid shape: the dimensions do not multiply to the "
-                       "total number of elements in the tensor");
-  }
-
   return success();
 }
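Note that dropping known_dim_product and the final divisibility check loses nothing: every entry other than the -1 must equal the matching output dimension, and the input/output element counts already agree, so the value a -1 stands for is fully pinned down by the output shape. A self-contained sketch of the same shape-attribute rules over plain vectors (illustrative, not the verifier itself; the element-count check is elided since it only needs the two tensor types):

#include <cstdint>
#include <string>
#include <vector>

// Mirrors the verifier rules above: non-empty shape, rank match, at most one
// -1, and every other entry positive and equal to the output dim at its index.
// Returns an empty string on success, otherwise the failure reason.
std::string verifyReshapeShape(const std::vector<int64_t> &shape,
                               const std::vector<int64_t> &outputShape) {
  if (shape.empty())
    return "shape attribute must be non-empty";
  if (shape.size() != outputShape.size())
    return "shape attribute size must match output tensor rank";

  bool hasNegative = false;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      if (hasNegative)
        return "shape attribute must have at most one -1 element";
      hasNegative = true;
    } else if (shape[i] <= 0) {
      return "all dimensions must be positive except the one with -1";
    } else if (shape[i] != outputShape[i]) {
      return "shape attribute must match the output tensor shape";
    }
  }
  return ""; // valid
}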
50 changes: 24 additions & 26 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -647,58 +647,56 @@ ::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() {
   auto shape = getShape();
   int64_t shapeSize = static_cast<int64_t>(shape.size());
 
-  // Check that the shape size matches the rank of the output tensor
-  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
-    return emitOpError("Shape attribute size must match output tensor rank");
-  }
-  // Check that the shape attribute is non-empty
+  // Check that the shape attribute is non-empty.
   if (shapeSize == 0) {
     return emitOpError("Shape attribute must be non-empty");
   }
 
-  // Cardinality of the input and output tensors must be the same
+  // Check that the shape size matches the rank of the output tensor.
+  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
+    return emitOpError() << "Shape attribute size " << shapeSize
+                         << " must match output tensor rank "
+                         << outputType.getRank();
+  }
+
+  // Cardinality of the input and output tensors must be the same.
   if (inputType.getNumElements() != outputType.getNumElements()) {
-    return emitOpError(
-        "Input and output tensors must have the same number of elements");
+    return emitOpError() << "Input tensor number of elements "
+                         << inputType.getNumElements()
+                         << " and output tensor number of elements "
+                         << outputType.getNumElements() << " must be the same";
   }
 
-  bool has_negative = false;
-  int64_t known_dim_product = 1;
+  bool hasNegative = false;
   auto outputShape = outputType.getShape();
 
   // Check that all dimensions are positive except for at most one -1
   // Check that the non-negative dimensions match the output tensor shape
   // Calculate the product of the known dimensions
   for (int64_t i = 0; i < shapeSize; i++) {
-    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();
+    int64_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();
 
-    if (dim_value == -1) {
-      if (has_negative) {
+    if (dimValue == -1) {
+      if (hasNegative) {
        return emitOpError("Shape attribute must have at most one -1 element");
      }
-      has_negative = true;
+      hasNegative = true;
    } else {
-      if (dim_value <= 0) {
+      if (dimValue <= 0) {
        return emitOpError(
            "All dimensions must be positive except the one with -1");
      }
 
       // Ensure that the non-negative dimensions match the output tensor shape
-      if (dim_value != outputShape[i]) {
-        return emitOpError("Shape attribute must match the output tensor shape "
-                           "for dimensions that are not -1");
+      if (dimValue != outputShape[i]) {
+        return emitOpError()
+               << "Shape attribute " << dimValue
+               << " must match the output tensor shape " << outputShape[i]
+               << " at index " << i << " for dimension that is not -1";
      }
-
-      known_dim_product *= dim_value;
    }
  }
 
-  // If there's a -1, ensure that it can be inferred correctly
-  if (has_negative && inputType.getNumElements() % known_dim_product != 0) {
-    return emitOpError("Invalid shape: the dimensions do not multiply to the "
-                       "total number of elements in the tensor");
-  }
-
   return success();
 }
9 changes: 5 additions & 4 deletions lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -386,7 +386,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern<ttnn::AllReduceOp> {
 
       // Create a new reshape op.
       ttnn::ReshapeOp preReshapeOp = rewriter.create<ttnn::ReshapeOp>(
-          loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr);
+          loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr,
+          /* memory_config */ nullptr);
 
       // Determine new dimension since entire tensor shape got shifted.
       dimension = dimension + requiredOnesInput;
@@ -424,9 +425,9 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern<ttnn::AllReduceOp> {
           loc, Type(reshapedOutputType), reduceScatterOp.getResult(),
           deviceValue, dimension, clusterAxis);
 
-      rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(op, Type(outputType),
-                                                   allGatherOp.getResult(),
-                                                   reshapedOutputShapeAttr);
+      rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
+          op, Type(outputType), allGatherOp.getResult(),
+          reshapedOutputShapeAttr, /* memory_config */ nullptr);
    } else {
      // TODO(wooseoklee): Once ttnn supports all_reduce op
      // (https://github.com/tenstorrent/tt-metal/issues/13835), we can convert
11 changes: 10 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1262,7 +1262,16 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) {
   auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                kHostAllocatedSize);
 
-  return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape);
+  std::optional<mlir::tt::ttnn::MemoryConfigAttr> memoryConfig =
+      op.getMemoryConfig();
+  auto tileShape = getTensorValueTileShape(op.getResult());
+  auto coreRangeSet = getTensorValueCoreRangeSet(cache, op.getResult());
+
+  return ::tt::target::ttnn::CreateReshapeOp(
+      *cache.fbb, in, out, shape,
+      memoryConfig ? memoryConfigToFlatbuffer(cache, memoryConfig.value(),
+                                              tileShape, coreRangeSet)
+                   : 0);
 }
 
 template <typename RepeatOp>