Skip to content

Commit

Permalink
Merge pull request #23 from btalamini/bugfix/fix_failing_test_due_to_…
Browse files Browse the repository at this point in the history
…incompatible_integer_types

Fix a test that fails on update of jax due to mismatched integer types
  • Loading branch information
ralberd authored Jul 7, 2022
2 parents 8bc8b1b + 5b45c16 commit 6172cb5
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions optimism/TensorMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,13 @@ def cond_f(loopData):

def compute_pade_degree(diff, j, itk):
j += 1
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right')
# Manually force the return type of searchsorted to be 64-bit int, because it
# returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
# like a bug. I filed an issue (#11375) with Jax to correct this.
# If they fix it, the conversions on p and q can be removed.
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64)
p += 2
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right')
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64)
q += 2
m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2),
(p+1,j,True), (0,j,False))
Expand Down

0 comments on commit 6172cb5

Please sign in to comment.