From aee33b539c9ab92fd603bd22524d268e8b14c434 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 23 Dec 2024 11:19:20 -0600 Subject: [PATCH] Add refinement test for if/case (#2675) cc @bartchr808 for pointing out this test gap --- .../transforms/stablehlo_refine_shapes.mlir | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir index 397c08d388..b262bf095e 100644 --- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir @@ -1199,6 +1199,60 @@ func.func @refine_while(%arg0: tensor<4xf32>) -> tensor { // ----- +// Unlike WhileOp which requires separate patterns for propagating information +// to block arguments, If/Case don't have region arguments, so relies on all +// free variables and body variables to be refined before the case op is refined + +// CHECK-LABEL: func @refine_case +// CHECK-SAME: tensor<2x3x224x224xf32> +func.func @refine_case() -> tensor { + %c = stablehlo.constant dense<1> : tensor + %0 = "stablehlo.case"(%c) ({ + %cst = stablehlo.constant dense<1.000000e+00> : tensor + %c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32> + %1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor, tensor<4xi32>) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + stablehlo.return %1 : tensor + }, { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32> + %1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor, tensor<4xi32>) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + stablehlo.return %1 : tensor + }) : (tensor) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + return %0 : tensor +} + +// ----- + +// Unlike WhileOp which requires separate patterns for propagating information +// to block arguments, If/Case don't have region arguments, so relies on all +// free variables and body variables to be refined before the case op is refined + +// CHECK-LABEL: func @refine_if +// CHECK-SAME: tensor<2x3x224x224xf32> +func.func @refine_if() -> tensor { + %c = stablehlo.constant dense : tensor + %0 = "stablehlo.if"(%c) ({ + %cst = stablehlo.constant dense<1.000000e+00> : tensor + %c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32> + %1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor, tensor<4xi32>) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + stablehlo.return %1 : tensor + }, { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32> + %1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor, tensor<4xi32>) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + stablehlo.return %1 : tensor + }) : (tensor) -> tensor + // CHECK: return {{.*}} : tensor<2x3x224x224xf32> + return %0 : tensor +} + +// ----- + // TODO: Implement support for these ops. // * dynamic_conv (#867). // * dynamic_fft (#1366).