Skip to content

Commit

Permalink
test: revert getkey removal
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman authored and patrick-kidger committed Feb 1, 2025
1 parent a576b8d commit 7d65e9b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.random as jr
import equinox.internal as eqxi
import pytest


@pytest.fixture()
def getkey():
return lambda: jr.PRNGKey(0)
return eqxi.GetKey()
10 changes: 7 additions & 3 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def run3(mlp, vector):
run3(mlp, vector)


def test_materialise(getkey):
def test_materialise():
key = jr.key(0)

key, *subkeys = jr.split(key, 3)
x_false = lora.LoraArray(
jr.normal(getkey(), (3, 3)), rank=2, allow_materialise=False, key=getkey()
jr.normal(subkeys[0], (3, 3)), rank=2, allow_materialise=False, key=subkeys[1]
)
key, *subkeys = jr.split(key, 3)
x_true = lora.LoraArray(
jr.normal(getkey(), (3, 3)), rank=2, allow_materialise=True, key=getkey()
jr.normal(subkeys[0], (3, 3)), rank=2, allow_materialise=True, key=subkeys[1]
)

_ = quax.quaxify(jax.nn.relu)(x_true)
Expand Down

0 comments on commit 7d65e9b

Please sign in to comment.