
Merge pull request #232 from Xilinx/bump_to_6524838b
Merge with fixes of 6524838 (8)
cmcgirr-amd authored Aug 15, 2024
2 parents b457551 + 1ee6e12 commit bb40cfa
Showing 4 changed files with 74 additions and 88 deletions.
65 changes: 15 additions & 50 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5609,12 +5609,10 @@ namespace {
// output size, the kernelSize, stride, and padding are calculated as follows:
// strideH = inH // outH
// strideW = inW // outW
// kernelH = inH - [(outH - 1) * strideH]
// kernelW = inW - [(outW - 1) * strideW]
// kernelH = inH - [(outH - 1) * strideH] = strideH
// kernelW = inW - [(outW - 1) * strideW] = strideW
// paddingH = 0, paddingW = 0
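// For example (illustration only): inH = 15 and outH = 5 give
// strideH = 15 // 5 = 3 and kernelH = 15 - (5 - 1) * 3 = 15 - 12 = 3 = strideH.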
//
// For the special case, when the output size is one for all dimensions,
// the kernel size is same as the input size.
class DecomposeAtenAdaptiveAvgPool2dOp
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
using OpRewritePattern::OpRewritePattern;
@@ -5645,23 +5643,8 @@ class DecomposeAtenAdaptiveAvgPool2dOp
getListConstructElements(outputShape, outputShapeSizesTorchInt);

// TODO: Add support for cases other than:
// 1.) inH == outH and inW == outW.
// 2.) outH == outW == 1
bool unitOutputSize = true;
for (Value outShape : outputShapeSizesTorchInt) {
int64_t outShapeInt;
if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
return rewriter.notifyMatchFailure(
op, "output size is expected to be a constant");
}
if (outShapeInt != 1) {
unitOutputSize = false;
break;
}
}
// inH % outH != 0 or inW % outW != 0

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -5670,40 +5653,22 @@ class DecomposeAtenAdaptiveAvgPool2dOp
SmallVector<Value, 2> kernelSize;

for (unsigned i = 0; i < inputHW.size(); i++) {
if (unitOutputSize) {
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
? inputHW[i]
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(
inputShape[rank - 2 + i])));
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
rewriter.create<RuntimeAssertOp>(loc, cond,
"unimplemented: only support cases "
"where input and output size are "
"equal for non-unit output size");
}
Value outMinusOne = rewriter.create<AtenSubIntOp>(
loc, outputShapeSizesTorchInt[i], constantOne);
kernelSize.push_back(
rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
}
Value remainder = rewriter.create<AtenRemainderIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
rewriter.create<RuntimeAssertOp>(loc, cond,
"unimplemented: only support cases where "
"input size is an integer multiple of "
"output size");
Value stride = rewriter.create<AtenFloordivIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
Value kernelSizeValue = stride;
kernelSize.push_back(kernelSizeValue);
}

Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
// Currently we only support cases where input size is equal to the output
// size or unit output size. For the former case, stride is always equal to
// one and for the latter the stride value doesn't matter, since the kernel
// size is the same as the input size. Therefore, we keep the stride as one
// for the latter case as well, for ease of implementation.
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne, constantOne});
Value strideList = kernelSizeList;
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero, constantZero});
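For reference, the arithmetic this decomposition emits can be sketched in plain PyTorch. This is a minimal illustration, assuming the input's spatial sizes are integer multiples of the output size; the helper name decompose_adaptive_avg_pool2d is hypothetical and not part of the patch.

import torch
import torch.nn.functional as F

def decompose_adaptive_avg_pool2d(x, out_hw):
    # Hypothetical helper mirroring the decomposition above; it is valid
    # only when each input spatial dim is a multiple of the output dim.
    in_h, in_w = x.shape[-2], x.shape[-1]
    out_h, out_w = out_hw
    assert in_h % out_h == 0 and in_w % out_w == 0, \
        "input size must be an integer multiple of output size"
    stride = (in_h // out_h, in_w // out_w)
    # kernel = in - (out - 1) * stride, which simplifies to stride itself
    # when the sizes divide evenly, so kernel_size reuses stride below.
    return F.avg_pool2d(x, kernel_size=stride, stride=stride, padding=0)

x = torch.rand(1, 512, 15, 28)
assert torch.allclose(decompose_adaptive_avg_pool2d(x, (5, 7)),
                      F.adaptive_avg_pool2d(x, (5, 7)))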
5 changes: 4 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -407,6 +407,7 @@
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AddIntModule_basic",
"AliasModule_basic",
"AllBoolFalseModule_basic",
@@ -961,8 +962,9 @@
"ElementwiseSignIntModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddCDiv_Module_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AddCDivModule_basic",
"AddCDiv_Module_basic",
"AddCMul_Module_basic",
"AddCMulModule_basic",
"Add_Module_basic",
@@ -1740,6 +1742,7 @@
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic",
"AdaptiveMaxPool2dDynamic_basic",
"AdaptiveMaxPool2dStaticWithIndices_basic",
43 changes: 43 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -54,6 +54,49 @@ def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))


class AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((5, 7))

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)


@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule())
def AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 15, 28))


class AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((3, 7))

@export
@annotate_args([
None,
([1, 512, 15, 14], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)


@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule())
def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 15, 14))


class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):

def __init__(self):
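As a quick cross-check of the shapes the new tests exercise (an editorial illustration, not part of the suite): in the static case 15 % 3 == 0 and 14 % 7 == 0, so the adaptive pool should match a plain AvgPool2d with kernel = stride = (15 // 3, 14 // 7) = (5, 2).

import torch

# Static-shape case from the test above: AdaptiveAvgPool2d((3, 7)) on a
# (1, 512, 15, 14) input agrees with AvgPool2d using kernel = stride = (5, 2).
aap2d = torch.nn.AdaptiveAvgPool2d((3, 7))
avg2d = torch.nn.AvgPool2d(kernel_size=(5, 2), stride=(5, 2), padding=0)
x = torch.rand(1, 512, 15, 14)
assert torch.allclose(aap2d(x), avg2d(x))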
49 changes: 12 additions & 37 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -26,30 +26,29 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
}

// -----
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CST6:.*]] = torch.constant.int 6
// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[DIM2]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM3]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[REMAINDER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[REMAINDER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input size is an integer multiple of output size"
// CHECK: %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[REMAINDER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[REMAINDER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input size is an integer multiple of output size"
// CHECK: %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int7 = torch.constant.int 7
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
@@ -58,30 +57,6 @@ func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vte

// -----

// CHECK-LABEL: func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[DIM3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%output_size = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.acos$int_type(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si32> {
