Script to split stablehlo module into sub modules [#2317] #2323

Draft · wants to merge 1 commit into main
Conversation

ddilbazTT
Contributor

Add a script that splits a stablehlo module into sub-modules, each representing a single op. Run pip install -r tools/stablehlo_splitter/requirements.txt to install dependencies.

Ticket

#2317

Problem description

Front-ends (tt-xla, tt-torch) want to run stablehlo graphs op-by-op. We need a mechanism to break a stablehlo module down into standalone sub-modules, one per op.

What's changed

Added tools/stablehlo_splitter/shlo_split.py and tools/stablehlo_splitter/requirements.txt
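
For reference, the requirements file presumably just pulls the stablehlo dev wheel, matching the pip command noted in the script below (a sketch; the exact contents may differ):

# tools/stablehlo_splitter/requirements.txt (sketch, inferred from the pip comment)
-f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels
stablehlo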

# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

## pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels

from mlir.ir import Context, Module
import mlir.dialects.stablehlo as stablehlo


def parse_module_from_str(module_str):
    """Parse an MLIR module string with the stablehlo dialect registered."""
    module = None
    with Context() as ctx:
        stablehlo.register_dialect(ctx)
        module = Module.parse(module_str)
    return module


class StablehloSplitter:
    def __init__(self, module: str):
        self.module = module
        self.parsed_module = parse_module_from_str(module)
        self.sub_ops = []
        self.get_ops_in_module()

    def get_ops_in_module(self):
        for func_op in self.parsed_module.body.operations:
            for block in func_op.regions[0].blocks:
                for op in block.operations:
                    # Skip structural ops (function scaffolding, returns);
                    # only computational ops get wrapped into sub-modules.
                    if op.name.startswith(("func.", "return")):
                        continue

                    # Map each operand's SSA name to its type; these become
                    # the arguments of the standalone function signature.
                    inputs = {
                        operand.get_name(): str(operand.type) for operand in op.operands
                    }
                    args_str = ", ".join(f"{key}: {typ}" for key, typ in inputs.items())

                    # Handle multiple results in the operation
                    result_names = [str(result.get_name()) for result in op.results]
                    result_types = [str(result.type) for result in op.results]

                    # Construct the function signature based on the number of results
                    if len(result_names) == 1:
                        result_str = f"{result_types[0]}"
                        return_stmt = f"return {result_names[0]} : {result_types[0]}"
                    else:
                        result_str = f"({', '.join(result_types)})"
                        # MLIR's `return` takes comma-separated operands without
                        # parentheses: `return %a, %b : t1, t2`.
                        return_stmt = f"return {', '.join(result_names)} : {', '.join(result_types)}"
                    # Build the new standalone module string wrapping the op in @main
                    new_module_str = (
                        f"module {{\n"
                        f"  func.func @main({args_str}) -> {result_str} {{\n"
                        f"    {str(op)}\n"
                        f"    {return_stmt}\n"
                        f"  }}\n"
                        f"}}"
                    )
                    dict_item = {
                        "op_id": ", ".join(result_names),
                        "op": str(op),
                        "module": new_module_str,
                    }
                    self.sub_ops.append(dict_item)

How to use the script?

From each front-end, run pip install -r third_party/tt-mlir/src/tt-mlir/tools/stablehlo_splitter/requirements.txt to install dependencies.
I have tested that the script works in both tt-xla and tt-torch:
tt-xla test branch
tt-torch test branch

The test script is as follows:

# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import sys

# Define the path to the script
script_path = "third_party/tt-mlir/src/tt-mlir/tools/stablehlo_splitter/shlo_split.py"

# Load the module
spec = importlib.util.spec_from_file_location("shlo_split", script_path)
shlo_split = importlib.util.module_from_spec(spec)
sys.modules["shlo_split"] = shlo_split
spec.loader.exec_module(shlo_split)

mlir_path = "./Autoencoder.mlir"
module_str = ""
with open(mlir_path, "r") as file:
    module_str = file.read()
# Now you can use the StablehloSplitter class
splitter = shlo_split.StablehloSplitter(module_str)
print(splitter.sub_ops)

This script reads the stablehlo module produced by the autoencoder model in tt-torch, saved as Autoencoder.mlir. If you run python temp.py, you can see the sub-ops printed, as attached to this PR.
Autoencoder_sub_ops.txt

I am creating this as a draft PR to start a discussion on:

  • Are we happy with the way each individual op is represented right now? Currently, each op is a dictionary:
    op_id is the SSA id of the op in the original mlir string
    op is the op as it was written in the original mlir string
    module is a new standalone module that can be compiled/executed on its own
    e.g. {'op_id': '%33', 'op': '%33 = stablehlo.dot_general %32, %arg7, contracting_dims = [1] x [0] : (tensor<1x12xf32>, tensor<12x3xf32>) -> tensor<1x3xf32>', 'module': 'module {\n func.func @main(%32: tensor<1x12xf32>, %arg7: tensor<12x3xf32>) -> tensor<1x3xf32> {\n %33 = stablehlo.dot_general %32, %arg7, contracting_dims = [1] x [0] : (tensor<1x12xf32>, tensor<12x3xf32>) -> tensor<1x3xf32>\n return %33 : tensor<1x3xf32>\n }\n }'}
  • Is it OK that this script returns a list of dictionaries, one per op? tt-torch has its own class representing an op, which wouldn't translate into a generic op easily, so I thought each front-end could use this dictionary to populate its internal data structures on its own (see the sketch after this list).
  • Any concerns regarding the structure/usability of the script for front-ends?
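
To make that concrete, here is a minimal sketch of how a front-end might walk the returned list; compile_and_run is a hypothetical per-op hook standing in for the front-end's own compile/execute path, not part of this PR:

# Hypothetical front-end loop over splitter.sub_ops (sketch only).
def run_op_by_op(splitter, compile_and_run):
    results = {}
    for sub_op in splitter.sub_ops:
        # Each entry: {"op_id": ..., "op": ..., "module": ...}
        results[sub_op["op_id"]] = compile_and_run(sub_op["module"])
    return results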

@kmitrovicTT
Contributor

I suggest we focus on keeping this as small as possible while still generic enough to be sufficient. Once front-ends hook into it, we will see whether any modifications are needed.

For example, instead of using a dict, create a simple StableHLOOp dataclass to hold the info you want. If tt-torch or tt-xla have their own op representations, that's OK as long as they can be populated from the info in StableHLOOp.
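
For illustration, a minimal sketch of such a dataclass, with field names assumed from the current dict keys:

from dataclasses import dataclass

@dataclass
class StableHLOOp:
    op_id: str   # SSA name(s) of the op's result(s) in the original module
    op: str      # the op as written in the original module
    module: str  # standalone sub-module wrapping just this op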
