ThunderFX: handles the callable input of fx.Node #1548

Draft · wants to merge 6 commits into main
17 changes: 16 additions & 1 deletion thunder/core/jit_ext.py
@@ -791,7 +791,22 @@ def _generate_random_str_id() -> str:
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)
new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]

# N.B.(crcrpar) With `torch.compile(..., dynamic=True)`, a GraphModule's forward
# may take `SymInt` and other non-tensor values as arguments. Unfortunately, that
# information does not seem to be reflected in ``args_tensor_mask`` nor
# ``non_differentiable_idx``. Thus we optimistically iterate over ``fwd_args`` and
# also gather the trailing non-tensor values into ``new_fwd_args``.
new_fwd_args = []
for i, v in enumerate(fwd_args):
if i < length_of_tensor_args:
new_fwd_args.append(v)
else:
# note(crcrpar): in the near future we might want to also treat `FutureTensorProxy`
# and proxies of tensor subclasses as tensors here.
if not isinstance(unwrap(v), TensorProxy):
new_fwd_args.append(v)
new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)

aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
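For context, a minimal self-contained sketch of the filtering rule the new loop implements (the `is_tensor_proxy` predicate below is a hypothetical stand-in for Thunder's `isinstance(unwrap(v), TensorProxy)` check): arguments covered by `args_tensor_mask` are kept unconditionally, and trailing arguments are kept only when they are not tensor proxies.

def filter_fwd_args(fwd_args, args_tensor_mask, is_tensor_proxy):
    # Arguments covered by the mask are kept as-is; trailing non-tensor values
    # (e.g. SymInt under dynamic=True) are kept, trailing tensor proxies are dropped.
    length_of_tensor_args = sum(args_tensor_mask)
    new_fwd_args = []
    for i, v in enumerate(fwd_args):
        if i < length_of_tensor_args or not is_tensor_proxy(v):
            new_fwd_args.append(v)
    return tuple(new_fwd_args)

# Example with strings standing in for tensors: the mask covers the first two
# arguments, the trailing int survives, and the extra "tensor" is dropped.
assert filter_fwd_args(("t0", "t1", 8, "t2"), [1, 1], lambda v: isinstance(v, str)) == ("t0", "t1", 8)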
8 changes: 8 additions & 0 deletions thunder/dynamo/splitter.py
@@ -134,6 +134,14 @@ def callback(node) -> int:
supported_partitions.add(partition_cnt)
return partition_cnt

# Remove unused torch.autograd.function.FunctionCtx nodes
functionctx_nodes_to_del = (
n for n in gm.graph.find_nodes(op="call_function", target=torch.autograd.function.FunctionCtx) if not n.users
)
for n in functionctx_nodes_to_del:
gm.graph.erase_node(n)
gm.recompile()

# `split_module` iterates over nodes and determines the partition to place them based on the callback.
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
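As a rough standalone illustration of the clean-up above (generic torch.fx usage, not the Thunder splitter itself), dead FunctionCtx nodes can be removed from any GraphModule as follows; the list is materialized first so the graph is not mutated while it is being iterated.

import torch

def erase_unused_functionctx(gm: torch.fx.GraphModule) -> None:
    # Collect call_function nodes targeting FunctionCtx that nothing consumes.
    dead = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function"
        and n.target is torch.autograd.function.FunctionCtx
        and not n.users
    ]
    for n in dead:
        gm.graph.erase_node(n)
    gm.recompile()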
27 changes: 17 additions & 10 deletions thunder/dynamo/utils.py
@@ -160,10 +160,15 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):

def make_tensor_proxy(arg_node):
def make_input_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors or
# a PyTorch object like one representing torch.autocast.
if isinstance(arg_node, torch.fx.Node):
# Higher-order operator nodes take get_attr nodes as inputs to reference the called module
if arg_node.op == "get_attr":
attr = getattr(arg_node.graph.owning_module, arg_node.target)
if isinstance(attr, torch.nn.Module):
return attr
if "example_value" not in arg_node.meta:
# This is a non-tensor object like a `torch.autocast` context-manager object.
return arg_node
@@ -186,14 +191,14 @@ def make_tensor_proxy(arg_node):
)
else:
# NOTE - This will be caught and become part of the SplitReason.
raise TypeError(f"Received `make_tensor_proxy` received example_value which wasn't Tensor or Tuple")
raise TypeError(f"Received `make_input_proxy` received example_value which wasn't Tensor or Tuple")
return proxy(example_value)

# This is int, float, etc.
return arg_node

proxy_args = torch.fx.map_arg(node.args, make_tensor_proxy)
proxy_kwargs = {k: torch.fx.map_arg(v, make_tensor_proxy) for k, v in node.kwargs.items()}
proxy_args = torch.fx.map_arg(node.args, make_input_proxy)
proxy_kwargs = {k: torch.fx.map_arg(v, make_input_proxy) for k, v in node.kwargs.items()}
return proxy_args, proxy_kwargs


@@ -355,13 +360,15 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
return False, split_reason

# The function(s) called by the higher-order op (e.g. the checkpointed function) must be fully supported by Thunder
if target is torch.ops.higher_order.tag_activation_checkpoint:
if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
m = node.graph.owning_module
get_attr_node = node.args[0]
assert get_attr_node.op == "get_attr"
checkpointed_fn = getattr(m, get_attr_node.target)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn)
return is_module_supported, split_reason
for arg_node in node.args:
if arg_node.op == "get_attr":
called_module = getattr(m, arg_node.target)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
if not is_module_supported:
return is_module_supported, split_reason
return True, None

# If thunder has a mapping for this operation, try executing the meta function and see.
# We have a symbol for `torch.where`, but we don't support one overload of it.
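Both changes in this file lean on the same FX behavior: a get_attr node names an attribute of the owning module, and resolving it with getattr reveals whether it is a called GraphModule (as for the fwd/bwd bodies of autograd_function_apply) or just a parameter/buffer. A small generic example, independent of Thunder's code:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x * self.w

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        attr = getattr(node.graph.owning_module, node.target)
        # Prints "w False": here the attribute is a Parameter, not an nn.Module;
        # the isinstance check above is what tells called submodules apart from
        # parameters and buffers.
        print(node.target, isinstance(attr, torch.nn.Module))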
15 changes: 11 additions & 4 deletions thunder/tests/test_dynamo.py
@@ -274,7 +274,7 @@ def forward(ctx, x):
@staticmethod
def backward(ctx, g):
(x,) = ctx.saved_tensors
return g * torch.cos(x)
return g * torch.cos(x) * 100

def func(x):
y = torch.cos(x) + Sin.apply(x)
@@ -286,9 +286,16 @@ def func(x):
actual = cfunc(x)

backend = cfunc._backend
targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
assert any(target.startswith("thunder_") for target in targets)
assert any(target.startswith("inductor_") for target in targets)
assert len(backend.subgraph_infos) == 1 # no graph break in dynamo
subgraph_info = backend.subgraph_infos[0]
assert len(subgraph_info.split_reasons) == 0 # no split
assert len(subgraph_info.thunder_compiled_fns) == 1
jfunc = subgraph_info.thunder_compiled_fns[0]
trc = last_traces(jfunc)[0]
assert any(
isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply")
for bsym in trc.bound_symbols
)

# Verify forward pass
torch.testing.assert_close(actual, expected)
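For reference, a hedged sketch in plain PyTorch (no Thunder) of the capture behavior the updated test relies on: with a trivial inspection backend, recent torch.compile versions record Sin.apply as the higher-order op torch.ops.higher_order.autograd_function_apply, with its fwd/bwd bodies passed in as get_attr arguments.

import torch

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return g * torch.cos(x)

def inspect_backend(gm, example_inputs):
    # Print what dynamo captured; on recent PyTorch this includes a
    # call_function node for torch.ops.higher_order.autograd_function_apply.
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm.forward

cfunc = torch.compile(lambda x: torch.cos(x) + Sin.apply(x), backend=inspect_backend)
cfunc(torch.randn(4, requires_grad=True))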