Skip to content

Commit

Permalink
bump jax everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 24, 2024
1 parent e9f38cc commit b7e6314
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/launch_small_fast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: Install locally
run: |
python -m pip install --upgrade pip
pip install -e .[test] "jax[cpu]==0.4.30"
pip install -e .[test] "jax[cpu]==0.4.38"
- name: Launch Small Fast TPU Train LM job
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_pre_commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.14"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_ray_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
Expand Down
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
Expand Down

0 comments on commit b7e6314

Please sign in to comment.