diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 089c78b4e..cf6898f18 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -9,13 +9,14 @@ def vectorize_data(model: js.model.JaxSimModel, batch_size: int): key = jax.random.PRNGKey(seed=0) + keys = jax.random.split(key, num=batch_size) return jax.vmap( lambda key: js.data.random_model_data( model=model, key=key, ) - )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) + )(keys) def benchmark_test_function( @@ -26,7 +27,10 @@ def benchmark_test_function( # Warm-up call to avoid including compilation time jax.vmap(func, in_axes=(None, 0))(model, data) - benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) + + # Benchmark the function call + # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch + benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data) @pytest.mark.benchmark