Skip to content

Commit

Permalink
Fix a test that fails on update of jax due to mismatched integer types
Browse files Browse the repository at this point in the history
Looks to be a bug in Jax. A search function is returning a 32-bit
integer in x64 mode. I filed an issue with Jax.

For now, the workaround is to manually cast to int64.
  • Loading branch information
btalamini committed Jul 5, 2022
1 parent 8bc8b1b commit 5b45c16
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions optimism/TensorMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,13 @@ def cond_f(loopData):

def compute_pade_degree(diff, j, itk):
j += 1
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right')
# Manually force the return type of searchsorted to be 64-bit int, because it
# returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
# like a bug. I filed an issue (#11375) with Jax to correct this.
# If they fix it, the conversions on p and q can be removed.
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64)
p += 2
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right')
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64)
q += 2
m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2),
(p+1,j,True), (0,j,False))
Expand Down

0 comments on commit 5b45c16

Please sign in to comment.