Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThunderFX: handles the callable input of fx.Node #1548

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

Conversation

kiya00
Copy link
Collaborator

@kiya00 kiya00 commented Dec 12, 2024

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1539.

Needs #1463 to pass no_grad regions to thunder in thunderFX; #1568
As the analysis in #1539 (comment), this PR try to fix it by adding a dead code elimination and processing the get_attr node

@kiya00 kiya00 marked this pull request as draft December 12, 2024 20:10
Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix, @kiya00 . The fix itself looks good, however the test failures look real.

Also, once everything works fine, I'd expect the following test to fail as it shouldn't find any module which was passed to inductor. (So it will need an update)

def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
x = torch.ones(2, device=device, dtype=dtype, requires_grad=True)

This line should raise an assertion error -

assert any(target.startswith("inductor_") for target in targets)

@kiya00
Copy link
Collaborator Author

kiya00 commented Dec 13, 2024

Hi @kshitij12345 I met 2 problems, do you maybe have any suggestions?

  1. When we use the graph.eliminate_dead_code() to remove the unused function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None, other inplace ops can be deleted, like in the case test_thundercompiler_optim_step. We need a function to specify if the node is impure if use graph.eliminate_dead_code(is_impure_node: Optional[Callable[[Node], bool]])
# graph in test_thundercompiler_optim_step
def forward(self, L_self_param_groups_0_params_0_grad : torch.Tensor, L_self_param_groups_0_params_1_grad : torch.Tensor, L_self_param_groups_0_params_2_grad : torch.Tensor, L_self_param_groups_0_params_3_grad : torch.Tensor):
    l_self_param_groups_0_params_0_grad = L_self_param_groups_0_params_0_grad
    l_self_param_groups_0_params_1_grad = L_self_param_groups_0_params_1_grad
    l_self_param_groups_0_params_2_grad = L_self_param_groups_0_params_2_grad
    l_self_param_groups_0_params_3_grad = L_self_param_groups_0_params_3_grad
    p = self.self___param_groups_0__params___0
    p_1 = self.self___param_groups_0__params___1
    p_2 = self.self___param_groups_0__params___2
    p_3 = self.self___param_groups_0__params___3
    _foreach_add_ = torch._foreach_add_([p, p_1, p_2, p_3], [l_self_param_groups_0_params_0_grad, l_self_param_groups_0_params_1_grad, l_self_param_groups_0_params_2_grad, l_self_param_groups_0_params_3_grad], alpha = -0.001);  p = p_1 = p_2 = p_3 = l_self_param_groups_0_params_0_grad = l_self_param_groups_0_params_1_grad = l_self_param_groups_0_params_2_grad = l_self_param_groups_0_params_3_grad = None
    return ()

Or do we set a specific pass to delete the torch.autograd.function.FunctionCtx()? or maybe support the torch.autograd.function.FunctionCtx op?

  1. when we check if thunder supports autograd_function_apply in splitter, I assume that we should take it as supported only if the 2 input fwd/bwd submodules are fully supported. But there's _set_grad_enabled(False) in the 2 submodules , which causes the autograd_function_apply to always be not supported by Thunder.

Here is the FX graph structure of autograd_function_apply

GraphModule(
  (fwd_body_0): GraphModule()
  (bwd_body_0): GraphModule()
)

def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
    l_x_ = L_x_
    cos = torch.cos(l_x_)
    function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
    fwd_body_0 = self.fwd_body_0
    bwd_body_0 = self.bwd_body_0
    autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = s0 = None
    y = cos + autograd_function_apply;  cos = autograd_function_apply = None
    matmul = torch.matmul(l_x_, y);  l_x_ = y = None
    return (matmul,)

the fwd_body_0 module:

def forward(self, ctx : torch.autograd.function.Function, x : torch.Tensor, s0 : torch.SymInt):
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    sin = torch.sin(x)
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    return (sin, [s0, x])

P.S. when we fix the above problems we also need to let the fwd_body_0 go through the converter, because for the same reason as the checkpoint operator, thunder couldn't trace the torch.sin in the submodule

cc: @IvanYashchuk

@kshitij12345
Copy link
Collaborator

kshitij12345 commented Dec 13, 2024

Regarding1., I am not very familiar with graph.eliminate_dead_code so if using that is tricky then adding a pass that just removes unused function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None sounds good. I am curious as to how torch.compile pipeline deals with this.

For 2., #1463 should take care of this as regions within _set_grad_enabled will be passed to thunder.

@kiya00
Copy link
Collaborator Author

kiya00 commented Dec 16, 2024

By cherry-pick #1463, the autograd_function_apply is compiled by thunder, but the trace has a problem with the _set_grad_enabled(False)
Here is the FX graph structure of autograd_function_apply

GraphModule(
  (fwd_body_0): GraphModule()
  (bwd_body_0): GraphModule()
)

