Skip to content

Commit

Permalink
ThunderFX: handles the callable input of fx.Node (#1539)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Dec 18, 2024
1 parent 2d0199e commit 64ffc97
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def callback(node) -> int:
return partition_cnt

# `split_module` iterates over nodes and determines the partition to place them based on the callback.
gm.graph.eliminate_dead_code()
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)
Expand Down
10 changes: 6 additions & 4 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,12 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):

def make_tensor_proxy(arg_node):
def make_input_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors or
# a PyTorch object like one representing torch.autocast.
if isinstance(arg_node, torch.fx.Node):
if arg_node.op == "get_attr":
return getattr(arg_node.graph.owning_module, arg_node.target)
if "example_value" not in arg_node.meta:
# This is a non tensor object like `torch.autocast` ctx manager object.
return arg_node
Expand All @@ -185,14 +187,14 @@ def make_tensor_proxy(arg_node):
)
else:
# NOTE - This will be caught will be caught and be part of the SplitReason.
raise TypeError(f"Received `make_tensor_proxy` received example_value which wasn't Tensor or Tuple")
raise TypeError(f"Received `make_input_proxy` received example_value which wasn't Tensor or Tuple")
return proxy(example_value)

# This is int, float, etc.
return arg_node

proxy_args = torch.fx.map_arg(node.args, make_tensor_proxy)
proxy_kwargs = {k: torch.fx.map_arg(v, make_tensor_proxy) for k, v in node.kwargs.items()}
proxy_args = torch.fx.map_arg(node.args, make_input_proxy)
proxy_kwargs = {k: torch.fx.map_arg(v, make_input_proxy) for k, v in node.kwargs.items()}
return proxy_args, proxy_kwargs


Expand Down

0 comments on commit 64ffc97

Please sign in to comment.