From 92b1ad908db7cc2128e1ce1edaf0a84231e47b45 Mon Sep 17 00:00:00 2001
From: Milan Topalovic
Date: Wed, 26 Feb 2025 17:22:23 +0000
Subject: [PATCH] Remove dps interface from ttnn ops

---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     | 61 +++----------
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      | 59 +++++-------
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp    |  6 +--
 lib/Dialect/TTNN/IR/TTNNOps.cpp               | 12 +---
 lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp       |  5 --
 .../Decomposition/CumSumOpRewritePattern.cpp  | 24 +----
 lib/Target/TTNN/TTNNToFlatbuffer.cpp          | 55 ++++++------
 runtime/lib/ttnn/operations/matmul/matmul.cpp | 10 +--
 .../Workarounds/cumsum_workaround.mlir        | 58 ++++--------
 .../embedding_backward_workaround.mlir        | 19 ++--
 .../Workarounds/embedding_workaround.mlir     | 33 ++-----
 .../Workarounds/max_pool2d_workaround.mlir    | 21 ++---
 .../Workarounds/slice_workaround.mlir         | 15 +---
 .../concat/concat_multiple_tensors.mlir       |  4 -
 .../concat/concat_negative_dim.mlir           |  1 -
 .../data_movement/concat/simple_concat.mlir   |  1 -
 .../TTNN/embedding/embedding_non_tile.mlir    |  1 -
 .../TTNN/embedding/gather_to_embedding.mlir   |  3 -
 .../embedding/simple_embedding_backward.mlir  |  1 -
 .../TTNN/linear/linear_tests_positive.mlir    | 24 -----
 .../Dialect/TTNN/linear/simple_linear.mlir    |  3 -
 .../TTNN/matmul/matmul_tests_negative.mlir    | 90 ++++++++-----------
 test/ttmlir/Dialect/TTNN/simple_cumsum.mlir   |  2 -
 .../StableHLO/n150/Binary/concat_op.mlir      | 17 +---
 .../n150/dot_general/dot_general_op_2d.mlir   |  2 -
 .../dot_general_op_batch_matmul.mlir          |  2 -
 .../Silicon/StableHLO/n150/maxpool2d_op.mlir  |  3 -
 .../StableHLO/n150/moreh_cumsum_op.mlir       | 14 +--
 .../Silicon/StableHLO/n150/slice_op.mlir      | 14 ---
 test/ttmlir/Silicon/TTNN/n150/deallocate.mlir |  8 +-
 .../n150/eltwise/binary/concat/concat.mlir    |  1 -
 .../n150/embedding/embedding_1d_tensor.mlir   |  1 -
 .../n150/embedding/embedding_backward.mlir    |  1 -
 .../n150/embedding/embedding_non_tile.mlir    |  1 -
 .../TTNN/n150/embedding/simple_embedding.mlir |  1 -
 .../TTNN/n150/perf/test_perf_concat.mlir      |  1 -
 .../n150/perf/test_perf_conv2d_config.mlir    | 15 ++--
 .../TTNN/n150/perf/test_perf_cumsum.mlir      |  1 -
 .../TTNN/n150/perf/test_perf_embedding.mlir   |  1 -
 .../TTNN/n150/perf/test_perf_linear.mlir      |  3 -
 .../Silicon/TTNN/n150/simple_cumsum.mlir      |  4 -
 .../Silicon/TTNN/n150/simple_linear.mlir      |  9 --
 .../OpModel/TTNN/Op/TestOpModelInterface.cpp  | 43 +++++----
 43 files changed, 196 insertions(+), 454 deletions(-)

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 405d478a3c..40982d2296 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -753,20 +753,18 @@ def TTNN_ProdOp : TTNN_Op<"prod"> {
   let hasVerifier = 1;
 }
 
-def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
+def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
     Embedding operation.
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$weight,
-                       AnyRankedTensor:$output);
+                       AnyRankedTensor:$weight);
 
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds();
     }
@@ -809,7 +807,7 @@ def TTNN_FillCacheOp : TTNN_InplaceOp<"fill_cache"> {
   let hasVerifier = 1;
 }
 
-def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
+def TTNN_EmbeddingBackwardOp : TTNN_Op<"embedding_bw"> {
   let summary = "Embedding backward op.";
   let description = [{
     Embedding backward operation. Generates the gradient of the embedding operation with respect to the input.
@@ -819,13 +817,11 @@ def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
                        AnyRankedTensor:$weight,
                        AnyRankedTensor:$in_gradient,
                        OptionalAttr<TT_DataTypeAttr>:$dtype,
-                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
-                       AnyRankedTensor:$output);
+                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
 
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingBackwardOpOperandsWorkarounds();
     }
@@ -834,7 +830,7 @@ def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
   let hasVerifier = 1;
 }
 
-def TTNN_MorehCumSumOp : TTNN_NamedDPSOp<"moreh_cumsum"> {
+def TTNN_MorehCumSumOp : TTNN_Op<"moreh_cumsum"> {
   let summary = "Moreh cummulative sum op.";
   let description = [{
     Computes the cumulative sum of elements of a tensor along specified dimension.
@@ -854,14 +850,11 @@
   let arguments = (ins AnyRankedTensor:$input,
                        I64Attr:$dim,
-                       AnyRankedTensor:$output,
                        OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
 
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       RankedTensorType inputType = getInput().getType();
       return wa::TTNNOperandsWorkaroundsFactory::createCumSumOpOperandsWorkarounds(inputType);
     }
@@ -924,22 +917,19 @@ def TTNN_RepeatInterleaveOp : TTNN_Op<"repeat_interleave"> {
   let hasVerifier = 1;
 }
 
-def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat", [HasMemoryConfigTrait]> {
+def TTNN_ConcatOp : TTNN_Op<"concat", [HasMemoryConfigTrait]> {
   let summary = "Concat op.";
   let description = [{
     Concat tensors along a given dimension.
   }];
 
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       AnyRankedTensor:$output,
                        SI32Attr:$dim,
                        OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
 
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       ::mlir::Operation::operand_range inputs = getInputs();
       int64_t numOperands = getOperands().size();
@@ -1016,7 +1006,7 @@ def TTNN_PadOp: TTNN_Op<"pad"> {
   let hasVerifier = 1;
 }
 
-def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
+def TTNN_SliceOp: TTNN_Op<"slice"> {
   let summary = "Slice op.";
   let description = [{
     Extract a portion of a tensor based on the specified start (`begins`), stop (`ends`), and step
@@ -1024,7 +1014,6 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
                        I32ArrayAttr:$begins,
                        I32ArrayAttr:$ends,
                        I32ArrayAttr:$step);
 
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       ttnn::TTNNLayoutAttr layoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
-        getOutput().getType().getEncoding());
+        getResult().getType().getEncoding());
       ::mlir::ArrayAttr begins = getBegins();
       ::mlir::ArrayAttr step = getStep();
       return wa::TTNNOperandsWorkaroundsFactory::
@@ -1047,7 +1034,7 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
   let hasVerifier = 1;
 }
 
-def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
+def TTNN_LinearOp : TTNN_Op<"linear"> {
   let summary = "Linear transformation of inputs.";
   let description = [{
@@ -1064,41 +1051,31 @@
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
                        Optional<AnyRankedTensor>:$bias,
-                       AnyRankedTensor:$output,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
 // ANCHOR: adding_an_op_matmul_ttnn
-def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul",
+def TTNN_MatmulOp : TTNN_Op<"matmul",
     [DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
 > {
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
-                       AnyRankedTensor:$output,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 // ANCHOR_END: adding_an_op_matmul_ttnn
 
-def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
+def TTNN_Conv2dOp : TTNN_Op<"conv2d"> {
   let summary = "Conv2d operation.";
   let description = [{
     Applies a 2D convolution over an input image composed of several input planes.
@@ -1107,7 +1084,6 @@
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$weight,
                        Optional<AnyRankedTensor>:$bias,
-                       AnyRankedTensor:$output,
                        TT_Device:$device,
                        I32Attr:$in_channels,
                        I32Attr:$out_channels,
@@ -1123,14 +1099,10 @@
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
-def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
+def TTNN_ConvTranspose2dOp : TTNN_Op<"conv_transpose2d"> {
   let summary = "ConvTranspose2d operation.";
   let description = [{
     Applies a 2D transposed convolution operator over an input image composed of several input planes.
@@ -1191,7 +1163,6 @@ def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$weight,
                        Optional<AnyRankedTensor>:$bias,
-                       AnyRankedTensor:$output,
                        TT_Device:$device,
                        I32Attr:$in_channels,
                        I32Attr:$out_channels,
@@ -1207,21 +1178,16 @@ def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
-def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
+def TTNN_MaxPool2dOp : TTNN_Op<"max_pool2d"> {
   let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
   let description = [{
     Applies a 2D max pooling over an input signal composed of several input planes.
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
                        TT_Device:$device,
                        SI32Attr:$batch_size,
                        SI32Attr:$input_height,
@@ -1240,7 +1206,6 @@
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
     wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
       return wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds();
     }
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 5da2b2d8a2..71db06d59b 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -399,7 +399,7 @@ class EmbeddingOpConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput());
+        adaptor.getInput(), adaptor.getWeight());
     return success();
   }
 
@@ -455,7 +455,7 @@ class EmbeddingBackwardOpConversionPattern
     rewriter.replaceOpWithNewOp<ttnn::EmbeddingBackwardOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
         adaptor.getInput(), adaptor.getWeight(), reshapedGrad, dTypeAttr,
-        memoryConfigAttr, adaptor.getOutput());
+        memoryConfigAttr);
     return success();
   }
 };
@@ -471,7 +471,7 @@ class CumSumOpConversionPattern : public OpConversionPattern<ttir::CumSumOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::MorehCumSumOp>(
        op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getDim(), adaptor.getOutput(), nullptr);
+        adaptor.getInput(), adaptor.getDim(), nullptr);
     return success();
   }
 };
@@ -654,7 +654,7 @@ class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
     }
     rewriter.replaceOpWithNewOp<ttnn::ConcatOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInputs(), adaptor.getOutput(), dim,
+        adaptor.getInputs(), dim,
         /* memory_config */ nullptr);
     return success();
   }
@@ -687,8 +687,8 @@ class SliceOpConversionPattern : public OpConversionPattern<ttir::SliceOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::SliceOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getOutput(), adaptor.getBegins(),
-        adaptor.getEnds(), adaptor.getStep());
+        adaptor.getInput(), adaptor.getBegins(), adaptor.getEnds(),
+        adaptor.getStep());
     return success();
   }
 };
@@ -932,8 +932,8 @@ class LinearOpConversionPattern : public OpConversionPattern<ttir::LinearOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::LinearOp>(
         op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-        adaptor.getB(), adaptor.getBias(), adaptor.getOutput(),
-        adaptor.getTransposeA(), adaptor.getTransposeB());
+        adaptor.getB(), adaptor.getBias(), adaptor.getTransposeA(),
+        adaptor.getTransposeB());
     return success();
   }
 };
@@ -950,8 +950,7 @@ class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
         op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-        adaptor.getB(), adaptor.getOutput(), adaptor.getTransposeA(),
-        adaptor.getTransposeB());
+        adaptor.getB(), adaptor.getTransposeA(), adaptor.getTransposeB());
     return success();
   }
 };
@@ -971,7 +970,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
 
     auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
     auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
-    auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
+    auto outputTy = op.getResult().getType();
 
     auto batchSizeAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(0));
     auto inputHeightAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(1));
@@ -1028,12 +1027,11 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
                               outputTy.getElementType(), outputTy.getEncoding());
 
-    ttnn::Conv2dOp newConv = ttmlir::utils::createDPSOp<ttnn::Conv2dOp>(
-        rewriter, op.getLoc(), outputTy, adaptor.getInput(),
-        adaptor.getWeight(), adaptor.getBias(), device, inChannelsAttr,
-        outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
-        kernelSizeAttr, *strideAttr, reducedPaddingAttr, *dilationAttr,
-        groupsAttr, nullptr);
+    ttnn::Conv2dOp newConv = rewriter.create<ttnn::Conv2dOp>(
+        op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
+        adaptor.getBias(), device, inChannelsAttr, outChannelsAttr,
+        batchSizeAttr, inputHeightAttr, inputWidthAttr, kernelSizeAttr,
+        *strideAttr, reducedPaddingAttr, *dilationAttr, groupsAttr, nullptr);
 
     Value output =
         ttir_to_ttnn::utils::generateReshape(newConv, outputShape, rewriter);
@@ -1091,7 +1089,7 @@ class ConvTranspose2dOpConversionPattern
 
     auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
     auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
-    auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
+    auto outputTy = op.getResult().getType();
 
     auto batchSizeAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(0));
     auto inputHeightAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(1));
@@ -1151,21 +1149,12 @@ class ConvTranspose2dOpConversionPattern
     outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
         outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));
 
-    // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
-    // attribute determination
-    auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
-        adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
-        outputTy.getElementType());
-
-    // Must set the type to the output type to maintain the layout attributes
-    convDPSOutput.getResult().setType(outputTy);
-
     ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
         op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
-        adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
-        outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
-        kernelSizeAttr, *strideAttr, reducedPaddingAttr, *outputPaddingAttr,
-        *dilationAttr, groupsAttr);
+        adaptor.getBias(), device, inChannelsAttr, outChannelsAttr,
+        batchSizeAttr, inputHeightAttr, inputWidthAttr, kernelSizeAttr,
+        *strideAttr, reducedPaddingAttr, *outputPaddingAttr, *dilationAttr,
+        groupsAttr);
 
     // Restore the normal shape (N x H x W x C)
     Value output =
@@ -1239,8 +1228,7 @@ class MaxPool2dOpConversionPattern
             mlir::cast<TypedValue<RankedTensorType>>(adaptor.getInput()),
             rewriter);
 
-    auto outputType =
-        mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
+    auto outputType = op.getResult().getType();
 
     llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
 
     llvm::SmallVector<int64_t> flattenedOutputShape{
@@ -1250,8 +1238,9 @@ class MaxPool2dOpConversionPattern
         outputType.getElementType(), outputType.getEncoding());
 
-    auto newPool = ttmlir::utils::createDPSOp<ttnn::MaxPool2dOp>(
-        rewriter, op.getLoc(), outputType, flattenedInput, device, batchSize,
+    auto newPool = rewriter.create<ttnn::MaxPool2dOp>(
+        op.getLoc(), this->getTypeConverter()->convertType(outputType),
+        flattenedInput, device, batchSize,
         static_cast<int32_t>(inputShape[inputShape.size() - 3]),
         static_cast<int32_t>(inputShape[inputShape.size() - 2]), channels,
         adaptor.getKernelHeight(), adaptor.getKernelWidth(),
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index da17cb30fd..6b958b96e1 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -283,8 +283,7 @@ class LinearOpConversionPattern
          /*compute_kernel_config=*/
          ttnn_to_emitc::utils::createStdNullopt(rewriter),
          /*core_grid=*/ttnn_to_emitc::utils::createStdNullopt(rewriter),
-         /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter),
-         rewriter.getIndexAttr(3)});
+         /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter)});
 
     rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
         linearOp, this->getTypeConverter()->convertType(linearOp.getType()),
@@ -327,8 +326,7 @@ class MatmulOpConversionPattern
          /*compute_kernel_config=*/
          ttnn_to_emitc::utils::createStdNullopt(rewriter),
          /*core_grid=*/ttnn_to_emitc::utils::createStdNullopt(rewriter),
-         /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter),
-         rewriter.getIndexAttr(2)});
+         /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter)});
     // ANCHOR_END: adding_an_op_matmul_ttnn_to_emitc_array_attrs
 
     rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index e172934fc9..f42e103383 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -57,7 +57,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ClampOp::verify() {
 ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() {
   mlir::RankedTensorType inputType = getInput().getType();
   mlir::RankedTensorType weightType = getWeight().getType();
-  mlir::RankedTensorType outputType = getOutput().getType();
+  mlir::RankedTensorType outputType = getResult().getType();
   std::optional<mlir::RankedTensorType> bias =
       getBias().getImpl() ? std::make_optional(getBias().getType())
                           : std::nullopt;
@@ -256,7 +256,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() {
 ::mlir::LogicalResult mlir::tt::ttnn::ConvTranspose2dOp::verify() {
   mlir::RankedTensorType inputType = getInput().getType();
   mlir::RankedTensorType weightType = getWeight().getType();
-  mlir::RankedTensorType outputType = getOutput().getType();
+  mlir::RankedTensorType outputType = getResult().getType();
   std::optional<mlir::RankedTensorType> bias =
       getBias().getImpl() ? std::make_optional(getBias().getType())
                           : std::nullopt;
@@ -749,7 +749,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::SliceOp::verify() {
   ::mlir::ArrayAttr begins = getBeginsAttr();
   ::mlir::ArrayAttr ends = getEndsAttr();
   ::mlir::ArrayAttr stepAttr = getStepAttr();
-  ::mlir::RankedTensorType outputType = getOutput().getType();
+  ::mlir::RankedTensorType outputType = getResult().getType();
 
   // Verify that the input is at least 1D tensor
   if (inputType.getRank() < 1) {
@@ -936,7 +936,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmbeddingBackwardOp::verify() {
   ::mlir::RankedTensorType inputType = getInput().getType();
   ::mlir::RankedTensorType weightType = getWeight().getType();
   ::mlir::RankedTensorType inputGradType = getInGradient().getType();
-  ::mlir::RankedTensorType outputType = getOutput().getType();
+  ::mlir::RankedTensorType outputType = getResult().getType();
 
   // inputType checks:
   // 1. Last dimension must be divisible by TILE_WIDTH.
@@ -1149,7 +1149,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() {
   ::mlir::RankedTensorType inputBType = getB().getType();
   std::optional<::mlir::RankedTensorType> biasType =
       getBias() ? std::make_optional(getBias().getType()) : std::nullopt;
-  ::mlir::RankedTensorType outputType = getOutput().getType();
+  ::mlir::RankedTensorType outputType = getResult().getType();
 
   llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
   llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
@@ -1298,7 +1298,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() {
 ::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
   ::mlir::RankedTensorType inputAType = getA().getType();
   ::mlir::RankedTensorType inputBType = getB().getType();
-  ::mlir::RankedTensorType outputType = getOutput().getType();
+  ::mlir::RankedTensorType outputType = getResult().getType();
 
   llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
   llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
index 7518a1b1f7..a5dce00a1e 100644
--- a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
@@ -101,7 +101,6 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
   rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
   rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16;
   return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
-      .addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
       .addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
       .addOutputOperandWorkaround(rowMajorLayoutBF16Workaround);
 }
@@ -128,7 +127,6 @@ TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
   return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
       .addInputOperandWorkaround(inputRowMajorInt32Workaround)
       .addInputOperandWorkaround(bf16Workaround)
-      .addInputOperandWorkaround(bf16Workaround)
       .addOutputOperandWorkaround(bf16Workaround);
 }
 
@@ -156,7 +154,6 @@ TTNNOperandsWorkaroundsFactory::createEmbeddingBackwardOpOperandsWorkarounds() {
       .addInputOperandWorkaround(inputRowMajorInt32Workaround)
       .addInputOperandWorkaround(bf16Workaround)
       .addInputOperandWorkaround(bf16Workaround)
-      .addInputOperandWorkaround(bf16Workaround)
       .addOutputOperandWorkaround(bf16Workaround);
 }
 
@@ -190,7 +187,6 @@ TTNNOperandsWorkaroundsFactory::createCumSumOpOperandsWorkarounds(
           ? TTNNOperandWorkarounds(DataType::Float32)
           : TTNNOperandWorkarounds();
   return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
-      .addInputOperandWorkaround(typeWorkaround)
       .addInputOperandWorkaround(typeWorkaround)
       .addOutputOperandWorkaround(typeWorkaround);
 }
@@ -324,7 +320,6 @@ TTNNOperandsWorkaroundsFactory::createSliceOpOperandsWorkarounds(
     rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
   }
   return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
-      .addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
       .addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
       .addOutputOperandWorkaround(rowMajorLayoutBF16Workaround);
 }
diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.cpp
index beaf590573..407783cd50 100644
--- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.cpp
+++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.cpp
@@ -47,31 +47,9 @@ CumSumOpRewritePattern::matchAndRewrite(ttnn::MorehCumSumOp srcOp,
       RankedTensorType::get(reshapeOutputType.getShape(),
                             outputType.getElementType(), newOutputLayoutAttr);
 
-  DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(),
-                                             newOutputLayoutAttr.getDataType());
-  ttnn::LayoutAttr tensorLayoutAttr =
-      ttnn::LayoutAttr::get(getContext(), newOutputLayoutAttr.getLayout());
-
-  // Create MemoryConfigAttr
-  ttnn::ShapeAttr shapeAttr =
-      ttnn::ShapeAttr::get(rewriter.getContext(), newOutputType.getShape());
-
-  ttnn::BufferTypeAttr bufferTypeAttr = ttnn::BufferTypeAttr::get(
-      getContext(), newOutputLayoutAttr.getBufferType());
-  ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get(
-      getContext(),
-      ttnn::ShapeAttr::get(getContext(), newOutputLayoutAttr.getShardShape()));
-  ttnn::MemoryConfigAttr memoryConfigAttr =
-      ttnn::MemoryConfigAttr::get(getContext(), bufferTypeAttr, shardSpecAttr,
-                                  newOutputLayoutAttr.getMemLayout());
-
-  EmptyOp emptyOp = rewriter.create<EmptyOp>(
-      srcOp->getLoc(), newOutputType, shapeAttr, dTypeAttr, tensorLayoutAttr,
-      ttnn::utils::getOrInsertDevice(rewriter, srcOp), memoryConfigAttr);
-
   MorehCumSumOp cumsumOp = rewriter.create<MorehCumSumOp>(
       srcOp->getLoc(), newOutputType, preReshapeOp->getResult(0),
-      srcOp.getDim(), emptyOp, nullptr);
+      srcOp.getDim(), nullptr);
 
   llvm::ArrayRef<int64_t> outputShapeAttr(inputTypeShape);
   mlir::TypedValue<RankedTensorType> cumsumOutput =
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index a0a6b47cd0..7de2f54fb4 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -659,8 +659,8 @@ createOp(FlatbufferObjectCache &cache, LinearOp op) {
           ? cache.at<::tt::target::ttnn::TensorRef>(
                 getOperandThroughDPSOps(op.getBias()))
           : flatbuffers::Offset<::tt::target::ttnn::TensorRef>();
-  auto output = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getOutput()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedSize);
   return ::tt::target::ttnn::CreateLinearOp(
       *cache.fbb, a, b, bias, output, op.getTransposeA(), op.getTransposeB());
 }
@@ -672,8 +672,8 @@ createOp(FlatbufferObjectCache &cache, MatmulOp op) {
       getOperandThroughDPSOps(op.getA()));
   auto b = cache.at<::tt::target::ttnn::TensorRef>(
       getOperandThroughDPSOps(op.getB()));
-  auto output = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getOutput()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedSize);
   return ::tt::target::ttnn::CreateMatmulOp(
       *cache.fbb, a, b, output, op.getTransposeA(), op.getTransposeB());
 }
@@ -683,11 +683,12 @@ ::flatbuffers::Offset<::tt::target::ttnn::MorehCumSumOp>
 createOp(FlatbufferObjectCache &cache, MorehCumSumOp op) {
   auto in = cache.at<::tt::target::ttnn::TensorRef>(
       getOperandThroughDPSOps(op.getInput()));
-  auto dpsOutput = getOperandThroughDPSOps(op.getResult());
-  auto output = cache.at<::tt::target::ttnn::TensorRef>(dpsOutput);
+  auto outputType = op.getResult();
+  auto output = cache.getOrCreate(outputType, tensorValueToFlatbuffer,
+                                  kHostAllocatedSize);
 
-  auto tileShape = getTensorValueTileShape(dpsOutput);
-  auto coreRangeSet = getTensorValueCoreRangeSet(cache, dpsOutput);
+  auto tileShape = getTensorValueTileShape(outputType);
+  auto coreRangeSet = getTensorValueCoreRangeSet(cache, outputType);
   auto memoryConfig =
       op.getMemoryConfig()
           ? memoryConfigToFlatbuffer(cache, op.getMemoryConfig().value(),
@@ -708,8 +709,8 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
           ? flatbuffers::Offset<::tt::target::ttnn::TensorRef>()
           : cache.at<::tt::target::ttnn::TensorRef>(
                 getOperandThroughDPSOps(op.getBias()));
-  auto output = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedSize);
 
   auto device = getOperandThroughDPSOps(op.getDevice());
@@ -746,8 +747,8 @@ createOp(FlatbufferObjectCache &cache, ConvTranspose2dOp op) {
           ? flatbuffers::Offset<::tt::target::ttnn::TensorRef>()
           : cache.at<::tt::target::ttnn::TensorRef>(
                 getOperandThroughDPSOps(op.getBias()));
-  auto output = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedSize);
 
   auto device = getOperandThroughDPSOps(op.getDevice());
@@ -1190,15 +1191,16 @@ createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) {
                   getOperandThroughDPSOps(input)));
   }
 
-  auto dpsOutput = getOperandThroughDPSOps(op.getResult());
-  auto out = cache.at<::tt::target::ttnn::TensorRef>(dpsOutput);
+  auto outputType = op.getResult();
+  auto out = cache.getOrCreate(outputType, tensorValueToFlatbuffer,
+                               kHostAllocatedSize);
 
   int32_t dim = op.getDim();
   std::optional<::mlir::tt::ttnn::MemoryConfigAttr> memoryConfig =
       op.getMemoryConfig();
 
-  auto tileShape = getTensorValueTileShape(dpsOutput);
-  auto coreRangeSet = getTensorValueCoreRangeSet(cache, dpsOutput);
+  auto tileShape = getTensorValueTileShape(outputType);
+  auto coreRangeSet = getTensorValueCoreRangeSet(cache, outputType);
 
   return ::tt::target::ttnn::CreateConcatOpDirect(
       *cache.fbb, &ins, out, dim,
       memoryConfig ? memoryConfigToFlatbuffer(cache, memoryConfig.value(),
                                               tileShape, coreRangeSet)
                    : 0);
 }
@@ -1212,8 +1214,8 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
       getOperandThroughDPSOps(op.getInput()));
   auto in1 = cache.at<::tt::target::ttnn::TensorRef>(
       getOperandThroughDPSOps(op.getWeight()));
-  auto out = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                               kHostAllocatedSize);
 
   return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out);
 }
@@ -1231,11 +1233,12 @@ createEmbeddingBackwardOp(FlatbufferObjectCache &cache,
   std::optional<::mlir::tt::ttnn::MemoryConfigAttr> memoryConfig =
       op.getMemoryConfig();
 
-  auto dpsOutput = getOperandThroughDPSOps(op.getResult());
-  auto out = cache.at<::tt::target::ttnn::TensorRef>(dpsOutput);
+  auto outputType = op.getResult();
+  auto out = cache.getOrCreate(outputType, tensorValueToFlatbuffer,
+                               kHostAllocatedSize);
 
-  auto tileShape = getTensorValueTileShape(dpsOutput);
-  auto coreRangeSet = getTensorValueCoreRangeSet(cache, dpsOutput);
+  auto tileShape = getTensorValueTileShape(outputType);
+  auto coreRangeSet = getTensorValueCoreRangeSet(cache, outputType);
   return ::tt::target::ttnn::CreateEmbeddingBackwardOp(
       *cache.fbb, in0, in1, in2,
       dtype.has_value()
@@ -1299,8 +1302,8 @@ ::flatbuffers::Offset<::tt::target::ttnn::SliceOp>
 createSliceOp(FlatbufferObjectCache &cache, SliceOp op) {
   auto in = cache.at<::tt::target::ttnn::TensorRef>(
       getOperandThroughDPSOps(op.getInput()));
-  auto out = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                               kHostAllocatedSize);
 
   auto begins = arrayAttrToFlatbuffer(cache, op.getBegins());
   auto ends = arrayAttrToFlatbuffer(cache, op.getEnds());
@@ -1316,8 +1319,8 @@ ::flatbuffers::Offset<::tt::target::ttnn::MaxPool2dOp>
 createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) {
   auto in = cache.at<::tt::target::ttnn::TensorRef>(
       getOperandThroughDPSOps(op.getInput()));
-  auto out = cache.at<::tt::target::ttnn::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                               kHostAllocatedSize);
 
   auto device = getOperandThroughDPSOps(op.getDevice());
   return ::tt::target::ttnn::CreateMaxPool2dOp(
diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp
index 5537a18d41..567c9a7932 100644
--- a/runtime/lib/ttnn/operations/matmul/matmul.cpp
+++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp
@@ -17,10 +17,8 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   const ::ttnn::Tensor &lhs = tensorPool.at(op->a()->global_id());
   const ::ttnn::Tensor &rhs = tensorPool.at(op->b()->global_id());
-  const ::ttnn::Tensor &out = tensorPool.at(op->out()->global_id());
   DEBUG_ASSERT(lhs.is_allocated());
   DEBUG_ASSERT(rhs.is_allocated());
-  DEBUG_ASSERT(out.is_allocated());
 
   auto outputMemoryConfig =
       ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
@@ -35,7 +33,8 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
       lhs, rhs, op->transpose_a(), op->transpose_b(), outputMemoryConfig,
       outputDataType, /*program_config=*/std::nullopt,
       /*activation=*/std::nullopt, /*compute_kernel_config=*/std::nullopt,
-      /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt, out);
+      /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt,
+      /*optional_output_tensor=*/std::nullopt);
 
   tensorPool.insert_or_assign(op->out()->global_id(), output);
 }
@@ -48,11 +47,9 @@ void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) {
   std::optional<::ttnn::Tensor> bias =
       op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id()))
                  : std::nullopt;
-  const ::ttnn::Tensor &out = tensorPool.at(op->out()->global_id());
   DEBUG_ASSERT(lhs.is_allocated());
   DEBUG_ASSERT(rhs.is_allocated());
   DEBUG_ASSERT(!bias || bias->is_allocated());
-  DEBUG_ASSERT(out.is_allocated());
 
   auto outputMemoryConfig =
       ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
@@ -67,7 +64,8 @@ void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) {
       lhs, rhs, bias, op->transpose_a(), op->transpose_b(), outputMemoryConfig,
       outputDataType, /*program_config=*/std::nullopt,
       /*activation=*/std::nullopt, /*compute_kernel_config=*/std::nullopt,
-      /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt, out);
+      /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt,
+      /*optional_output_tensor=*/std::nullopt);
 
   tensorPool.insert_or_assign(op->out()->global_id(), output);
 }
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/cumsum_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/cumsum_workaround.mlir
index 00b701f5d4..3ecb3fa5f8 100644
--- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/cumsum_workaround.mlir
+++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/cumsum_workaround.mlir
@@ -20,28 +20,19 @@ module @moreh_cumsum attributes {tt.device = #device, tt.system_desc = #system_d
     %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<1x32xui32, #ttnn_layout>) -> tensor<1x32xui32, #ttnn_layout1>
     %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<1x32xui32, #ttnn_layout1>, !tt.device<#device>) -> tensor<1x32xui32, #ttnn_layout2>
     "ttnn.deallocate"(%1) <{force = false}> : (tensor<1x32xui32, #ttnn_layout1>) -> ()
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >, shape = #ttnn.shape<1x32>}> : (!tt.device<#device>) -> tensor<1x32xui32, #ttnn_layout2>
     // CHECK: %[[RESHAPE:[0-9]+]] = "ttnn.reshape"
     // CHECK-SAME: {shape = [1 : i32, 32 : i32, 1 : i32, 1 : i32]}
     // CHECK-SAME: tensor<1x32xui32
     // CHECK-SAME: -> tensor<1x32x1x1xui32
-    // CHECK: %[[EMPTY:[0-9]+]] = "ttnn.empty"
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: -> tensor<1x32x1x1xui32
     // CHECK: %[[ARG1:[0-9]+]] = "ttnn.to_layout"(%[[RESHAPE]]
     // CHECK-SAME: dtype = #tt.supportedDataTypes
     // CHECK-SAME: tensor<1x32x1x1xui32
     // CHECK-SAME: -> tensor<1x32x1x1xf32
-    // CHECK: %[[ARG2:[0-9]+]] = "ttnn.to_layout"(%[[EMPTY]]
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: tensor<1x32x1x1xui32
-    // CHECK-SAME: -> tensor<1x32x1x1xf32
-    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG1]], %[[ARG2]])
+    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG1]])
     // CHECK-SAME: {dim = 0 : i64}
     // CHECK-SAME: tensor<1x32x1x1xf32
-    // CHECK-SAME: tensor<1x32x1x1xf32
     // CHECK-SAME: -> tensor<1x32x1x1xf32
-    %4 = "ttnn.moreh_cumsum"(%2, %3) <{dim = 0 : i64}> : (tensor<1x32xui32, #ttnn_layout2>, tensor<1x32xui32, #ttnn_layout2>) -> tensor<1x32xui32, #ttnn_layout2>
+    %3 = "ttnn.moreh_cumsum"(%2) <{dim = 0 : i64}> : (tensor<1x32xui32, #ttnn_layout2>) -> tensor<1x32xui32, #ttnn_layout2>
     // CHECK: %[[POSTLAYOUT:[0-9]+]] = "ttnn.to_layout"(%[[CUMSUM]]
     // CHECK-SAME: dtype = #tt.supportedDataTypes
     // CHECK-SAME: tensor<1x32x1x1xf32
@@ -51,11 +42,11 @@ module @moreh_cumsum attributes {tt.device = #device, tt.system_desc = #system_d
     // CHECK-SAME: tensor<1x32x1x1xui32
     // CHECK-SAME: -> tensor<1x32xui32
     "ttnn.deallocate"(%2) <{force = false}> : (tensor<1x32xui32, #ttnn_layout2>) -> ()
-    %5 = "ttnn.from_device"(%4) : (tensor<1x32xui32, #ttnn_layout2>) -> tensor<1x32xui32, #ttnn_layout1>
+    %4 = "ttnn.from_device"(%3) : (tensor<1x32xui32, #ttnn_layout2>) -> tensor<1x32xui32, #ttnn_layout1>
     "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x32xui32, #ttnn_layout2>) -> ()
-    %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<1x32xui32, #ttnn_layout1>) -> tensor<1x32xui32, #ttnn_layout>
-    "ttnn.deallocate"(%5) <{force = false}> : (tensor<1x32xui32, #ttnn_layout1>) -> ()
-    return %6 : tensor<1x32xui32, #ttnn_layout>
+    %5 = "ttnn.to_layout"(%4) <{layout = #ttnn.layout}> : (tensor<1x32xui32, #ttnn_layout1>) -> tensor<1x32xui32, #ttnn_layout>
+    "ttnn.deallocate"(%4) <{force = false}> : (tensor<1x32xui32, #ttnn_layout1>) -> ()
+    return %5 : tensor<1x32xui32, #ttnn_layout>
   }
 
   func.func public @test_cumsum_reshape(%arg0: tensor<1x32xf32, #ttnn_layout3>) -> tensor<1x32xf32, #ttnn_layout3> {
@@ -64,63 +55,52 @@ module @moreh_cumsum attributes {tt.device = #device, tt.system_desc = #system_d
     %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<1x32xf32, #ttnn_layout3>) -> tensor<1x32xf32, #ttnn_layout4>
     %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<1x32xf32, #ttnn_layout4>, !tt.device<#device>) -> tensor<1x32xf32, #ttnn_layout5>
     "ttnn.deallocate"(%1) <{force = false}> : (tensor<1x32xf32, #ttnn_layout4>) -> ()
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >, shape = #ttnn.shape<1x32>}> : (!tt.device<#device>) -> tensor<1x32xf32, #ttnn_layout5>
     // CHECK: %[[RESHAPE:[0-9]+]] = "ttnn.reshape"
     // CHECK-SAME: {shape = [1 : i32, 32 : i32, 1 : i32, 1 : i32]}
     // CHECK-SAME: tensor<1x32xf32
     // CHECK-SAME: -> tensor<1x32x1x1xf32
-    // CHECK: %[[EMPTY:[0-9]+]] = "ttnn.empty"
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: -> tensor<1x32x1x1xf32
-    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[RESHAPE]], %[[EMPTY]])
+    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[RESHAPE]])
     // CHECK-SAME: {dim = 1 : i64}
     // CHECK-SAME: tensor<1x32x1x1xf32
     // CHECK-SAME: -> tensor<1x32x1x1xf32
-    %4 = "ttnn.moreh_cumsum"(%2, %3) <{dim = 1 : i64}> : (tensor<1x32xf32, #ttnn_layout5>, tensor<1x32xf32, #ttnn_layout5>) -> tensor<1x32xf32, #ttnn_layout5>
+    %3 = "ttnn.moreh_cumsum"(%2) <{dim = 1 : i64}> : (tensor<1x32xf32, #ttnn_layout5>) -> tensor<1x32xf32, #ttnn_layout5>
    // CHECK: "ttnn.reshape"(%[[CUMSUM]])
     // CHECK-SAME: {shape = [1 : i32, 32 : i32]}
     // CHECK-SAME: tensor<1x32x1x1xf32
     // CHECK-SAME: -> tensor<1x32xf32
     "ttnn.deallocate"(%2) <{force = false}> : (tensor<1x32xf32, #ttnn_layout5>) -> ()
-    %5 = "ttnn.from_device"(%4) : (tensor<1x32xf32, #ttnn_layout5>) -> tensor<1x32xf32, #ttnn_layout4>
+    %4 = "ttnn.from_device"(%3) : (tensor<1x32xf32, #ttnn_layout5>) -> tensor<1x32xf32, #ttnn_layout4>
     "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x32xf32, #ttnn_layout5>) -> ()
-    %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<1x32xf32, #ttnn_layout4>) -> tensor<1x32xf32, #ttnn_layout3>
-    "ttnn.deallocate"(%5) <{force = false}> : (tensor<1x32xf32, #ttnn_layout4>) -> ()
-    return %6 : tensor<1x32xf32, #ttnn_layout3>
+    %5 = "ttnn.to_layout"(%4) <{layout = #ttnn.layout}> : (tensor<1x32xf32, #ttnn_layout4>) -> tensor<1x32xf32, #ttnn_layout3>
+    "ttnn.deallocate"(%4) <{force = false}> : (tensor<1x32xf32, #ttnn_layout4>) -> ()
+    return %5 : tensor<1x32xf32, #ttnn_layout3>
   }
 
   func.func public @test_cumsum_layout(%arg0: tensor<1x32x64x64xui32, #ttnn_layout6>) -> tensor<1x32x64x64xui32, #ttnn_layout6> {
     // CHECK-LABEL: func.func public @test_cumsum_layout(
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
     %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<1x32x64x64xui32, #ttnn_layout6>) -> tensor<1x32x64x64xui32, #ttnn_layout7>
+    // CHECK: "ttnn.to_layout"
     %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<#dram, <<64x2>>, >}> : (tensor<1x32x64x64xui32, #ttnn_layout7>, !tt.device<#device>) -> tensor<1x32x64x64xui32, #ttnn_layout8>
     "ttnn.deallocate"(%1) <{force = false}> : (tensor<1x32x64x64xui32, #ttnn_layout7>) -> ()
-    // CHECK: %[[EMPTY:[0-9]]] = "ttnn.empty"
-    // CHECK-SAME: -> tensor<1x32x64x64xui32
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x2>>, >, shape = #ttnn.shape<1x32x64x64>}> : (!tt.device<#device>) -> tensor<1x32x64x64xui32, #ttnn_layout8>
     // CHECK: %[[ARG1:[0-9]+]] = "ttnn.to_layout"
     // CHECK-SAME: dtype = #tt.supportedDataTypes
     // CHECK-SAME: tensor<1x32x64x64xui32
     // CHECK-SAME: -> tensor<1x32x64x64xf32
-    // CHECK: %[[ARG2:[0-9]+]] = "ttnn.to_layout"(%[[EMPTY]]
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: tensor<1x32x64x64xui32
-    // CHECK-SAME: -> tensor<1x32x64x64xf32
-    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG1]], %[[ARG2]])
+    // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG1]])
     // CHECK-SAME: <{dim = 1 : i64}>
     // CHECK-SAME: tensor<1x32x64x64xf32
-    // CHECK-SAME: tensor<1x32x64x64xf32
     // CHECK-SAME: -> tensor<1x32x64x64xf32
-    %4 = "ttnn.moreh_cumsum"(%2, %3) <{dim = 1 : i64}> : (tensor<1x32x64x64xui32, #ttnn_layout8>, tensor<1x32x64x64xui32, #ttnn_layout8>) -> tensor<1x32x64x64xui32, #ttnn_layout8>
+    %3 = "ttnn.moreh_cumsum"(%2) <{dim = 1 : i64}> : (tensor<1x32x64x64xui32, #ttnn_layout8>) -> tensor<1x32x64x64xui32, #ttnn_layout8>
     // CHECK: "ttnn.to_layout"(%[[CUMSUM]],
     // CHECK-SAME: dtype = #tt.supportedDataTypes tensor<1x32x64x64xui32
     "ttnn.deallocate"(%2) <{force = false}> : (tensor<1x32x64x64xui32, #ttnn_layout8>) -> ()
-    %5 = "ttnn.from_device"(%4) : (tensor<1x32x64x64xui32, #ttnn_layout8>) -> tensor<1x32x64x64xui32, #ttnn_layout7>
+    %4 = "ttnn.from_device"(%3) : (tensor<1x32x64x64xui32, #ttnn_layout8>) -> tensor<1x32x64x64xui32, #ttnn_layout7>
     "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x32x64x64xui32, #ttnn_layout8>) -> ()
-    %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<1x32x64x64xui32, #ttnn_layout7>) -> tensor<1x32x64x64xui32, #ttnn_layout6>
-    "ttnn.deallocate"(%5) <{force = false}> : (tensor<1x32x64x64xui32, #ttnn_layout7>) -> ()
-    return %6 : tensor<1x32x64x64xui32, #ttnn_layout6>
+    %5 = "ttnn.to_layout"(%4) <{layout = #ttnn.layout}> : (tensor<1x32x64x64xui32, #ttnn_layout7>) -> tensor<1x32x64x64xui32, #ttnn_layout6>
+    "ttnn.deallocate"(%4) <{force = false}> : (tensor<1x32x64x64xui32, #ttnn_layout7>) -> ()
+    return %5 : tensor<1x32x64x64xui32, #ttnn_layout6>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_backward_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_backward_workaround.mlir
index e6ee5b5786..54a2913e58 100644
--- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_backward_workaround.mlir
+++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_backward_workaround.mlir
@@ -16,10 +16,8 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
     %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<1x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32xf32, #ttnn_layout3>
     %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout4>
     %3 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x4>>, >}> : (tensor<1x32x128xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x32x128xf32, #ttnn_layout5>
-    %4 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >, shape = #ttnn.shape<512x128>}> : (!tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout4>
-    // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
-    // CHECK-NEXT: "ttnn.reshape"
-    %5 = "ttnn.reshape"(%3) <{shape = [1 : i32, 1 : i32, 32 : i32, 128 : i32]}> : (tensor<1x32x128xf32, #ttnn_layout5>) -> tensor<1x1x32x128xf32, #ttnn_layout5>
+    // CHECK: "ttnn.reshape"
+    %4 = "ttnn.reshape"(%3) <{shape = [1 : i32, 1 : i32, 32 : i32, 128 : i32]}> : (tensor<1x32x128xf32, #ttnn_layout5>) -> tensor<1x1x32x128xf32, #ttnn_layout5>
     // Check that the input operand is transformed into the row major layout.
     // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
     // CHECK-SAME: dtype = #tt.supportedDataTypes
@@ -39,19 +37,14 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
     // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<1x4>>, >
     // CHECK-SAME: -> tensor<1x1x32x128xbf16
     // Check that the data type of the output operand is transformed in bf16.
-    // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]])
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: layout = #ttnn.layout
-    // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16x4>>, >
-    // CHECK-SAME: -> tensor<512x128xbf16
-    %6 = "ttnn.embedding_bw"(%1, %2, %5, %4) <{dtype = #tt.supportedDataTypes, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >}> : (tensor<1x32xf32, #ttnn_layout3>, tensor<512x128xf32, #ttnn_layout4>, tensor<1x1x32x128xf32, #ttnn_layout5>, tensor<512x128xf32, #ttnn_layout4>) -> tensor<512x128xf32, #ttnn_layout4>
-    // CHECK-NEXT: %[[EMBEDDING_BW_OP:.*]] = "ttnn.embedding_bw"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_WEIGHTS]], %[[TO_LAYOUT_IN_GRADIENT]], %[[TO_LAYOUT_OUTPUT_DPS]])
+    %5 = "ttnn.embedding_bw"(%1, %2, %4) <{dtype = #tt.supportedDataTypes, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >}> : (tensor<1x32xf32, #ttnn_layout3>, tensor<512x128xf32, #ttnn_layout4>, tensor<1x1x32x128xf32, #ttnn_layout5>) -> tensor<512x128xf32, #ttnn_layout4>
+    // CHECK-NEXT: %[[EMBEDDING_BW_OP:.*]] = "ttnn.embedding_bw"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_WEIGHTS]], %[[TO_LAYOUT_IN_GRADIENT]])
     // Check that the output operand is transformed back into the f32 data type.
// CHECK-NEXT: "ttnn.to_layout"(%[[EMBEDDING_BW_OP]]) // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<512x128>>> - %7 = "ttnn.to_layout"(%6) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<512x128>>>}> : (tensor<512x128xf32, #ttnn_layout4>) -> tensor<512x128xf32, #ttnn_layout1> - return %7 : tensor<512x128xf32, #ttnn_layout1> + %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<512x128>>>}> : (tensor<512x128xf32, #ttnn_layout4>) -> tensor<512x128xf32, #ttnn_layout1> + return %6 : tensor<512x128xf32, #ttnn_layout1> } } diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir index a3a3860e30..173cb744cb 100644 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir @@ -3,20 +3,12 @@ #dram = #ttnn.buffer_type #system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> #system_memory = #ttnn.buffer_type -#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xf32, #system_memory>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<512x128xf32, #system_memory>> -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1024x128xf32, #system_memory>> -#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > -#ttnn_layout4 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<16x4x!tt.tile<32x32, f32>, #dram>, > -#ttnn_layout5 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x4x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<16x4x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x4x!tt.tile<32x32, f32>, #dram>, > module attributes {tt.device = #device, tt.system_desc = #system_desc} { func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> { - %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device" - %1 = 
"ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3> - %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<16x4>>, >}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout4> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<32x4>>, >, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout5> - // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) // Check that the input operand is transformed into the row major layout. // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" // CHECK-SAME: dtype = #tt.supportedDataTypes @@ -29,20 +21,13 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} { // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16x4>>, > // CHECK-SAME: -> tensor<512x128xbf16 - // Check that the data type of the output operand is transformed in bf16. - // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) - // CHECK-SAME: dtype = #tt.supportedDataTypes - // CHECK-SAME: layout = #ttnn.layout - // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x4>>, > - // CHECK-SAME: -> tensor<32x32x128xbf16 - %4 = "ttnn.embedding"(%1, %2, %3) : (tensor<32x32xf32, #ttnn_layout3>, tensor<512x128xf32, #ttnn_layout4>, tensor<32x32x128xf32, #ttnn_layout5>) -> tensor<32x32x128xf32, #ttnn_layout5> - // CHECK-NEXT: %[[EMBEDDING_OP:.*]] = "ttnn.embedding"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_WEIGHTS]], %[[TO_LAYOUT_OUTPUT_DPS]]) + %0 = "ttnn.embedding"(%arg0, %arg1) : (tensor<32x32xf32, #ttnn_layout>, tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[EMBEDDING_OP:.*]] = "ttnn.embedding"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_WEIGHTS]]) // Check that the output operand is transformed back into the f32 data type. 
- // CHECK-NEXT: "ttnn.to_layout"(%[[EMBEDDING_OP]]) + // CHECK-NEXT: "ttnn.to_layout"(%[[EMBEDDING_OP]], %[[DEVICE_OP]]) // CHECK-SAME: dtype = #tt.supportedDataTypes - // CHECK-SAME: layout = #ttnn.layout - // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>> - %5 = "ttnn.to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xf32, #ttnn_layout5>) -> tensor<32x32x128xf32, #ttnn_layout2> - return %5 : tensor<32x32x128xf32, #ttnn_layout2> + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x4>>, > + return %0 : tensor<32x32x128xf32, #ttnn_layout2> } } diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir index 929eaca4e5..259e912922 100644 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir @@ -12,32 +12,25 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device" %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<512x1>>, >}> : (tensor<1x128x128x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xf32, #ttnn_layout2> + // CHECK: "ttnn.to_layout" %2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xf32, #ttnn_layout2>) -> tensor<1x1x16384x32xf32, #ttnn_layout2> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<128x1>>, >, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> - // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) // Check that the input operand is transformed into the row major layout. - // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" + // CHECK: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, > // CHECK-SAME: -> tensor<1x1x16384x32xbf16, - // Check that the output operand is transformed into the row major layout. 
-    // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]])
-    // CHECK-SAME: dtype = #tt.supportedDataTypes
-    // CHECK-SAME: layout = #ttnn.layout
-    // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<4096x32>>, >
-    // CHECK-SAME: -> tensor<1x1x4096x32xbf16,
-    %4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, tensor<1x1x4096x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
-    // CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_OUTPUT_DPS]], %[[DEVICE_OP]])
+    %3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
+    // CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[DEVICE_OP]])
     // Check that the output operand is transformed back into the tile and f32 data type.
     // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]])
     // CHECK-SAME: dtype = #tt.supportedDataTypes
     // CHECK-SAME: layout = #ttnn.layout
     // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<128x1>>, >
     // CHECK-SAME: -> tensor<1x1x4096x32xf32
-    %5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout3>
+    %4 = "ttnn.reshape"(%3) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout3>
     // CHECK-NEXT: ttnn.reshape
-    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout1>
-    return %6 : tensor<1x64x64x32xf32, #ttnn_layout1>
+    %5 = "ttnn.to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout1>
+    return %5 : tensor<1x64x64x32xf32, #ttnn_layout1>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/slice_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/slice_workaround.mlir
index 36eb4d9c6d..3981c38851 100644
--- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/slice_workaround.mlir
+++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/slice_workaround.mlir
@@ -8,30 +8,21 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
   func.func @test_strided_slice_workaround(%arg0: tensor<4x32x32xf32, #ttnn_layout>) -> tensor<2x16x8xf32, #ttnn_layout1> {
     // CHECK-LABEL: @test_strided_slice_workaround(
-    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    // CHECK: %[[EMPTY:[0-9]+]] = "ttnn.empty"
CHECK-SAME: layout = #ttnn.layout - %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >, shape = #ttnn.shape<2x16x8>}> : (!tt.device<#device>) -> tensor<2x16x8xf32, #ttnn_layout1> // CHECK: %[[ARG0:[0-9]+]] = "ttnn.to_layout"(%arg0, // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: tensor<4x32x32xf32 // CHECK-SAME: -> tensor<4x32x32xbf16 - // CHECK: %[[ARG1:[0-9]+]] = "ttnn.to_layout"(%[[EMPTY]], - // CHECK-SAME: dtype = #tt.supportedDataTypes - // CHECK-SAME: tensor<2x16x8xf32 - // CHECK-SAME: -> tensor<2x16x8xbf16 - // CHECK: %[[SLICE:[0-9]+]] = "ttnn.slice"(%[[ARG0]], %[[ARG1]]) + // CHECK: %[[SLICE:[0-9]+]] = "ttnn.slice"(%[[ARG0]]) // CHECK-SAME: begins = [0 : i32, 0 : i32, 0 : i32] // CHECK-SAME: ends = [2 : i32, 16 : i32, 16 : i32] // CHECK-SAME: step = [1 : i32, 1 : i32, 2 : i32]} // CHECK-SAME: tensor<4x32x32xbf16 - // CHECK-SAME: tensor<2x16x8xbf16 // CHECK-SAME: -> tensor<2x16x8xbf16 - %2 = "ttnn.slice"(%arg0, %1) <{begins = [0 : i32, 0 : i32, 0 : i32], ends = [2 : i32, 16 : i32, 16 : i32], step = [1 : i32, 1 : i32, 2 : i32]}> : (tensor<4x32x32xf32, #ttnn_layout>, tensor<2x16x8xf32, #ttnn_layout1>) -> tensor<2x16x8xf32, #ttnn_layout1> + %1 = "ttnn.slice"(%arg0) <{begins = [0 : i32, 0 : i32, 0 : i32], ends = [2 : i32, 16 : i32, 16 : i32], step = [1 : i32, 1 : i32, 2 : i32]}> : (tensor<4x32x32xf32, #ttnn_layout>) -> tensor<2x16x8xf32, #ttnn_layout1> // CHECK: "ttnn.to_layout"(%[[SLICE]] // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: tensor<2x16x8xbf16 // CHECK-SAME: -> tensor<2x16x8xf32 - return %2 : tensor<2x16x8xf32, #ttnn_layout1> + return %1 : tensor<2x16x8xf32, #ttnn_layout1> } } diff --git a/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_multiple_tensors.mlir b/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_multiple_tensors.mlir index eec9c91028..09ebb3ecaf 100644 --- a/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_multiple_tensors.mlir +++ b/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_multiple_tensors.mlir @@ -1,13 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { func.func @forward() -> tensor<32x224xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: = "ttnn.empty" %1 = tensor.empty() : tensor<32x64xf32> - // CHECK: = "ttnn.empty" %2 = tensor.empty() : tensor<32x128xf32> - // CHECK: = "ttnn.empty" %3 = tensor.empty() : tensor<32x224xf32> // CHECK: = "ttnn.concat" %4 = "ttir.concat"(%0, %1, %2, %3) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x128xf32>, tensor<32x224xf32>) -> tensor<32x224xf32> diff --git a/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_negative_dim.mlir b/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_negative_dim.mlir index 0798f29714..c47f0c6e85 100644 --- a/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_negative_dim.mlir +++ b/test/ttmlir/Dialect/TTNN/data_movement/concat/concat_negative_dim.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x96xf32> // CHECK: = "ttnn.concat" %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> diff --git a/test/ttmlir/Dialect/TTNN/data_movement/concat/simple_concat.mlir 
b/test/ttmlir/Dialect/TTNN/data_movement/concat/simple_concat.mlir index 4ce27a0da6..5165c60f34 100644 --- a/test/ttmlir/Dialect/TTNN/data_movement/concat/simple_concat.mlir +++ b/test/ttmlir/Dialect/TTNN/data_movement/concat/simple_concat.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x96xf32> // CHECK: = "ttnn.concat" %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir index 4c792322bb..d8ed061a5c 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<1x32x128xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16> diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir index 2b82396cea..836396149c 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { func.func @gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<1x32x1024xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.gather"(%operand, %start_indices, %0) { @@ -18,7 +17,6 @@ module attributes {} { } func.func @gather_1(%operand: tensor<448x384xbf16>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.gather"(%operand, %start_indices, %0) <{ @@ -35,7 +33,6 @@ module attributes {} { } func.func @gather_2(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.gather"(%operand, %start_indices, %0) <{ diff --git a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding_backward.mlir b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding_backward.mlir index 28c3d98ac8..aebc17084c 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding_backward.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding_backward.mlir @@ -4,7 +4,6 @@ module attributes {} { // Capture reshape output layout for validation // CHECK: [[RESHAPE_OUTPUT_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 32 + d1 * 32 + d2, d3), <1x1>, memref<1x4x!tt.tile<32x32, bf16>, #dram>, > %0 = tensor.empty() : tensor<512x128xbf16> - // CHECK: "ttnn.empty" // Verify inserted reshape op // CHECK: "ttnn.reshape" // CHECK-SAME: <{shape = [1 : i32, 1 : i32, 32 : i32, 128 : i32]}> diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir 
b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir index caf9ee5177..b2576f3c92 100644 --- a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir @@ -1,79 +1,62 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module { func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<1xbf16 %0 = tensor.empty() : tensor<1xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<128xbf16 // CHECK-SAME: tensor<128xbf16 // CHECK-SAME: tensor<1xbf16 // CHECK-SAME: tensor<1xbf16 - // CHECK-SAME: tensor<1xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } func.func @linear_1d_1d_bias_broadcast(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<128xbf16>) -> tensor<128xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<128xbf16 %0 = tensor.empty() : tensor<128xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<128xbf16 // CHECK-SAME: tensor<128xbf16 // CHECK-SAME: tensor<128xbf16 // CHECK-SAME: tensor<128xbf16 - // CHECK-SAME: tensor<128xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16> return %1 : tensor<128xbf16> } func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } // linear nd - nd tests func.func @linear_nd_nd_bias_broadcast_bias(%arg0: tensor<14x7x32x32xbf16>, %arg1:tensor<14x1x32x64xbf16>, %bias: tensor<64xbf16>) -> tensor<14x7x32x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<14x7x32x64xbf16 %0 = tensor.empty() : tensor<14x7x32x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<14x7x32x32xbf16 // CHECK-SAME: tensor<14x1x32x64xbf16 // CHECK-SAME: tensor<64xbf16 // CHECK-SAME: tensor<14x7x32x64xbf16 - // CHECK-SAME: tensor<14x7x32x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16> return %1 : tensor<14x7x32x64xbf16> } func.func @linear_nd_nd_bias_broadcast_matmul(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<4x3x128x32xbf16>, %bias: tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<14x4x3x64x32xbf16 %0 = tensor.empty() : tensor<14x4x3x64x32xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<3x64x128xbf16 // CHECK-SAME: tensor<4x3x128x32xbf16 // CHECK-SAME: tensor<14x4x3x64x32xbf16 - // CHECK-SAME: tensor<14x4x3x64x32xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> return %1 : tensor<14x4x3x64x32xbf16> } // Linear with transposed inputs tests. 
func.func @linear_2d_tranpose_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<128x128xbf16 %0 = tensor.empty() : tensor<128x128xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: transpose_a = true @@ -81,13 +64,10 @@ module { // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x128xbf16 - // CHECK-SAME: tensor<128x128xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> return %1 : tensor<128x128xbf16> } func.func @linear_2d_2d_transpose_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: transpose_a = false @@ -95,13 +75,10 @@ module { // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } func.func @linear_2d_tranpose_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<128x128xbf16 %0 = tensor.empty() : tensor<128x128xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: transpose_a = true @@ -109,7 +86,6 @@ module { // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<128x128xbf16 - // CHECK-SAME: tensor<128x128xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true, transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> return %1 : tensor<128x128xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir index 579bb2e82d..62bf9d5650 100644 --- a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir +++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir @@ -2,15 +2,12 @@ module { func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir index 576777cecf..b3b4195674 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir @@ -5,9 +5,8 @@ module { func.func @matmul_negative_0d_1d_input_scalar(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input A must be at least a 1D tensor - %0 = tensor.empty() : tensor<1xbf16> - %1 = 
"ttnn.matmul"(%arg0, %arg1, %0) : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> - return %1 : tensor<1xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor, tensor<64xbf16>) -> tensor<1xbf16> + return %0 : tensor<1xbf16> } } @@ -15,9 +14,8 @@ module { module { func.func @matmul_negative_1d_0d_input_scalar(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input B must be at least a 1D tensor - %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> - return %1 : tensor<1xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<128xbf16>, tensor) -> tensor<1xbf16> + return %0 : tensor<1xbf16> } } @@ -26,9 +24,8 @@ module { module { func.func @matmul_negative_1d_1d_output_scalar(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { // CHECK: error: 'ttnn.matmul' op Scalar output is not supported, output must be at least a 1D tensor - %0 = tensor.empty() : tensor - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor - return %1 : tensor + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<128xbf16>, tensor<128xbf16>) -> tensor + return %0 : tensor } } @@ -36,9 +33,8 @@ module { module { func.func @matmul_negative_1d_1d_nonone_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { // CHECK: error: 'ttnn.matmul' op Scalar output must be a 1D tensor of size 1 - %0 = tensor.empty() : tensor<2xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> - return %1 : tensor<2xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<128xbf16>, tensor<128xbf16>) -> tensor<2xbf16> + return %0 : tensor<2xbf16> } } @@ -47,9 +43,8 @@ module { module { func.func @matmul_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> - return %1 : tensor<1xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<128xbf16>, tensor<64xbf16>) -> tensor<1xbf16> + return %0 : tensor<1xbf16> } } @@ -57,9 +52,8 @@ module { module { func.func @matmul_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](64) and B[-2](128) must have matching inner dimensions - %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> - return %1 : tensor<64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<64xbf16>, tensor<128x64xbf16>) -> tensor<64xbf16> + return %0 : tensor<64xbf16> } } @@ -67,9 +61,8 @@ func.func @matmul_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16> module { func.func @matmul_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> - return %1 : tensor<64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return 
%0 : tensor<64xbf16> } } @@ -77,9 +70,8 @@ module { module { func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> - return %1 : tensor<64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> } } @@ -87,9 +79,8 @@ module { module { func.func @matmul_negative_2d_transpose_2d_inner_dimension_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<128x128xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> - return %1 : tensor<128x128xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<128x128xbf16> + return %0 : tensor<128x128xbf16> } } @@ -97,9 +88,8 @@ module { module { func.func @matmul_negative_2d_2d_transpose_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> - return %1 : tensor<64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> } } @@ -107,9 +97,8 @@ module { module { func.func @matmul_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions - %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> - return %1 : tensor<7x64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> + return %0 : tensor<7x64x64xbf16> } } @@ -118,9 +107,8 @@ module { module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(7) and B(2) are not broadcast compatible - %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> - return %1 : tensor<7x64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> + return %0 : tensor<7x64x64xbf16> } } @@ -128,9 +116,8 @@ module { module { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch 
dimensions of input A(2,7) and B(7,1) are not broadcast compatible - %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> - return %1 : tensor<7x7x64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> + return %0 : tensor<7x7x64x64xbf16> } } @@ -138,9 +125,8 @@ module { module { func.func @matmul_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible - %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> - return %1 : tensor<12x7x7x64x64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %0 : tensor<12x7x7x64x64xbf16> } } @@ -149,9 +135,8 @@ module { module { func.func @matmul_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape rank(1) must match the expected output shape rank(2) - %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> - return %1 : tensor<64xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<64xbf16> + return %0 : tensor<64xbf16> } } @@ -159,9 +144,8 @@ module { module { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) - %0 = tensor.empty() : tensor<64x128xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<64x128xbf16> + return %0 : tensor<64x128xbf16> } } @@ -170,9 +154,8 @@ module { module { func.func @matmul_negative_2d_transpose_2d_output_shape_mismatch(%arg0: tensor<128x64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<128x128xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) - %0 = tensor.empty() : tensor<128x128xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> - return %1 : tensor<128x128xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) <{transpose_a = true}> : (tensor<128x64xbf16>, tensor<128x64xbf16>) -> tensor<128x128xbf16> + return %0 : tensor<128x128xbf16> } } @@ -180,8 +163,7 @@ module { module { func.func @matmul_negative_2d_2d_transpose_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape dimension[0](128) doesn't match the expected output shape dimension[0](64) - %0 = tensor.empty() : tensor<128x128xbf16> - %1 = "ttnn.matmul"(%arg0, %arg1, %0) <{transpose_b = true}> : 
(tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> - return %1 : tensor<128x128xbf16> + %0 = "ttnn.matmul"(%arg0, %arg1) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<128x128xbf16> + return %0 : tensor<128x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_cumsum.mlir b/test/ttmlir/Dialect/TTNN/simple_cumsum.mlir index 8127692bd0..d88b157bdd 100644 --- a/test/ttmlir/Dialect/TTNN/simple_cumsum.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_cumsum.mlir @@ -7,7 +7,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 0 : i64 // CHECK-SAME: tensor<1x32x128x128xbf16, - // CHECK-SAME: tensor<1x32x128x128xbf16, // CHECK-SAME: -> tensor<1x32x128x128xbf16, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 0 : i64}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>) -> tensor<1x32x128x128xbf16> return %1 : tensor<1x32x128x128xbf16> @@ -19,7 +18,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 1 : i64 // CHECK-SAME: tensor<4x4x128x128xf32, - // CHECK-SAME: tensor<4x4x128x128xf32, // CHECK-SAME: -> tensor<4x4x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 1 : i64}> : (tensor<4x4x128x128xf32>, tensor<4x4x128x128xf32>) -> tensor<4x4x128x128xf32> return %1 : tensor<4x4x128x128xf32> diff --git a/test/ttmlir/Silicon/StableHLO/n150/Binary/concat_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/Binary/concat_op.mlir index 8344448c00..98d1823cfa 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/Binary/concat_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/Binary/concat_op.mlir @@ -9,12 +9,10 @@ module @jit_concat attributes {} { func.func public @test_concat_0(%arg0: tensor<32x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<96x32xf32> { // CHECK-LABEL: func.func public @test_concat_0 - // CHECK: ttnn.empty // CHECK: ttnn.concat // CHECK-SAME: dim = 0 // CHECK-SAME: tensor<32x32xf32, // CHECK-SAME: tensor<64x32xf32, - // CHECK-SAME: tensor<96x32xf32, // CHECK-SAME: -> tensor<96x32xf32, %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 @@ -24,12 +22,10 @@ module @jit_concat attributes {} { func.func public @test_concat_1(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { // CHECK-LABEL: func.func public @test_concat_1 - // CHECK: ttnn.empty // CHECK: ttnn.concat // CHECK-SAME: dim = 1 // CHECK-SAME: tensor<32x32xf32, // CHECK-SAME: tensor<32x64xf32, - // CHECK-SAME: tensor<32x96xf32, // CHECK-SAME: -> tensor<32x96xf32, %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 @@ -40,12 +36,10 @@ module @jit_concat attributes {} { func.func public @test_concat_2(%arg0: tensor<128x64xf32>, %arg1: tensor<128x96xf32>) -> tensor<128x160xf32> { // CHECK-LABEL: func.func public @test_concat_2 - // CHECK: ttnn.empty // CHECK: ttnn.concat // CHECK-SAME: dim = 1 // CHECK-SAME: tensor<128x64xf32, // CHECK-SAME: tensor<128x96xf32, - // CHECK-SAME: tensor<128x160xf32, // CHECK-SAME: -> tensor<128x160xf32, %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 @@ -55,12 +49,10 @@ module @jit_concat attributes {} { func.func public @test_concat_3(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x96xf32> { // CHECK-LABEL: func.func public @test_concat_3 - // CHECK: ttnn.empty // CHECK: ttnn.concat // CHECK-SAME: dim = 1 // CHECK-SAME: tensor<64x32xf32, // CHECK-SAME: tensor<64x64xf32, - // CHECK-SAME: tensor<64x96xf32, // CHECK-SAME: -> tensor<64x96xf32, %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 1 : 
i64 @@ -70,12 +62,10 @@ module @jit_concat attributes {} { func.func public @test_concat_4(%arg0: tensor<32x32x32x32xf32>, %arg1: tensor<32x32x32x64xf32>) -> tensor<32x32x32x96xf32> { // CHECK-LABEL: func.func public @test_concat_4 - // CHECK: ttnn.empty // CHECK: ttnn.concat // CHECK-SAME: dim = 3 // CHECK-SAME: tensor<32x32x32x32xf32, // CHECK-SAME: tensor<32x32x32x64xf32, - // CHECK-SAME: tensor<32x32x32x96xf32, // CHECK-SAME: -> tensor<32x32x32x96xf32, %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 3 : i64 @@ -93,15 +83,10 @@ module @jit_concat attributes {} { // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: tensor<1x1xsi32 // CHECK-SAME: -> tensor<1x1xbf16 - // CHECK: %[[ARG2:[0-9]+]] = "ttnn.typecast" - // CHECK-SAME: dtype = #tt.supportedDataTypes - // CHECK-SAME: tensor<1x54xsi32 - // CHECK-SAME: -> tensor<1x54xbf16 - // CHECK: %[[CONCAT:[0-9]+]] = "ttnn.concat"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + // CHECK: %[[CONCAT:[0-9]+]] = "ttnn.concat"(%[[ARG0]], %[[ARG1]]) // CHECK-SAME: dim = 1 : si32 // CHECK-SAME: tensor<1x53xbf16 // CHECK-SAME: tensor<1x1xbf16 - // CHECK-SAME: tensor<1x54xbf16 // CHECK-SAME: -> tensor<1x54xbf16 %0 = stablehlo.concatenate %arg0, %arg1, dim = 1 : (tensor<1x53xi64>, tensor<1x1xi64>) -> tensor<1x54xi64> // CHECK: "ttnn.typecast"(%[[CONCAT]]) diff --git a/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_2d.mlir b/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_2d.mlir index 179f112b49..96692aa0c9 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_2d.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_2d.mlir @@ -9,11 +9,9 @@ module @jit_dot_general_2d attributes {} { func.func public @test_dot_general_2d(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { // CHECK-LABEL: func.func public @test_dot_general - // CHECK: ttnn.empty // CHECK: ttnn.matmul // CHECK-SAME: tensor<16x32xf32, // CHECK-SAME: tensor<32x8xf32, - // CHECK-SAME: tensor<16x8xf32, // CHECK-SAME: -> tensor<16x8xf32 %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32> return %0 : tensor<16x8xf32> diff --git a/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_batch_matmul.mlir b/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_batch_matmul.mlir index f23ece73ff..e0e3c2d66a 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_batch_matmul.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/dot_general/dot_general_op_batch_matmul.mlir @@ -9,11 +9,9 @@ module @jit_dot_general_4d attributes {} { func.func public @test_dot_general_4d(%arg0 : tensor<1x128x16x32xf32>, %arg1 : tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> { // CHECK-LABEL: func.func public @test_dot_general - // CHECK: ttnn.empty // CHECK: ttnn.matmul // CHECK-SAME: tensor<1x128x16x32xf32, // CHECK-SAME: tensor<1x128x32x8xf32, - // CHECK-SAME: tensor<1x128x16x8xf32, // CHECK-SAME: -> tensor<1x128x16x8xf32 %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<1x128x16x32xf32>, tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> return %0 : tensor<1x128x16x8xf32> diff --git a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir index 7b2c50f967..768d125922 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir @@ -29,7 +29,6 @@ 
module @max_pool2d attributes {} { // CHECK-SAME: padding_height = 1 : si32, padding_width = 1 : si32, // CHECK-SAME: stride_height = 3 : si32, stride_width = 3 : si32} // CHECK-SAME: tensor<1x1x1024x128xbf16 - // CHECK-SAME: tensor<1x1x121x128xbf16 // CHECK-SAME: -> tensor<1x1x121x128xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -70,7 +69,6 @@ module @max_pool2d attributes {} { // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32, // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32} // CHECK-SAME: tensor<1x1x784x192xbf16 - // CHECK-SAME: tensor<1x1x784x192xbf16 // CHECK-SAME: -> tensor<1x1x784x192xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -111,7 +109,6 @@ module @max_pool2d attributes {} { // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32, // CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32} // CHECK-SAME: tensor<1x1x784x192xbf16 - // CHECK-SAME: tensor<1x1x270x192xbf16 // CHECK-SAME: -> tensor<1x1x270x192xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): diff --git a/test/ttmlir/Silicon/StableHLO/n150/moreh_cumsum_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/moreh_cumsum_op.mlir index 00cf8c8857..659d590c12 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/moreh_cumsum_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/moreh_cumsum_op.mlir @@ -13,7 +13,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 0 : i64 // CHECK-SAME: tensor<8x2x4x16xbf16, - // CHECK-SAME: tensor<8x2x4x16xbf16, // CHECK-SAME: -> tensor<8x2x4x16xbf16, %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[7, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -29,7 +28,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 1 : i64 // CHECK-SAME: tensor<8x2x4x16xbf16, - // CHECK-SAME: tensor<8x2x4x16xbf16, // CHECK-SAME: -> tensor<8x2x4x16xbf16, %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [1, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -45,7 +43,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 2 : i64 // CHECK-SAME: tensor<8x2x4x16xbf16, - // CHECK-SAME: tensor<8x2x4x16xbf16, // CHECK-SAME: -> tensor<8x2x4x16xbf16, %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [3, 0], [0, 0]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -61,7 +58,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 3 : i64 // CHECK-SAME: tensor<8x2x4x16xbf16, - // CHECK-SAME: tensor<8x2x4x16xbf16, // CHECK-SAME: -> tensor<8x2x4x16xbf16, %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [0, 0], [15, 0]]> : 
tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -78,21 +74,13 @@ module @moreh_cumsum attributes {} { // CHECK-SAME: {shape = [1 : i32, 10 : i32, 1 : i32, 1 : i32]} // CHECK-SAME: tensor<1x10xui32 // CHECK-SAME: -> tensor<1x10x1x1xui32 - // CHECK: %[[EMPTY:[0-9]+]] = "ttnn.empty" - // CHECK-SAME: {dtype = #tt.supportedDataTypes - // CHECK-SAME: -> tensor<1x10x1x1xsi32 // CHECK: %[[ARG0:[0-9]+]] = "ttnn.typecast"(%[[RESHAPE]]) // CHECK-SAME: {dtype = #tt.supportedDataTypes} // CHECK-SAME: tensor<1x10x1x1xui32 // CHECK-SAME: -> tensor<1x10x1x1xf32 - // CHECK: %[[ARG1:[0-9]+]] = "ttnn.typecast"(%[[EMPTY]]) - // CHECK-SAME: {dtype = #tt.supportedDataTypes} - // CHECK-SAME: tensor<1x10x1x1xsi32 - // CHECK-SAME: -> tensor<1x10x1x1xf32 - // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG0]], %[[ARG1]]) + // CHECK: %[[CUMSUM:[0-9]+]] = "ttnn.moreh_cumsum"(%[[ARG0]]) // CHECK-SAME: <{dim = 1 : i64}> // CHECK-SAME: tensor<1x10x1x1xf32 - // CHECK-SAME: tensor<1x10x1x1xf32 // CHECK-SAME: -> tensor<1x10x1x1xf32 // CHECK: %[[TYPECAST:[0-9]+]] = "ttnn.typecast"(%[[CUMSUM]]) // CHECK-SAME: {dtype = #tt.supportedDataTypes} diff --git a/test/ttmlir/Silicon/StableHLO/n150/slice_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/slice_op.mlir index 71497db195..dada11fd9c 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/slice_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/slice_op.mlir @@ -9,13 +9,11 @@ module @mod_slice attributes {} { func.func public @test_slice(%arg0: tensor<32x64xbf16>) -> tensor<8x8xbf16> { // CHECK-LABEL: func.func public @test_slice - // CHECK: ttnn.empty // CHECK: ttnn.slice // CHECK-SAME: begins = [0 : i32, 16 : i32], // CHECK-SAME: ends = [16 : i32, 32 : i32], // CHECK-SAME: step = [2 : i32, 2 : i32] // CHECK-SAME: tensor<32x64xbf16, - // CHECK-SAME: tensor<8x8xbf16, // CHECK-SAME: -> tensor<8x8xbf16 %result = "stablehlo.slice"(%arg0) { start_indices = array, @@ -27,13 +25,11 @@ module @mod_slice attributes {} { func.func public @test_slice_f32(%arg0: tensor<32x64xf32>) -> (tensor<16x16xf32>) { // CHECK-LABEL: @test_slice_f32( - // CHECK: ttnn.empty // CHECK: ttnn.slice // CHECK-SAME: begins = [0 : i32, 16 : i32], // CHECK-SAME: ends = [16 : i32, 32 : i32], // CHECK-SAME: step = [1 : i32, 1 : i32] // CHECK-SAME: tensor<32x64xf32 - // CHECK-SAME: tensor<16x16xf32 // CHECK-SAME: -> tensor<16x16xf32 %0 = stablehlo.slice %arg0 [0:16, 16:32] : (tensor<32x64xf32>) -> tensor<16x16xf32> return %0 : tensor<16x16xf32> @@ -41,13 +37,11 @@ module @mod_slice attributes {} { func.func public @test_slice_non_tilize(%arg0: tensor<32x64xf32>) -> (tensor<14x14xf32>) { // CHECK-LABEL: @test_slice_non_tilize( - // CHECK: ttnn.empty // CHECK: ttnn.slice // CHECK-SAME: begins = [0 : i32, 16 : i32], // CHECK-SAME: ends = [14 : i32, 30 : i32], // CHECK-SAME: step = [1 : i32, 1 : i32] // CHECK-SAME: tensor<32x64xf32 - // CHECK-SAME: tensor<14x14xf32 // CHECK-SAME: -> tensor<14x14xf32 %0 = stablehlo.slice %arg0 [0:14, 16:30] : (tensor<32x64xf32>) -> tensor<14x14xf32> return %0 : tensor<14x14xf32> @@ -55,13 +49,11 @@ module @mod_slice attributes {} { func.func public @test_slice_strided(%arg0: tensor<32x64xf32>) -> (tensor<8x8xf32>) { // CHECK-LABEL: @test_slice_strided( - // CHECK: ttnn.empty // CHECK: ttnn.slice // CHECK-SAME: begins = [0 : i32, 16 : i32], // CHECK-SAME: ends = [16 : i32, 32 : i32], // CHECK-SAME: step = [2 : i32, 2 : i32] // CHECK-SAME: tensor<32x64xbf16 - // CHECK-SAME: tensor<8x8xbf16 // 
CHECK-SAME: -> tensor<8x8xbf16 %0 = stablehlo.slice %arg0 [0:16:2, 16:32:2] : (tensor<32x64xf32>) -> tensor<8x8xf32> return %0 : tensor<8x8xf32> @@ -69,21 +61,15 @@ module @mod_slice attributes {} { func.func @test_slice_strided_f32(%arg0: tensor<1x128x128x192xf32>) -> tensor<1x64x128x192xf32> { // CHECK-LABEL: @test_slice_strided_f32( - // CHECK: ttnn.empty // CHECK: ttnn.typecast // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: tensor<1x128x128x192xf32 // CHECK-SAME:-> tensor<1x128x128x192xbf16 - // CHECK: ttnn.typecast - // CHECK-SAME: dtype = #tt.supportedDataTypes - // CHECK-SAME: tensor<1x64x128x192xf32 - // CHECK-SAME: -> tensor<1x64x128x192xbf16 // CHECK: ttnn.slice // CHECK-SAME: begins = [0 : i32, 0 : i32, 0 : i32, 0 : i32], // CHECK-SAME: ends = [1 : i32, 128 : i32, 128 : i32, 192 : i32], // CHECK-SAME: step = [1 : i32, 2 : i32, 1 : i32, 1 : i32] // CHECK-SAME: tensor<1x128x128x192xbf16 - // CHECK-SAME: tensor<1x64x128x192xbf16 // CHECK-SAME: -> tensor<1x64x128x192xbf16 %0 = stablehlo.slice %arg0 [0:1, 0:128:2, 0:128, 0:192] : (tensor<1x128x128x192xf32>) -> tensor<1x64x128x192xf32> // CHECK: ttnn.typecast diff --git a/test/ttmlir/Silicon/TTNN/n150/deallocate.mlir b/test/ttmlir/Silicon/TTNN/n150/deallocate.mlir index a26d8ac8e3..82c2acda18 100644 --- a/test/ttmlir/Silicon/TTNN/n150/deallocate.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/deallocate.mlir @@ -6,21 +6,21 @@ module @"dealloc_test" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("Dealloc":4294967295:0), %arg1: tensor<1x10xf32> loc("Dealloc":4294967295:0), %arg2: tensor<256x10xf32> loc("Dealloc":4294967295:0), %arg3: tensor<1x256xf32> loc("Dealloc":4294967295:0), %arg4: tensor<784x256xf32> loc("Dealloc":4294967295:0)) -> tensor<1x10xf32> { %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) - // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O1:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> + // CHECK: %[[MATMUL1:.*]] = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> - // CHECK: "ttnn.deallocate"([[O1]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () + // CHECK: "ttnn.deallocate"(%[[MATMUL1]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) // CHECK: %{{.+}} = "ttnn.relu"([[I1:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) %7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) - // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O4:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> + // CHECK: %[[MATMUL2:.*]] = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]]) {{.+}} -> 
tensor<1x10xf32,{{.+}}> - // CHECK: "ttnn.deallocate"([[O4]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () + // CHECK: "ttnn.deallocate"(%[[MATMUL2]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) // CHECK: %{{.+}} = "ttnn.softmax"([[I1:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> diff --git a/test/ttmlir/Silicon/TTNN/n150/eltwise/binary/concat/concat.mlir b/test/ttmlir/Silicon/TTNN/n150/eltwise/binary/concat/concat.mlir index 2fd1a0c59c..0349171341 100644 --- a/test/ttmlir/Silicon/TTNN/n150/eltwise/binary/concat/concat.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/eltwise/binary/concat/concat.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x96xf32> // CHECK: = "ttnn.concat" %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_1d_tensor.mlir index aadc76d55a..dcdada99b3 100644 --- a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_1d_tensor.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_1d_tensor.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x128xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16> diff --git a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_backward.mlir b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_backward.mlir index 4686dca728..54b75b95d5 100644 --- a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_backward.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_backward.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { func.func @backward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>, %arg2: tensor<1x32x128xf32>) -> tensor<512x128xf32> { - // CHECK: %{{[0-9]+}} = "ttnn.empty" %0 = tensor.empty() : tensor<512x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.embedding_bw" %1 = "ttir.embedding_backward"(%arg0, %arg1, %arg2, %0) : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>, tensor<512x128xf32>) -> tensor<512x128xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_non_tile.mlir b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_non_tile.mlir index 80c373981a..e7512ec987 100644 --- a/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_non_tile.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/embedding/embedding_non_tile.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<1x32x128xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16> diff --git 
a/test/ttmlir/Silicon/TTNN/n150/embedding/simple_embedding.mlir b/test/ttmlir/Silicon/TTNN/n150/embedding/simple_embedding.mlir index 38f2f1a353..cf7d572c80 100644 --- a/test/ttmlir/Silicon/TTNN/n150/embedding/simple_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/embedding/simple_embedding.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x32x128xf32> // CHECK: = "ttnn.embedding" %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_concat.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_concat.mlir index 8716cb104b..034fa469a0 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_concat.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_concat.mlir @@ -2,7 +2,6 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x96xf32> // CHECK: = "ttnn.concat" %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir index 1776705e7e..038111da77 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir @@ -35,14 +35,13 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} { func.func @forward(%arg0: tensor<1x32x32x64xbf16, #ttnn_layout>, %arg1: tensor<64x64x3x3xbf16, #ttnn_layout1>, %arg2: tensor<1x1x1x64xbf16, #ttnn_layout2>) -> tensor<1x30x30x64xbf16, #ttnn_layout3> { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<29x2>>, >, shape = #ttnn.shape<1x1x900x64>}> : (!tt.device<#device>) -> tensor<1x1x900x64xbf16, #ttnn_layout4> - %2 = "ttnn.conv2d"(%arg0, %arg1, %arg2, %1, %0) <{batch_size = 1 : i32, conv2d_config = #conv2d_config, dilation = array, groups = 1 : i32, in_channels = 64 : i32, input_height = 32 : i32, input_width = 32 : i32, kernel_size = array, out_channels = 64 : i32, padding = array, stride = array}> : (tensor<1x32x32x64xbf16, #ttnn_layout>, tensor<64x64x3x3xbf16, #ttnn_layout1>, tensor<1x1x1x64xbf16, #ttnn_layout2>, tensor<1x1x900x64xbf16, #ttnn_layout4>, !tt.device<#device>) -> tensor<1x1x900x64xbf16, #ttnn_layout4> - %3 = "ttnn.reshape"(%2) <{shape = [1 : i32, 30 : i32, 30 : i32, 64 : i32]}> : (tensor<1x1x900x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout4> + %1 = "ttnn.conv2d"(%arg0, %arg1, %arg2, %0) <{batch_size = 1 : i32, conv2d_config = #conv2d_config, dilation = array, groups = 1 : i32, in_channels = 64 : i32, input_height = 32 : i32, input_width = 32 : i32, kernel_size = array, out_channels = 64 : i32, padding = array, stride = array}> : (tensor<1x32x32x64xbf16, #ttnn_layout>, tensor<64x64x3x3xbf16, #ttnn_layout1>, tensor<1x1x1x64xbf16, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x900x64xbf16, #ttnn_layout4> + %2 = 
"ttnn.reshape"(%1) <{shape = [1 : i32, 30 : i32, 30 : i32, 64 : i32]}> : (tensor<1x1x900x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout4> "ttnn.deallocate"(%1) <{force = false}> : (tensor<1x1x900x64xbf16, #ttnn_layout4>) -> () - %4 = "ttnn.from_device"(%3) : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout5> - "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> () - %5 = "ttnn.to_layout"(%4) <{layout = #ttnn.layout}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> tensor<1x30x30x64xbf16, #ttnn_layout3> - "ttnn.deallocate"(%4) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> () - return %5 : tensor<1x30x30x64xbf16, #ttnn_layout3> + %3 = "ttnn.from_device"(%2) : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout5> + "ttnn.deallocate"(%2) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> () + %4 = "ttnn.to_layout"(%3) <{layout = #ttnn.layout}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> tensor<1x30x30x64xbf16, #ttnn_layout3> + "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> () + return %4 : tensor<1x30x30x64xbf16, #ttnn_layout3> } } diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_cumsum.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_cumsum.mlir index 9a0120a917..4ec0cb7167 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_cumsum.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_cumsum.mlir @@ -9,7 +9,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 0 : i64 // CHECK-SAME: tensor<1x32x128x128xf32, - // CHECK-SAME: tensor<1x32x128x128xf32, // CHECK-SAME: -> tensor<1x32x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 0 : i64}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> return %1 : tensor<1x32x128x128xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_embedding.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_embedding.mlir index 1d1457602f..28d7acbfbf 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_embedding.mlir @@ -3,7 +3,6 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> { - // CHECK: = "ttnn.empty" %0 = tensor.empty() : tensor<32x32x128xbf16> // CHECK: = "ttnn.embedding" %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_linear.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_linear.mlir index ab073ef75d..bbca51f90e 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_linear.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_linear.mlir @@ -4,15 +4,12 @@ module { func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> 
return %1 : tensor<64x64xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/n150/simple_cumsum.mlir b/test/ttmlir/Silicon/TTNN/n150/simple_cumsum.mlir index 6cb7ef69b2..73d4f7bf37 100644 --- a/test/ttmlir/Silicon/TTNN/n150/simple_cumsum.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/simple_cumsum.mlir @@ -11,7 +11,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 0 : i64 // CHECK-SAME: tensor<1x32x128x128xf32, - // CHECK-SAME: tensor<1x32x128x128xf32, // CHECK-SAME: -> tensor<1x32x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 0 : i64}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> return %1 : tensor<1x32x128x128xf32> @@ -23,7 +22,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 1 : i64 // CHECK-SAME: tensor<4x4x128x128xf32, - // CHECK-SAME: tensor<4x4x128x128xf32, // CHECK-SAME: -> tensor<4x4x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 1 : i64}> : (tensor<4x4x128x128xf32>, tensor<4x4x128x128xf32>) -> tensor<4x4x128x128xf32> return %1 : tensor<4x4x128x128xf32> @@ -35,7 +33,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 2 : i64 // CHECK-SAME: tensor<4x4x128x128xf32, - // CHECK-SAME: tensor<4x4x128x128xf32, // CHECK-SAME: -> tensor<4x4x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 2 : i64}> : (tensor<4x4x128x128xf32>, tensor<4x4x128x128xf32>) -> tensor<4x4x128x128xf32> return %1 : tensor<4x4x128x128xf32> @@ -47,7 +44,6 @@ module @moreh_cumsum attributes {} { // CHECK: ttnn.moreh_cumsum // CHECK-SAME: dim = 3 : i64 // CHECK-SAME: tensor<4x4x128x128xf32, - // CHECK-SAME: tensor<4x4x128x128xf32, // CHECK-SAME: -> tensor<4x4x128x128xf32, %1 = "ttir.cumsum"(%arg0, %0) <{dim = 3 : i64}> : (tensor<4x4x128x128xf32>, tensor<4x4x128x128xf32>) -> tensor<4x4x128x128xf32> return %1 : tensor<4x4x128x128xf32> diff --git a/test/ttmlir/Silicon/TTNN/n150/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/n150/simple_linear.mlir index a26c06c2f3..6357e329c7 100644 --- a/test/ttmlir/Silicon/TTNN/n150/simple_linear.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/simple_linear.mlir @@ -4,23 +4,18 @@ module { func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } func.func @linear_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<128x128xbf16 %0 = tensor.empty() : tensor<128x128xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: transpose_a = true @@ -28,14 +23,11 @@ module { // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<128x128xbf16 - // CHECK-SAME: tensor<128x128xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16> return %1 : tensor<128x128xbf16> } func.func @linear_transpose_second(%arg0: tensor<64x128xbf16>, 
%arg1: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x64xbf16 %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: "ttnn.linear" // CHECK-SAME: transpose_a = false @@ -43,7 +35,6 @@ module { // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x64xbf16 - // CHECK-SAME: tensor<64x64xbf16 %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } diff --git a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp index 27536e8bc6..d6961e5504 100644 --- a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp +++ b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp @@ -36,10 +36,11 @@ class OpModelBase : public OpModelFixture { std::vector getInputLayouts(Operation *op) { std::vector inputs; - // TODO(odjuricic): check for DPS explicitly. - auto numOperand = op->getNumOperands(); - // some ops have multiple operands - auto limit = (numOperand > 1) ? numOperand - 1 : numOperand; + auto limit = op->getNumOperands(); + if (isa(op)) { + limit--; + } + for (size_t i = 0; i < limit; i++) { auto operand = op->getOperand(i); auto inputShape = @@ -74,10 +75,15 @@ class OpModelBase : public OpModelFixture { return DeviceAttr::get(&context, workerGrid, map4, map4, {1}, {0}); } - mlir::Value createEmptyTensor(llvm::ArrayRef tensorShape) { + mlir::RankedTensorType createRankedTensorType(llvm::ArrayRef shape) { Type elementType = builder.getBF16Type(); RankedTensorType rankedTensorType = - RankedTensorType::get(tensorShape, elementType); + RankedTensorType::get(shape, elementType); + return rankedTensorType; + } + + mlir::Value createEmptyTensor(llvm::ArrayRef tensorShape) { + RankedTensorType rankedTensorType = createRankedTensorType(tensorShape); return builder.create(builder.getUnknownLoc(), rankedTensorType, ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr, nullptr); @@ -89,10 +95,10 @@ TEST_F(OpModelBase, ReluInterface) { llvm::SmallVector tensorShape = {workerCoresN300, 1024}; auto input = createEmptyTensor(tensorShape); - auto output = createEmptyTensor(tensorShape); + auto outputType = createRankedTensorType(tensorShape); - auto relu = builder.create(builder.getUnknownLoc(), output.getType(), - ::mlir::ValueRange{input, output}); + auto relu = builder.create(builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{input}); relu->setAttr(DeviceAttr::name, getFakeDeviceAttr()); // test ReluOp interface @@ -120,10 +126,10 @@ TEST_F(OpModelBase, SoftmaxInterface) { llvm::SmallVector tensorShape = {workerCoresN300, 1024}; auto input = createEmptyTensor(tensorShape); - auto output = createEmptyTensor(tensorShape); + auto output = createRankedTensorType(tensorShape); - auto softmax = builder.create(builder.getUnknownLoc(), - output.getType(), input, -1); + auto softmax = + builder.create(builder.getUnknownLoc(), output, input, -1); softmax->setAttr(DeviceAttr::name, getFakeDeviceAttr()); // test SoftmaxOp interface @@ -153,10 +159,10 @@ TEST_F(OpModelBase, AddInterface) { auto input1 = createEmptyTensor(tensorShape); auto input2 = createEmptyTensor(tensorShape); - auto output = createEmptyTensor(tensorShape); + auto outputType = createRankedTensorType(tensorShape); - auto add = builder.create(builder.getUnknownLoc(), output.getType(), - ::mlir::ValueRange{input1, 
input2, output}); + auto add = builder.create(builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{input1, input2}); add->setAttr(DeviceAttr::name, getFakeDeviceAttr()); // test AddOp interface @@ -188,11 +194,10 @@ TEST_F(OpModelBase, MatmulInterface) { auto inputA = createEmptyTensor(tensorShapeA); auto inputB = createEmptyTensor(tensorShapeB); - auto output = createEmptyTensor(tensorShapeO); + auto outputType = createRankedTensorType(tensorShapeO); - auto matmul = - builder.create(builder.getUnknownLoc(), output.getType(), - ::mlir::ValueRange{inputA, inputB, output}); + auto matmul = builder.create(builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{inputA, inputB}); matmul->setAttr(DeviceAttr::name, getFakeDeviceAttr()); // test MatmulOp interface
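For reference, a minimal sketch of the before/after construction pattern the updated tests converge on, assuming the OpModelBase fixture helpers shown above (createEmptyTensor, createRankedTensorType, getFakeDeviceAttr) and the TTNN op class names implied by the test comments (the MatmulOp template argument and exact namespaces are assumptions here, since the patch text elides them):

    // Sketch only, not part of the patch. Before this change, ops were DPS:
    // the tests materialized an output tensor and passed it as a trailing
    // operand, e.g.
    //   auto output = createEmptyTensor(shapeO);
    //   auto matmul = builder.create<MatmulOp>(
    //       builder.getUnknownLoc(), output.getType(),
    //       mlir::ValueRange{inputA, inputB, output});
    // After the change, only the result type is supplied and the operand
    // list carries real inputs; no ttnn.empty buffer is created.
    llvm::SmallVector<int64_t> shapeA = {64, 128};
    llvm::SmallVector<int64_t> shapeB = {128, 64};
    llvm::SmallVector<int64_t> shapeO = {64, 64};
    mlir::Value inputA = createEmptyTensor(shapeA);
    mlir::Value inputB = createEmptyTensor(shapeB);
    mlir::RankedTensorType outputType = createRankedTensorType(shapeO);
    auto matmul = builder.create<MatmulOp>(builder.getUnknownLoc(), outputType,
                                           mlir::ValueRange{inputA, inputB});
    matmul->setAttr(DeviceAttr::name, getFakeDeviceAttr());

The same shift runs through the whole test suite above: with no DPS init operand, the output layout is derived from the op's result type, so the ttnn.empty ops, their CHECK lines, and the output-operand layout workarounds all disappear.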