Sync fx_importer from torch-mlir. (nod-ai#475)
I was hoping that we could just depend on it via IREE, but it needed
some local patches (sending upstream) for type imports that don't exist
in old versions of PyTorch. So for now, just updating the local fork.
stellaraccident authored Feb 24, 2024
1 parent fabd52c commit 45f2e57
Showing 2 changed files with 800 additions and 94 deletions.
50 changes: 33 additions & 17 deletions core/shark_turbine/aot/builtins/jittable.py
```diff
@@ -22,13 +22,21 @@
 )
 from torch.fx.passes.shape_prop import TensorMetadata
 
-from ...dynamo.passes import (
-    DEFAULT_DECOMPOSITIONS,
-)
+# TODO: Switch to upstream fx_importer vs local fork when ready.
+# from iree.compiler.extras.fx_importer import (
+#     GraphNodeImporter,
+#     FxImporter,
+#     FxImporterHooks,
+# )
+
 from ...importers.fx_importer import (
     GraphNodeImporter,
     FxImporter,
     FxImporterHooks,
 )
 
+from ...dynamo.passes import (
+    DEFAULT_DECOMPOSITIONS,
+)
+
 from ...support.ir_imports import (
```
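The TODO above spells out the migration path: once the upstream patches land, the vendored copy under `...importers.fx_importer` goes away in favor of `iree.compiler.extras.fx_importer`. A hypothetical transition shim, not part of this commit, could prefer the upstream module and fall back to the fork; it assumes the absolute package name `shark_turbine.importers.fx_importer` for the vendored copy:

```python
# Hypothetical shim (not in this commit): prefer the upstream importer when
# it is available, fall back to the vendored fork otherwise.
try:
    # Upstream location suggested by the commented-out import above.
    from iree.compiler.extras.fx_importer import (
        GraphNodeImporter,
        FxImporter,
        FxImporterHooks,
    )
except ImportError:
    # Vendored fork, kept in-tree until the upstream patches land.
    from shark_turbine.importers.fx_importer import (
        GraphNodeImporter,
        FxImporter,
        FxImporterHooks,
    )
```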
```diff
@@ -68,23 +76,33 @@
 StringAttrOrStr = Union[StringAttr, str]
 
 
-def _make_literal_resolver(module_builder: ModuleBuilder):
-    # When we first encounter a global during import, we have to pull it
-    # into the local module being populated by the GraphNodeImporter. This
-    # will exactly match the global in the target module we are merging into
-    # and exists so that the IR is valid during Fx import. We keep the set of
-    # symbols we have done this to here.
-    cloned_global_symbols: Set[str] = set()
+class _Hooks(FxImporterHooks):
+    __slots__ = [
+        "cloned_global_symbols",
+        "module_builder",
+    ]
+
+    def __init__(self, module_builder: ModuleBuilder):
+        self.module_builder = module_builder
+        # When we first encounter a global during import, we have to pull it
+        # into the local module being populated by the GraphNodeImporter. This
+        # will exactly match the global in the target module we are merging into
+        # and exists so that the IR is valid during Fx import. We keep the set of
+        # symbols we have done this to here.
+        self.cloned_global_symbols: set[str] = set()
+
+    def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]:
+        module_builder = self.module_builder
+        cloned_global_symbols = self.cloned_global_symbols
 
-    def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
         # We support resolution of tracked reference types. Currently this
         # only includes Tensors. All others we let the importer do what it
         # is going to do.
-        if not isinstance(py_value, torch.Tensor):
+        if not isinstance(literal, torch.Tensor):
             return None
 
         # See if we know about it.
-        mapping = module_builder.global_ref_tracker.track(py_value)
+        mapping = module_builder.global_ref_tracker.track(literal)
         if mapping.is_empty:
             # If it is unknown, just let the default importer take it on.
             return None
```
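This hunk replaces the closure-based `literal_resolver_callback` with an `FxImporterHooks` subclass. The contract, as visible in the diff: `resolve_literal` returns `None` to defer to the importer's default literal materialization, or an IR `Value` to override it for that Python object. A minimal sketch of a custom hook, assuming the vendored module path and that `Value` comes from the IREE MLIR bindings:

```python
from typing import Any, Optional

from iree.compiler.ir import Value  # assumed binding path for the MLIR Value type
from shark_turbine.importers.fx_importer import FxImporterHooks, GraphNodeImporter


class PassthroughHooks(FxImporterHooks):
    """Illustrative hooks subclass that never intercepts a literal."""

    def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]:
        # Returning None lets the importer materialize the literal itself.
        # _Hooks above instead returns a loaded-and-converted global for
        # tensors it already tracks.
        return None
```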
```diff
@@ -101,7 +119,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
             cloned_global_symbols.add(materialized_global.symbol_name)
 
         # Emit a global load and conversion.
-        vtensor_type = gni._cc.tensor_to_vtensor_type(py_value)
+        vtensor_type = gni._cc.tensor_to_vtensor_type(literal)
         loaded_value = util_d.GlobalLoadOp(
             materialized_global.ir_type, materialized_global.symbol_name
         ).result
```
```diff
@@ -112,8 +130,6 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
         ).result
         return converted_value
 
-    return resolver
-
 
 ALL_PASSES: Set[str] = set(["functorch_functionalize"])
 DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",)
```
```diff
@@ -234,7 +250,7 @@ def flat_wrapped_f(*args):
         fx_importer = FxImporter(
             context=proc_trace.context,
             config_check=False,
-            literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
+            hooks=_Hooks(proc_trace.module_builder),
             py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
         )
         fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
```
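For callers, the wiring change is confined to the `FxImporter` constructor: the `literal_resolver_callback=` keyword is gone, and the same state now travels on a hooks object. A sketch of the before/after call sites, assuming `context`, `module_builder`, a traced `gm`, and the `_Hooks` class from this file are in scope, with an illustrative function name:

```python
# Old style (pre-sync): a resolver closure.
# fx_importer = FxImporter(
#     context=context,
#     config_check=False,
#     literal_resolver_callback=_make_literal_resolver(module_builder),
# )

# New style (this commit): a hooks object carrying the same state.
fx_importer = FxImporter(
    context=context,
    config_check=False,
    hooks=_Hooks(module_builder),
    py_attr_tracker=module_builder.fx_py_attr_tracker,
)
fx_importer.import_stateless_graph(gm.graph, func_name="forward")
```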
core/shark_turbine/importers/fx_importer.py: remaining 767 additions and 77 deletions (diff not rendered).
