Skip to content

Commit

Permalink
[CI] Add old Jax tests (#1279)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Nov 3, 2024
1 parent e96c9ba commit f88f304
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ jobs:
architecture: 'x64'
- name: install
run: |
pip install -U pip
if [ "${{ matrix.python-version }}" = "3.9" ]; then
pip install "jax==0.4.6" "jaxlib==0.4.6" "numpy<2.0.0"
else
pip install jax jaxlib
fi
make install
- name: test
run: |
Expand Down
3 changes: 2 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# jax must be installed previously depending on hardware
jax>=0.4.1
# jax==0.4.6 is the oldest version available now (2024.11.03)
jax>=0.4.6
typing_extensions>=4.2.0
svgwrite
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def _read_requirements(fname):
include_package_data=True,
install_requires=_read_requirements("requirements.txt"),
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_animal_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_api():

def test_buggy_samples():
# https://github.com/sotetsuk/pgx/pull/1209
state = init(jax.random.key(0))
state = init(jax.random.PRNGKey(0))
state = step(state, 3 * 12 + 6) # White: Up PAWN
state = step(state, 0 * 12 + 11) # Black: Right Up Bishop
state = step(state, 8 * 12 + 1) # White: Drop PAWN to 1
Expand All @@ -248,7 +248,7 @@ def test_buggy_samples():
assert mask[LEFT_GOLD]

# https://github.com/sotetsuk/pgx/pull/1218
state = init(jax.random.key(0))
state = init(jax.random.PRNGKey(0))
state = step(state, 3 * 12 + 6) # White: Up PAWN
state = step(state, 0 * 12 + 11) # Black: Right Up Bishop
DROP_PAWN_TO_0 = 8 * 12 + 0
Expand Down

0 comments on commit f88f304

Please sign in to comment.