From df81dfad05f68da522886460931d3ee2e9c6b708 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 15 Jan 2025 17:46:55 +0100 Subject: [PATCH 1/2] Update test_benchmark.py to use different data for batch processing --- tests/test_benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 089c78b4e..641e0f243 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -9,13 +9,14 @@ def vectorize_data(model: js.model.JaxSimModel, batch_size: int): key = jax.random.PRNGKey(seed=0) + keys = jax.random.split(key, num=batch_size) return jax.vmap( lambda key: js.data.random_model_data( model=model, key=key, ) - )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) + )(keys) def benchmark_test_function( From 709ceedf42dc28d2ca3f3082749ea4b635141d86 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 15 Jan 2025 17:48:26 +0100 Subject: [PATCH 2/2] Update `test_benchmark.py` to use jax.block_until_ready for more accurate measurement of computation time For more info see https://jax.readthedocs.io/en/latest/async_dispatch.html#async-dispatch --- tests/test_benchmark.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 641e0f243..cf6898f18 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -27,7 +27,10 @@ def benchmark_test_function( # Warm-up call to avoid including compilation time jax.vmap(func, in_axes=(None, 0))(model, data) - benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) + + # Benchmark the function call + # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch + benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data) @pytest.mark.benchmark