diff --git a/.github/workflows/CI-jax.yml b/.github/workflows/CI-jax.yml new file mode 100644 index 00000000..8b9d4cc6 --- /dev/null +++ b/.github/workflows/CI-jax.yml @@ -0,0 +1,56 @@ +name: CI with jax and OpenMM + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + schedule: + # Nightly tests run on master by default: + # Scheduled workflows run on the latest commit on the default or base branch. + # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) + - cron: "0 0 * * *" + + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: [3.9] + + steps: + + - uses: actions/checkout@v2 + + # More info on options: https://github.com/conda-incubator/setup-miniconda + - uses: conda-incubator/setup-miniconda@v2 + with: + python-version: ${{ matrix.python-version }} + environment-file: devtools/conda-env.yml + channels: conda-forge, pytorch, defaults + activate-environment: test + auto-update-conda: true + auto-activate-base: false + show-channel-urls: true + + - name: Install pip dependencies + shell: bash -l {0} + run: | + pip install einops + pip install nflows + pip install jax jax2torch + + - name: Install package + shell: bash -l {0} + run: | + python setup.py install + + - name: Test with pytest + shell: bash -l {0} + run: | + pytest -vs diff --git a/devtools/conda-env.yml b/devtools/conda-env.yml index 0a682693..c2f88c00 100644 --- a/devtools/conda-env.yml +++ b/devtools/conda-env.yml @@ -15,4 +15,3 @@ dependencies: - ase - openmmtools - pytorch - - jax diff --git a/tests/nn/flow/transformer/test_jax_bridge.py b/tests/nn/flow/transformer/test_jax_bridge.py index 25a4369a..fe27c981 100644 --- a/tests/nn/flow/transformer/test_jax_bridge.py +++ b/tests/nn/flow/transformer/test_jax_bridge.py @@ -116,6 +116,7 @@ def test_approx_inv_gradients(): jax_config.update("jax_enable_x64", True) threshold = 1e-6 + np.random.seed(44) bijectors = [exp_bijector, sin_bijector, monomial_bijector] inverses = [exp_bijector_inv, sin_bijector_inv, monomial_bijector_inv] @@ -157,6 +158,10 @@ def test_bgflow_interface(ctx): num_mixtures = 7 num_params = 4 + rtol = 1e-2 if ctx["dtype"] == torch.float32 else 1e-4 + atol = 1e-4 if ctx["dtype"] == torch.float32 else 1e-6 + np.random.seed(45) + net = torch.nn.Sequential( torch.nn.Linear(dimx, 128), torch.nn.ReLU(), @@ -179,18 +184,14 @@ def compute_params(x, y_shape): bisection_eps=1e-20 ).to(**ctx) - print(ctx) - x = torch.rand(103, dimx).to(**ctx) - y = torch.rand(103, dimy).to(**ctx) - print(x.dtype) + x = torch.tensor(np.random.uniform(0.0, 1.0, (103, dimx)), **ctx) + y = torch.tensor(np.random.uniform(0.0, 1.0, (103, dimy)), **ctx) y1, ldj1 = transformer(x, y, inverse=False) - print(y1.dtype) y2, ldj2 = transformer(x, y1, inverse=True) - print(y2.dtype) - assert torch.allclose(y, y2, atol=1e-5, rtol=1e-3), (y - y2).abs().max() - assert torch.allclose(ldj1, -ldj2, atol=1e-5, rtol=1e-3), (ldj1 + ldj2).abs().max() + assert torch.allclose(y, y2, atol=atol, rtol=rtol), (y - y2).abs().max() + assert torch.allclose(ldj1, -ldj2, atol=atol, rtol=rtol), (ldj1 + ldj2).abs().max() @contextlib.contextmanager