From 3d42c10eaf8dcb5be38c106d1ab2a7477f3acc9e Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Mon, 4 Nov 2024 19:47:53 +0100
Subject: [PATCH] thunderFX: delegate autocast regions to thunder (#1378)

---
 thunder/dynamo/utils.py      | 27 ++++++++-----
 thunder/tests/test_dynamo.py | 75 +++++++++++++++++++++++++++++-------
 thunder/torch/__init__.py    |  9 ++++-
 3 files changed, 87 insertions(+), 24 deletions(-)

diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 2887a162c9..8b4c690c0a 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -127,8 +127,13 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
     with thunder.core.trace.tracectx(TraceCtx()):
 
         def make_tensor_proxy(arg_node):
-            # This is a Node in the graph representing a Tensor or tuple of Tensors.
+            # This is a Node in the graph representing a Tensor, a tuple of Tensors, or
+            # a PyTorch object such as the one representing a torch.autocast context manager.
             if isinstance(arg_node, torch.fx.Node):
+                if "example_value" not in arg_node.meta:
+                    # This is a non-tensor object, e.g. a `torch.autocast` ctx manager object.
+                    return arg_node
+
                 example_value = arg_node.meta["example_value"]
 
                 if isinstance(example_value, torch.Tensor):
@@ -176,7 +181,15 @@ def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> t
     """
     import thunder
     from thunder.core.trace import TraceCtx
+    from thunder.core.compile_data import compile_data_and_stats
+    from thunder.common import CompileData, CompileStats
+
+    # This is required for verifying `_enter_autocast`,
+    # which pushes state onto `CompileData.autocast_stack`.
+    cd = CompileData(fn=lambda x: x, disable_preprocessing=True)
+    cs = CompileStats()
 
+    @compile_data_and_stats(cd, cs)
     @thunder._with_cache_info_ctx
     def _run_with_cache_info():
@@ -226,8 +239,8 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
     nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
     ctx_cnt = 0  # Count of `enters_autocast` we have seen till now
 
-    # We want to mark nodes with `_enter_autocast` and `_exit_autocast`
-    # as unsupported as `thunder` doesn't correctly deal with these stateful functions.
+    # We want to mark nodes inside `no_grad`/`enable_grad` regions as unsupported
+    # because `thunder` doesn't correctly deal with the stateful `torch._C._set_grad_enabled`.
 
     def is_no_grad_ctx_enter(node):
         if node.target == torch._C._set_grad_enabled:
@@ -244,13 +257,9 @@ def is_no_grad_ctx_exit(node):
         return False
 
     for node in gm.graph.nodes:
-        if node.op == "call_function" and (
-            node.target in (torch.amp.autocast_mode._enter_autocast,) or is_no_grad_ctx_enter(node)
-        ):
+        if node.op == "call_function" and is_no_grad_ctx_enter(node):
             ctx_cnt += 1
-        elif node.op == "call_function" and (
-            node.target in (torch.amp.autocast_mode._exit_autocast,) or is_no_grad_ctx_exit(node)
-        ):
+        elif node.op == "call_function" and is_no_grad_ctx_exit(node):
             ctx_cnt -= 1
         else:
             if ctx_cnt > 0:
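With the hunks above, get_nodes_in_unsupported_ctx_regions no longer treats _enter_autocast/_exit_autocast as unsupported; it only tracks regions toggled by torch._C._set_grad_enabled. A minimal, standalone sketch of that counter-based scan follows; nodes_in_no_grad_regions is an illustrative name rather than thunder's API, and the args checks assume the way dynamo emits the grad-toggle calls.

import torch


def nodes_in_no_grad_regions(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
    # Walk the FX graph once, keeping a depth counter of grad-toggling enters/exits.
    marked: set[torch.fx.Node] = set()
    depth = 0
    for node in gm.graph.nodes:
        is_grad_toggle = node.op == "call_function" and node.target == torch._C._set_grad_enabled
        if is_grad_toggle and node.args == (False,):
            depth += 1  # entering a `torch.no_grad()` region
        elif is_grad_toggle and node.args == (True,):
            depth -= 1  # leaving it again
        elif depth > 0:
            marked.add(node)  # a regular op that executes inside the region
    return marked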
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index d2f4e6a676..2f9bb0d124 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -7,6 +7,7 @@
 from thunder import dtypes
 from thunder.dynamo import ThunderCompiler
+from thunder.dynamo.utils import CompilerType
 from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
 from thunder import last_traces
 from thunder.core.symbol import Symbol
@@ -126,7 +127,7 @@ def func(x):
         ),
     ),
 )
-def test_splitter_unsupported_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
+def test_splitter_autocast_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
     x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
 
     backend = ThunderCompiler()
@@ -149,15 +150,10 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)
 
     assert len(backend.subgraph_infos) == 1
-    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
-    assert any(
-        "it is in unsupported context" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
-    )
-    targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
-    assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
-    assert any(
-        target.startswith("inductor_") for target in targets
-    )  # Verify that the submodules have name `inductor_*`
+    assert len(backend.subgraph_infos[0].split_reasons) == 0
+    compiled_functions = tuple(backend.subgraph_infos[0].submodule_to_compiled_functions.values())
+    assert all(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
+    assert not any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)
 
 
 @instantiate(
@@ -172,7 +168,7 @@ def func(x):
         ),
     ),
 )
-def test_splitter_unsupported_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
+def test_splitter_autocast_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
     x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
 
     backend = ThunderCompiler()
@@ -184,7 +180,7 @@ def func(x):
         torch._dynamo.graph_break()
         return torch.matmul(x, y)
 
-    expected = torch.compile(func, dynamic=False)(x)
+    expected = torch.compile(func, dynamic=dynamic)(x)
     cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
     actual = cfunc(x)
 
@@ -197,8 +193,59 @@ def func(x):
     # 2 subgraphs due to graph-break
     assert len(backend.subgraph_infos) == 2
     for subgraph_info in backend.subgraph_infos:
-        # Verify that for each subgraph we had split due to `autocast` being enabled.
-        assert any("it is in unsupported context" in split_reason.info for split_reason in subgraph_info.split_reasons)
+        assert len(subgraph_info.split_reasons) == 0
+        compiled_functions = tuple(subgraph_info.submodule_to_compiled_functions.values())
+        assert all(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
+        assert not any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)
+
+
+@instantiate(
+    dtypes=NOTHING,
+    executors=[DynamoThunderExecutor],
+    decorators=(
+        pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),
+        pytest.mark.xfail(
+            condition=IS_WINDOWS,
+            strict=True,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
+)
+def test_splitter_autocast_ctx_with_split(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
+    x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
+
+    backend = ThunderCompiler()
+
+    def func(x):
+        x = x + 2
+        with torch.autocast(device):
+            y = torch.sin(x)
+
+            # torch.sinc has an automatic fallback registered,
+            # so that operation will be handed over to inductor.
+            y = torch.sinc(y)
+        return torch.matmul(x, y)
+
+    expected = torch.compile(func, dynamic=dynamic)(x)
+    cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
+    actual = cfunc(x)
+
+    g = torch.rand_like(actual)
+    torch.testing.assert_close(actual, expected)
+    actual_grad = torch.autograd.grad(actual, x, g)
+    expected_grad = torch.autograd.grad(expected, x, g)
+    torch.testing.assert_close(actual_grad, expected_grad)
+
+    assert len(backend.subgraph_infos) == 1  # no graph break in dynamo
+
+    subgraph_info = backend.subgraph_infos[0]
+    assert len(subgraph_info.split_reasons) > 1  # Split due to `torch.sinc`
+    compiled_functions = tuple(subgraph_info.submodule_to_compiled_functions.values())
+    assert any(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
+    assert any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)
+    assert any(
+        "automatic torch fallback" in split_reason.info for split_reason in subgraph_info.split_reasons
+    )  # Verify that the split happened because we detected an automatically registered fallback operator
 
 
 @instantiate(
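Condensed, the new tests assert the behavior this patch enables: a torch.autocast region by itself no longer produces a split reason, so the whole graph is handed to thunder. A rough usage sketch follows (CPU device and tensor shapes are illustrative; it assumes thunder and its dynamo backend are importable as in the tests above).

import torch
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.utils import CompilerType

backend = ThunderCompiler()


def func(x):
    x = x + 2
    with torch.autocast("cpu"):
        y = torch.sin(x)
    return torch.matmul(x, y)


x = torch.rand(2, 2, requires_grad=True)
out = torch.compile(func, backend=backend)(x)

info = backend.subgraph_infos[0]
assert not info.split_reasons  # the autocast region did not cause a split
assert all(
    fn.compiler == CompilerType.THUNDER for fn in info.submodule_to_compiled_functions.values()
)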
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 3188855323..7047485ff6 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5616,7 +5616,12 @@ def backward_autograd_function_apply(
     id="torch.amp.autocast_mode._enter_autocast",
     tags=(prims.OpTags.DONT_DCE, prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP),
 )
-def autocast_enter(device_type, dtype=None, enabled=True):
+def autocast_enter(device_type, dtype=None, enabled=True, _unused_cache_enabled=True):
+    # We may receive device_type="cuda:0", but PyTorch applies autocast
+    # irrespective of the device index, so we extract just the device type
+    # from the string.
+    device_type, unused_deviceno = devices._device_from_string_helper(device_type)
+    device_type = devices.devicetype_string(device_type)
     if dtype is None:
         dtype = torch.get_autocast_dtype(device_type)
     get_compile_data().autocast_stack.push(device_type, dtype, enabled)
@@ -5628,6 +5633,8 @@ def autocast_enter(device_type, dtype=None, enabled=True):
     tags=(prims.OpTags.DONT_DCE, prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP),
 )
 def autocast_exit(*args):
+    if get_compile_data().autocast_stack.is_empty():
+        return
     get_compile_data().autocast_stack.pop()
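The comments added to autocast_enter rely on the fact that PyTorch applies autocast per device type, not per device index, so a device_type such as "cuda:0" is normalized before being pushed onto the autocast stack. In plain PyTorch terms (thunder uses its own devices helpers internally; normalize_device_type below is just an illustrative stand-in):

import torch


def normalize_device_type(device_type: str) -> str:
    # torch.device accepts both "cuda" and "cuda:0"; `.type` drops the index.
    return torch.device(device_type).type


assert normalize_device_type("cuda:0") == "cuda"
assert normalize_device_type("cpu") == "cpu"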