
[BUG] 'AttributeError: 'MaskedNode' object has no attribute 'add'' error when specifying 'mask' parameter for functional adamw API #232

Open
gkennickell opened this issue Nov 12, 2024 · 0 comments
Labels: bug Something isn't working


Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

Installed via pip install torchopt
Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on linux
TorchOpt 0.7.3, PyTorch 2.5.0a0+872d972e41.nv24.08, functorch 2.5.0a0+872d972e41.nv24.08

Problem description

When using the functional adamw optimizer with the mask parameter specified, the expectation is that the update is applied with weight decay skipped for the masked parameters. Instead, the update fails with AttributeError: 'MaskedNode' object has no attribute 'add'.

The docstring for MaskedNode states: "This node is ignored when mapping functions across the tree, e.g. using :func:`pytree.tree_map`, since it is a container without children. It can therefore be used to mask out parts of a tree." However, the traceback below shows that MaskedNode instances are still handed to the mapped function.
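To illustrate what "a container without children" is supposed to mean, here is a stdlib-only sketch (an assumption for illustration, not torchopt's actual pytree machinery): flattening a nested structure yields only leaves, so an empty container contributes nothing and mapped functions never see it.

```python
# Stdlib-only sketch of pytree flattening. An empty dict plays the role
# torchopt's MaskedNode is documented to play: a childless container that
# simply vanishes when the tree is flattened.

def flatten(tree):
    if isinstance(tree, dict):
        return [x for v in tree.values() for x in flatten(v)]
    if isinstance(tree, (list, tuple)):
        return [x for v in tree for x in flatten(v)]
    return [tree]  # anything else is a leaf

tree = {"w": 1.0, "masked": {}}   # {} stands in for MaskedNode
print(flatten(tree))              # the empty container contributes no leaves
```

If MaskedNode behaved this way end to end, the weight-decay transform would never call a tensor method on it; the traceback suggests the flat-tree code path treats it as an ordinary leaf instead.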

Reproducible example code

Python snippet:

    mask = lambda p: torchopt.pytree.tree_map(lambda x: x.ndim != 1, p)
    optimizer = torchopt.adamw(lr=0.2, weight_decay=0.1, mask=mask)

Command lines:

python parallel_train_torchopt.py

Extra dependencies:


Steps to reproduce:

  1. Use example https://github.com/metaopt/torchopt/blob/main/examples/FuncTorch/parallel_train_torchopt.py#L188
  2. Change the optimizer to adamw and specify a mask as shown in the Python snippet above.
  3. python parallel_train_torchopt.py
  4. optimizer.update fails
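The failure mechanism can be reproduced without torchopt installed. The sketch below is a simplified assumption that mirrors the names in the traceback (MaskedNode, the `g.add(p, alpha=weight_decay)` line); FakeTensor and add_decayed_weights are hypothetical stand-ins, not library code.

```python
class MaskedNode:
    """Stand-in for torchopt's sentinel; deliberately has no `add` method."""

class FakeTensor:
    """Tiny tensor stand-in providing the `add(other, alpha=...)` method."""
    def __init__(self, value):
        self.value = value
    def add(self, other, alpha=1.0):
        return FakeTensor(self.value + alpha * other.value)

def add_decayed_weights(params, updates, weight_decay=0.1):
    # Mirrors the failing line in the traceback:
    #     g.add(p, alpha=weight_decay) if g is not None else g
    # A MaskedNode is not None, so it reaches .add and raises AttributeError.
    def f(p, g):
        return g.add(p, alpha=weight_decay) if g is not None else g
    return [f(p, g) for p, g in zip(params, updates)]

params  = [FakeTensor(1.0), FakeTensor(2.0)]
updates = [MaskedNode(), FakeTensor(0.5)]  # first update was masked out

try:
    add_decayed_weights(params, updates)
except AttributeError as exc:
    print(exc)  # 'MaskedNode' object has no attribute 'add'
```

The guard only checks `g is not None`, so any non-tensor sentinel that survives into the flat update list triggers the same error.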

Traceback

File "torchopt_test.py", line 230, in <module>
    functorch_original.test_train_step_fn(weights, opt_state, points, labels)
  File "torchopt_test.py", line 160, in test_train_step_fn
    loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
  File "torchopt_test.py", line 154, in train_step_fn
    updates, new_opt_state = optimizer.update(grads, opt_state, params=weights, inplace=False)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/combine.py", line 92, in update_fn
    flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/base.py", line 196, in update_fn
    updates, new_s = fn(updates, s, params=params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 132, in update_fn
    new_masked_updates, new_inner_state = inner.update(
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 243, in update_fn
    updates = tree_map(f, params, updates)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 65, in tree_map_flat
    return flat_arg.__class__(map(fn, flat_arg, *flat_args))  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 63, in fn
    return func(x, *xs) if x is not None else None
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 241, in f
    return g.add(p, alpha=weight_decay) if g is not None else g
AttributeError: 'MaskedNode' object has no attribute 'add'

Expected behavior

When a mask is supplied to adamw, the update should succeed, with weight decay skipped for the masked parameters.
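The expected semantics can be sketched with plain floats standing in for tensors (a hedged illustration, not torchopt code): decay is added only to update entries whose mask value is True, and masked entries pass through unchanged.

```python
# Expected masked weight-decay behavior, sketched with plain floats.
# masked_add_decayed_weights is a hypothetical helper for illustration.

def masked_add_decayed_weights(params, updates, mask, weight_decay=0.25):
    return [
        g + weight_decay * p if m else g
        for p, g, m in zip(params, updates, mask)
    ]

params  = [1.0, 2.0]
updates = [0.5, 0.5]
mask    = [True, False]   # e.g. x.ndim != 1: decay weights, skip biases

print(masked_add_decayed_weights(params, updates, mask))  # → [0.75, 0.5]
```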

Additional context

No response

@gkennickell gkennickell added the bug Something isn't working label Nov 12, 2024