What version of TorchOpt are you using?

0.7.3

System information

Installed via pip install torchopt.
Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on Linux.
TorchOpt 0.7.3, PyTorch 2.5.0a0+872d972e41.nv24.08, functorch 2.5.0a0+872d972e41.nv24.08.

Problem description

When using the functional API with the adamw optimizer and its mask parameter specified, the expectation is that the update is applied with weight decay skipped for the masked parameters. Instead, optimizer.update fails with AttributeError: 'MaskedNode' object has no attribute 'add'.

The docstring for MaskedNode states: "This node is ignored when mapping functions across the tree e.g. using :func:`pytree.tree_map` since it is a container without children. It can therefore be used to mask out parts of a tree." However, this does not appear to be the case.
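As a sanity check on that docstring claim, a regular tree_map over a tree containing a MaskedNode can be tried directly. This is a minimal sketch, and it assumes MaskedNode is importable from torchopt.transform.add_decayed_weights (the module the traceback below points at) and is constructible with no arguments:

```python
import torchopt.pytree as pytree
# Assumed import path, inferred from the traceback below:
from torchopt.transform.add_decayed_weights import MaskedNode

# MaskedNode is registered as a pytree container without children, so a
# regular tree_map should pass over it and only touch the real leaf 'a'.
tree = {'a': 1.0, 'b': MaskedNode()}
print(pytree.tree_map(lambda x: x + 1, tree))  # expected: {'a': 2.0, 'b': MaskedNode()}
```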
Reproducible example code

The Python snippets:
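A minimal sketch of the kind of setup that triggers the failure; the parameter names, shapes, and toy loss are illustrative assumptions, not the original code:

```python
import torch
import torchopt

# Toy parameters as a plain dict pytree.
params = {
    'weight': torch.randn(4, 3, requires_grad=True),
    'bias': torch.randn(3, requires_grad=True),
}

# Mask with the same structure as params: True applies weight decay,
# False skips it (here the bias is masked out).
mask = {'weight': True, 'bias': False}

optimizer = torchopt.adamw(lr=1e-3, weight_decay=1e-2, mask=mask)
opt_state = optimizer.init(params)

# Dummy forward/backward pass to obtain gradients.
x = torch.randn(8, 4)
loss = (x @ params['weight'] + params['bias']).square().mean()
grad_w, grad_b = torch.autograd.grad(loss, (params['weight'], params['bias']))
grads = {'weight': grad_w, 'bias': grad_b}

# Fails with: AttributeError: 'MaskedNode' object has no attribute 'add'
updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```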
Command lines:

python parallel_train_torchopt.py

Steps to reproduce:

Change the optimizer to adamw and specify a mask as shown in the Python snippet above; optimizer.update then fails with the traceback below.
Traceback
File "torchopt_test.py", line 230, in <module>
functorch_original.test_train_step_fn(weights, opt_state, points, labels)
File "torchopt_test.py", line 160, in test_train_step_fn
loss, (weights, opt_state) =self.train_step_fn((weights, opt_state), points, labels)
File "torchopt_test.py", line 154, in train_step_fn
updates, new_opt_state = optimizer.update(grads, opt_state, params=weights, inplace=False)
File "/usr/local/lib/python3.10/dist-packages/torchopt/combine.py", line 92, in update_fn
flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
File "/usr/local/lib/python3.10/dist-packages/torchopt/base.py", line 196, in update_fn
updates, new_s = fn(updates, s, params=params, inplace=inplace)
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 132, in update_fn
new_masked_updates, new_inner_state = inner.update(
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 243, in update_fn
updates = tree_map(f, params, updates)
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 65, in tree_map_flatreturn flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type:ignore[call-arg]
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 63, in fnreturn func(x, *xs) if x isnotNoneelseNone
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 241, in freturn g.add(p, alpha=weight_decay) if g isnotNoneelse g
AttributeError: 'MaskedNode' object has no attribute 'add'
Expected behavior
When a mask is supplied to adamw, optimizer.update should succeed and weight decay should be skipped for the masked parameters.
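As one way to make that expectation concrete, a hypothetical check reusing the toy params, grads, and mask from the sketch above: the masked-out bias should receive the same update as plain adam (no decay), while the unmasked weight's update should differ.

```python
# Hypothetical check, reusing params, grads, and mask from the sketch above.
adamw_masked = torchopt.adamw(lr=1e-3, weight_decay=1e-2, mask=mask)
adam_plain = torchopt.adam(lr=1e-3)

u_masked, _ = adamw_masked.update(grads, adamw_masked.init(params), params=params, inplace=False)
u_plain, _ = adam_plain.update(grads, adam_plain.init(params), params=params, inplace=False)

assert torch.allclose(u_masked['bias'], u_plain['bias'])          # masked: no decay
assert not torch.allclose(u_masked['weight'], u_plain['weight'])  # unmasked: decayed
```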
Additional context
No response