From ee08942c8fa51ae6fcdc0c231b138be16ec5a7ae Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 2 Dec 2024 09:14:33 +0100 Subject: [PATCH 1/2] aten.pow / onnx.Pow: Fix (float,int) / (int, float) accuracy (#3894) Fixes `onnx.Pow(float,int)` and `Pow(int,float)` accuracy. Torch uses `double` internally to compute pow if one argument is integer and the other one is floating point (due to C++ promotion rules). This PR keeps `onnx.Pow(int,int)` as is, which still produces numeric mismatches for values that overflow. torch uses a pure-integer implementation, where torch-mlir currently maps it to `Pow(float,float)` --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 26 +----- .../TorchToLinalg/Uncategorized.cpp | 14 ++- projects/pt1/e2e_testing/xfail_sets.py | 12 +-- .../torch_mlir_e2e_test/test_suite/basic.py | 89 +++++++++++++++++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 7 +- 5 files changed, 105 insertions(+), 43 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 206f1eecbf53..7446b7faaa08 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3014,6 +3014,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( }); patterns.onOp( "Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ONNX specifies that the result types matches the type of lhs. + // In torch, the result type is integer when both operands are integer, + // and otherwise operand types are promoted to f64. Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -3022,35 +3025,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto loc = binder.getLoc(); - auto lhsTy = cast(lhs.getType()); - auto rhsTy = cast(rhs.getType()); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); Value none = rewriter.create(loc); - auto torchDtype = Torch::getScalarTypeForType(rewriter.getF32Type()); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - static_cast(torchDtype))); - - if (isa(lhsTy.getDtype())) { - lhsTy = rewriter.getType( - lhsTy.getSizes(), rewriter.getF32Type()); - lhs = rewriter.create(loc, lhsTy, lhs, tyConst, - cstFalse, cstFalse, none); - } - - if (isa(rhsTy.getDtype())) { - rhsTy = rewriter.getType( - rhsTy.getSizes(), rewriter.getF32Type()); - rhs = rewriter.create(loc, rhsTy, rhs, tyConst, - cstFalse, cstFalse, none); - } auto powType = resultType; if (isa(resultType.getDtype())) { powType = rewriter.getType( - resultType.getSizes(), rewriter.getF32Type()); + resultType.getSizes(), rewriter.getF64Type()); } Value pow = rewriter.create(loc, powType, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 35e4144f30eb..d6b5aaf869c8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1019,12 +1019,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = cast(converter->convertType(pow.getType())) .getElementType(); if (!isa(dtype)) { + // The result type is integer when both operands are integer. 
+ // Torch then uses the following implementation: + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Pow.h pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + Type powType = dtype; + if (payloadArgs[0].getType().isInteger() || + payloadArgs[1].getType().isInteger()) + powType = mlir::FloatType::getF64(op->getContext()); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType); + auto powOp = b.create(loc, lhs, rhs); + return convertScalarToDtype(b, loc, powOp, dtype); } if (auto imag = dyn_cast(op)) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 629d72be3580..f3c8a9cd7837 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -33,6 +33,8 @@ # if a dimension is specified in all expand lists, and not in sumdim list. # This is a bug in the implementation of _trilinear in PyTorch. "Aten_TrilinearModuleZerodDimBug_basic", + # missing lowering from aten.pow.Tensor_Tensor for integer result + "PowIntIntModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): @@ -220,7 +222,6 @@ "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "IntFloatModule_basic", - "PowIntFloatModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len "LenStrModule_basic", @@ -448,7 +449,7 @@ "NllLossModuleBackward1D_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", + "PowIntIntModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -796,7 +797,6 @@ "NormalFunctionalModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -2301,6 +2301,8 @@ "PadWithNoneValModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", + "PowFloatFloatModule_basic", + "PowFloatIntModule_basic", "PrimListUnpackNumMismatchModule_basic", "PrimsIotaModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", @@ -3081,7 +3083,7 @@ "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", - "PowIntFloatModule_basic", + "PowIntIntModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -3766,7 +3768,6 @@ "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", - "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -4626,7 +4627,6 @@ "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", - "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 94f1538dbc21..5e3aa3bc02f6 100644 --- 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4426,25 +4426,100 @@ def IntImplicitModule_basic(module, tu: TestUtils): # ============================================================================== -class PowIntFloat(torch.nn.Module): +class PowModule(torch.nn.Module): def __init__(self): super().__init__() - self.value = 2 - self.power_value = 3.0 @export @annotate_args( [ None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), ] ) - def forward(self): - return torch.ops.aten.pow(self.value, self.power_value) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) -@register_test_case(module_factory=lambda: IntFloatModule()) +@register_test_case(module_factory=lambda: PowModule()) +def PowFloatFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class PowIntFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowIntFloatModule()) def PowIntFloatModule_basic(module, tu: TestUtils): - module.forward() + module.forward(tu.randint(3, 4, 5, dtype=torch.int32), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class PowFloatIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowFloatIntModule()) +def PowFloatIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.randint(3, 4, 5, dtype=torch.int32)) + + +# ============================================================================== + + +class PowIntIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowIntIntModule()) +def PowIntIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, 5, high=10, dtype=torch.int32), + tu.randint(3, 4, 5, high=20, dtype=torch.int32), + ) # ============================================================================== diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 20f4a85b9f54..5a5fb83d5fc0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1182,12 +1182,9 @@ func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 func.func @test_pow_i32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[NONE:.+]] = torch.constant.none - // 
CHECK: %[[DTY:.+]] = torch.constant.int 6 - // CHECK: %[[CAST_LHS:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] - // CHECK: %[[CAST_RHS:.+]] = torch.aten.to.dtype %arg1, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] - // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %[[CAST_LHS]], %[[CAST_RHS]] + // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],f64> // CHECK: %[[DTY:.+]] = torch.constant.int 3 - // CHECK: %[[RES:.+]] = torch.aten.to.dtype %2, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[POW]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] // CHECK: return %[[RES]] %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> return %0 : !torch.vtensor<[3,4,5],si32> From 456232afe4a15f1c4689109376d1e4527d064c1e Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Mon, 2 Dec 2024 08:48:50 -0800 Subject: [PATCH 2/2] Support bf16 on aten.uniform lowering (#3895) --- lib/Conversion/TorchToLinalg/Random.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index aa4ec91d7da5..854e3f86d367 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -186,7 +186,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value res = randomUniformF64(b, loc, linearIndex, key, min, max); Value truncRes = res; - if (isa(elemTy)) + if (isa(elemTy)) truncRes = b.create(loc, elemTy, res); b.create(loc, truncRes); })
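
For reference, a small eager-mode sketch (not part of either patch) of the promotion behaviour the first patch implements for `aten.pow` / `onnx.Pow` with mixed integer/floating-point operands: promote both sides to f64, compute the pow, and cast back to the result dtype. The helper name `pow_like_new_lowering` and the sample shapes/values below are illustrative only.

```python
import torch

def pow_like_new_lowering(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Models the mixed int/float case handled by the new lowering: if either
    # operand is an integer tensor, do the pow in float64 and cast back to the
    # promoted result dtype.
    result_dtype = torch.result_type(lhs, rhs)
    compute_dtype = result_dtype
    if not lhs.dtype.is_floating_point or not rhs.dtype.is_floating_point:
        compute_dtype = torch.float64
    out = torch.pow(lhs.to(compute_dtype), rhs.to(compute_dtype))
    return out.to(result_dtype)

x = torch.randint(0, 10, (3, 4, 5), dtype=torch.int32)
y = torch.rand(3, 4, 5, dtype=torch.float32) * 8

ref = torch.pow(x, y)                                   # eager PyTorch reference
new = pow_like_new_lowering(x, y)                       # f64 intermediate (this PR)
old = torch.pow(x.to(torch.float32), y).to(ref.dtype)   # f32 intermediate (old lowering)

print("max |new - ref|:", (new - ref).abs().max().item())
print("max |old - ref|:", (old - ref).abs().max().item())
```

With an f32 intermediate, last-ulp mismatches of the kind reported in #3894 can show up for some random draws; with the f64 intermediate the result should agree with eager PyTorch, which (per the commit message) computes mixed-type pow in double.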
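
The `Pow(int,int)` caveat in the first commit message can be seen in eager PyTorch alone; the snippet below is only an illustration (with arbitrary values) of why a float-based mapping, which torch-mlir currently emits for integer/integer pow, cannot match integer pow once the result overflows.

```python
import torch

a = torch.tensor(10, dtype=torch.int32)
b = torch.tensor(10, dtype=torch.int32)

# Eager PyTorch keeps this computation in int32, so 10**10 overflows; on
# typical builds it wraps around rather than producing 1e10.
int_pow = torch.pow(a, b)

# A float-based computation yields ~1e10, which cannot reproduce the
# overflowed int32 value above once squeezed back into int32.
float_pow = torch.pow(a.to(torch.float64), b.to(torch.float64))

print("int32 pow:", int_pow.item())
print("f64 pow  :", float_pow.item())
```

This is why `PowIntIntModule_basic` is added to the xfail/crashing sets rather than enabled everywhere.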
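
For the second patch, the eager-mode op it unblocks is simply `uniform_` on a bf16 tensor: the lowering still produces the random value with `randomUniformF64` and then truncates to the element type, and the patch widens the guard so bf16 takes that truncation path too. A minimal repro, with arbitrary shape and bounds:

```python
import torch

# aten.uniform on a bfloat16 tensor, filled in place.
t = torch.empty(4, 4, dtype=torch.bfloat16)
t.uniform_(0.0, 1.0)

print(t.dtype, t.min().item(), t.max().item())
```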