Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StableHLO to linalg conversions to python bindings #2660

Merged
merged 5 commits into from
Dec 9, 2024

Conversation

mamanain
Copy link
Contributor

@mamanain mamanain commented Dec 6, 2024

Testing function:

from mlir.dialects import stablehlo
from mlir.ir import Context, Location, Module
import mlir.dialects.arith
from mlir.passmanager import PassManager

mlir_text = """
func.func @dot_general(%arg0: tensor<?x?x?xf32>,
                  %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = "stablehlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #stablehlo.dot<
      lhs_batching_dimensions = [1],
      lhs_contracting_dimensions = [2],
      rhs_batching_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
    someattr
  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}
"""

with Context() as ctx:
    stablehlo.register_dialect(ctx)
    stablehlo.register_stablehlo_passes()


    module = Module.parse(mlir_text)

    pm = PassManager.parse(
        "builtin.module(func.func("
        "shape-legalize-to-stablehlo,"
        "stablehlo-aggressive-folder,"
        "stablehlo-aggressive-simplification,"
        "stablehlo-legalize-to-linalg"
        "))"
    )

    pm.run(module.operation)

    print(f"{module}")

Before this change we get an error:

    pm = PassManager.parse(
         ^^^^^^^^^^^^^^^^^^
ValueError: MLIR Textual PassPipeline Parser:1:103: error: 'stablehlo-legalize-to-linalg' does not refer to a registered pass or pass pipeline
func.func(shape-legalize-to-stablehlo,stablehlo-aggressive-folder,stablehlo-aggressive-simplification,stablehlo-legalize-to-linalg)

Now we get an expected result:

#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
  func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
    %c0 = arith.constant 0 : index
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
    %c0_1 = arith.constant 0 : index
    %dim_2 = tensor.dim %arg1, %c0_1 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_2 : tensor<3xindex>
    %c0_3 = arith.constant 0 : index
    %extracted = tensor.extract %from_elements[%c0_3] : tensor<3xindex>
    %c1_4 = arith.constant 1 : index
    %extracted_5 = tensor.extract %from_elements[%c1_4] : tensor<3xindex>
    %c2 = arith.constant 2 : index
    %extracted_6 = tensor.extract %from_elements[%c2] : tensor<3xindex>
    %0 = tensor.empty(%extracted, %extracted_5, %extracted_6) : tensor<?x?x?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%1 : tensor<?x?x?xf32>) attrs =  {someattr} {
    ^bb0(%in: f32, %in_7: f32, %out: f32):
      %3 = arith.mulf %in, %in_7 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<?x?x?xf32>
    return %2 : tensor<?x?x?xf32>
  }
}

Copy link

google-cla bot commented Dec 6, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@GleasonK GleasonK requested a review from sdasgup3 December 6, 2024 16:12
@GleasonK GleasonK added the Transformations Pertaining to MLIR passes and transformations label Dec 6, 2024
@GleasonK
Copy link
Member

GleasonK commented Dec 9, 2024

Happy to merge this once CI finishes.

Wanted to note: There is a chance we split this into a different file / target in the future. There are some users who want a "minimal StableHLO build" CAPI / python target which doesn't require all the upstream dialect dependencies but still has serialization APIs, but that's a much larger task and we can refactor this alongside it.

@mamanain
Copy link
Contributor Author

mamanain commented Dec 9, 2024

Perfect, thank you very much! If it would be possible to still have an opportunity to call these passes through python bindings in jax or from the standalone ones that would be great.

Will this commit be included into the next jax release automatically? We use bindings that are shipped with it.

@GleasonK GleasonK merged commit ef176a1 into openxla:main Dec 9, 2024
10 checks passed
@GleasonK
Copy link
Member

GleasonK commented Dec 9, 2024

I believe yes -- this file is included in the JAX build:

[]() { mlirRegisterAllStablehloPasses(); });

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Transformations Pertaining to MLIR passes and transformations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants