Skip to content

Commit

Permalink
Add refinement test for if/case (#2675)
Browse files Browse the repository at this point in the history
cc @bartchr808 for pointing out this test gap
  • Loading branch information
GleasonK authored Dec 23, 2024
1 parent 960b231 commit aee33b5
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,60 @@ func.func @refine_while(%arg0: tensor<4xf32>) -> tensor<?xf32> {

// -----

// Unlike WhileOp which requires separate patterns for propagating information
// to block arguments, If/Case don't have region arguments, so relies on all
// free variables and body variables to be refined before the case op is refined

// CHECK-LABEL: func @refine_case
// CHECK-SAME: tensor<2x3x224x224xf32>
func.func @refine_case() -> tensor<?x3x224x224xf32> {
%c = stablehlo.constant dense<1> : tensor<i32>
%0 = "stablehlo.case"(%c) ({
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32>
%1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
stablehlo.return %1 : tensor<?x3x224x224xf32>
}, {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32>
%1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
stablehlo.return %1 : tensor<?x3x224x224xf32>
}) : (tensor<i32>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
return %0 : tensor<?x3x224x224xf32>
}

// -----

// Unlike WhileOp which requires separate patterns for propagating information
// to block arguments, If/Case don't have region arguments, so relies on all
// free variables and body variables to be refined before the case op is refined

// CHECK-LABEL: func @refine_if
// CHECK-SAME: tensor<2x3x224x224xf32>
func.func @refine_if() -> tensor<?x3x224x224xf32> {
%c = stablehlo.constant dense<true> : tensor<i1>
%0 = "stablehlo.if"(%c) ({
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32>
%1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
stablehlo.return %1 : tensor<?x3x224x224xf32>
}, {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%c_0 = stablehlo.constant dense<[2, 3, 224, 224]> : tensor<4xi32>
%1 = stablehlo.dynamic_broadcast_in_dim %cst, %c_0, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
stablehlo.return %1 : tensor<?x3x224x224xf32>
}) : (tensor<i1>) -> tensor<?x3x224x224xf32>
// CHECK: return {{.*}} : tensor<2x3x224x224xf32>
return %0 : tensor<?x3x224x224xf32>
}

// -----

// TODO: Implement support for these ops.
// * dynamic_conv (#867).
// * dynamic_fft (#1366).
Expand Down

0 comments on commit aee33b5

Please sign in to comment.