Add support for loss functions with auxiliary data to linesearch #1177

Draft

ro0mquy wants to merge 1 commit into main
Conversation

@ro0mquy commented Jan 16, 2025

Summary

This change adds support for loss functions that return auxiliary data alongside their primary value, like (loss_value, extra_data). This pattern is commonly used with jax.value_and_grad(fn, has_aux=True).
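For concreteness, such a loss function and its use with jax.value_and_grad look like this (a minimal sketch; the loss and aux contents are made up for illustration):

import jax
import jax.numpy as jnp

def loss(params):
  residuals = params - 1.0
  # Scalar loss plus arbitrary auxiliary data.
  aux = {'residual_norm': jnp.linalg.norm(residuals)}
  return jnp.sum(residuals ** 2), aux

params = jnp.zeros(3)
(value, aux), grad = jax.value_and_grad(loss, has_aux=True)(params)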

The approach:

  1. Added value_fn_has_aux flag to zoom_linesearch and scale_by_zoom_linesearch
  2. Modified value handling to properly unpack auxiliary data when needed, using a new _unpack_value helper that extracts just the loss value (see the sketch after this list)
  3. Updated value storage in state to keep the full value+aux tuple when needed
  4. Added has_aux parameter to value_and_grad_from_state to properly handle auxiliary data when reusing cached values
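
As a rough sketch, the helper named in point 2 amounts to something like the following (the actual implementation in this PR may differ):

def _unpack_value(value, value_fn_has_aux=False):
  # When the loss returns (loss_value, aux), keep only the scalar loss.
  return value[0] if value_fn_has_aux else value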

This allows the linesearch algorithms to work with loss functions that return auxiliary data while still optimizing only the primary loss value.
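
Putting it together, intended usage would look roughly like this. This is a hypothetical sketch assuming the flags proposed in this PR (value_fn_has_aux on the linesearch, has_aux on value_and_grad_from_state) and assuming the state-stored value is the full (loss, aux) tuple; the final calling convention may differ:

import optax

opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_zoom_linesearch(
        max_linesearch_steps=15, value_fn_has_aux=True),
)
opt_state = opt.init(params)
value_and_grad = optax.value_and_grad_from_state(loss, has_aux=True)

for _ in range(100):
  (value, aux), grad = value_and_grad(params, state=opt_state)
  updates, opt_state = opt.update(
      grad, opt_state, params,
      value=(value, aux), grad=grad, value_fn=loss)
  params = optax.apply_updates(params, updates)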

Input needed: How to initialize opt_state?

The linesearch algorithm stores value and grad in the optimizer state to enable reuse of function evaluations. When using auxiliary data, JAX compilation needs to know the structure of this data upfront.

Currently, I'm initializing it like this:

import jax.numpy as jnp
import optax

opt_state = optimizer.init(params)
# Run the loss function once, just to get the structure of the auxiliary data.
_, aux = loss(params)
# Store infinity as the value (forcing a recomputation on the first step)
# while keeping the aux structure for compilation.
value = (jnp.asarray(jnp.inf), aux)
opt_state = optax.tree_utils.tree_set(opt_state, value=value)

This feels a bit hacky since it requires an extra function evaluation just to get the structure. Is there a better way to handle this initialization?

The challenge is that the auxiliary data structure is determined by the loss function and could be arbitrary (e.g., dictionaries, nested structures, etc.).
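
One possible alternative, offered only as a sketch and not part of this PR: jax.eval_shape traces the loss function abstractly, so the aux structure can be recovered without paying for a real forward pass, and placeholder arrays can then be built from the resulting ShapeDtypeStructs:

import jax
import jax.numpy as jnp
import optax

# Abstract evaluation: returns ShapeDtypeStructs, runs no actual computation.
_, aux_shapes = jax.eval_shape(loss, params)
aux = jax.tree.map(lambda s: jnp.zeros(s.shape, s.dtype), aux_shapes)
value = (jnp.asarray(jnp.inf), aux)
opt_state = optax.tree_utils.tree_set(opt_state, value=value)

Since the stored value is infinity, the placeholder zeros are never used as real data; they only pin down the pytree structure for compilation.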

ToDos

  • Add support to backtracking linesearch
  • Add documentation and doc strings
  • Add tests
  • Improve handling of initial opt_state


google-cla bot commented Jan 16, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

