From 8b398c630d3d862b3d405e56564849a4da2cb23d Mon Sep 17 00:00:00 2001 From: Stefan Djordjevic Date: Mon, 24 Feb 2025 09:46:05 +0000 Subject: [PATCH 1/2] Adding memory config to a reshape op. --- include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 7 ++++--- .../Decomposition/ReduceOpsRewritePattern.h | 2 +- include/ttmlir/Target/TTNN/program.fbs | 1 + lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 6 +++--- lib/Conversion/TTIRToTTNN/Utils.cpp | 5 +++-- .../TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp | 9 +++++---- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 11 ++++++++++- runtime/lib/ttnn/operations/data_movement/reshape.cpp | 6 +++++- .../data_movement/reshape/reshape_folding_test.mlir | 7 ++++--- .../n150/{ => data_movement}/reshape/reshape.mlir | 2 +- 10 files changed, 37 insertions(+), 19 deletions(-) rename test/ttmlir/Silicon/TTNN/n150/{ => data_movement}/reshape/reshape.mlir (94%) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 45f6311646..c95f200fe2 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -953,8 +953,8 @@ def TTNN_ConcatOp : TTNN_Op<"concat", [HasMemoryConfigTrait]> { let hasVerifier = 1; } -def TTNN_ReshapeOp : TTNN_Op<"reshape", - [DeclareOpInterfaceMethods] +def TTNN_ReshapeOp : TTNN_Op<"reshape", [HasMemoryConfigTrait, + DeclareOpInterfaceMethods] > { let summary = "Reshape op."; let description = [{ @@ -962,7 +962,8 @@ def TTNN_ReshapeOp : TTNN_Op<"reshape", }]; let arguments = (ins AnyRankedTensor:$input, - I32ArrayAttr:$shape); + I32ArrayAttr:$shape, + OptionalAttr:$memory_config); let results = (outs AnyRankedTensor:$result); diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h index f80d6b80c2..f1bf5d45e2 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h +++ b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h @@ -95,7 +95,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern { llvm::SmallVector(outputType.getShape())); rewriter.replaceOpWithNewOp( - srcOp, outputType, newReduceOp, shapeAttr); + srcOp, outputType, newReduceOp, shapeAttr, /* memory_config */ nullptr); } // Determine if the workaround is required. diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 5ed126f1c5..a074610c22 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -267,6 +267,7 @@ table ReshapeOp { in: tt.target.ttnn.TensorRef; out: tt.target.ttnn.TensorRef; shape: [int32]; + memory_config: tt.target.ttnn.MemoryConfig; } table RepeatOp { diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 71db06d59b..c8c6e6ee27 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -671,7 +671,7 @@ class ReshapeOpConversionPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getShape()); + adaptor.getInput(), adaptor.getShape(), /* memory_config */ nullptr); return success(); } }; @@ -731,7 +731,7 @@ class SqueezeOpConversionPattern : public OpConversionPattern { // Replace the SqueezeOp with a ReshapeOp rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), shapeAttr); + adaptor.getInput(), shapeAttr, /* memory_config */ nullptr); return success(); } @@ -854,7 +854,7 @@ class UnsqueezeOpConversionPattern // Replace the UnsqueezeOp with a ReshapeOp rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), shapeAttr); + adaptor.getInput(), shapeAttr, /* memory_config */ nullptr); return success(); } diff --git a/lib/Conversion/TTIRToTTNN/Utils.cpp b/lib/Conversion/TTIRToTTNN/Utils.cpp index 04fb629d40..9d4293741b 100644 --- a/lib/Conversion/TTIRToTTNN/Utils.cpp +++ b/lib/Conversion/TTIRToTTNN/Utils.cpp @@ -27,8 +27,9 @@ ttnn::ReshapeOp generateReshape(mlir::TypedValue input, newShape, inputType.getElementType(), outputLayoutAttr); llvm::SmallVector newShapeI32(newShape.begin(), newShape.end()); - return rewriter.create( - input.getLoc(), outputType, input, rewriter.getI32ArrayAttr(newShapeI32)); + return rewriter.create(input.getLoc(), outputType, input, + rewriter.getI32ArrayAttr(newShapeI32), + /* memory_config */ nullptr); } ttnn::ReshapeOp diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp index 02dd836f54..caea96bbad 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp @@ -386,7 +386,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { // Create a new reshape op. ttnn::ReshapeOp preReshapeOp = rewriter.create( - loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr); + loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr, + /* memory_config */ nullptr); // Determine new dimension since entire tensor shape got shifted. dimension = dimension + requiredOnesInput; @@ -424,9 +425,9 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { loc, Type(reshapedOutputType), reduceScatterOp.getResult(), deviceValue, dimension, clusterAxis); - rewriter.replaceOpWithNewOp(op, Type(outputType), - allGatherOp.getResult(), - reshapedOutputShapeAttr); + rewriter.replaceOpWithNewOp( + op, Type(outputType), allGatherOp.getResult(), + reshapedOutputShapeAttr, /* memory_config */ nullptr); } else { // TODO(wooseoklee): Once ttnn supports all_reduce op // (https://github.com/tenstorrent/tt-metal/issues/13835), we can convert diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 3d78258906..6bb9131c4a 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -1262,7 +1262,16 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedSize); - return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape); + std::optional memoryConfig = + op.getMemoryConfig(); + auto tileShape = getTensorValueTileShape(op.getResult()); + auto coreRangeSet = getTensorValueCoreRangeSet(cache, op.getResult()); + + return ::tt::target::ttnn::CreateReshapeOp( + *cache.fbb, in, out, shape, + memoryConfig ? memoryConfigToFlatbuffer(cache, memoryConfig.value(), + tileShape, coreRangeSet) + : 0); } template diff --git a/runtime/lib/ttnn/operations/data_movement/reshape.cpp b/runtime/lib/ttnn/operations/data_movement/reshape.cpp index 0be7114f00..ed66ca6906 100644 --- a/runtime/lib/ttnn/operations/data_movement/reshape.cpp +++ b/runtime/lib/ttnn/operations/data_movement/reshape.cpp @@ -5,6 +5,7 @@ #include "operations/data_movement/reshape.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::data_movement { void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) { @@ -13,7 +14,10 @@ void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) { DEBUG_ASSERT(in.is_allocated()); const auto *fbShape = op->shape(); std::vector shape(fbShape->begin(), fbShape->end()); - ::ttnn::Tensor out = ::ttnn::reshape(in, shape); + std::optional<::ttnn::MemoryConfig> memoryConfig = + ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( + op->memory_config()); + ::ttnn::Tensor out = ::ttnn::reshape(in, shape, memoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::data_movement diff --git a/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_folding_test.mlir index 15e41eaacf..142d33580a 100644 --- a/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_folding_test.mlir +++ b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_folding_test.mlir @@ -1,11 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. + module @reshape_test { + // Test folding of "ttir.reshape" when called with identical shapes. func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK-NOT: = "ttnn.reshape"[C:.*]] - // CHECK: return %arg0 : tensor<1xsi32, #{{.*}}> + // CHECK-NOT: = "ttnn.reshape" return %1 : tensor<1xi32> + // CHECK: return %arg0 : tensor<1xsi32, #{{.*}}> } } diff --git a/test/ttmlir/Silicon/TTNN/n150/reshape/reshape.mlir b/test/ttmlir/Silicon/TTNN/n150/data_movement/reshape/reshape.mlir similarity index 94% rename from test/ttmlir/Silicon/TTNN/n150/reshape/reshape.mlir rename to test/ttmlir/Silicon/TTNN/n150/data_movement/reshape/reshape.mlir index ec7c9f25dd..3d3584dd8c 100644 --- a/test/ttmlir/Silicon/TTNN/n150/reshape/reshape.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/data_movement/reshape/reshape.mlir @@ -4,7 +4,7 @@ func.func @reshape(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { %0 = tensor.empty() : tensor<2x4x32x32xbf16> - // CHECK: = "ttnn.reshape" %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + // CHECK: "ttnn.reshape" return %1 : tensor<2x4x32x32xbf16> } From d7343e1b725c02713d9f1a1b24f52498a2360729 Mon Sep 17 00:00:00 2001 From: Stefan Djordjevic Date: Mon, 24 Feb 2025 22:10:57 +0000 Subject: [PATCH 2/2] Adding tests to cover reshape op. --- lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 60 ++++++++++++----- lib/Dialect/TTIR/IR/TTIROps.cpp | 63 +++++++++-------- lib/Dialect/TTNN/IR/TTNNOps.cpp | 50 +++++++------- .../ttnn/operations/data_movement/concat.cpp | 7 +- .../ttnn/operations/data_movement/reshape.cpp | 7 +- .../reshape/reshape_tests_negative.mlir | 67 +++++++++++++++++++ .../reshape/reshape_tests_negative.mlir | 61 +++++++++++++++++ .../reshape/reshape_tests_positive.mlir | 22 ++++++ test/ttmlir/Dialect/TTNN/simple_reshape.mlir | 9 --- 9 files changed, 259 insertions(+), 87 deletions(-) create mode 100644 test/ttmlir/Dialect/TTIR/data_movement/reshape/reshape_tests_negative.mlir create mode 100644 test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_negative.mlir create mode 100644 test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_positive.mlir delete mode 100644 test/ttmlir/Dialect/TTNN/simple_reshape.mlir diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 6b958b96e1..0f74364829 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -511,19 +511,34 @@ class ReshapeOpConversionPattern ttnn::ReshapeOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(ttnn::ReshapeOp srcOp, OpAdaptor adaptor, + matchAndRewrite(ttnn::ReshapeOp reshapeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Create operands vector + // + llvm::SmallVector operands{ + adaptor.getOperands()[0], // Input tensor + }; // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so // an ArrayAttr object holding IndexTypes is created to denote this. // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {rewriter.getIndexAttr(0), ttnn_to_emitc::utils::convertArrayAttrToSpan( - rewriter, srcOp.getShapeAttr())}); + ArrayAttr arrayAttrs = rewriter.getArrayAttr({ + rewriter.getIndexAttr(0), // Input tensor + ttnn_to_emitc::utils::convertArrayAttrToSpan( + rewriter, reshapeOp.getShapeAttr()), // Shape span + reshapeOp.getMemoryConfig() + ? (operands.append(1, ttnn_to_emitc::utils::createMemoryConfigOp( + rewriter, reshapeOp.getMemoryConfigAttr(), + reshapeOp.getLoc()) + ->getResult(0)), + mlir::cast(rewriter.getIndexAttr(1))) + : ttnn_to_emitc::utils::createStdNullopt( + rewriter) // ttnn::MemoryConfig + }); rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + reshapeOp, this->getTypeConverter()->convertType(reshapeOp.getType()), + this->convertOpName(reshapeOp), arrayAttrs, nullptr, operands); return success(); } @@ -566,7 +581,7 @@ class ConcatOpConversionPattern ttnn::ConcatOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(ttnn::ConcatOp srcOp, OpAdaptor adaptor, + matchAndRewrite(ttnn::ConcatOp concatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // ttnn::concat op requires a `std::vector<>` of `Tensor` objects, but we @@ -575,23 +590,38 @@ class ConcatOpConversionPattern // by creating a utility function within the IR that converts a list of // `Tensor` objects into a `std::vector`. - ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, srcOp); + ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, concatOp); mlir::emitc::CallOpaqueOp vectorOp = rewriter.create( - srcOp.getLoc(), + concatOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "std::vector"), ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr, nullptr, adaptor.getInputs()); - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - srcOp.getDimAttr()}); + // Create operands vector + // + llvm::SmallVector operands{ + vectorOp->getResult(0), // Input vector of tensors + }; + + ArrayAttr arrayAttrs = rewriter.getArrayAttr({ + mlir::IntegerAttr::get(rewriter.getIndexType(), + 0), // Input vector of tensors + concatOp.getDimAttr(), // Concat dimension + concatOp.getMemoryConfig() + ? (operands.append(1, ttnn_to_emitc::utils::createMemoryConfigOp( + rewriter, concatOp.getMemoryConfigAttr(), + concatOp.getLoc()) + ->getResult(0)), + mlir::cast(rewriter.getIndexAttr(1))) + : ttnn_to_emitc::utils::createStdNullopt( + rewriter) // ttnn::MemoryConfig + }); rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, - ValueRange(vectorOp->getResults())); + concatOp, this->getTypeConverter()->convertType(concatOp.getType()), + this->convertOpName(concatOp), arrayAttrs, nullptr, operands); return success(); } diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index ac113ecc25..2b4c5d6807 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -691,61 +691,58 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); ::mlir::RankedTensorType outputType = getOutput().getType(); auto shape = getShape(); - int64_t shape_size = static_cast(shape.size()); + int64_t shapeSize = static_cast(shape.size()); - // Check that the shape size matches the rank of the output tensor - if (shape_size != static_cast(outputType.getRank())) { - return emitOpError("Shape attribute size must match output tensor rank"); + // Check that the shape attribute is non-empty. + if (shapeSize == 0) { + return emitOpError("Shape attribute must be non-empty"); } - // Check that the shape attribute is non-empty - if (shape_size == 0) { - return emitOpError("Shape attribute must be non-empty"); + // Check that the shape size matches the rank of the output tensor. + if (shapeSize != static_cast(outputType.getRank())) { + return emitOpError() << "Shape attribute size " << shapeSize + << " must match output tensor rank " + << outputType.getRank(); } - // Cardinality of the input and output tensors must be the same + // Cardinality of the input and output tensors must be the same. if (inputType.getNumElements() != outputType.getNumElements()) { - return emitOpError( - "Input and output tensors must have the same number of elements"); + return emitOpError() << "Input tensor number of elements " + << inputType.getNumElements() + << " and output tensor number of elements " + << outputType.getNumElements() << " must be the same"; } - bool has_negative = false; - int64_t known_dim_product = 1; + bool hasNegative = false; auto outputShape = outputType.getShape(); - // Check that all dimensions are positive except for at most one -1 - // Check that the non-negative dimensions match the output tensor shape - // Calculate the product of the known dimensions - for (int64_t i = 0; i < shape_size; i++) { - int64_t dim_value = mlir::cast(shape[i]).getInt(); + // Check that all dimensions are positive except for at most one -1. + // Check that the non-negative dimensions match the output tensor shape. + // Calculate the product of the known dimensions. + for (int64_t i = 0; i < shapeSize; i++) { + int64_t dimValue = mlir::cast(shape[i]).getInt(); - if (dim_value == -1) { - if (has_negative) { + if (dimValue == -1) { + if (hasNegative) { return emitOpError("Shape attribute must have at most one -1 element"); } - has_negative = true; + hasNegative = true; } else { - if (dim_value <= 0) { + if (dimValue <= 0) { return emitOpError( "All dimensions must be positive except the one with -1"); } - // Ensure that the non-negative dimensions match the output tensor shape - if (dim_value != outputShape[i]) { - return emitOpError("Shape attribute must match the output tensor shape " - "for dimensions that are not -1"); + // Ensure that the non-negative dimensions match the output tensor shape. + if (dimValue != outputShape[i]) { + return emitOpError() + << "Shape attribute " << dimValue + << " must match the output tensor shape " << outputShape[i] + << " at index " << i << " for dimension that is not -1"; } - - known_dim_product *= dim_value; } } - // If there's a -1, ensure that it can be inferred correctly - if (has_negative && inputType.getNumElements() % known_dim_product != 0) { - return emitOpError("Invalid shape: the dimensions do not multiply to the " - "total number of elements in the tensor"); - } - return success(); } diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index f42e103383..65bfb0b5aa 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -647,58 +647,56 @@ ::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() { auto shape = getShape(); int64_t shapeSize = static_cast(shape.size()); - // Check that the shape size matches the rank of the output tensor - if (shapeSize != static_cast(outputType.getRank())) { - return emitOpError("Shape attribute size must match output tensor rank"); - } - // Check that the shape attribute is non-empty + // Check that the shape attribute is non-empty. if (shapeSize == 0) { return emitOpError("Shape attribute must be non-empty"); } - // Cardinality of the input and output tensors must be the same + // Check that the shape size matches the rank of the output tensor. + if (shapeSize != static_cast(outputType.getRank())) { + return emitOpError() << "Shape attribute size " << shapeSize + << " must match output tensor rank " + << outputType.getRank(); + } + + // Cardinality of the input and output tensors must be the same. if (inputType.getNumElements() != outputType.getNumElements()) { - return emitOpError( - "Input and output tensors must have the same number of elements"); + return emitOpError() << "Input tensor number of elements " + << inputType.getNumElements() + << " and output tensor number of elements " + << outputType.getNumElements() << " must be the same"; } - bool has_negative = false; - int64_t known_dim_product = 1; + bool hasNegative = false; auto outputShape = outputType.getShape(); // Check that all dimensions are positive except for at most one -1 // Check that the non-negative dimensions match the output tensor shape // Calculate the product of the known dimensions for (int64_t i = 0; i < shapeSize; i++) { - int64_t dim_value = mlir::cast(shape[i]).getInt(); + int64_t dimValue = mlir::cast(shape[i]).getInt(); - if (dim_value == -1) { - if (has_negative) { + if (dimValue == -1) { + if (hasNegative) { return emitOpError("Shape attribute must have at most one -1 element"); } - has_negative = true; + hasNegative = true; } else { - if (dim_value <= 0) { + if (dimValue <= 0) { return emitOpError( "All dimensions must be positive except the one with -1"); } // Ensure that the non-negative dimensions match the output tensor shape - if (dim_value != outputShape[i]) { - return emitOpError("Shape attribute must match the output tensor shape " - "for dimensions that are not -1"); + if (dimValue != outputShape[i]) { + return emitOpError() + << "Shape attribute " << dimValue + << " must match the output tensor shape " << outputShape[i] + << " at index " << i << " for dimension that is not -1"; } - - known_dim_product *= dim_value; } } - // If there's a -1, ensure that it can be inferred correctly - if (has_negative && inputType.getNumElements() % known_dim_product != 0) { - return emitOpError("Invalid shape: the dimensions do not multiply to the " - "total number of elements in the tensor"); - } - return success(); } diff --git a/runtime/lib/ttnn/operations/data_movement/concat.cpp b/runtime/lib/ttnn/operations/data_movement/concat.cpp index a7f97ed6f0..af61196316 100644 --- a/runtime/lib/ttnn/operations/data_movement/concat.cpp +++ b/runtime/lib/ttnn/operations/data_movement/concat.cpp @@ -19,8 +19,11 @@ void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) { } int32_t dim = op->dim(); std::optional<::ttnn::MemoryConfig> memoryConfig = - ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( - op->memory_config()); + op->memory_config() == 0 + ? ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( + ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(op->out())) + : ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( + op->memory_config()); ::ttnn::Tensor out = ::ttnn::concat(inputs, dim, memoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/data_movement/reshape.cpp b/runtime/lib/ttnn/operations/data_movement/reshape.cpp index ed66ca6906..3f41bfc096 100644 --- a/runtime/lib/ttnn/operations/data_movement/reshape.cpp +++ b/runtime/lib/ttnn/operations/data_movement/reshape.cpp @@ -15,8 +15,11 @@ void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) { const auto *fbShape = op->shape(); std::vector shape(fbShape->begin(), fbShape->end()); std::optional<::ttnn::MemoryConfig> memoryConfig = - ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( - op->memory_config()); + op->memory_config() == 0 + ? ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( + ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(op->out())) + : ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded( + op->memory_config()); ::ttnn::Tensor out = ::ttnn::reshape(in, shape, memoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/test/ttmlir/Dialect/TTIR/data_movement/reshape/reshape_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/data_movement/reshape/reshape_tests_negative.mlir new file mode 100644 index 0000000000..de27fe6bb7 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/data_movement/reshape/reshape_tests_negative.mlir @@ -0,0 +1,67 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for reshape operation + +// Verify that verification fails when shape attribute is empty. +module { + func.func @reshape_shape_attribute_empty(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %0 = tensor.empty() : tensor<32x2x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = []}> : (tensor<2x32x32xbf16>, tensor<32x2x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttir.reshape' op Shape attribute must be non-empty + return %1 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails when shape size doesn't matches the rank of the output tensor +module { + func.func @reshape_shape_size_different_from_the_output_rank(%arg0: tensor<2x32x32xbf16>) -> tensor<1x32x2x32xbf16> { + %0 = tensor.empty() : tensor<1x32x2x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [32: i32, 2: i32, 32: i32]}> : (tensor<2x32x32xbf16>, tensor<1x32x2x32xbf16>) -> tensor<1x32x2x32xbf16> + // CHECK: error: 'ttir.reshape' op Shape attribute size 3 must match output tensor rank 4 + return %1 : tensor<1x32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails when input and output tensor have different number of elements. +module { + func.func @reshape_input_output_elements_mismatch(%arg0: tensor<2x32x32xbf16>) -> tensor<32x3x32xbf16> { + %0 = tensor.empty() : tensor<32x3x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [32: i32, 3: i32, 32: i32]}> : (tensor<2x32x32xbf16>, tensor<32x3x32xbf16>) -> tensor<32x3x32xbf16> + // CHECK: error: 'ttir.reshape' op Input tensor number of elements 2048 and output tensor number of elements 3072 must be the same + return %1 : tensor<32x3x32xbf16> + } +} + +// ----- +// Verify that verification fails when shape attribute has more than one -1(infer) element. +module { + func.func @reshape_infer_dim_negative(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %0 = tensor.empty() : tensor<32x2x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [32: i32, -1: i32, -1: i32]}> : (tensor<2x32x32xbf16>, tensor<32x2x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttir.reshape' op Shape attribute must have at most one -1 element + return %1 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails if the shape attribute has negative dimension which is not -1. +module { + func.func @reshape_infer_dim_negative(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %0 = tensor.empty() : tensor<32x2x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [32: i32, -1: i32, -32: i32]}> : (tensor<2x32x32xbf16>, tensor<32x2x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttir.reshape' op All dimensions must be positive except the one with -1 + return %1 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails if the shape attribute is different from the output tensor shape. +module { + func.func @reshape_shape_mismatch(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %0 = tensor.empty() : tensor<32x2x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [32: i32, 3: i32, 32: i32]}> : (tensor<2x32x32xbf16>, tensor<32x2x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttir.reshape' op Shape attribute 3 must match the output tensor shape 2 at index 1 for dimension that is not -1 + return %1 : tensor<32x2x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_negative.mlir new file mode 100644 index 0000000000..76f4f91efd --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_negative.mlir @@ -0,0 +1,61 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for reshape operation + +// Verify that verification fails when shape attribute is empty. +module { + func.func @reshape_shape_attribute_empty(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %0 = "ttnn.reshape"(%arg0) <{shape = []}> : (tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttnn.reshape' op Shape attribute must be non-empty + return %0 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails when shape size doesn't matches the rank of the output tensor +module { + func.func @reshape_shape_size_different_from_the_output_rank(%arg0: tensor<2x32x32xbf16>) -> tensor<1x32x2x32xbf16> { + %1 = "ttnn.reshape"(%arg0) <{shape = [32: i32, 2: i32, 32: i32]}> : (tensor<2x32x32xbf16>) -> tensor<1x32x2x32xbf16> + // CHECK: error: 'ttnn.reshape' op Shape attribute size 3 must match output tensor rank 4 + return %1 : tensor<1x32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails when input and output tensor have different number of elements. +module { + func.func @reshape_input_output_elements_mismatch(%arg0: tensor<2x32x32xbf16>) -> tensor<32x3x32xbf16> { + %1 = "ttnn.reshape"(%arg0) <{shape = [32: i32, 3: i32, 32: i32]}> : (tensor<2x32x32xbf16>) -> tensor<32x3x32xbf16> + // CHECK: error: 'ttnn.reshape' op Input tensor number of elements 2048 and output tensor number of elements 3072 must be the same + return %1 : tensor<32x3x32xbf16> + } +} + +// ----- +// Verify that verification fails when shape attribute has more than one -1(infer) element. +module { + func.func @reshape_infer_dim_negative(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %1 = "ttnn.reshape"(%arg0) <{shape = [32: i32, -1: i32, -1: i32]}> : (tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttnn.reshape' op Shape attribute must have at most one -1 element + return %1 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails if the shape attribute has negative dimension which is not -1. +module { + func.func @reshape_infer_dim_negative(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %1 = "ttnn.reshape"(%arg0) <{shape = [32: i32, -1: i32, -32: i32]}> : (tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttnn.reshape' op All dimensions must be positive except the one with -1 + return %1 : tensor<32x2x32xbf16> + } +} + +// ----- +// Verify that verification fails if the shape attribute is different from the output tensor shape. +module { + func.func @reshape_shape_mismatch(%arg0: tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> { + %1 = "ttnn.reshape"(%arg0) <{shape = [32: i32, 3: i32, 32: i32]}> : (tensor<2x32x32xbf16>) -> tensor<32x2x32xbf16> + // CHECK: error: 'ttnn.reshape' op Shape attribute 3 must match the output tensor shape 2 at index 1 for dimension that is not -1 + return %1 : tensor<32x2x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_positive.mlir new file mode 100644 index 0000000000..226eaf7b12 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/data_movement/reshape/reshape_tests_positive.mlir @@ -0,0 +1,22 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s +module attributes {} { + func.func @reshape_positive(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { + %0 = tensor.empty() : tensor<2x4x32x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + // CHECK: "ttnn.reshape" + // CHECK-SAME: <{shape = [2 : i32, 4 : i32, 32 : i32, 32 : i32]}> + // CHECK-SAME: tensor<4x2x32x32xbf16 + // CHECK-SAME: -> tensor<2x4x32x32xbf16 + return %1 : tensor<2x4x32x32xbf16> + } + + func.func @reshape_with_minus_one(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { + %0 = tensor.empty() : tensor<2x4x32x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, -1: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + // CHECK: "ttnn.reshape" + // CHECK-SAME: <{shape = [2 : i32, -1 : i32, 32 : i32, 32 : i32]}> + // CHECK-SAME: tensor<4x2x32x32xbf16 + // CHECK-SAME: -> tensor<2x4x32x32xbf16 + return %1 : tensor<2x4x32x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_reshape.mlir b/test/ttmlir/Dialect/TTNN/simple_reshape.mlir deleted file mode 100644 index 72eedd3cdb..0000000000 --- a/test/ttmlir/Dialect/TTNN/simple_reshape.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -module attributes {} { - func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { - %0 = tensor.empty() : tensor<2x4x32x32xbf16> - // CHECK: = "ttnn.reshape" - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> - return %1 : tensor<2x4x32x32xbf16> - } -}