Skip to content

Commit

Permalink
Don't require inlining for shape refinement (#2631)
Browse files Browse the repository at this point in the history
Currently we require MLIR function inlining before refining shapes.
Refining shapes in order and following call operations (enforcing no
recursive calls) should allow for `refine(dynamic_export, static_args)
== static_export` to be true
  • Loading branch information
GleasonK authored Nov 20, 2024
1 parent 2d42d12 commit 7efac85
Show file tree
Hide file tree
Showing 6 changed files with 765 additions and 74 deletions.
53 changes: 36 additions & 17 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,28 +272,47 @@ type of every argument to the `main` method being refined.

_Refines shapes across a StableHLO program._

Walks through a StableHLO program refining shapes within ops.
Walks through a StableHLO program refining shapes within ops.

The flagship use case for this pass is specializing dynamically-shaped
programs to static shapes. If a dynamically-shaped StableHLO program has the
right structure, then updating its argument types from dynamic shapes to
static shapes and running this pass will propagate static shapes across
the program.
The flagship use case for this pass is specializing dynamically-shaped
programs to static shapes. If a dynamically-shaped StableHLO program has the
right structure, then updating its argument types from dynamic shapes to
static shapes and running this pass will propagate static shapes across
the program.

This pass removes `custom_call @shape_refinement_operand_wrapper` by
replacing uses of the result with the operand directly, and propagates
static shapes throughout the program.
This pass removes `custom_call @shape_refinement_operand_wrapper` by
replacing uses of the result with the operand directly, and propagates
static shapes throughout the program.

```
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...}
: (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
%1 = stablehlo.add %0, %0 : tensor<?xf32>
```
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...}
: (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
%1 = stablehlo.add %0, %0 : tensor<?xf32>
==>
==>
%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
```
%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
```

Modules valid for shape refinement must have the following properties:

* All the dynamic shapes depend only on the input shapes (no shape
dependency on the input array contents). We refer to the operations that
depend transitively only on the input shapes (e.g., as given by
`stablehlo.get_dimension_size`) or global constants like the resolved
values of symbolic integers (i.e. tensor<Axf32> : A = 5), as `dimension`
operations. All dimension values can be resolved to constants through
inter-procedural constant folding.
* Intermediate functions may take a number of token arguments (of type
!stablehlo.token) at the start of the argument list, followed by some
global constant arguments which are constant integer scalars, such as the
resolved values of symbolic integers (i.e. tensor<Axf32> : A = 5).
* Some intermediate functions may return computations on global constants,
i.e. `floordiv` on symint values. These functions are indicated by only
returning constant values after refinement. These functions are inlined.
* All calls to a single function resolve to the same argument shapes, and no
recursive / co-recursive function calls are made.
### `-vhlo-legalize-to-stablehlo`

_Legalize VHLO to StableHLO._
Expand Down
Loading

0 comments on commit 7efac85

Please sign in to comment.