diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 6de6dde8d04f..fe6bad45ab44 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16329,6 +16329,31 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ }]; } +def Torch_AtenEqBoolOp : Torch_Op<"aten.eq.bool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqBoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqBoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [ AllowsTypeRefinement, HasValueSemantics, @@ -16476,6 +16501,31 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ let hasCanonicalizer = 1; } +def Torch_AtenMulLeftTOp : Torch_Op<"aten.mul.left_t", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.left_t : (t[], int) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$l, + Torch_IntType:$n + ); + let results = (outs + AnyTorchListType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulLeftTOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulLeftTOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 85dbfdac1961..d8517fbd156d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -7,12 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "llvm/Support/FormatVariadic.h" #include using namespace mlir; @@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( }); patterns.onOp( "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); - } else { - for (unsigned i = 0; i < kernelShape.size(); i++) { - if (weightShape[i + 2] != kernelShape[i]) { - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: kernel_shape value " - "should be equal to the weight tensor shape"); - } - } } } @@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( ArrayRef inputShape = inputTensorType.getSizes(); padding.resize_for_overwrite(2 * spatialRank); for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { + if (weightShape[dimIdx + 2] == Torch::kUnknownSize || + inputShape[dimIdx + 2] == Torch::kUnknownSize) + return rewriter.notifyMatchFailure( + binder.op, + "expected weight and input tensor to have static shape"); const int64_t dilatedKernelSize = dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1; int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / @@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } paddingList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (matchedPads) { for (unsigned i = 0; i < padding.size() / 2; i++) { cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + loc, rewriter.getI64IntegerAttr(padding[i]))); } paddingList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector inputPaddingList; for (uint32_t i = 0; i < padding.size() / 2; i++) { padsRearrange.emplace_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - padding[padding.size() / 2 - i - 1]))); + loc, rewriter.getI64IntegerAttr( + padding[padding.size() / 2 - i - 1]))); padsRearrange.emplace_back(rewriter.create( - binder.getLoc(), + loc, rewriter.getI64IntegerAttr(padding[padding.size() - i - 1]))); inputPaddingList.emplace_back( rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); + loc, rewriter.getI64IntegerAttr(0))); } // The conv op itself will have no padding since the actual padding // is performed using the torch.pad preceding it. paddingList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), inputPaddingList); Value padsSizeList = rewriter .create( - binder.getLoc(), + loc, Torch::ListType::get( rewriter.getType()), padsRearrange) .getResult(); Value modeVal = rewriter.create( - binder.getLoc(), rewriter.getStringAttr("constant")); + loc, rewriter.getStringAttr("constant")); Value constantValue; if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + loc, rewriter.getI64IntegerAttr(0)); if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + loc, rewriter.getF64FloatAttr(0.0f)); // Pad output shape must be computed explicitly from the pad values SmallVector newInputShape(inputTensorType.getSizes()); for (uint32_t i = 0; i < padding.size() / 2; i++) { @@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto padTy = rewriter.getType( newInputShape, inputTensorType.getDtype()); paddedInput = rewriter.create( - binder.getLoc(), padTy, input, padsSizeList, modeVal, - constantValue); + loc, padTy, input, padsSizeList, modeVal, constantValue); } } for (int64_t i : dilations) { cstDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + loc, rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; Value dilationsList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); Value stridesList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value outputPaddingList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); - Value transposed = - rewriter.create(binder.getLoc(), false); + Value transposed = rewriter.create(loc, false); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { - bias = rewriter.create(binder.getLoc()); + bias = rewriter.create(loc); } Value cstGroup = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(group)); + loc, rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, paddedInput, weight, bias, stridesList, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 17df96c6ea5a..1d7af20a7b92 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -36,21 +36,24 @@ namespace { // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for // noop_with_empty_axes handling before that. -LogicalResult reducedSumImpl(OpBinder binder, - ConversionPatternRewriter &rewriter, Value data, - Torch::ValueTensorType resultType, - Value &storeResult, int64_t keepDims, - int64_t noop_with_empty_axes, - bool isIntermediateOp) { - +template +LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, + Value data, Torch::ValueTensorType resultType, + Value &storeResult, int64_t keepDims, + int64_t noop_with_empty_axes, + bool isIntermediateOp) { + + auto inputType = dyn_cast(data.getType()); + if (!inputType) + return failure(); SmallVector axesList; Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: expected input and result to have shapes"); - } + auto axesTy = dyn_cast(axesVal.getType()); + if (!axesTy || !axesTy.areAllSizesKnown() || axesTy.getSizes().size() > 1) + return failure(); + auto axesShape = axesTy.getSizes(); + uint64_t numAxes = (axesShape.empty()) ? 1 : axesShape.front(); if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { SmallVector inputShape{inputType.getSizes()}; @@ -77,22 +80,25 @@ LogicalResult reducedSumImpl(OpBinder binder, } else { reduceDims.push_back(i); if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) + resultShape[resultShapeCounter] == 1 && keepDims == 1) resultShapeCounter++; } } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } + if (reduceDims.size() == numAxes) { + for (auto i : reduceDims) { + axesList.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else + binder.op->emitWarning( + "Number of inferred reduce dims, " + + std::to_string(reduceDims.size()) + + ", does not match the provided number of axes, " + + std::to_string(numAxes) + "."); } } if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + if (axesTy.getSizes()[0] == Torch::kUnknownSize) return failure(); Value zero = rewriter.create( @@ -100,9 +106,8 @@ LogicalResult reducedSumImpl(OpBinder binder, rewriter.getI64IntegerAttr(0)); SmallVector selectSizes{1}; auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { + selectSizes, axesTy.getOptionalDtype()); + for (uint64_t i = 0; i < numAxes; ++i) { Value iv = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i)); @@ -117,38 +122,60 @@ LogicalResult reducedSumImpl(OpBinder binder, SmallVector axesInts; if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); + for (int64_t i : axesInts) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); } } // Do not include absolute value in the noop - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, storeResult); + if (axesList.empty() && noop_with_empty_axes == 1) { + if (!isIntermediateOp) + rewriter.replaceOp(binder.op, data); + else + storeResult = data; return success(); } + // if the axes list is still empty, reduce everything. + if (axesList.empty()) { + if (keepDims == 0 && !resultType.getSizes().empty()) + return rewriter.notifyMatchFailure( + binder.op, + "no axes provided & no keepdim: expected result to be rank zero."); + if (keepDims == 1 && + (resultType.getSizes().size() != inputType.getSizes().size() || + llvm::any_of(resultType.getSizes(), + [](int64_t size) { return size != 1; }))) + return rewriter.notifyMatchFailure( + binder.op, "no axes provided & keepdim: expected result to have all " + "dimensions equal to 1."); + for (uint64_t i = 0; i < inputType.getSizes().size(); i++) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); + } + } + Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), axesList); Value keepDimBool = rewriter.create(binder.getLoc(), keepDims); - Value dType = rewriter.create(binder.getLoc()); - // If we are using the ReducedSum as an intermediate op to be passed into + // If we are using the reduction op as an intermediate op to be passed into // another operation, we might not want to replace the Op. So we create a new // Op and store the result in a variable. + SmallVector operands = {data, dimValueList, keepDimBool}; + if (llvm::is_one_of()) + operands.push_back( + /*dtype=*/rewriter.create(binder.getLoc())); if (!isIntermediateOp) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + rewriter.replaceOpWithNewOp(binder.op, resultType, + operands); } else { - storeResult = rewriter.create( - binder.getLoc(), resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + storeResult = rewriter.create(binder.getLoc(), + resultType, operands); } return success(); } @@ -1039,25 +1066,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); - patterns.onOp("ReduceL1", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - int64_t keepDims, noop_with_empty_axes; - Value operand; - if (binder.tensorOperandAtIndex(operand, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t keepDims, noop_with_empty_axes; + Value operand; + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - Value data = rewriter.create( - binder.getLoc(), operand.getType(), operand); + Value data = rewriter.create( + binder.getLoc(), operand.getType(), operand); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/operand, keepDims, - noop_with_empty_axes, false); - }); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/operand, keepDims, noop_with_empty_axes, false); + }); patterns.onOp( "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1075,9 +1102,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value squareOfOperand = rewriter.create( binder.getLoc(), operand.getType(), operand, operand); - auto reducedSum = - reducedSumImpl(binder, rewriter, squareOfOperand, resultType, - operand, keepDims, noop_with_empty_axes, true); + auto reducedSum = reduceOpImpl( + binder, rewriter, squareOfOperand, resultType, operand, keepDims, + noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( binder.op, @@ -1112,32 +1139,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceLogSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - auto reducedSumBool = - reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + auto reducedSumBool = reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - binder.op, - "Failed to perform sum operation on square of operand"); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data); - return success(); - }); + rewriter.replaceOpWithNewOp(binder.op, resultType, + data); + return success(); + }); patterns.onOp( "ReduceLogSumExp", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1169,7 +1196,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, dataCast); auto f64ReduceType = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF64Type()); - auto reducedSumBool = reducedSumImpl( + auto reducedSumBool = reduceOpImpl( binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSumBool)) @@ -1186,7 +1213,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceSum", 1, + patterns.onOp( + "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + }); + patterns.onOp("ReduceSumSquare", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -1198,11 +1241,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); + Value dataSquare = rewriter.create( + binder.getLoc(), data.getType(), data, data); + + return reduceOpImpl( + binder, rewriter, dataSquare, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, + false); }); - patterns.onOp("ReduceSumSquare", 1, + patterns.onOp("ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -1214,140 +1261,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - Value dataSquare = rewriter.create( - binder.getLoc(), data.getType(), data, data); - - return reducedSumImpl(binder, rewriter, dataSquare, - resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); + Value reduceSum = data; + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/reduceSum, keepDims, noop_with_empty_axes, + false); }); - patterns.onOp( - "ReduceMean", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - - SmallVector axesList; - - Value axesVal; - if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: expected input and result to have shapes"); - } - - // If the input shape and result shape is statically known then the - // list of dims to be squeezed can be derived from those shapes. As a - // result, we don't have to wait for the dim values to be known at - // runtime which is also expected by the downstream pipeline. - if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { - SmallVector inputShape{inputType.getSizes()}; - SmallVector resultShape{resultType.getSizes()}; - if (llvm::equal(inputShape, resultShape)) { - // Case: none of the dimension is reduced. - rewriter.replaceOp(binder.op, data); - return success(); - } - if (areAllElementsDistinct(inputShape)) { - // The check for the input shape elements to be distinct is added - // for the cases like: - // Input: [3, 2, 2] -> Output: [3, 2] - // For the above case, from the input and output shape it can't be - // inferred whether the dim:1 is reduced or dim:2. To avoid these - // type of cases, the check has been placed. - SmallVector reduceDims; - unsigned resultShapeCounter = 0; - for (unsigned i = 0; i < inputShape.size(); i++) { - if (resultShapeCounter < resultShape.size() && - inputShape[i] == resultShape[resultShapeCounter]) { - resultShapeCounter++; - } else { - reduceDims.push_back(i); - if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) - resultShapeCounter++; - } - } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - } - } - - if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) - return failure(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - SmallVector selectSizes{1}; - auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - binder.getLoc(), selType, axesVal, zero, iv); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - } - - SmallVector axesInts; - if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); - } - } - - // deal with case when axes is empty - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - Value noneVal = rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/noneVal); - return success(); - }); patterns.onOp( "ReduceMax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // AtenAmaxOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1412,87 +1337,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } - - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - } - - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); patterns.onOp( @@ -1501,7 +1348,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // AtenAminOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1565,87 +1411,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } - - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - } - - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); patterns.onOp( @@ -3104,7 +2872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; - float extrapolation_value; + float extrapolation_value, cubic_coeff_a; Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { @@ -3129,7 +2897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.f32FloatAttr(extrapolation_value, "extrapolation_value", 0.0) || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", - "round_prefer_floor")) + "round_prefer_floor") || + binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); if (antialias != 0) { return rewriter.notifyMatchFailure( @@ -3158,6 +2927,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "except asymmetric and half_pixel"); } + if (mode == "cubic" && cubic_coeff_a != -0.75) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: cubic coeff must be -0.75"); + } + unsigned rank = dyn_cast(operands[0].getType()) .getSizes() .size(); @@ -3173,8 +2947,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; if (mode == "cubic") { - return rewriter.notifyMatchFailure(binder.op, - "unimplemented: bicubic mode"); + std::string modeStr = "cubic"; + if (coordTfMode != "half_pixel") + modeStr = modeStr + "_" + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); } // supported modes: // bilinear (half_pixel), bilinear with align_corners, diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 143b46694030..458ea31852ec 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -82,6 +82,25 @@ class ConvertAtenBinaryOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenNegIntOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenNegIntOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = adaptor.getA(); + rewriter.replaceOpWithNewOp( + op, + rewriter.create(op.getLoc(), /*value=*/0, + /*bitwidth=*/64), + a); + return success(); + } +}; +} // namespace + namespace { template class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { @@ -465,11 +484,14 @@ class ConvertTorchToArith target.addIllegalOp(); patterns.add(typeConverter, context); - + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); + AtenMulIntOp, AtenRemainderIntOp>(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 3991711081ea..d6b5aaf869c8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2691,7 +2691,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace -static Value NearestInterpolate(OpBuilder &b, Location loc, +static Value nearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes, SmallVector scaleValues, @@ -2779,12 +2779,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, return retVal; } -static Value BilinearInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, - std::string coordStr) { +static SmallVector coordinateTransform( + OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, + SmallVector outputSizes, Value input, SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, bool alignCornersBool, + SmallVector indices, bool clip) { + unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -2793,15 +2793,7 @@ static Value BilinearInterpolate(OpBuilder &b, Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - bool alignCornersBool; - matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); - } - - SmallVector proj, projEps, high, low, highFP, lowFP; + SmallVector proj; for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = @@ -2864,13 +2856,50 @@ static Value BilinearInterpolate(OpBuilder &b, outputSizeFP, cstOneFloat); preClip = b.create(loc, cmp, zero, preClip); } - // preClip is the fp position inside the input image to extract from. - // clip to [0,inf) - Value max = b.create(loc, preClip, zero); + if (clip) { + // preClip is the fp position inside the input image to extract from. + // clip to [0,inf) + Value max = b.create(loc, preClip, zero); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1]. + // proj is properly within the input image. + proj.push_back(b.create(loc, max, inputSubOne)); + } else { + proj.push_back(preClip); + } + } + return proj; +} + +static Value bilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, high, low, highFP, lowFP; + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + true); + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); Value inputSubOne = b.create(loc, inputFP, cstOneFloat); - // clip to [0,length_original - 1]. - // proj is properly within the input image. - proj.push_back(b.create(loc, max, inputSubOne)); // for bilinear interpolation, we look for the nearest indices below and // above proj @@ -2934,6 +2963,176 @@ static Value BilinearInterpolate(OpBuilder &b, return b.create(loc, left, right); } +static Value bicubicInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value inputFPH = + b.create(loc, b.getF32Type(), inputSizes[0]); + Value inputFPW = + b.create(loc, b.getF32Type(), inputSizes[1]); + + Value a = b.create(loc, b.getF32FloatAttr(-0.75)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); + Value cstThreeFloat = + b.create(loc, b.getF32FloatAttr(3.0)); + Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); + Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); + Value cstEightFloat = + b.create(loc, b.getF32FloatAttr(8.0)); + + // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1) + auto WeightLessThanEqualOne = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + + Value lessEqualOne = b.create(loc, a, cstTwoFloat); + lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); + Value aPlusThree = b.create(loc, a, cstThreeFloat); + aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, cstOneFloat); + + return lessEqualOne; + }; + + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2) + auto WeightLessThanTwo = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + // a|x|^3 + Value lessThanTwo = b.create(loc, xDistanceCubed, a); + + Value fiveA = b.create(loc, xDistanceSquared, a); + fiveA = b.create(loc, fiveA, cstFiveFloat); + // a|x|^3 - 5a|x|^2 + lessThanTwo = b.create(loc, lessThanTwo, fiveA); + + Value eightA = b.create(loc, a, xDistance); + eightA = b.create(loc, eightA, cstEightFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| + lessThanTwo = b.create(loc, eightA, lessThanTwo); + + Value fourA = b.create(loc, a, cstFourFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a + lessThanTwo = b.create(loc, lessThanTwo, fourA); + return lessThanTwo; + }; + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj; + + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + false); + + // get the nearest neighbors of proj + Value x1 = b.create(loc, proj[1]); + Value x_1 = b.create(loc, x1, cstOneFloat); + Value x_2 = b.create(loc, x_1, cstOneFloat); + Value x2 = b.create(loc, x1, cstOneFloat); + + Value y1 = b.create(loc, proj[0]); + Value y_1 = b.create(loc, y1, cstOneFloat); + Value y_2 = b.create(loc, y_1, cstOneFloat); + Value y2 = b.create(loc, y1, cstOneFloat); + + // calculate the distance of nearest neighbors x and y to proj + Value y2Distance = b.create(loc, proj[0], y2); + y2Distance = b.create(loc, y2Distance); + Value y1Distance = b.create(loc, proj[0], y1); + y1Distance = b.create(loc, y1Distance); + Value y_1Distance = b.create(loc, proj[0], y_1); + y_1Distance = b.create(loc, y_1Distance); + Value y_2Distance = b.create(loc, proj[0], y_2); + y_2Distance = b.create(loc, y_2Distance); + + Value x2Distance = b.create(loc, proj[1], x2); + x2Distance = b.create(loc, x2Distance); + Value x1Distance = b.create(loc, proj[1], x1); + x1Distance = b.create(loc, x1Distance); + Value x_1Distance = b.create(loc, proj[1], x_1); + x_1Distance = b.create(loc, x_1Distance); + Value x_2Distance = b.create(loc, proj[1], x_2); + x_2Distance = b.create(loc, x_2Distance); + + SmallVector y{y_2, y_1, y1, y2}; + SmallVector x{x_2, x_1, x1, x2}; + + SmallVector wys{ + WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance), + WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)}; + SmallVector wxs{ + WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance), + WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)}; + + // clip the nearest neighbors points to inside the original image + for (int k = 0; k < 4; k++) { + Value yClipped = b.create(loc, y[k], zero); + Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); + yClipped = b.create(loc, yClipped, inputHSubOne); + Value yInt = b.create(loc, b.getI64Type(), yClipped); + y[k] = b.create(loc, b.getIndexType(), yInt); + + Value xClipped = b.create(loc, x[k], zero); + Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); + xClipped = b.create(loc, xClipped, inputWSubOne); + Value xInt = b.create(loc, b.getI64Type(), xClipped); + x[k] = b.create(loc, b.getIndexType(), xInt); + } + // 1. Compute x_original and y_original (proj) + // 2. Compute nearest x and y neighbors + // 3. Compute Wx Wy + // 4. Extract inputs at nearest neighbors (inputExtracts) + // 5. Compute weighted sum (yield this) + + // 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original + // 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original + // Sum_x is over 4 nearest x neighbors (similar for Sum_y) + // f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y] + // * W(y_original - y) + Value fxy = zero; + + for (int j = 0; j < 4; j++) { + Value wy = wys[j]; + Value xInterpy = zero; + + indices[dimOffset] = y[j]; + + for (int i = 0; i < 4; i++) { + Value wx = wxs[i]; + + indices[dimOffset + 1] = x[i]; + + Value p = b.create(loc, input, indices); + + Value wxp = b.create(loc, wx, p); + xInterpy = b.create(loc, xInterpy, wxp); + } + Value wyXInterpy = b.create(loc, wy, xInterpy); + fxy = b.create(loc, fxy, wyXInterpy); + } + + return fxy; +} + namespace { class ConvertInterpolateOp : public OpConversionPattern { @@ -2949,7 +3148,8 @@ class ConvertInterpolateOp // coordinate_transformation_mode="asymmetric" will lower to an interpolate // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" && + mode.substr(0, 5) != "cubic") { return failure(); } @@ -3031,13 +3231,18 @@ class ConvertInterpolateOp (mode.find(",") == std::string::npos) ? "" : mode.substr(mode.find(",") + 1); - retVal = NearestInterpolate( + retVal = nearestInterpolate( b, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, coordTfMode, nearestMode); } else if (mode.substr(0, 8) == "bilinear") { - retVal = BilinearInterpolate( + retVal = bilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, mode.substr(8)); + } else if (mode.substr(0, 5) == "cubic") { + + retVal = bicubicInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(5)); } b.create(loc, retVal); }) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 18e8fb449ef5..cf41bbcd711b 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -578,6 +578,12 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, int64_t inputRank = inType.getRank(); Type elementType = inType.getElementType(); + // Check for 0-D tensor. + if (inputRank == 0) { + result = input; + return success(); + } + // Check if the dimensions are a valid constants. int64_t numDimensions = dimensions.size(); if (inputRank != numDimensions) @@ -596,28 +602,10 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); - SmallVector idExprs; - SmallVector swapExprs; - for (uint32_t i = 0; i < inputRank; i++) - idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); - for (uint32_t i = 0; i < inputRank; i++) - swapExprs.push_back(idExprs[dimensions[i]]); - - AffineMap inputMap = - AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); - AffineMap outputMap = - AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext()); - SmallVector indexingMaps{inputMap, outputMap}; - SmallVector iteratorTypes(inputRank, - utils::IteratorType::parallel); - result = rewriter - .create( - loc, outVector.getType(), input, outVector, indexingMaps, - iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); + + result = + rewriter.create(loc, input, outVector, dimensions) + ->getResult(0); return success(); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 38cf65bd2831..464dfab5b700 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -74,11 +74,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, + auto self = adaptor.getSelf(); + + auto outType = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); + op.getType())); + + self = tosa::promoteType(rewriter, self, outType); + + rewriter.replaceOpWithNewOp(op, outType, self); + return success(); } }; @@ -2184,30 +2189,30 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Rsub"); - if (!isa(selfTy.getElementType())) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization supported"); + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto resultElemTy = resultTy.getElementType(); + + self = tosa::promoteType(rewriter, self, resultTy); Value otherTensor, alphaTensor; if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor, - selfTy.getElementType(), {}))) + resultElemTy, {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA Rsub operation"); if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar, - alphaTensor, selfTy.getElementType(), + alphaTensor, resultElemTy, /*checkForUnity=*/true))) return failure(); - auto multTensor = rewriter.create( - op->getLoc(), getTypeConverter()->convertType(op.getType()), self, - alphaTensor, /*shift=*/0); + auto multTensor = rewriter.create(op->getLoc(), resultTy, self, + alphaTensor, /*shift=*/0); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), otherTensor, - multTensor); + rewriter.replaceOpWithNewOp(op, resultTy, otherTensor, + multTensor); return success(); } @@ -4922,6 +4927,108 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.clamp.Tensor +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenClampTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // We are not using tosa.clamp to lower aten.clamp.Tensor, as + // aten.clamp.Tensor's min and max attributes are tensors that can have size + // greater than 1, which is not compatible with tosa.clamp. + // + // Instead, we use the following formula: + // yi = min(max(xi, min_valuei), max_valuei) + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + // Get min tensor. If None, there is no lower bound. + Value min; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) { + min = adaptor.getMin(); + } else { + min = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::lowest(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + + // Get max tensor. If None, there is no upper bound. + Value max; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) { + max = adaptor.getMax(); + } else { + max = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + + // max(xi, min_valuei) + auto minThresholdCheck = tosa::createBinaryOpAndCast( + rewriter, op, resultType, self, min); + + // yi = min(max(xi, min_valuei), max_valuei) + auto result = tosa::createBinaryOpAndCast( + rewriter, op, resultType, minThresholdCheck, max); + + rewriter.replaceOp(op, result); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, @@ -5431,11 +5538,29 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { ConvertAtenPoolingBaseOp::transposePoolingOutputToChw( op, rewriter, pooledOutput); - rewriter.replaceOpWithNewOp( - op, + Value result = transposedOutput; + auto resultTy = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - transposedOutput); + op.getType())); + + if constexpr (std::is_same() || + std::is_same()) { + auto resultShape = resultTy.getShape(); + auto resultElemTy = resultTy.getElementType(); + + result = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(resultShape), + resultElemTy), + transposedOutput, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + } + + rewriter.replaceOpWithNewOp( + op, resultTy, + // OpConversionPattern::getTypeConverter()->convertType( + // op.getType()), + result); return success(); } @@ -5582,6 +5707,12 @@ static LogicalResult getOutputTypeAndPoolingParameters( return rewriter.notifyMatchFailure( op, "Non-const kernel_size for pooling op unsupported"); + // Expand kernel size parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + kernelSizeInts.push_back(1); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure( op, "Non-const stride for pooling op unsupported"); @@ -5589,13 +5720,46 @@ static LogicalResult getOutputTypeAndPoolingParameters( // list during import. For such a case, the stride value is the kernel size. // See: // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d - if (strideInts.empty()) + if (strideInts.empty()) { strideInts.assign(kernelSizeInts); + } else { + // Expand stride parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + strideInts.push_back(1); + } if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) return rewriter.notifyMatchFailure( op, "Non-const padding factor for pooling op unsupported"); + // Expand padding parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + paddingInts.push_back(0); + + if constexpr (std::is_same() || + std::is_same()) { + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answer (SWA) when the `count_include_pad` value is `true.` + // + // Note: We need to check for `count_include_pad` only when the `padding` + // value is non-zero. + bool countIncludePad; + if ((paddingInts[0] != 0 || paddingInts[1] != 0) && + (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + + countIncludePad)) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool " + "`count_include_pad` value should be `False`."); + } + } + SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -5651,6 +5815,68 @@ class ConvertAtenMaxPool2dOp } }; +// Legalization for aten.max_pool1d +class ConvertAtenMaxPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for MaxPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + SmallVector dilationArray; + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationArray))) + return rewriter.notifyMatchFailure( + op, "Non-const dilation for pooling op unsupported."); + // TOSA pooling only supports unit dilation. + if (dilationArray[0] > 1) + return rewriter.notifyMatchFailure( + op, "Cannot process non-unit pooling dilation."); + + // Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp + dilationArray.push_back(1); + + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + class ConvertAtenAvgPool2dOp : public ConvertAtenPoolingBaseOp { public: @@ -5662,18 +5888,6 @@ class ConvertAtenAvgPool2dOp DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { - // Currently, we can not represent `count_include_pad` with the existing - // TOSA AvgPool2d specification. Without the below check, we produce silent - // wrong answers (SWA) when the `count_include_pad` value is `true.` - bool countIncludePad; - if (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)) || - countIncludePad) { - return rewriter.notifyMatchFailure( - op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " - "`count_include_pad` value should be `False`."); - } - // Currently, we can not represent `divisor_override` with the existing TOSA // AvgPool2d specification. Without the below check, we produce silent wrong // answers (SWA) when the `divisor_override` value is other than `None.` @@ -5699,6 +5913,56 @@ class ConvertAtenAvgPool2dOp } }; +// Legalization for aten.avg_pool1d +class ConvertAtenAvgPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for AvgPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + SmallVector dilationArray{1, 1}; + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + // Ref: Error checking based on the Torch to LinAlg lowering template class ConvertAtenConstPatternOp : public OpConversionPattern { @@ -6019,8 +6283,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto builtinTensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); - for (auto &in : builtinTensors) - in = tosa::promoteType(rewriter, in, outType); + for (auto &tensor : builtinTensors) + tensor = tosa::promoteType(rewriter, tensor, outType); auto result = tosa::CreateOpAndInfer( rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); @@ -7078,6 +7342,475 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for torch.prims.collapse +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsCollapseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getA(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + int64_t start, end; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + return rewriter.notifyMatchFailure( + op, "Only constant int start value is supported"); + + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + return rewriter.notifyMatchFailure( + op, "Only constant int end value is supported"); + + // Identity case + if (start == end) { + rewriter.replaceOp(op, self); + return success(); + } + + // Technically, I should calculate the output shape based on the input shape, + // start value, and end value. However, that would just give the same result + // as me taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach + // here, which is more simple and quicker. + rewriter.replaceOpWithNewOp( + op, resultType, self, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + + return success(); +} + +// Legalization for aten.reflection_pad1d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1]) + return rewriter.notifyMatchFailure( + op, "Padding should be less than input boundary size"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + SmallVector resultTensors; + + // Use tosa.slice and tosa.reverse to get the reflection pads based on the + // padding size + if (paddingLeft > 0) { + SmallVector leftStartSlice(selfRank, 0); + SmallVector leftSizeSlice(selfShape); + + leftStartSlice[selfRank - 1] = 1; + leftSizeSlice[selfRank - 1] = paddingLeft; + + SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); + leftPadShape.push_back(paddingLeft); + + auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); + + auto leftPadSlice = rewriter.create( + op->getLoc(), leftPadType, self, + rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + auto leftPad = rewriter.create( + op->getLoc(), leftPadType, leftPadSlice.getResult(), + static_cast(selfRank - 1)); + + resultTensors.push_back(leftPad.getResult()); + } + + resultTensors.push_back(self); + + if (paddingRight > 0) { + SmallVector rightStartSlice(selfRank, 0); + SmallVector rightSizeSlice(selfShape); + + rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; + rightSizeSlice[selfRank - 1] = paddingRight; + + SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); + rightPadShape.push_back(paddingRight); + + auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); + + auto rightPadSlice = rewriter.create( + op->getLoc(), rightPadType, self, + rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + auto rightPad = rewriter.create( + op->getLoc(), rightPadType, rightPadSlice.getResult(), + static_cast(selfRank - 1)); + + resultTensors.push_back(rightPad.getResult()); + } + + auto result = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.reflection_pad2d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1] || + paddingTop >= selfShape[selfRank - 2] || + paddingBottom >= selfShape[selfRank - 2]) + return rewriter.notifyMatchFailure( + op, "Padding must be less than the corresponding input dimension"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + // Use tosa.slice and tosa.reverse to get the reflection pads based on the + // padding size + SmallVector sideTensors; + + if (paddingLeft > 0) { + SmallVector leftStartSlice(selfRank, 0); + SmallVector leftSizeSlice(selfShape); + + leftStartSlice[selfRank - 1] = 1; + leftSizeSlice[selfRank - 1] = paddingLeft; + + SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); + leftPadShape.push_back(paddingLeft); + + auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); + + auto leftPadSlice = rewriter.create( + op->getLoc(), leftPadType, self, + rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + auto leftPad = rewriter.create( + op->getLoc(), leftPadType, leftPadSlice.getResult(), + static_cast(selfRank - 1)); + + sideTensors.push_back(leftPad.getResult()); + } + + sideTensors.push_back(self); + + if (paddingRight > 0) { + SmallVector rightStartSlice(selfRank, 0); + SmallVector rightSizeSlice(selfShape); + + rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; + rightSizeSlice[selfRank - 1] = paddingRight; + + SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); + rightPadShape.push_back(paddingRight); + + auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); + + auto rightPadSlice = rewriter.create( + op->getLoc(), rightPadType, self, + rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + auto rightPad = rewriter.create( + op->getLoc(), rightPadType, rightPadSlice.getResult(), + static_cast(selfRank - 1)); + + sideTensors.push_back(rightPad.getResult()); + } + + SmallVector selfSidePaddedShape(selfShape.begin(), + selfShape.end() - 1); + selfSidePaddedShape.push_back(resultShape.back()); + + auto selfSidePadded = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, + selfRank - 1); + + SmallVector resultTensors; + + if (paddingTop > 0) { + SmallVector topStartSlice(selfRank, 0); + SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); + topSizeSlice.push_back(resultShape.back()); + + topStartSlice[selfRank - 2] = 1; + topSizeSlice[selfRank - 2] = paddingTop; + + SmallVector topPadShape(selfShape.begin(), selfShape.end() - 2); + topPadShape.push_back(paddingTop); + topPadShape.push_back(resultShape.back()); + + auto topPadType = RankedTensorType::get(topPadShape, selfElemTy); + + auto topPadSlice = rewriter.create( + op->getLoc(), topPadType, selfSidePadded, + rewriter.getDenseI64ArrayAttr(topStartSlice), + rewriter.getDenseI64ArrayAttr(topSizeSlice)); + + auto topPad = rewriter.create( + op->getLoc(), topPadType, topPadSlice.getResult(), + static_cast(selfRank - 2)); + + resultTensors.push_back(topPad.getResult()); + } + + resultTensors.push_back(selfSidePadded.getResult()); + + if (paddingBottom > 0) { + SmallVector bottomStartSlice(selfRank, 0); + SmallVector bottomSizeSlice(selfShape.begin(), + selfShape.end() - 1); + bottomSizeSlice.push_back(resultShape.back()); + + bottomStartSlice[selfRank - 2] = + selfShape[selfRank - 2] - paddingBottom - 1; + bottomSizeSlice[selfRank - 2] = paddingBottom; + + SmallVector bottomPadShape(selfShape.begin(), selfShape.end() - 2); + bottomPadShape.push_back(paddingBottom); + bottomPadShape.push_back(resultShape.back()); + + auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy); + + auto bottomPadSlice = rewriter.create( + op->getLoc(), bottomPadType, selfSidePadded, + rewriter.getDenseI64ArrayAttr(bottomStartSlice), + rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); + + auto bottomPad = rewriter.create( + op->getLoc(), bottomPadType, bottomPadSlice.getResult(), + static_cast(selfRank - 2)); + + resultTensors.push_back(bottomPad.getResult()); + } + + auto result = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.replication_pad2d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + // Use tosa.slice to get the reflection pads based on the padding size + SmallVector sideTensors; + + if (paddingLeft > 0) { + SmallVector leftStartSlice(selfRank, 0); + SmallVector leftSizeSlice(selfShape); + + leftStartSlice[selfRank - 1] = 0; + leftSizeSlice[selfRank - 1] = 1; + + SmallVector leftPadSliceShape(selfShape.begin(), + selfShape.end() - 1); + leftPadSliceShape.push_back(1); + + auto leftPadSliceType = + RankedTensorType::get(leftPadSliceShape, selfElemTy); + + auto leftPadSlice = rewriter.create( + op->getLoc(), leftPadSliceType, self, + rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + for (int64_t i = 0; i < paddingLeft; i++) + sideTensors.push_back(leftPadSlice.getResult()); + } + + sideTensors.push_back(self); + + if (paddingRight > 0) { + SmallVector rightStartSlice(selfRank, 0); + SmallVector rightSizeSlice(selfShape); + + rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - 1; + rightSizeSlice[selfRank - 1] = 1; + + SmallVector rightPadSliceShape(selfShape.begin(), + selfShape.end() - 1); + rightPadSliceShape.push_back(1); + + auto rightPadSliceType = + RankedTensorType::get(rightPadSliceShape, selfElemTy); + + auto rightPadSlice = rewriter.create( + op->getLoc(), rightPadSliceType, self, + rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + for (int64_t i = 0; i < paddingRight; i++) + sideTensors.push_back(rightPadSlice.getResult()); + } + + SmallVector selfSidePaddedShape(selfShape.begin(), + selfShape.end() - 1); + selfSidePaddedShape.push_back(resultShape.back()); + + auto selfSidePadded = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, + selfRank - 1); + + SmallVector resultTensors; + + if (paddingTop > 0) { + SmallVector topStartSlice(selfRank, 0); + SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); + topSizeSlice.push_back(resultShape.back()); + + topStartSlice[selfRank - 2] = 0; + topSizeSlice[selfRank - 2] = 1; + + SmallVector topPadSliceShape(selfShape.begin(), + selfShape.end() - 2); + topPadSliceShape.push_back(1); + topPadSliceShape.push_back(resultShape.back()); + + auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy); + + auto topPadSlice = rewriter.create( + op->getLoc(), topPadSliceType, selfSidePadded, + rewriter.getDenseI64ArrayAttr(topStartSlice), + rewriter.getDenseI64ArrayAttr(topSizeSlice)); + + for (int64_t i = 0; i < paddingTop; i++) + resultTensors.push_back(topPadSlice.getResult()); + } + + resultTensors.push_back(selfSidePadded.getResult()); + + if (paddingBottom > 0) { + SmallVector bottomStartSlice(selfRank, 0); + SmallVector bottomSizeSlice(selfShape.begin(), + selfShape.end() - 1); + bottomSizeSlice.push_back(resultShape.back()); + + bottomStartSlice[selfRank - 2] = selfShape[selfRank - 2] - 1; + bottomSizeSlice[selfRank - 2] = 1; + + SmallVector bottomPadSliceShape(selfShape.begin(), + selfShape.end() - 2); + bottomPadSliceShape.push_back(1); + bottomPadSliceShape.push_back(resultShape.back()); + + auto bottomPadSliceType = + RankedTensorType::get(bottomPadSliceShape, selfElemTy); + + auto bottomPadSlice = rewriter.create( + op->getLoc(), bottomPadSliceType, selfSidePadded, + rewriter.getDenseI64ArrayAttr(bottomStartSlice), + rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); + + for (int64_t i = 0; i < paddingBottom; i++) + resultTensors.push_back(bottomPadSlice.getResult()); + } + + auto result = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); + + rewriter.replaceOp(op, result); + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -7304,9 +8037,15 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ @@ -7402,6 +8141,11 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 0ab2e5782633..a90a66431a09 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -769,6 +769,22 @@ OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenEqBoolOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqBoolOp::fold(FoldAdaptor adaptor) { + if (getOperand(0) == getOperand(1)) + return IntegerAttr::get(IntegerType::get(getContext(), 1), true); + + auto intAttrA = dyn_cast_or_null(adaptor.getA()); + auto intAttrB = dyn_cast_or_null(adaptor.getB()); + if (!intAttrA || !intAttrB) + return nullptr; + return IntegerAttr::get(IntegerType::get(getContext(), 1), + intAttrA.getValue() == intAttrB.getValue()); +} + //===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// @@ -777,12 +793,12 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); - bool a, b; - if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) - return nullptr; - if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) + auto intAttrA = dyn_cast_or_null(adaptor.getA()); + auto intAttrB = dyn_cast_or_null(adaptor.getB()); + if (!intAttrA || !intAttrB) return nullptr; - return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); + return IntegerAttr::get(IntegerType::get(getContext(), 1), + intAttrA.getValue() != intAttrB.getValue()); } //===----------------------------------------------------------------------===// @@ -1131,6 +1147,35 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenMulLeftTOp +//===----------------------------------------------------------------------===// + +void AtenMulLeftTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + // `[1,2] * 3` -> `[1,2,1,2,1,2]`, if it is not mutated. + patterns.add(+[](AtenMulLeftTOp op, PatternRewriter &rewriter) { + auto listLiteral = op.getL().getDefiningOp(); + if (!listLiteral || isListPotentiallyMutated(listLiteral)) + return failure(); + + int64_t numReps; + if (!matchPattern(op.getN(), m_TorchConstantInt(&numReps))) + return failure(); + + SmallVector newListElements; + for (int rep = 0; rep < numReps; ++rep) { + for (auto operand : listLiteral.getOperands()) { + newListElements.push_back(operand); + } + } + + rewriter.replaceOpWithNewOp(op, op.getL().getType(), + newListElements); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenMinOtherOp //===----------------------------------------------------------------------===// @@ -4083,6 +4128,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); + if (lConstant && lhs == 1) + return getOperand(1); + if (rConstant && rhs == 1) + return getOperand(0); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index d626fe6da4b2..de9db68b608f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7758,7 +7758,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.If %2 -> (!torch.list) {\n" " %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" " %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list \n" +" %7 = torch.aten.mul.left_t %5, %6 : !torch.list, !torch.int -> !torch.list\n" " %8 = torch.aten.add.t %7, %arg1 : !torch.list, !torch.list -> !torch.list\n" " torch.prim.If.yield %8 : !torch.list\n" " } else {\n" @@ -8969,7 +8969,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list, !torch.list, !torch.list) -> !torch.list\n" " %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list\n" " %16 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" -" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list, !torch.int) -> !torch.list \n" +" %17 = torch.aten.mul.left_t %15, %16 : !torch.list, !torch.int -> !torch.list\n" " %18 = torch.aten.len.t %arg6 : !torch.list -> !torch.int\n" " torch.prim.Loop %18, %true, init() {\n" " ^bb0(%arg8: !torch.int):\n" @@ -9833,7 +9833,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %76 = torch.aten.append.t %72, %75 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %74 = torch.operator \"aten.add_.t\"(%71, %72) : (!torch.list, !torch.list) -> !torch.list \n" +" %74 = torch.aten.add.t %71, %72 : !torch.list, !torch.list -> !torch.list\n" " return %74 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" @@ -10997,7 +10997,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" " %24 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.bool\n" -" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" %25 = torch.aten.eq.bool %24, %true : !torch.bool, !torch.bool -> !torch.bool\n" " torch.prim.If.yield %25 : !torch.bool\n" " }\n" " torch.prim.If %17 -> () {\n" @@ -11016,7 +11016,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" " %23 = torch.prim.If %22 -> (!torch.bool) {\n" " %24 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" -" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" %25 = torch.aten.eq.bool %24, %false : !torch.bool, !torch.bool -> !torch.bool\n" " torch.prim.If.yield %25 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 06c56ba14e3f..d371cf32d735 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4587,6 +4587,11 @@ class DecomposeAtenUnflattenIntOp if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + if (inputShape[dimInt] == Torch::kUnknownSize && + llvm::count(sizesInts, -1) > 0) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dynamic unflatten dim with an inferred size."); + SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 3d1a54de29f9..989057501957 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern { ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Rank 0 item op prop - if (selfTy.getSizes().size() == 0) { + if (selfTy.getSizes().empty()) { auto numToTensor = self.getDefiningOp(); auto squeezeDim = self.getDefiningOp(); if (!squeezeDim && !numToTensor) @@ -746,6 +746,109 @@ class PropagateAtenItemPattern : public OpRewritePattern { }; } // namespace +namespace { + +LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b, + SmallVector &converted, + SmallVector &elements, + Type inputDtype, Type resultDtype) { + auto inputIsInt = dyn_cast(inputDtype); + auto resultIsInt = dyn_cast(resultDtype); + if (!inputIsInt && !isa(inputDtype)) + return failure(); + if (!resultIsInt && !isa(resultDtype)) + return failure(); + + // if dtypes are both int or both float, no conversion needed + if (static_cast(inputIsInt) == static_cast(resultIsInt)) { + converted = elements; + return success(); + } + + if (resultIsInt) { + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + auto eFloatAttr = dyn_cast_or_null(eAttr); + if (!eFloatAttr) + return failure(); + + converted.push_back(IntegerAttr::get( + resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); + } + return success(); + } + + // result is float + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + auto eIntAttr = dyn_cast(eAttr); + if (!eIntAttr) + return failure(); + + auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue() + : eIntAttr.getValue().getZExtValue(); + converted.push_back(FloatAttr::get(resultDtype, static_cast(eInt))); + } + return success(); +} + +class PropagateAtenToDtypePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenToDtypeOp op, + PatternRewriter &rewriter) const override { + bool nonBlocking, copyArg; + // The non_blocking arg must be `False`. + if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || + nonBlocking) + return failure(); + // The copy arg must be `False`. + if (!matchPattern(op.getCopy(), m_TorchConstantBool(©Arg)) || copyArg) + return failure(); + // The memory_format arg must be `none`. + if (!isa(op.getMemoryFormat().getType())) + return failure(); + + auto inputType = dyn_cast(op.getSelf().getType()); + auto resultType = dyn_cast(op.getType()); + if (!inputType || !resultType || !inputType.hasDtype() || + !resultType.hasDtype()) + return failure(); + auto inputDtype = inputType.getDtype(); + auto resultDtype = resultType.getDtype(); + + SmallVector elements; + if (failed(getListFromTensor(op.getSelf(), elements))) + return failure(); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector converted; + if (failed(convertOpFoldResults(b, converted, elements, inputDtype, + resultDtype))) + return rewriter.notifyMatchFailure( + op, "Unhandled attribute type encountered."); + + SmallVector vals; + if (failed(materializeFolds(b, converted, vals))) + return failure(); + + Value result = constructAtenTensorOpFromList(b, op.getType(), vals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { template class PropagateAtenViewLikePattern : public OpRewritePattern { @@ -828,7 +931,7 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern { if (failed(materializeFolds(b, resultFolds, resultVals))) return failure(); - if (resultTy.getSizes().size() == 0) { + if (resultTy.getSizes().empty()) { rewriter.replaceOpWithNewOp( op, resultTy, resultVals.front()); return success(); @@ -841,6 +944,48 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern { }; } // namespace +namespace { +template +class PropagateAtenUnaryPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Check type + auto resultTy = cast(op.getType()); + if (resultTy.getSizes().size() > 1) + return rewriter.notifyMatchFailure(op, "unsupported: rank > 1"); + if (!resultTy.hasDtype() || !isa(resultTy.getDtype())) + return rewriter.notifyMatchFailure(op, "not an int type"); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfFold; + if (failed(getListFromTensor(op.getSelf(), selfFold))) + return failure(); + SmallVector selfVals; + if (failed(materializeFolds(b, selfFold, selfVals))) + return failure(); + SmallVector resultFolds; + for (uint64_t i = 0; i < selfVals.size(); i++) { + resultFolds.push_back( + b.createOrFold(selfVals[i].getType(), selfVals[i])); + } + SmallVector resultVals; + if (failed(materializeFolds(b, resultFolds, resultVals))) + return failure(); + + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, resultTy, resultVals.front()); + return success(); + } + + Value result = constructAtenTensorOpFromList(b, resultTy, resultVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace /// ------ Fold Patterns ------ /// // These are shape-specific folding patterns @@ -915,6 +1060,11 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, op.getType(), elements.front()); + return success(); + } auto loc = op.getLoc(); SmallVector sizes; @@ -922,12 +1072,10 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); - Value one = rewriter.create( - loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), - one); + sizes); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); @@ -1031,6 +1179,24 @@ class FoldAtenWhereSelf : public OpRewritePattern { }; } // namespace +namespace { +// fold ridiculous patterns like size.int -> float.scalar -> int.scalar +class FoldAtenIntScalarPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIntScalarOp op, + PatternRewriter &rewriter) const override { + auto floatScalarOp = op.getA().getDefiningOp(); + if (!floatScalarOp) + return failure(); + auto sizeOp = floatScalarOp.getA().getDefiningOp(); + if (!sizeOp) + return failure(); + rewriter.replaceOp(op, floatScalarOp.getA()); + return success(); + } +}; +} // namespace namespace { class FoldAtenUnsqueezePattern : public OpRewritePattern { public: @@ -1182,8 +1348,29 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { if (inputUnmatched == 1 && outputUnmatched > 1) { Value dimVal = rewriter.create(op.getLoc(), leftMatchEnd); - ArrayRef unflattenSizes(viewSizes.begin() + leftMatchEnd, - viewSizes.end() - rightMatchEnd); + SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd, + viewSizes.end() - rightMatchEnd); + // try to convert a single dynamic size input to -1 + int64_t dynCount = 0; + int64_t dynIdx = 0; + for (auto [i, v] : llvm::enumerate(unflattenSizes)) { + int64_t szeInt; + if (!matchPattern(v, m_TorchConstantInt(&szeInt))) { + dynCount++; + dynIdx = i; + continue; + } + // if we have a -1 already, make dynCount invalid and break + if (szeInt == -1) { + dynCount = -1; + break; + } + } + // if only one size is dynamic, make it -1 + if (dynCount == 1) + unflattenSizes[dynIdx] = + rewriter.create(op.getLoc(), -1); + Value unflattenList = rewriter.create( op.getLoc(), op.getSize().getType(), unflattenSizes); rewriter.replaceOpWithNewOp( @@ -1227,6 +1414,18 @@ template class RemoveUnusedPattern : public OpRewritePattern { namespace { +bool isItemForSliceOp(Operation *op) { + auto itemOp = dyn_cast_or_null(op); + if (!itemOp) + return false; + for (OpOperand &use : op->getUses()) { + Operation *userOp = use.getOwner(); + if (isa(userOp)) + return true; + } + return false; +} + bool isSourceOpForShapeScalarization(Operation *op) { return llvm::isa(op); @@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) { bool isAnchorOp(Operation *op) { return isa(op) || isa(op) || - isPrimListOfInts(op); + isPrimListOfInts(op) || isItemForSliceOp(op); } // The argument to this function, op, is the use of some source op, srcOp. If @@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op, void populateScalarizationFoldPatterns(RewritePatternSet &patterns) { patterns.insert, FoldAtenSqueezePattern, - FoldAtenUnsqueezePattern, FoldAtenWhereSelf, - FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>( - patterns.getContext()); + FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern, + FoldAtenWhereSelf, FoldAtenTensorSplatPattern, + FoldAtenEqIntPattern>(patterns.getContext()); } void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) { @@ -1303,10 +1502,12 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, - PropagateAtenTransposeIntPattern, + PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern, + PropagateAtenUnaryPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern>( patterns.getContext()); } @@ -1314,6 +1515,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { patterns.insert, RemoveUnusedPattern, + RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, @@ -1321,6 +1523,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, RemoveUnusedPattern>( patterns.getContext()); } diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 1e6879530ce6..229b352094e8 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -104,7 +104,8 @@ class UnpackQuantizedMatmulWeights char mask = (1 << unpackedBitWidth) - 1; for (int b = 0; b < packRatio; b++) { newData[i * packRatio + b] = - APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b)); + APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b), + /*isSigned=*/false, /*implicitTrunc=*/true); mask = mask << unpackedBitWidth; } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0b1694e3b5fb..03fa943f25b1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1806,6 +1806,49 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSinIntModule_basic", + "FloatPowerTensorTensorStaticModule_basic", + "AdaptiveMaxPool1dDimOneStatic_basic", + "CollapseAllDimensionsModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "Exp2StaticIntModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "RsubIntModule_noalpha_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", @@ -1831,7 +1874,6 @@ "SliceCopy_Module_basic", "Threshold1dIntModule_basic", "Threshold2dIntModule_basic", - "Threshold3dIntModule_basic", "EmptyModule_contiguous", "EmptyModule_defaultDtype", "EmptyModule_falsePinMemory", @@ -2364,6 +2406,7 @@ "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", + "ResNet18StaticModule_basic", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", @@ -2480,6 +2523,8 @@ TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "IsInfiniteModule_basic", "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "ResNet18StaticModule_basic", @@ -2555,6 +2600,8 @@ } ) - { ### Test failing in make_fx_tosa but not in tosa + "AdaptiveMaxPool1dDimOneStatic_basic", + "FloatPowerTensorTensorStaticModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' @@ -3442,21 +3489,17 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "IsInfiniteModule_basic", + "LayerNormFwAndBwModule_basic", + "LayerNormManualFwAndBwModule_basic", + "SelfAttentionFwAndBwModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", - "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", @@ -3584,11 +3627,6 @@ "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", "CeilFloatModule_basic", - "CollapseAllDimensionsModule_basic", - "CollapseFullDynamicModule_basic", - "CollapsePartialDynamicModule_basic", - "CollapseRank1DynamicModule_basic", - "CollapseStaticModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", @@ -3659,19 +3697,12 @@ "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", - "ElementwiseAtenLogicalNotOpPromoteModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", - "ElementwiseClampTensorFloatModule_basic", - "ElementwiseClampTensorIntModule_basic", "ElementwiseCopysignModule_basic", - "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseCreateComplexModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", - "ElementwiseErfIntModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", @@ -3688,12 +3719,9 @@ "ElementwiseMulTensorComplexModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", - "ElementwiseReciprocalIntModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", - "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", - "ElementwiseSinIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", "ElementwiseTanIntModule_basic", @@ -3845,23 +3873,9 @@ "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ReduceSumDimIntListEmptyDimModule_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", "RepeatInterleaveFillModule_basic", "RepeatInterleaveModule_basic", "RepeatInterleaveStaticModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "RollModule_basic", "ResNet18Module_basic", "ResNet18StaticModule_basic", @@ -3943,7 +3957,6 @@ # Unexpected failures due to new PyTorch version update "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", @@ -3963,24 +3976,15 @@ "ScaledDotProductAttentionSameCausalModule_basic", "ScaledDotProductAttentionSameDynamicModule_basic", "ScaledDotProductAttentionSameModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", "GridSamplerBasic4_basic", - "MaxPool1dEmptyStrideStaticModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", "MaxPool3dLargeDatadModule_basic", "MaxPool3dModuleRandomSimple_basic", "MaxPool3dModule_basic", "MaxPool3dStaticModule_basic", - "RepeatInterleaveSelfIntModule_basic", "Mlp1LayerModule_basic", "Mlp2LayerModuleNoBias_basic", "Mlp2LayerModule_basic", @@ -3993,6 +3997,7 @@ if torch_version_for_comparison() < version.parse("2.6.0.dev"): # Passing on stable but not on nightly FX_IMPORTER_TOSA_XFAIL_SET -= { + "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", @@ -4062,6 +4067,7 @@ "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "RsubInt0d_NumToTensor_Module_basic", + "AdaptiveMaxPool1dDimOneStatic_basic", } ONNX_TOSA_CRASHING_SET = { @@ -4074,6 +4080,18 @@ } ONNX_TOSA_XFAIL_SET = { + "FloatPowerTensorTensorStaticModule_basic", + "IsInfiniteModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "ElementwiseSignbitModule_basic", + "Exp2StaticIntModule_basic", + "NllLossStaticModule_basic", + "NllLossStaticModule_mean_basic", + "NllLossStaticModule_sum_basic", + "NllLossStaticModule_weight_basic", "Exp2StaticModule_basic", "ElementwiseRreluWithNoiseEvalModule_basic", "ElementwiseRreluWithNoiseEvalStaticModule_basic", @@ -4097,7 +4115,6 @@ "TriuIndicesAllZerosModule_basic", "ElementwiseCreateComplexModule_basic", "ReduceAllDimFloatModule_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", @@ -4176,7 +4193,6 @@ "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", @@ -4309,7 +4325,6 @@ "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", "CollapseAllDimensionsModule_basic", "CollapseFullDynamicModule_basic", "CollapsePartialDynamicModule_basic", @@ -4432,10 +4447,6 @@ "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampTensorInt8Module_basic", "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", @@ -4482,7 +4493,6 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRelu6Module_basic", "ElementwiseRemainderScalarModule_Bool_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", @@ -4561,8 +4571,6 @@ "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", - "HardTanhIntModule_basic", - "HardTanhModule_basic", "HardtanhBackward_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", @@ -4610,7 +4618,6 @@ "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", - "IndexTensorMultiIndexStaticModule_basic", "IndexTensorMultiInputContiguousCenter_basic", "IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousDynamic_basic", @@ -4621,8 +4628,6 @@ "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", "IndexTensorSelectDimModule_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_scales_recompute_bilinear", @@ -4650,10 +4655,7 @@ "Matmul_matvec", "Matmul_vecmat", "MaxPool1dCeilModeTrueModule_basic", - "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic", @@ -4697,7 +4699,6 @@ "MeanDimNoneDimModule_basic", "MeanDtypeModule_basic", "MeanDynamicSizesModule_basic", - "MeanModule_basic", "Mlp1LayerModule_basic", "Mlp2LayerModuleNoBias_basic", "Mlp2LayerModule_basic", @@ -4754,7 +4755,6 @@ "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalFunctionalModule_basic", - "NormalizeModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "NumelModule_basic", @@ -4855,29 +4855,10 @@ "ReduceSumDimIntListDtypeFloatModule_basic", "ReduceSumDimIntListDtypeIntModule_basic", "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeIntModule_basic", "ReduceSumElementTypeBoolModule_basic", - "ReduceSumFloatModule_basic", - "ReduceSumSignedIntModule_basic", - "ReduceSumUnsignedIntModule_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "ResNet18Module_basic", - "ResNet18StaticModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeCollapseModule_basic", @@ -5039,10 +5020,6 @@ "TypePromotionDifferentCategoryModule_basic", "TypePromotionSameCategoryDifferentWidthModule_basic", "TypePromotionZeroRankHigherCategoryModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", "UniformModule_basic", "UniformNoCorrelationModule_basic", "UniformStaticShapeModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 59c15f22582f..70ae7459eb2a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1764,8 +1764,7 @@ def aten〇col2im〡shape(self: List[int], output_size: List[int], kernel_size: # compute the shape of the output num_channels = n_input_plane // (kernel_size[0] * kernel_size[1]) - out: List[int] = [self[0], num_channels] if batch_dim == 0 else [num_channels] - out += [elem for elem in output_size] + out: List[int] = ([self[0], num_channels] if batch_dim == 0 else [num_channels]) + [elem for elem in output_size] return out diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 22418207271f..4c9968e6e9ea 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1143,12 +1143,14 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gt.float_int : (float, int) -> (bool)") emit("aten::pow.int_float : (int, float) -> (float)", has_folder=True) emit("aten::__and__.bool : (bool, bool) -> (bool)") + emit("aten::eq.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) + emit("aten::mul.left_t : (t[], int) -> (t[])", has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index fa0e2a89dbba..4f852d34bb0a 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -137,7 +137,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # Load the temp file and the external data. inferred_model = onnx.load(temp_inferred_file, load_external_data=False) data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) - onnx.load_external_data_for_model(inferred_model, data_dir) + onnx.load_external_data_for_model(inferred_model, str(data_dir)) # Remove the inferred shape file unless asked to keep it if not args.keep_temps: diff --git a/pytorch-hash.txt b/pytorch-hash.txt index dd4f3a19ad33..ad873201dbba 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -c787213d413e85c66bdad0d8c9cde1c5ced34b1b +0d5247caf3ffd618d31cf4cf880c47b7dbd323a7 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index cfccf5d69871..ee5f9cecb5e1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -707,17 +707,8 @@ func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_max_bool_inputs func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> @@ -729,17 +720,8 @@ func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! // CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> @@ -751,19 +733,9 @@ func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 // CHECK-LABEL: func.func @test_reduce_max_all_dims_default func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I0:.+]] = torch.constant.int 0 - // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> @@ -775,13 +747,7 @@ func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] @@ -793,9 +759,12 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens // CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -845,8 +814,11 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f // CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -944,7 +916,10 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2 // CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1000,7 +975,10 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> @@ -1092,7 +1070,10 @@ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vte // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1177,7 +1158,10 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1385,17 +1369,8 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> @@ -1407,17 +1382,8 @@ func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! // CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> @@ -1431,17 +1397,7 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[I0:.+]] = torch.constant.int 0 // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> @@ -1453,13 +1409,7 @@ func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir new file mode 100644 index 000000000000..dd5e5c553d31 --- /dev/null +++ b/test/Conversion/TorchToLinalg/datamovement.mlir @@ -0,0 +1,34 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.permute( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64,32,16,8,4],f32> -> tensor<64x32x16x8x4xf32> +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64x8x4x32x16xf32> +// CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<64x32x16x8x4xf32>) outs(%[[VAL_2]] : tensor<64x8x4x32x16xf32>) permutation = [0, 3, 4, 1, 2] +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<64x8x4x32x16xf32> -> !torch.vtensor<[64,8,4,32,16],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[64,8,4,32,16],f32> +// CHECK: } +func.func @torch.aten.permute(%arg0: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int3, %int4, %int1, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[64,32,16,8,4],f32>, !torch.list -> !torch.vtensor<[64,8,4,32,16],f32> + return %1 : !torch.vtensor<[64,8,4,32,16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.permute$rank0( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_2]] : !torch.vtensor<[],f32> +// CHECK: } +func.func @torch.aten.permute$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[],f32> + return %1 : !torch.vtensor<[],f32> +} diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 7976b1ad8b16..1dfe45492312 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 // CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32 // CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32 - // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32 + // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32 // CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32 // CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32 // CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32 // CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32 // CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 // CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32 + // CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32 // CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64 // CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index // CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32> @@ -304,4 +304,51 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens return %5 : !torch.vtensor<[?,?,?],f32> } +// CHECK-LABEL: func.func @test_resize_sizes_cubic +func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 +: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32 + // CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32 + // CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32 + // CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32 + // CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32 + // CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32 + // CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32 + // CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32 + // CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32 + // CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32 + // CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32 + // CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32 + // CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32 + // CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32 + // CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32 + // CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32 + // CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32 + // CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32 + // CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32 + // CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32 + // CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32 + // CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "cubic" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> +} + // ----- diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 7c3394261558..59ce87a2b2d0 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2500,16 +2500,8 @@ func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: // ----- -// CHECK-LABEL: func.func @torch.aten.uniform$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { -// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01 -// CHECK: %[[VAL_3:.*]] = torch.constant.none -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64> -// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> -// CHECK: } +// CHECK-LABEL: torch.aten.uniform$basic +// CHECK: tosa.const func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { %float1.000000e00 = torch.constant.float 1.000000e+00 %float1.000000e01 = torch.constant.float 1.000000e+01 @@ -2555,3 +2547,217 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor %2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list, !torch.list, !torch.none -> !torch.vtensor<[3,3],f32> return %2 : !torch.vtensor<[3,3],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> +// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32> +// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32> +// CHECK: } +func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56],f32> + return %4 : !torch.vtensor<[1,64,56],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32> +// CHECK: } +func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +// CHECK: } +func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { + %none = torch.constant.none + %0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32> + %1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + %2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.collapse$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x12xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32> +// CHECK: } +func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> + return %0 : !torch.vtensor<[2,12],f32> +} + +// ----- + +func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false = torch.constant.bool false + %count_include_pad = torch.constant.bool true + %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32> +// CHECK: } +func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { + %int3 = torch.constant.int 3 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad1d %arg0, %0 : !torch.vtensor<[1,2,4],f32>, !torch.list -> !torch.vtensor<[1,2,8],f32> + return %1 : !torch.vtensor<[1,2,8],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reflection_pad2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 10 +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32> +// CHECK: } +func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { + %int10 = torch.constant.int 10 + %0 = torch.prim.ListConstruct %int10, %int10, %int10, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,20,20],f32>, !torch.list -> !torch.vtensor<[1,40,40],f32> + return %1 : !torch.vtensor<[1,40,40],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,3,3],f32> -> tensor<1x1x3x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32> +// CHECK: } +func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list -> !torch.vtensor<[1,1,10,6],f32> + return %1 : !torch.vtensor<[1,1,10,6],f32> +} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 263e69169cf3..ef478617d0d8 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -137,6 +137,46 @@ func.func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torc return %0 : !torch.bool } +// CHECK-LABEL: func.func @torch.aten.eq.bool$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.eq.bool$same_value() -> !torch.bool { + %a = torch.constant.bool false + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.eq.bool$different_value() -> !torch.bool { + %a = torch.constant.bool true + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$same_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.eq.bool$same_operand(%arg0: !torch.bool) -> !torch.bool { + %0 = torch.aten.eq.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$different_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[RET:.*]] = torch.aten.eq.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool +// CHECK: return %[[RET]] : !torch.bool +func.func @torch.aten.eq.bool$different_operand(%a: !torch.bool) -> !torch.bool { + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.ne.bool() -> !torch.bool { // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool @@ -698,6 +738,20 @@ func.func @torch.aten.len.t$no_fold_list_mutated() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.left_t( +// CHECK: %[[C4:.*]] = torch.constant.int 4 +// CHECK: %[[C5:.*]] = torch.constant.int 5 +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]], %[[C4]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: return %[[LIST]] : !torch.list +func.func @torch.aten.mul.left_t() -> !torch.list { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.mul.left_t %0, %int2 : !torch.list, !torch.int -> !torch.list + return %1 : !torch.list +} + // CHECK-LABEL: func.func @torch.aten.__getitem__.t( // CHECK: %[[C5:.*]] = torch.constant.int 5 // CHECK: return %[[C5]] : !torch.int diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 0a01189df7d2..6f86f0f1a40b 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -145,9 +145,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v // CHECK-LABEL: test_einsum_inner_prod func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} { - // CHECK: %[[INT5:.+]] = torch.constant.int 5 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]] // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 5ea715735c70..c7fc2c280a2b 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -27,12 +27,8 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso // CHECK-LABEL: @shape_as_tensor_dim func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false - // CHECK-DAG: %[[NONE:.+]] = torch.constant.none - // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]] - // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]] + // CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32> // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32> %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> %dim = torch.constant.int 0 @@ -43,6 +39,49 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt return %select : !torch.vtensor<[],si32> } +// ----- + +// CHECK-LABEL: @cast_int_int +func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64> + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64> + %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],si64> +} + +// ----- + +// CHECK-LABEL: @cast_int_float +func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32> + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32> + %item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float + %item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int + %list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],f32> +} // ----- @@ -89,14 +128,12 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?] // CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[int12_1:.*]] = torch.constant.int 12 - // CHECK: %[[int1_2:.*]] = torch.constant.int 1 // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> - // CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32> + // CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> + // CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32> %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> %float0.000000e00 = torch.constant.float 0.000000e+00