From cd28291936cb6bb86c2fc00f81bf0ad3c48da4e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= Date: Fri, 18 Feb 2022 12:28:44 +0100 Subject: [PATCH 1/5] loosen tolerance in jax tests --- tests/nn/flow/transformer/test_jax_bridge.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/nn/flow/transformer/test_jax_bridge.py b/tests/nn/flow/transformer/test_jax_bridge.py index 25a4369a..43cb7e6b 100644 --- a/tests/nn/flow/transformer/test_jax_bridge.py +++ b/tests/nn/flow/transformer/test_jax_bridge.py @@ -179,18 +179,16 @@ def compute_params(x, y_shape): bisection_eps=1e-20 ).to(**ctx) - print(ctx) x = torch.rand(103, dimx).to(**ctx) y = torch.rand(103, dimy).to(**ctx) - print(x.dtype) y1, ldj1 = transformer(x, y, inverse=False) - print(y1.dtype) y2, ldj2 = transformer(x, y1, inverse=True) - print(y2.dtype) - assert torch.allclose(y, y2, atol=1e-5, rtol=1e-3), (y - y2).abs().max() - assert torch.allclose(ldj1, -ldj2, atol=1e-5, rtol=1e-3), (ldj1 + ldj2).abs().max() + rtol = 1e-2 if dtype == torch.float32 else 1e-4 + atol = 1e-4 if dtype == torch.float32 else 1e-6 + assert torch.allclose(y, y2, atol=atol, rtol=rtol), (y - y2).abs().max() + assert torch.allclose(ldj1, -ldj2, atol=atol, rtol=rtol), (ldj1 + ldj2).abs().max() @contextlib.contextmanager From 3e9b0bf27d3dc2bd72880ee3c7515c513be1649f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= Date: Fri, 18 Feb 2022 12:35:30 +0100 Subject: [PATCH 2/5] bugfix --- tests/nn/flow/transformer/test_jax_bridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nn/flow/transformer/test_jax_bridge.py b/tests/nn/flow/transformer/test_jax_bridge.py index 43cb7e6b..b34e0adc 100644 --- a/tests/nn/flow/transformer/test_jax_bridge.py +++ b/tests/nn/flow/transformer/test_jax_bridge.py @@ -185,8 +185,8 @@ def compute_params(x, y_shape): y1, ldj1 = transformer(x, y, inverse=False) y2, ldj2 = transformer(x, y1, inverse=True) - rtol = 1e-2 if dtype == torch.float32 else 1e-4 - atol = 1e-4 if dtype == torch.float32 else 1e-6 + rtol = 1e-2 if ctx["dtype"] == torch.float32 else 1e-4 + atol = 1e-4 if ctx["dtype"] == torch.float32 else 1e-6 assert torch.allclose(y, y2, atol=atol, rtol=rtol), (y - y2).abs().max() assert torch.allclose(ldj1, -ldj2, atol=atol, rtol=rtol), (ldj1 + ldj2).abs().max() From 4bc84a3edaee9c7d9ceb6ae26874f70a582cd149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= Date: Fri, 18 Feb 2022 13:56:34 +0100 Subject: [PATCH 3/5] jax from pip --- .github/workflows/CI-jax.yml | 56 +++++++++++++++++++++++++++++++++ .github/workflows/CI-openmm.yml | 2 +- devtools/conda-env.yml | 1 - 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/CI-jax.yml diff --git a/.github/workflows/CI-jax.yml b/.github/workflows/CI-jax.yml new file mode 100644 index 00000000..5090db2d --- /dev/null +++ b/.github/workflows/CI-jax.yml @@ -0,0 +1,56 @@ +name: CI with OpenMM on conda + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + schedule: + # Nightly tests run on master by default: + # Scheduled workflows run on the latest commit on the default or base branch. + # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) + - cron: "0 0 * * *" + + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: [3.9] + + steps: + + - uses: actions/checkout@v2 + + # More info on options: https://github.com/conda-incubator/setup-miniconda + - uses: conda-incubator/setup-miniconda@v2 + with: + python-version: ${{ matrix.python-version }} + environment-file: devtools/conda-env.yml + channels: conda-forge, pytorch, defaults + activate-environment: test + auto-update-conda: true + auto-activate-base: false + show-channel-urls: true + + - name: Install pip dependencies + shell: bash -l {0} + run: | + pip install einops + pip install nflows + pip install jax + + - name: Install package + shell: bash -l {0} + run: | + python setup.py install + + - name: Test with pytest + shell: bash -l {0} + run: | + pytest -vs diff --git a/.github/workflows/CI-openmm.yml b/.github/workflows/CI-openmm.yml index 72198539..ff26a213 100644 --- a/.github/workflows/CI-openmm.yml +++ b/.github/workflows/CI-openmm.yml @@ -1,4 +1,4 @@ -name: CI with OpenMM on conda +name: CI with OpenMM and jax on: push: diff --git a/devtools/conda-env.yml b/devtools/conda-env.yml index 0a682693..c2f88c00 100644 --- a/devtools/conda-env.yml +++ b/devtools/conda-env.yml @@ -15,4 +15,3 @@ dependencies: - ase - openmmtools - pytorch - - jax From 19a1f016069522f92c5b1ed5635ded1d4975b39f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= Date: Fri, 18 Feb 2022 14:06:53 +0100 Subject: [PATCH 4/5] seed in jax tests --- .github/workflows/CI-jax.yml | 2 +- .github/workflows/CI-openmm.yml | 2 +- tests/nn/flow/transformer/test_jax_bridge.py | 11 +++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI-jax.yml b/.github/workflows/CI-jax.yml index 5090db2d..a921098f 100644 --- a/.github/workflows/CI-jax.yml +++ b/.github/workflows/CI-jax.yml @@ -1,4 +1,4 @@ -name: CI with OpenMM on conda +name: CI with jax and OpenMM on: push: diff --git a/.github/workflows/CI-openmm.yml b/.github/workflows/CI-openmm.yml index ff26a213..72198539 100644 --- a/.github/workflows/CI-openmm.yml +++ b/.github/workflows/CI-openmm.yml @@ -1,4 +1,4 @@ -name: CI with OpenMM and jax +name: CI with OpenMM on conda on: push: diff --git a/tests/nn/flow/transformer/test_jax_bridge.py b/tests/nn/flow/transformer/test_jax_bridge.py index b34e0adc..fe27c981 100644 --- a/tests/nn/flow/transformer/test_jax_bridge.py +++ b/tests/nn/flow/transformer/test_jax_bridge.py @@ -116,6 +116,7 @@ def test_approx_inv_gradients(): jax_config.update("jax_enable_x64", True) threshold = 1e-6 + np.random.seed(44) bijectors = [exp_bijector, sin_bijector, monomial_bijector] inverses = [exp_bijector_inv, sin_bijector_inv, monomial_bijector_inv] @@ -157,6 +158,10 @@ def test_bgflow_interface(ctx): num_mixtures = 7 num_params = 4 + rtol = 1e-2 if ctx["dtype"] == torch.float32 else 1e-4 + atol = 1e-4 if ctx["dtype"] == torch.float32 else 1e-6 + np.random.seed(45) + net = torch.nn.Sequential( torch.nn.Linear(dimx, 128), torch.nn.ReLU(), @@ -179,14 +184,12 @@ def compute_params(x, y_shape): bisection_eps=1e-20 ).to(**ctx) - x = torch.rand(103, dimx).to(**ctx) - y = torch.rand(103, dimy).to(**ctx) + x = torch.tensor(np.random.uniform(0.0, 1.0, (103, dimx)), **ctx) + y = torch.tensor(np.random.uniform(0.0, 1.0, (103, dimy)), **ctx) y1, ldj1 = transformer(x, y, inverse=False) y2, ldj2 = transformer(x, y1, inverse=True) - rtol = 1e-2 if ctx["dtype"] == torch.float32 else 1e-4 - atol = 1e-4 if ctx["dtype"] == torch.float32 else 1e-6 assert torch.allclose(y, y2, atol=atol, rtol=rtol), (y - y2).abs().max() assert torch.allclose(ldj1, -ldj2, atol=atol, rtol=rtol), (ldj1 + ldj2).abs().max() From b2dec03ccfbb30dccabb23542479a6df2efd9169 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= Date: Fri, 18 Feb 2022 15:03:42 +0100 Subject: [PATCH 5/5] CI with jax2torch --- .github/workflows/CI-jax.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI-jax.yml b/.github/workflows/CI-jax.yml index a921098f..8b9d4cc6 100644 --- a/.github/workflows/CI-jax.yml +++ b/.github/workflows/CI-jax.yml @@ -43,7 +43,7 @@ jobs: run: | pip install einops pip install nflows - pip install jax + pip install jax jax2torch - name: Install package shell: bash -l {0}