[AutoBump] Merge with fixes of 30c5193 (Nov 12)
mgehre-amd committed Feb 3, 2025
2 parents b229dae + 30c5193 commit 25ee6b5
Showing 3 changed files with 43 additions and 32 deletions.
46 changes: 21 additions & 25 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5737,6 +5737,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we cannot represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the check below, we produce a
// silent wrong answer (SWA) when the `count_include_pad` value is `true`.
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value; for TOSA AvgPool, "
"`count_include_pad` must be `False`.");
}
}

SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
paddingInts[1], paddingInts[1]};
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
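
For context, a minimal PyTorch sketch (not part of this patch) of the semantic gap the new guard protects against: with non-zero padding, `count_include_pad=True` divides by the full kernel size, while TOSA's avg_pool2d divides only by the number of non-padded elements, so the two disagree at the borders.

import torch

# Illustrative only: compare the two divisor conventions on an all-ones input.
x = torch.ones(1, 1, 4)
incl = torch.nn.AvgPool1d(kernel_size=3, stride=1, padding=1, count_include_pad=True)
excl = torch.nn.AvgPool1d(kernel_size=3, stride=1, padding=1, count_include_pad=False)
print(incl(x))  # border outputs are 2/3: the padded zero still counts in the divisor
print(excl(x))  # all outputs are 1.0: the divisor shrinks to 2 at the borders

With zero padding the two conventions coincide, which is why the guard above only fires when an entry of `paddingInts` is non-zero.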
@@ -5865,18 +5885,6 @@ class ConvertAtenAvgPool2dOp
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

// Currently, we can not represent `divisor_override` with the existing TOSA
// AvgPool2d specification. Without the below check, we produce silent wrong
// answers (SWA) when the `divisor_override` value is other than `None.`
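
`divisor_override` is similarly inexpressible: it replaces the divisor with an arbitrary constant, which the TOSA avg_pool2d definition has no way to encode. A hypothetical PyTorch sketch:

import torch
import torch.nn.functional as F

# Illustrative only: divisor_override substitutes a fixed divisor for the
# element count, with no counterpart in tosa.avg_pool2d.
x = torch.ones(1, 1, 4, 4)
default = F.avg_pool2d(x, kernel_size=2)                         # sum 4 / 4 = 1.0
overridden = F.avg_pool2d(x, kernel_size=2, divisor_override=3)  # sum 4 / 3 = 1.33..
print(default[0, 0, 0, 0].item(), overridden[0, 0, 0, 0].item())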
@@ -5925,7 +5933,7 @@ class ConvertAtenAvgPool1dOp
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for MaxPool1d should have rank 3");
op, "Input tensor for AvgPool1d should have rank 3");

// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
@@ -5936,18 +5944,6 @@
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
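The unsqueeze-to-rank-4 step above relies on a standard identity, sketched here in PyTorch (not part of this patch): a 1-D average pool over [N, C, W] equals a 2-D average pool over [N, C, 1, W] with a 1xK kernel, which is what lets AtenAvgPool1dOp reuse tosa::AvgPool2dOp.

import torch
import torch.nn.functional as F

# Illustrative only: 1-D pooling as degenerate 2-D pooling on an unsqueezed input.
x = torch.randn(1, 512, 10)
pooled_1d = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
pooled_2d = F.avg_pool2d(x.unsqueeze(2), kernel_size=(1, 3), stride=(1, 1),
                         padding=(0, 1), count_include_pad=False).squeeze(2)
print(torch.allclose(pooled_1d, pooled_2d))  # True
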
14 changes: 7 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1802,6 +1802,12 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@@ -2382,6 +2388,7 @@
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ResNet18StaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
@@ -3471,13 +3478,6 @@
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
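The Adaptive* entries move into TOSA_PASS_SET presumably because adaptive average pooling decomposes into a plain average pool with zero padding, so the padding-aware guard above never fires for them. A hypothetical PyTorch sketch of that reduction for the evenly divisible case:

import torch
import torch.nn.functional as F

# Illustrative only: when the input size divides evenly by the output size,
# adaptive average pooling reduces to a fixed avg_pool with padding=0.
x = torch.randn(1, 512, 10)
adaptive = F.adaptive_avg_pool1d(x, 5)
plain = F.avg_pool1d(x, kernel_size=2, stride=2)  # padding defaults to 0
print(torch.allclose(adaptive, plain))  # True
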
15 changes: 15 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2666,3 +2666,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}
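
The negative test pins down the guard: kernel 3, stride 1, padding 1 with `count_include_pad = true` has no TOSA equivalent, so legalization must fail rather than produce a silently wrong average. The same configuration in PyTorch, for illustration only:

import torch

# Illustrative only: the rejected configuration. The output shape matches the
# input, but border sums are divided by 3 (full kernel) instead of 2, which
# tosa.avg_pool2d cannot express.
pool = torch.nn.AvgPool1d(kernel_size=3, stride=1, padding=1, count_include_pad=True)
y = pool(torch.randn(1, 512, 10))
print(y.shape)  # torch.Size([1, 512, 10])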