def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
    l_x_ = L_x_
    cos = torch.cos(l_x_)
    function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
    fwd_body_0 = self.fwd_body_0
    bwd_body_0 = self.bwd_body_0
    autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = s0 = None
    y = cos + autograd_function_apply;  cos = autograd_function_apply = None
    matmul = torch.matmul(l_x_, y);  l_x_ = y = None
    return (matmul,)

the fwd_body_0 module:

def forward(self, ctx : torch.autograd.function.Function, x : torch.Tensor, s0 : torch.SymInt):
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    sin = torch.sin(x)
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    return (sin, [s0, x])

the augmented fwd trace is:

@torch.no_grad()
@no_autocast
def computation(l_x_):
  # l_x_: "cpu f32[2]"

  # <eval_with_key>.16:7:           cos = torch.cos(l_x_)
  cos = prims.cos(l_x_)  # cos: "cpu f32[2]"

  # <eval_with_key>.15:8:           _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
  autograd_function_apply = higher_order_autograd_function_apply_vvytl_130871302739856_0(None, l_x_)  # autograd_function_apply: "cpu f32[2]"
    # autograd_function_apply = ltorch.sin(l_x_)  # autograd_function_apply: "cpu f32[2]"
      # autograd_function_apply = prims.sin(l_x_)  # autograd_function_apply: "cpu f32[2]"

  # <eval_with_key>.16:9:           y = cos + autograd_function_apply;  cos = autograd_function_apply = None
  y = ltorch.add(cos, autograd_function_apply, alpha=1)  # y: "cpu f32[2]"
    # y = prims.add(cos, autograd_function_apply)  # y: "cpu f32[2]"

  # <eval_with_key>.16:10:          matmul = torch.matmul(l_x_, y);  l_x_ = y = None
  matmul = prims.matmul(l_x_, y)  # matmul: "cpu f32[]"
  return {'output': (matmul,), 'flat_args': [l_x_], 'flat_output': (matmul,)}, ((l_x_, y), ()), 

Note that the higher_order_autograd_function_apply_vvytl_130871302739856_0 corresponds to the above fwd_body_0 module and the tag is ProxyTag.DETACHED_AUTOGRAD_GRAPH, so the grad of it is skipped in the bwd trace

if is_constant_for_vjp(symbol):
# We can skip the pullback if all the arguments are constant
continue

I think the correct behavior should be to use the grad transformation:
def grad_transform(*args, **kwargs):

Do we maybe remove the torch._C._set_grad_enabled in the fwd_body_0 module? @kshitij12345 @IvanYashchuk , do you have some suggestions?
It seems if I let is_constant_for_vjp(symbol of higher_order_autograd_function_apply) return False even if the output of higher_order_autograd_function_apply has tag ProxyTag.DETACHED_AUTOGRAD_GRAPH, the backward trace is expected

@kshitij12345
Copy link
Collaborator

Thanks for the explanation @kiya00.

I was wondering that since the line below goes through thunder.jit, the relevant code from jit_ext which creates a new symbol with a grad rule should be applied automatically, right? What is the backward trace that is being generated?

autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);

@kiya00
Copy link
Collaborator Author

kiya00 commented Dec 18, 2024

I was wondering that since the line below goes through thunder.jit, the relevant code from jit_ext which creates a new symbol with a grad rule should be applied automatically, right? What is the backward trace that is being generated?

yes, it's automatically, the bwd of symbol torch.ops.higher_order.autograd_function_apply is created in backward_pass L2764, but because the fwd_body_0 has no_grad, it continues in L2744(so no bwd is added for autograd_function_apply)

if is_constant_for_vjp(symbol):
# We can skip the pullback if all the arguments are constant
continue
if all(cotangent is None for cotangent in cotangents):
# We can skip the pullback if the cotangent is None
safe_map(put_grad, symbol.args, (None,) * len(symbol.args))
continue
if symbol.sym.id == "torch.nn.functional.dropout" and not symbol.subsymbols:
# We can skip the pullback if the dropout probability is 0.0
# Assuming that the dropout symbol has the same output and argument
assert symbol.output.name == symbol.args[0].name, "Dropout symbol has a different output and argument"
if symbol.args[1] == 0.0 or symbol.args[2] is False:
continue
backward = backward_impls.get(symbol.sym.id)
aug_forward = augmented_forward_impls.get(symbol.sym.id)
if _get_gradfn_and_executor(symbol)[0] is not None:
aug_forward, backward = make_aug_forward_and_backward(symbol)

so I tried to let is_constant_for_vjp always return False for autograd_function_apply
then the fwd trace is(corresponds to the test case in the PR):

