From f88f30477b8f9f4f7b2b5e2a5dc9a8a7535cd1e9 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Sun, 3 Nov 2024 23:38:33 +0900 Subject: [PATCH] [CI] Add old Jax tests (#1279) --- .github/workflows/ci.yml | 6 ++++++ requirements/requirements.txt | 3 ++- setup.py | 1 - tests/test_animal_shogi.py | 4 ++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b3b2b346..3c272fb57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,6 +27,12 @@ jobs: architecture: 'x64' - name: install run: | + pip install -U pip + if [ "${{ matrix.python-version }}" = "3.9" ]; then + pip install "jax==0.4.6" "jaxlib==0.4.6" "numpy<2.0.0" + else + pip install jax jaxlib + fi make install - name: test run: | diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5573d3b84..44a30f15f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,4 +1,5 @@ # jax must be installed previously depending on hardware -jax>=0.4.1 +# jax==0.4.6 is the oldest version available now (2024.11.03) +jax>=0.4.6 typing_extensions>=4.2.0 svgwrite diff --git a/setup.py b/setup.py index 341d0a1bf..9d833908f 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ def _read_requirements(fname): include_package_data=True, install_requires=_read_requirements("requirements.txt"), classifiers=[ - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/test_animal_shogi.py b/tests/test_animal_shogi.py index 0cd7d3b5d..eb3ea612b 100644 --- a/tests/test_animal_shogi.py +++ b/tests/test_animal_shogi.py @@ -232,7 +232,7 @@ def test_api(): def test_buggy_samples(): # https://github.com/sotetsuk/pgx/pull/1209 - state = init(jax.random.key(0)) + state = init(jax.random.PRNGKey(0)) state = step(state, 3 * 12 + 6) # White: Up PAWN state = step(state, 0 * 12 + 11) # Black: Right Up Bishop state = step(state, 8 * 12 + 1) # White: Drop PAWN to 1 @@ -248,7 +248,7 @@ def test_buggy_samples(): assert mask[LEFT_GOLD] # https://github.com/sotetsuk/pgx/pull/1218 - state = init(jax.random.key(0)) + state = init(jax.random.PRNGKey(0)) state = step(state, 3 * 12 + 6) # White: Up PAWN state = step(state, 0 * 12 + 11) # Black: Right Up Bishop DROP_PAWN_TO_0 = 8 * 12 + 0