
pre-commit
vkovinicTT committed Feb 26, 2025
1 parent 0a6dc13 commit 5ea2928
Showing 9 changed files with 2,125 additions and 1,188 deletions.
5 changes: 4 additions & 1 deletion forge/forge/tvm_calls/__init__.py
@@ -1,5 +1,8 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
"""
TVM passes package for forge compilation.
Contains all the functions required to interact with TVM.
(What was previously in the /contrib/ folder in TVM.)
"""
"""
573 changes: 403 additions & 170 deletions forge/forge/tvm_calls/forge_compile.py

Large diffs are not rendered by default.

67 changes: 30 additions & 37 deletions forge/forge/tvm_calls/forge_utils.py
@@ -20,10 +20,10 @@


def extract_framework_model_outputs(
    framework: str,
    model,
    inputs,
    verify_tvm_compile: bool = False,
    path=None,
    input_dict={},
):
@@ -47,6 +47,7 @@ def extract_framework_model_outputs(
            else:
                assert False, "Don't know what to do with this"
        elif any([isinstance(x, (tuple, list)) for x in framework_outputs]):

            def flatten_outputs(outputs):
                new_outputs = []
                if isinstance(outputs, (tuple, list)):
@@ -55,21 +56,22 @@ def flatten_outputs(outputs):
                else:
                    new_outputs.append(outputs)
                return new_outputs
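            # Illustration (not in this commit): flatten_outputs((a, [b, (c,)]))
            # is expected to return [a, b, c].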

            framework_outputs = flatten_outputs(framework_outputs)

        framework_outputs = [x.detach().numpy() for x in framework_outputs]

elif framework == "tensorflow":
kwargs = {}
import inspect
import inspect

arg_names = inspect.getfullargspec(model.call).args
if "return_dict" in arg_names:
kwargs["return_dict"] = False

if "training" in arg_names:
kwargs["training"] = False

framework_outputs = model(*inputs, **kwargs)

# TODO ref sha: 1fe78625c809e6ca887a8da5fdde44836830f990
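For illustration, the inspect-based gating above can be exercised standalone; a minimal sketch with a hypothetical ToyModel class (not from this repo):

import inspect

class ToyModel:
    def call(self, x, training=False, return_dict=True):
        return {"logits": x} if return_dict else (x,)

kwargs = {}
arg_names = inspect.getfullargspec(ToyModel.call).args
if "return_dict" in arg_names:
    kwargs["return_dict"] = False  # prefer tuple outputs over HF-style dicts
if "training" in arg_names:
    kwargs["training"] = False  # run the model in inference mode

assert kwargs == {"return_dict": False, "training": False}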
@@ -84,13 +86,17 @@ def flatten_outputs(outputs):

        framework_outputs = flatten_structured_output(framework_outputs)
        supported_outputs = (tf.Tensor, torch.Tensor)
        framework_outputs = [x.numpy() for x in framework_outputs if isinstance(x, supported_outputs)]

elif framework == "jax":
import jax.numpy as jnp
args = [jnp.asarray(x.numpy(),) for x in inputs]

args = [
jnp.asarray(
x.numpy(),
)
for x in inputs
]
framework_outputs = model(*args)
if isinstance(framework_outputs, HFModelOutput):
framework_outputs = list(framework_outputs.values())
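For illustration, a minimal standalone sketch of the torch-to-JAX input conversion above (shapes are arbitrary):

import jax.numpy as jnp
import torch

inputs = [torch.randn(2, 3), torch.ones(2, 3)]
args = [jnp.asarray(x.numpy()) for x in inputs]  # each arg is now a JAX array
assert args[0].shape == (2, 3)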
@@ -117,9 +123,9 @@ def flatten_outputs(outputs):
        input_details = model.get_input_details()
        output_details = model.get_output_details()
        model.allocate_tensors()
        model.set_tensor(input_details[0]["index"], *inputs)
        model.invoke()
        framework_outputs = model.get_tensor(output_details[0]["index"])

    else:
        raise RuntimeError("Unsupported framework type: {}".format(framework))
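For illustration, the TFLite branch follows the standard tf.lite.Interpreter flow; a minimal sketch, assuming a model.tflite file exists on disk (the path is hypothetical):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")  # hypothetical file
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a zero tensor matching the model's declared input shape and dtype.
input_data = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
outputs = interpreter.get_tensor(output_details[0]["index"])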
@@ -155,12 +161,10 @@ def get_input_structure(inputs, input_names):
            for k, v in inputs.items():
                input_structure[k] = (tuple(v.shape), str(v.dtype).replace("torch.", ""))
            return input_structure

        input_structure = get_input_structure(inputs, input_names)
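        # Illustration (not in this commit): input_structure maps each input
        # name to (shape, dtype), e.g. {"x": ((1, 128), "int64")}.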

        flattened_inputs, flattened_input_names, flattened_name_map = flatten_inputs(inputs, input_names)

elif framework == "tensorflow":
# The tensorflow trace automatically flattens inputs
@@ -187,9 +191,7 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
    param_name_lookup = {}

    if not compiler_cfg.enable_tvm_constant_prop:
        tvm_mod = tvm.IRModule.from_expr(tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], {}))
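        # Note (not in this commit): binding an empty params dict keeps the
        # weights as free graph inputs, so TVM should not constant-fold them.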
    else:
        if len(compiler_cfg.tvm_constnat_prop_mask):
            propped_params = {
@@ -212,10 +214,7 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
        for (bad_name, value) in params.items():
            weight_found = False
            for tf_weight in model.weights:
                if np.array_equal(tf_weight.value().numpy(), value.numpy()) and tf_weight.name not in found_weights:
                    param_name_lookup[bad_name] = tf_weight.name
                    weight_found = True
                    found_weights.append(tf_weight.name)
@@ -233,12 +232,7 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
            propped_params = {
                k: (v, True)
                for k, v, in params.items()
                if any([mask in param_name_lookup[k] for mask in compiler_cfg.tvm_constnat_prop_mask])
            }
            propped_params.update(non_weight_params)
        else:
@@ -313,23 +307,22 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi

    return tvm_mod, param_name_lookup


def has_op(module, opname, attrs={}):
    class Visitor(ExprVisitor):
        def __init__(self):
            super().__init__()
            self.has_op = False

        def visit_call(self, call):
            if call.op.name == opname:
                self.has_op = True
                for key in attrs.keys():
                    self.has_op &= key in call.attrs.keys() and call.attrs[key] == attrs[key]
            if self.has_op:
                return
            super().visit_call(call)

    visitor = Visitor()
    visitor.visit(module)
    return visitor.has_op
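For illustration, a minimal usage sketch for has_op, assuming the definition above is in scope and TVM is installed:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

assert has_op(func, "nn.relu")       # the function contains an nn.relu call
assert not has_op(func, "nn.softmax")  # but no nn.softmax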
2 changes: 1 addition & 1 deletion forge/forge/tvm_calls/relay/op/__init__.py
@@ -4,4 +4,4 @@
from .forge import *
from .forge_passes import *
from .relay_passes import *
from .utils import *
