diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2de81537aa15..d6cd34b82091 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5609,12 +5609,10 @@ namespace {
 // output size the kernelSize, stride and padding is calculated as follows:
 // strideH = inH // outH
-// strideW = inH // outH
-// kernelH = inH - [(outH - 1) * strideH]
-// kernelW = inW - [(outW - 1) * strideW]
+// strideW = inW // outW
+// kernelH = inH - [(outH - 1) * strideH] = strideH
+// kernelW = inW - [(outW - 1) * strideW] = strideW
 // paddingH = 0, paddingW = 0
 //
-// For the special case, when the output size is one for all dimensions,
-// the kernel size is same as the input size.
 class DecomposeAtenAdaptiveAvgPool2dOp
     : public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -5645,23 +5643,8 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     getListConstructElements(outputShape, outputShapeSizesTorchInt);
 
     // TODO: Add support for cases other than:
-    // 1.) inH == outH and inW == outW.
-    // 2.) outH == outW == 1
-    bool unitOutputSize = true;
-    for (Value outShape : outputShapeSizesTorchInt) {
-      int64_t outShapeInt;
-      if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
-        return rewriter.notifyMatchFailure(
-            op, "output size is expected to be a constant");
-      }
-      if (outShapeInt != 1) {
-        unitOutputSize = false;
-        break;
-      }
-    }
+    // inH % outH == 0 and inW % outW == 0
 
-    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
     Value constantZero = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
     Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -5670,40 +5653,22 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     SmallVector<Value> kernelSize;
 
     for (unsigned i = 0; i < inputHW.size(); i++) {
-      if (unitOutputSize) {
-        BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
-        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-        kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
-                                 ? inputHW[i]
-                                 : rewriter.create<Torch::ConstantIntOp>(
-                                       loc, rewriter.getI64IntegerAttr(
-                                                inputShape[rank - 2 + i])));
-      } else {
-        if (!isAssumingStrictSymbolicShapes(rewriter)) {
-          Value cond = rewriter.create<AtenEqIntOp>(
-              loc, inputHW[i], outputShapeSizesTorchInt[i]);
-          rewriter.create<RuntimeAssertOp>(loc, cond,
-                                           "unimplemented: only support cases "
-                                           "where input and output size are "
-                                           "equal for non-unit output size");
-        }
-        Value outMinusOne = rewriter.create<AtenSubIntOp>(
-            loc, outputShapeSizesTorchInt[i], constantOne);
-        kernelSize.push_back(
-            rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
-      }
+      Value remainder = rewriter.create<AtenRemainderIntOp>(
+          loc, inputHW[i], outputShapeSizesTorchInt[i]);
+      Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
+      rewriter.create<RuntimeAssertOp>(loc, cond,
+                                       "unimplemented: only support cases "
+                                       "where input size is an integer "
+                                       "multiple of output size");
+      Value stride = rewriter.create<AtenFloordivIntOp>(
+          loc, inputHW[i], outputShapeSizesTorchInt[i]);
+      kernelSize.push_back(stride);
     }
 
     Value kernelSizeList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
-    // Currently we only support cases where input size is equal to the output
-    // size or unit output size. For the former case, stride is always equal to
-    // one and for the latter the stride value doesn't matter, since the kernel
-    // size is same as the input size. Therfore, keeping the stride as one for
-    // the latter case as well for the ease of implementation.
-    Value strideList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantOne, constantOne});
+    // The pooling windows tile the input exactly, so stride == kernel size.
+    Value strideList = kernelSizeList;
     Value paddingSizeList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero, constantZero});
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a94b6befc765..c993d72d2bcb 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -407,6 +407,7 @@
     "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AddIntModule_basic",
     "AliasModule_basic",
     "AllBoolFalseModule_basic",
@@ -961,8 +962,9 @@
     "ElementwiseSignIntModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
-    "AddCDiv_Module_basic",
+    "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AddCDivModule_basic",
+    "AddCDiv_Module_basic",
     "AddCMul_Module_basic",
     "AddCMulModule_basic",
     "Add_Module_basic",
@@ -1740,6 +1742,7 @@
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
+    "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
     "AdaptiveMaxPool2dDynamicWithIndices_basic",
     "AdaptiveMaxPool2dDynamic_basic",
     "AdaptiveMaxPool2dStaticWithIndices_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 8ab03ddeb019..8acb7778e0e3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -54,6 +54,49 @@ def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
         module, tu: TestUtils):
     module.forward(tu.rand(1, 512, 7, 7))
 
+
+class AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((5, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule())
+def AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 15, 28))
+
+
+class AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((3, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 15, 14], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule())
+def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 15, 14))
+
+
 class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 0e863ffdfe09..2e127fb769fa 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -26,30 +26,29 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
 }
 
 // -----
-// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(
+// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(
 // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
 // CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
-// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
 // CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
 // CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
-// CHECK-DAG: %[[CST6:.*]] = torch.constant.int 6
 // CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
 // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
 // CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK-DAG: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
 // CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
-// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[DIM2]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
-// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM3]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[REMAINDER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[REMAINDER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input size is an integer multiple of output size"
+// CHECK: %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[REMAINDER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[REMAINDER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input size is an integer multiple of output size"
+// CHECK: %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
-func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %int7 = torch.constant.int 7
   %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
   %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
@@ -58,30 +57,6 @@ func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vte
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(
-// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
-// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
-// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
-// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
-// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
-// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[DIM3]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
-func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-  %int1 = torch.constant.int 1
-  %output_size = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.acos$int_type(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si32> {
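
For reference, the identity this decomposition relies on can be checked directly against PyTorch. The snippet below is a standalone sketch, not part of the patch; the shapes mirror the new e2e tests, and it asserts that when the input sizes are integer multiples of the output sizes, adaptive_avg_pool2d matches avg_pool2d with kernel == stride == in // out and zero padding.

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 15, 28)
out_h, out_w = 5, 7
in_h, in_w = x.shape[-2], x.shape[-1]
# The decomposition (and its runtime assert) requires divisibility; for
# non-divisible sizes PyTorch uses overlapping windows of varying width,
# which a single avg_pool2d cannot express.
assert in_h % out_h == 0 and in_w % out_w == 0

# strideH = inH // outH, strideW = inW // outW; kernel == stride, padding == 0.
kernel = (in_h // out_h, in_w // out_w)
adaptive = F.adaptive_avg_pool2d(x, (out_h, out_w))
decomposed = F.avg_pool2d(x, kernel_size=kernel, stride=kernel, padding=0)
torch.testing.assert_close(adaptive, decomposed)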