From 8711d3ea87ea6172f9cdd0cbd3b80f6e61b7bbb0 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Mon, 2 Dec 2024 10:14:42 -0800 Subject: [PATCH] [TOSA] Add upsample_nearest2d, split_dim, outer, GELU tanh mode and misc (#3886) - Add Torch to TOSA lowering for the following ops: + torch.aten.upsample_nearest2d + torch.aten.upsample_nearest2d.vec + torch.aten.outer + torch.prims.split_dim - Add Tanh approximation mode for GELU lowering - Add different types support for compare ops - Add different input and output types support for linalg vector norm lowering - Update xfail with new e2e results - Add new LIT tests to basic.mlir Change-Id: I7b1d44d94319cf94fcc9d234cc07708ef9ce321e Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 438 +++++++++++++++++- .../TorchToTosa/TosaLegalizeCommon.cpp | 14 +- projects/pt1/e2e_testing/xfail_sets.py | 69 +-- test/Conversion/TorchToTosa/basic.mlir | 143 ++++++ 4 files changed, 599 insertions(+), 65 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e75d358b068d..e9c7c2cc2e97 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -23,6 +23,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "llvm/ADT/TypeSwitch.h" +#include #include #include #include @@ -405,6 +406,36 @@ class ConvertAtenCompareOp : public OpConversionPattern { "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; + auto rhsTensorTy = dyn_cast(rhsTensor.getType()); + auto rhsElemTy = rhsTensorTy.getElementType(); + + auto isLhsElemFloat = isa(lhsElemTy); + auto isRhsElemFloat = isa(rhsElemTy); + + // Support different types comparisons + if (lhsElemTy != rhsElemTy) { + if (isLhsElemFloat && !isRhsElemFloat) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else if (!isLhsElemFloat && isRhsElemFloat) { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } else if (isLhsElemFloat && isRhsElemFloat) { + auto lhsElemFloatTy = dyn_cast(lhsElemTy); + auto rhsElemFloatTy = dyn_cast(rhsElemTy); + if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } + } else { + auto lhsElemIntTy = dyn_cast(lhsElemTy); + auto rhsElemIntTy = dyn_cast(rhsElemTy); + if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } + } + } // There is no Lesser operator in TOSA. constexpr auto swapLhsRhs = (std::is_same() || std::is_same() || @@ -3196,9 +3227,10 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3209,21 +3241,104 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point datatype legalization supported"); } - // TODO: Handle approximate. 
+ auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + std::string approximate; - if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) || - approximate != "none") { - return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); + if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) { + return rewriter.notifyMatchFailure( + op, "Non-const approximate value not supported"); } - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); - cdf = rewriter.createOrFold( - op->getLoc(), - cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + if (approximate.compare("none") == 0) { + // GELU(x) = x * CDF(x) + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); + cdf = rewriter.createOrFold( + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + + rewriter.replaceOpWithNewOp(op, resultType, self, cdf, + /*shift=*/0); + } else if (approximate.compare("tanh") == 0) { + // "tanh" approximate + // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + // Formula taken from: + // https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + auto selfShape = selfType.getShape(); + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Only static shape tensor types are currently supported for Tanh " + "approximation"); + + auto numElem = std::accumulate(selfShape.begin(), selfShape.end(), 1, + std::multiplies()); + + Value half = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 0.5), + selfShape, selfElemTy) + .value(); + Value one = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 1.0), + selfShape, selfElemTy) + .value(); + Value three = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 3.0), + selfShape, selfElemTy) + .value(); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, - /*shift=*/0); + // 0.044715 + Value magicNumber = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, 0.044715), + selfShape, selfElemTy) + .value(); + + // From header: M_2_PI = 2 / pi + Value twoOverPi = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, M_2_PI), + selfShape, selfElemTy) + .value(); + + // 0.5 * x + auto halfInput = rewriter.create(op->getLoc(), resultType, + half, self, /*shift=*/0); + + // sqrt(2/pi) + auto sqrtTwoOverPi = + rewriter.create(op->getLoc(), resultType, twoOverPi, half); + + // x^3 + auto inputPowThree = + rewriter.create(op->getLoc(), resultType, self, three); + + // 0.044715 * x^3 + auto inputPowThreeMul = + rewriter.create(op->getLoc(), resultType, magicNumber, + inputPowThree.getResult(), /*shift=*/0); + + // x + 0.044715 * x^3 + auto inputPowThreeMulAdd = rewriter.create( + op->getLoc(), resultType, self, inputPowThreeMul.getResult()); + + // sqrt(2/pi) * (x + 0.044715 * x^3) + auto sqrtTwoOverPiMul = rewriter.create( + op->getLoc(), resultType, sqrtTwoOverPi.getResult(), + inputPowThreeMulAdd.getResult(), /*shift=*/0); + + // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + auto tanh = rewriter.create(op->getLoc(), resultType, + sqrtTwoOverPiMul.getResult()); + + // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + auto tanhAdd = rewriter.create(op->getLoc(), resultType, one, + tanh.getResult()); + + rewriter.replaceOpWithNewOp( + op, resultType, halfInput.getResult(), tanhAdd.getResult(), + /*shift=*/0); + } else { + return rewriter.notifyMatchFailure(op, + "Unsupported approximation algorithm"); + } return success(); } @@ 
-7620,6 +7735,296 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for torch.prims.split_dim +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsSplitDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getA(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + int64_t dim, outerLength; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "Only constant int dim value is supported"); + + auto selfRank = selfType.getRank(); + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Dim is invalid"); + + if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength))) + return rewriter.notifyMatchFailure( + op, "Only constant int outer length value is supported"); + + // Technically, I should calculate the output shape based on the dim and outer + // length values. However, that would just give the same result as me taking + // the result shape straight from resultType and applying tosa::ReshapeOp to + // the input. Therefore, I'm opting for the latter approach here, which is + // more simple and quicker. + rewriter.replaceOpWithNewOp( + op, resultType, self, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + + return success(); +} + +// Legalization for aten.outer +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenOuterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + if (selfType.getRank() != 1) + return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); + + auto vec2 = adaptor.getVec2(); + + auto vec2Type = dyn_cast(vec2.getType()); + if (!vec2Type) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + if (vec2Type.getRank() != 1) + return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + self = tosa::promoteType(rewriter, self, resultType); + vec2 = tosa::promoteType(rewriter, vec2, resultType); + + SmallVector resultShapeIndex1Replaced({resultShape[0], 1}); + SmallVector resultShapeIndex0Replaced({1, resultShape[1]}); + + // Reshape and tile self to shape {selfShape[0], resultShape[1]} + auto selfReshaped = rewriter.create( + op->getLoc(), + RankedTensorType::get(resultShapeIndex1Replaced, + resultType.getElementType()), + self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + + auto selfTiled = rewriter.create( + op->getLoc(), resultType, selfReshaped.getResult(), + rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + + // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]} + auto vec2Reshaped = rewriter.create( + op->getLoc(), + RankedTensorType::get(resultShapeIndex0Replaced, + resultType.getElementType()), + vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + + auto vec2Tiled = rewriter.create( + op->getLoc(), resultType, vec2Reshaped.getResult(), + 
rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + + auto result = + tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(), + vec2Tiled.getResult(), /*shift=*/0); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.upsample_nearest2d +template +class ConvertUpsampleNearest2dForward : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // aten.upsample_nearest2d lowering process: + // 1. Reshape input: (N, C, H, W) -> (N, C, H x W) + // 2. Calculate PyTorch-styled gather op indices based on the following + // formula (based on Torch to Linalg UpsampleNearest2d lowering formula): + // for i in range(N x C): + // for heightIndex in range(scaledHeight): + // for widthIndex in range(scaledWidth): + // indices.append(int(heightIndex // scalesH * selfWidth + + // widthIndex // scalesW)) + // 3. Convert PyTorch-styled indices to TensorFlow-styled indices + // 4. Apply TensorFlow-styled ConverGatherOpNd to retrieve the output + // 5. Reshape output to desired output shape + Value self; + if constexpr (std::is_same()) { + self = adaptor.getSelf(); + } else if constexpr (std::is_same()) { + self = adaptor.getInput(); + } else { + return rewriter.notifyMatchFailure( + op, "Expected either AtenUpsampleNearest2dOp or " + "AtenUpsampleNearest2dVecOp"); + } + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto selfHeight = selfShape[selfRank - 2]; + auto selfWidth = selfShape[selfRank - 1]; + + auto resultType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + auto resultShape = resultType.getShape(); + auto resultElemTy = resultType.getElementType(); + + // Get op's parameters + SmallVector outputSize; + SmallVector scaleFactors; + double scalesH; + double scalesW; + int64_t outputHeight; + int64_t outputWidth; + if constexpr (std::is_same()) { + if (!matchPattern(op.getOutputSize(), + m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + outputHeight = outputSize[0]; + outputWidth = outputSize[1]; + + if (isa(op.getScalesH().getType())) { + scalesH = + static_cast(outputHeight) / static_cast(selfHeight); + } else { + if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH))) + return rewriter.notifyMatchFailure( + op, "Non-constant height scales not supported"); + + scalesH = std::ceil(scalesH); + } + + if (isa(op.getScalesW().getType())) { + scalesW = + static_cast(outputWidth) / static_cast(selfWidth); + } else { + if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW))) + return rewriter.notifyMatchFailure( + op, "Non-constant width scales not supported"); + + scalesW = std::ceil(scalesW); + } + } else if constexpr (std::is_same()) { + auto isOutputSizeNone = + isa(op.getOutputSize().getType()); + auto isScaleFactorsNone = + isa(op.getScaleFactors().getType()); + + if ((isOutputSizeNone && isScaleFactorsNone) || + (!isOutputSizeNone && !isScaleFactorsNone)) + return rewriter.notifyMatchFailure( + op, "Must specify exactly one of output size and scale factors"); + + if 
(!isOutputSizeNone) { + if (!matchPattern(op.getOutputSize(), + m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + outputHeight = outputSize[0]; + outputWidth = outputSize[1]; + + // Output size values being provided implies that scale values are not + // provided + scalesH = + static_cast(outputHeight) / static_cast(selfHeight); + scalesW = + static_cast(outputWidth) / static_cast(selfWidth); + } else { + if (!matchPattern(op.getScaleFactors(), + m_TorchListOfConstantFloats(scaleFactors))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + scalesH = std::ceil(scaleFactors[0]); + scalesW = std::ceil(scaleFactors[1]); + + // Scale values being provided implies that output size values are not + // provided + outputHeight = static_cast(scalesH * selfHeight); + outputWidth = static_cast(scalesW * selfWidth); + } + } + + // Reshape input + SmallVector reshapedSelfShape(selfShape.begin(), + selfShape.end() - 2); + reshapedSelfShape.push_back(selfHeight * selfWidth); + + auto reshapedSelf = rewriter.create( + op->getLoc(), RankedTensorType::get(reshapedSelfShape, selfElemTy), + self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape)); + + // Calculate PyTorch-styled gather indices + SmallVector targetIndicesVec; + int64_t indexRepeat = std::accumulate( + selfShape.begin(), selfShape.end() - 2, 1, std::multiplies()); + for (int64_t i = 0; i < indexRepeat; i++) { + for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) { + for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) { + targetIndicesVec.push_back(static_cast( + std::floor(heightIndex / scalesH) * selfWidth + + std::floor(widthIndex / scalesW))); + } + } + } + + SmallVector targetIndicesShape(selfShape.begin(), + selfShape.end() - 2); + targetIndicesShape.push_back(outputHeight * outputWidth); + auto targetIndicesTorch = + tosa::getConstTensor(rewriter, op, targetIndicesVec, + targetIndicesShape) + .value(); + + // Convert PyTorch-styled indices to TensorFlow-styled indices + auto targetIndicesTF = tosa::convertTorchIndexToTfIndices( + rewriter, op, reshapedSelf.getResult(), targetIndicesTorch, + selfRank - 2); + if (!targetIndicesTF) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch-styled indices and dim " + "to TensorFlow-styled indices failed"); + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + auto gatherOp = tosa::convertGatherNdOp( + rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy), + reshapedSelf.getResult(), targetIndicesTF.value()); + if (!gatherOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + auto result = rewriter.create( + op->getLoc(), resultType, gatherOp.value(), + rewriter.getDenseI64ArrayAttr(resultShape)); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); + } +}; + } // namespace // ----------------------------------------------------------------------------- @@ -7891,6 +8296,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); #undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN +#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); +#undef 
INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN + #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); @@ -7950,6 +8362,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); + INSERT_ATENOP_PATTERN(AtenOuterOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 4df8a221d556..ee7f61becf4f 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -1031,11 +1031,17 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } - auto absVal = CreateOpAndInfer(rewriter, op->getLoc(), - input_type, input_value) + auto input_value_casted = + tosa::promoteType(rewriter, input_value, output_type); + auto absVal = CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(input_type.getShape(), elemType), + input_value_casted) .getResult(); - auto powVal = CreateOpAndInfer(rewriter, op->getLoc(), - input_type, absVal, ordVal) + auto powVal = CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(input_type.getShape(), elemType), + absVal, ordVal) .getResult(); std::optional result = convertReduceSumOp( rewriter, op, output_type, powVal, axes_elems, keep_dims); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f3c8a9cd7837..e8494a148da2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1709,27 +1709,26 @@ "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", - "GridSamplerBasic1_basic", - "GridSamplerBasic2_basic", - "GridSamplerBasic3_basic", - "GridSamplerBasic4_basic", "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", - "InterpolateDynamicModule_scales_recompute_bilinear", - "InterpolateDynamicModule_sizes_bilinear", - "InterpolateDynamicModule_sizes_nearest", - "InterpolateStaticModule_scales_bilinear_align_corners", - "UpSampleNearest2d_basic", - "UpSampleNearest2dStaticSize_basic", - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dDynamicFactor_basic", - "UpSampleNearest2dStaticFactor_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "Deg2radModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossNoReductionModule_basic", + "L1LossSumReductionModule_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RenormModuleFloat16_basic", + "SplitDimStaticModule_basic", "ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule3dInput_Left", @@ -3461,8 +3460,6 @@ "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dModule_basic", "Conv_Transpose3dStaticModule_basic", - "ElementwiseFloatTensorGtIntTensorModule_basic", - "ElementwiseIntTensorLtFloatTensorModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3470,7 +3467,6 @@ "MultinomialModule2D_F32", "MultinomialModule2D_basic", "MultinomialModule_basic", - "RenormModuleFloat16_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScatterAddStaticModule_basic", "TensorsConcatComplex128FloatModule_basic", @@ -3634,7 +3630,6 @@ "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", - "ElementwiseGeluApproximateTanhModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog10Module_basic", @@ -3690,8 +3685,6 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", - "InterpolateDynamicModule_sizes_nearest", - "InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", @@ -3702,7 +3695,6 @@ "LinalgVectorNormComplexModule_basic", "LinspaceDtypeModule_basic", "LinspaceEmptyModule_basic", - "MaskedFillTensorFloatValueModule_basic", "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", @@ -3763,11 +3755,7 @@ "NumelModule_basic", "NumelZeroRankModule_basic", "OnesLikeModule_falsePinMemory", - "PixelShuffleModuleFullDynamic_basic", - "PixelShuffleModuleSpatiallyDynamic_basic", - "PixelShuffleModuleSpatiallyStatic_basic", - "PixelShuffleModuleStaticRank3Int64_basic", - "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -3783,9 +3771,6 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "QuantizedSingleLayer_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", "RandnGeneratorModule_basic", @@ -3795,26 +3780,11 @@ "ReduceAllDimEmpty_basic", "ReduceFrobeniusNormComplexModule_basic", "ReduceL1NormComplexModule_basic", - "ReduceL1NormWithDTypeModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ReduceSumDimIntListEmptyDimModule_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - 
"ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "RollModule_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", @@ -3886,11 +3856,6 @@ "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", - "UpSampleNearest2dDynamicFactor_basic", - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticFactor_basic", - "UpSampleNearest2dStaticSize_basic", - "UpSampleNearest2d_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", @@ -3937,6 +3902,13 @@ } ONNX_TOSA_XFAIL_SET = { + "ColumnStack0dModule_basic", + "ColumnStack1dModule_basic", + "ColumnStackBasicIntModule_basic", + "Deg2radModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossNoReductionModule_basic", + "L1LossSumReductionModule_basic", "FloatPowerTensorTensorStaticModule_basic", "IsInfiniteModule_basic", "ElementwiseCopysignModule_basic", @@ -4645,7 +4617,6 @@ "QuantizedSingleLayer_basic", "RandIntDtypeModule_basic", "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", "RandLikeDtypeModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 4ea96a43249e..0463e0c3af92 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2519,3 +2519,146 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3 %1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list -> !torch.vtensor<[1,1,10,6],f32> return %1 : !torch.vtensor<[1,1,10,6],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.outer$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3xf32>) -> tensor<3x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array} : (tensor<3x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.split_dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64> +// CHECK: %[[VAL_2:.*]] = 
torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: } +func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prims.split_dim %arg0, %int1, %int2 : !torch.vtensor<[1,8,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,4,3,3],si64> + %1 = torch.prims.split_dim %0, %int2, %int2 : !torch.vtensor<[1,2,4,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,3],si64> + return %1 : !torch.vtensor<[1,2,2,2,3,3],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,3],f64> -> tensor<1x1x2x3xf64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 4.000000e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.float 3.000000e+00 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 8 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<3xi32>) -> tensor<72x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] 
{new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,8,9],f64> +// CHECK: } +func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { + %float4.000000e00 = torch.constant.float 4.000000e+00 + %float3.000000e00 = torch.constant.float 3.000000e+00 + %int8 = torch.constant.int 8 + %int9 = torch.constant.int 9 + %0 = torch.prim.ListConstruct %int8, %int9 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.upsample_nearest2d %arg0, %0, %float4.000000e00, %float3.000000e00 : !torch.vtensor<[1,1,2,3],f64>, !torch.list, !torch.float, !torch.float -> !torch.vtensor<[1,1,8,9],f64> + return %1 : !torch.vtensor<[1,1,8,9],f64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d.vec$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,5],f32> -> tensor<1x1x4x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<3xi32>) -> tensor<14x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x2x7xf32> -> 
!torch.vtensor<[1,1,2,7],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,2,7],f32> +// CHECK: } +func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { + %none = torch.constant.none + %int2 = torch.constant.int 2 + %int7 = torch.constant.int 7 + %0 = torch.prim.ListConstruct %int2, %int7 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.upsample_nearest2d.vec %arg0, %0, %none : !torch.vtensor<[1,1,4,5],f32>, !torch.list, !torch.none -> !torch.vtensor<[1,1,2,7],f32> + return %1 : !torch.vtensor<[1,1,2,7],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.gelu$tanh( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.str "tanh" +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.636619746> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_9:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], %[[VAL_10]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_1]], %[[VAL_11]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.tanh %[[VAL_13]] : (tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_4]], %[[VAL_14]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], %[[VAL_15]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,3],f32> +// CHECK: } +func.func @torch.aten.gelu$tanh(%arg0: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { + %str = torch.constant.str "tanh" + %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[5,3],f32>, !torch.str -> !torch.vtensor<[5,3],f32> + return %0 : !torch.vtensor<[5,3],f32> +}
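
The gelu$tanh test above exercises the tanh approximation path added in TorchToTosa.cpp. A minimal pure-Python sketch (illustration only, not part of the patch) of the formula that path implements, GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), compared against the exact erf-based definition:

import math

def gelu_exact(x):
    # Exact definition: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Mirrors the TOSA lowering above: sqrt(2/pi) is computed as pow(2/pi, 0.5).
    c = math.pow(2.0 / math.pi, 0.5)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

for x in (-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")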
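
The two upsample_nearest2d tests encode the gather indices produced by the index formula in the new ConvertUpsampleNearest2dForward pattern: floor(heightIndex / scalesH) * selfWidth + floor(widthIndex / scalesW) over the input flattened to H x W. A short Python sketch (illustration only), using the same 1x1x2x3 -> 1x1x8x9 case with scales (4.0, 3.0) as the first upsample test, reproduces the constant index tensor checked there:

import math

def nearest2d_gather_indices(self_w, out_h, out_w, scales_h, scales_w):
    # Index into the flattened (H * W) input for each output (h, w) position.
    return [
        int(math.floor(h / scales_h) * self_w + math.floor(w / scales_w))
        for h in range(out_h)
        for w in range(out_w)
    ]

idx = nearest2d_gather_indices(self_w=3, out_h=8, out_w=9, scales_h=4.0, scales_w=3.0)
print(idx[:9])   # first output row: [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(idx[-9:])  # last output row:  [3, 3, 3, 4, 4, 4, 5, 5, 5]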
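
For reference, the aten.outer lowering in the 3x4 test is the broadcasted product written out with explicit reshape/tile steps. A hedged pure-Python equivalent of that reshape/tile/mul sequence (illustration only, not part of the patch):

def outer(vec1, vec2):
    rows, cols = len(vec1), len(vec2)
    tiled1 = [[v] * cols for v in vec1]            # reshape vec1 to (rows, 1), tile to (rows, cols)
    tiled2 = [list(vec2) for _ in range(rows)]     # reshape vec2 to (1, cols), tile to (rows, cols)
    return [[tiled1[i][j] * tiled2[i][j] for j in range(cols)] for i in range(rows)]

print(outer([1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 4.0]))  # 3x4 result, row i equals vec1[i] * vec2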