Skip to content

Commit

Permalink
minor change to make backtracking init and update return the same sta…
Browse files Browse the repository at this point in the history
…te types, this avoids recompilation when calling update
  • Loading branch information
bafflingbits committed Jan 16, 2025
1 parent 98c73c5 commit 2ad8ad6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions optax/_src/linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,15 @@ def init_fn(params: base.Params) -> ScaleByBacktrackingLinesearchState:
grad = otu.tree_zeros_like(params)
else:
grad = None
# base output type on params type, except only real part if complex
val_dtype = jax.dtypes.canonicalize_dtype(jnp.real(params[0]).dtype)
return ScaleByBacktrackingLinesearchState(
learning_rate=jnp.array(1.0),
value=jnp.array(jnp.inf),
value=jnp.array(jnp.inf, dtype=val_dtype),
grad=grad,
info=BacktrackingLinesearchInfo(
num_linesearch_steps=0,
decrease_error=jnp.array(jnp.inf),
decrease_error=jnp.array(jnp.inf, dtype=val_dtype),
),
)

Expand Down

0 comments on commit 2ad8ad6

Please sign in to comment.