@torch.no_grad()
@no_autocast
def computation(l_x_):
  # l_x_: "cuda:0 f32[2]"
  [y] = nvFusion0(l_x_)
    # cos = prims.cos(l_x_)  # cos: "cuda:0 f32[2]"
    # autograd_function_apply = prims.sin(l_x_)  # autograd_function_apply: "cuda:0 f32[2]"
    # y = prims.add(cos, autograd_function_apply)  # y: "cuda:0 f32[2]"

  # <eval_with_key>.121:10:         matmul = torch.matmul(l_x_, y);  l_x_ = y = None
  matmul = torch.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
    # matmul = ltorch.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
      # matmul = prims.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
  return {'output': (matmul,), 'flat_args': [l_x_], 'flat_output': (matmul,)}, ((l_x_, y), ())

the bwd trace:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, y, = C0
  clear_mutable_collection(C0)
  del C0
  [t28] = nvFusion0(l_x_, t0, y)
    # t21 = prims.cos(l_x_)  # t21: "cuda:0 f32[2]"
    # t25 = prims.sin(l_x_)  # t25: "cuda:0 f32[2]"
    # t17 = prims.broadcast_in_dim(t0, (2,), ())  # t17: "cuda:0 f32[2]"
    # t18 = prims.mul(t17, y)  # t18: "cuda:0 f32[2]"
    # t20 = prims.mul(t17, l_x_)  # t20: "cuda:0 f32[2]"
    # t22 = prims.mul(t20, t21)  # t22: "cuda:0 f32[2]"
    # t23 = prims.mul(t22, 100.0)  # t23: "cuda:0 f32[2]"
    # t24 = prims.add(t18, t23)  # t24: "cuda:0 f32[2]"
    # t26 = prims.neg(t25)  # t26: "cuda:0 f32[2]"
    # t27 = prims.mul(t20, t26)  # t27: "cuda:0 f32[2]"
    # t28 = prims.add(t24, t27)  # t28: "cuda:0 f32[2]"
  del l_x_, t0, y
  return (t28,)

Comment on lines 2525 to 2526
if isinstance(symbol.sym.id, str) and symbol.sym.id.startswith("higher_order_autograd_function_apply"):
return False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Can you please provide an example of a failing case without these lines?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reason is #1548 (comment), the test case test_splitter_autograd_function modified in this PR will fail without it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's most likely a viable workaround for the problems seen in this PR, but I don't think this is a correct fix. There's a problem with thunder.torch._set_grad_enabled_with_warning(False) called inside the forward function passed to thunder.torch.autograd_function_apply that causes the system to ignore provided backward which should be fixed first:

import thunder
import torch

def forward(_, x):
    saved_for_backward = (x,)
    thunder.torch._set_grad_enabled_with_warning(False) # Without this line the specified backward is called as expected
    sin = thunder.torch.sin(x)
    thunder.torch._set_grad_enabled_with_warning(True)
    return sin, saved_for_backward

def backward(_, grad_output, *saved_tensors):
    raise NotImplementedError

def my_sin(x):
    res = thunder.torch.autograd_function_apply(
        forward,
        backward,
        x,
        args_tensor_mask=[True],
        non_differentiable_idx=[],
    )
    return res

jitted = thunder.jit(my_sin)
x = torch.randn((2, 2), requires_grad=True)

out = jitted(x) # Should raise NotImplementedError but it doesn't
out.backward(out)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the repro, I think autograd_function_apply should override any set_grad_enabled inside. I will try updating #1463 to work with this and discuss the fix.

Copy link
Collaborator Author

@kiya00 kiya00 Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thunder.torch.autograd_function_apply corresponds to the torch.ops.higher_order.autograd_function_apply that only appears in dynamo, and dynamo will add a no_grad guard around the forward function, so it's ok if we use it in thunderFX, but I agree we need to think if there are other ways to handle it.
if we write it as follows, it raises error

import thunder
import torch
from thunder.dynamo import thunderfx

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        with torch.no_grad():
            return torch.sin(x)

    @staticmethod
    def backward(ctx, g):
        #(x,) = ctx.saved_tensors
        #return g * torch.cos(x) * 100
        raise NotImplementedError("aaaaa")

def my_sin(x):
    return Sin.apply(x)

# jitted = thunder.jit(my_sin)
jitted = thunderfx(my_sin)
x = torch.randn((2, 2), requires_grad=True)

out = jitted(x) # Should raise NotImplementedError but it doesn't
out.backward(out)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kshitij12345, could you please create a pull request with the changes from #1548 (comment) and a test from #1548 (comment). With your fix, the code added here to is_constant_for_vjp shouldn't be necessary.

Copy link
Collaborator

@IvanYashchuk IvanYashchuk Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, the lookaside for torch.ops.higher_order.autograd_function_apply doesn't use thunder.torch.autograd_function_apply, and the fix proposed in #1548 (comment) needs to be duplicated for forward_result here:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have updated #1463

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to update to consider for comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @kshitij12345 , I've removed the cherry-picked commits and the modification in is_constant_for_vjp, I think after #1548 (comment) is fixed, the test case test_splitter_autograd_function in this PR should pass

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants