Skip to content

Commit

Permalink
Merge pull request #1173 from jakevdp:ci-nightly-jax
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715903657
  • Loading branch information
OptaxDev committed Jan 15, 2025
2 parents b05d247 + 47760ba commit 98c73c5
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 22 deletions.
31 changes: 21 additions & 10 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@ on:
schedule:
- cron: '0 3 * * *'

permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
linting:
name: "Lint check with flake8 and pylint"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.11"
cache: "pip"
Expand All @@ -38,8 +46,8 @@ jobs:
name: "Lint check with ruff"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.11"
cache: "pip"
Expand All @@ -57,8 +65,8 @@ jobs:
python-version: ["3.11"] # only build docs with a somewhat latest python
os: [ubuntu-latest]
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "${{ matrix.python-version }}"
cache: "pip"
Expand All @@ -83,9 +91,12 @@ jobs:
- python-version: "3.9"
os: "ubuntu-latest"
jax-version: "0.4.27" # Keep version in sync with pyproject.toml and copy.bara.sky!
- python-version: "3.12"
os: "ubuntu-latest"
jax-version: "nightly"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "${{ matrix.python-version }}"
cache: "pip"
Expand All @@ -98,9 +109,9 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Check links
uses: gaurav-nelson/github-action-markdown-link-check@v1
uses: gaurav-nelson/github-action-markdown-link-check@d53a906aa6b22b8979d33bc86170567e619495ec # v1.0.15
with:
use-quiet-mode: yes
use-verbose-mode: yes
Expand Down
1 change: 1 addition & 0 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def test_linesearch_with_jax_variants(self):
value = otu.tree_get(state, 'value')
self.assertFalse(jnp.isinf(value))

@absltest.skip('TODO(rdyro): need to match scipy linesearch algorithm')
@parameterized.product(
problem_name=[
'polynomial',
Expand Down
20 changes: 10 additions & 10 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ def ntxent(
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[-0.9155995 1.5534698 ]
[ 0.2623586 -1.5908985 ]
[-0.15977189 0.480501 ]
[ 0.58389133 0.10497775]]
>>> print("input:", x) # doctest: +SKIP
input: [[ 0.07592554 -0.48634264]
[ 1.2903206 0.5196119 ]
[ 0.30040437 0.31034866]
[ 0.5761609 -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[-1.0076267]
[-1.2960069]
[-1.1829865]
[-1.3485558]]
>>> print("Embeddings:", out) # doctest: +SKIP
Embeddings: [[0.08969027]
[1.6291292 ]
[0.8622629 ]
[0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123
Expand Down
4 changes: 2 additions & 2 deletions optax/perturbations/_make_pert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def exact_loss(inputs):
expect_hessian = jax.hessian(pert_loss)(self.array_small_jax, self.rng_jax)
got_hessian = jax.hessian(exact_loss)(self.array_small_jax)
chex.assert_trees_all_equal_shapes(expect_hessian, got_hessian)
chex.assert_trees_all_close(expected_grad, got_grad, atol=2e-2)
chex.assert_trees_all_close(expected_grad, got_grad, atol=6e-2)
expected_dict = pert_argmax_fun(self.tree_a_dict_jax, self.rng_jax)
got_dict = jtu.tree_map(softmax_fun, self.tree_a_dict_jax)
chex.assert_trees_all_close(expected_dict, got_dict, atol=2e-2)
chex.assert_trees_all_close(expected_dict, got_dict, atol=6e-2)

def test_values_on_tree(self):
"""Test that the perturbations are well applied for functions on trees.
Expand Down
2 changes: 2 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ if [ -z "${JAX_VERSION-}" ]; then
: # use version installed in requirements above
elif [ "$JAX_VERSION" = "newest" ]; then
python3 -m pip install --quiet --upgrade jax jaxlib
elif [ "$JAX_VERSION" = "nightly" ]; then
python3 -m pip install --quiet --upgrade --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
else
python3 -m pip install --quiet "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"
fi
Expand Down

0 comments on commit 98c73c5

Please sign in to comment.