From 7d65e9b7a5fe0118fe4da77ea3a4444fce70da2d Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Fri, 31 Jan 2025 12:56:39 -0500 Subject: [PATCH] test: revert getkey removal Signed-off-by: Nathaniel Starkman --- tests/conftest.py | 4 ++-- tests/test_lora.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 220b1e4..80bc645 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ -import jax.random as jr +import equinox.internal as eqxi import pytest @pytest.fixture() def getkey(): - return lambda: jr.PRNGKey(0) + return eqxi.GetKey() diff --git a/tests/test_lora.py b/tests/test_lora.py index 260789b..e20f8d5 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -95,12 +95,16 @@ def run3(mlp, vector): run3(mlp, vector) -def test_materialise(getkey): +def test_materialise(): + key = jr.key(0) + + key, *subkeys = jr.split(key, 3) x_false = lora.LoraArray( - jr.normal(getkey(), (3, 3)), rank=2, allow_materialise=False, key=getkey() + jr.normal(subkeys[0], (3, 3)), rank=2, allow_materialise=False, key=subkeys[1] ) + key, *subkeys = jr.split(key, 3) x_true = lora.LoraArray( - jr.normal(getkey(), (3, 3)), rank=2, allow_materialise=True, key=getkey() + jr.normal(subkeys[0], (3, 3)), rank=2, allow_materialise=True, key=subkeys[1] ) _ = quax.quaxify(jax.nn.relu)(x_true)