Add support for loss functions with auxiliary data to linesearch #1177
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This change adds support for loss functions that return auxiliary data alongside their primary value, like
(loss_value, extra_data)
. This pattern is commonly used withjax.value_and_grad(fn, has_aux=True)
.The approach:
value_fn_has_aux
flag tozoom_linesearch
andscale_by_zoom_linesearch
_unpack_value
helper that extracts just the loss valuehas_aux
parameter tovalue_and_grad_from_state
to properly handle auxiliary data when reusing cached valuesThis allows the linesearch algorithms to work with loss functions that return auxiliary data while maintaining the optimization over just the primary loss value.
Input needed: How to initialize
opt_state
?The linesearch algorithm stores
value
andgrad
in the optimizer state to enable reuse of function evaluations. When using auxiliary data, JAX compilation needs to know the structure of this data upfront.Currently, I'm initializing it like this:
This feels a bit hacky since it requires an extra function evaluation just to get the structure. Is there a better way to handle this initialization?
The challenge is that the auxiliary data structure is determined by the loss function and could be arbitrary (e.g., dictionaries, nested structures, etc.).
ToDos
opt_state