Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable parallelization for analysis.polymer #4871

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Fixes
the function to prevent shared state. (Issue #4655)

Enhancements
* Enables parallelization for analysis.polymer.PersistenceLength (Issue #4671, PR #4871)
* Addition of 'water' token for water selection (Issue #4839)
* Enables parallelization for analysis.density.DensityAnalysis (Issue #4677, PR #4729)
* Enables parallelization for analysis.contacts.Contacts (Issue #4660)
Expand Down
16 changes: 15 additions & 1 deletion package/MDAnalysis/analysis/polymer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .. import NoDataError
from ..core.groups import requires, AtomGroup
from ..lib.distances import calc_bonds
from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -233,8 +233,19 @@ class PersistenceLength(AnalysisBase):
Former ``results`` are now stored as ``results.bond_autocorrelation``.
:attr:`lb`, :attr:`lp`, :attr:`fit` are now stored in a
:class:`MDAnalysis.analysis.base.Results` instance.
.. versionchanged:: 2.9.0
Enabled **parallel execution** with the ``multiprocessing`` and ``dask``
backends; use the new method :meth:`get_supported_backends` to see all
supported backends.
"""

@classmethod
def get_supported_backends(cls):
    """Return the names of the backends this analysis declares support for."""
    backends = ("serial", "multiprocessing", "dask")
    return backends

_analysis_algorithm_is_parallelizable = True


def __init__(self, atomgroups, **kwargs):
super(PersistenceLength, self).__init__(
atomgroups[0].universe.trajectory, **kwargs
Expand Down Expand Up @@ -307,6 +318,9 @@ def _conclude(self):

self._perform_fit()

def _get_aggregator(self):
    """Return the ``ResultsGroup`` describing how results are combined.

    Returns
    -------
    ResultsGroup
        Maps ``bond_autocorrelation`` to ``ResultsGroup.ndarray_vstack``.
    """
    # NOTE(review): presumably used to merge partial results from parallel
    # workers; confirm that stacking is the intended way to combine
    # ``bond_autocorrelation`` -- the code that fills it is not visible here.
    merge_rules = {"bond_autocorrelation": ResultsGroup.ndarray_vstack}
    return ResultsGroup(lookup=merge_rules)

def _calc_bond_length(self):
"""calculate average bond length"""
bs = []
Expand Down
7 changes: 7 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.analysis.density import DensityAnalysis
from MDAnalysis.lib.util import is_installed
from MDAnalysis.analysis.polymer import PersistenceLength


def params_for_cls(cls, exclude: list[str] = None):
Expand Down Expand Up @@ -176,3 +177,9 @@ def client_Contacts(request):
@pytest.fixture(scope="module", params=params_for_cls(DensityAnalysis))
def client_DensityAnalysis(request):
    """Expose the current entry of ``params_for_cls(DensityAnalysis)``."""
    current_param = request.param
    return current_param

# MDAnalysis.analysis.polymer

@pytest.fixture(scope="module", params=params_for_cls(PersistenceLength))
def client_PersistenceLength(request):
    """Expose the current entry of ``params_for_cls(PersistenceLength)``."""
    current_param = request.param
    return current_param
34 changes: 30 additions & 4 deletions testsuite/MDAnalysisTests/analysis/test_persistencelength.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from MDAnalysis.analysis import polymer
from MDAnalysis.exceptions import NoDataError
from MDAnalysis.core.topologyattrs import Bonds
import MDAnalysis

import numpy as np
import matplotlib
Expand All @@ -52,8 +53,8 @@ def p(u):

@staticmethod
@pytest.fixture()
def p_run(p):
return p.run()
def p_run(p, client_PersistenceLength):
return p.run(**client_PersistenceLength)

def test_ag_ValueError(self, u):
ags = [u.atoms[:10], u.atoms[10:110]]
Expand Down Expand Up @@ -98,8 +99,8 @@ def test_current_axes(self, p_run):
assert ax2 is not ax

@pytest.mark.parametrize("attr", ("lb", "lp", "fit"))
def test(self, p, attr):
p_run = p.run(step=3)
def test(self, p, attr, client_PersistenceLength):
p_run = p.run(step=3, **client_PersistenceLength)
wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
getattr(p_run, attr) is p_run.results[attr]
Expand Down Expand Up @@ -166,4 +167,29 @@ def test_circular(self):

with pytest.raises(ValueError) as ex:
polymer.sort_backbone(u.atoms)

assert 'cyclical' in str(ex.value)

# tests for parallelization

@pytest.mark.parametrize(
    "classname,is_parallelizable",
    [(MDAnalysis.analysis.polymer.PersistenceLength, True)],
)
def test_class_is_parallelizable(classname, is_parallelizable):
    """Check the class-level parallelizable flag against the expected value."""
    flag = classname._analysis_algorithm_is_parallelizable
    assert flag == is_parallelizable


@pytest.mark.parametrize(
    "classname,backends",
    [
        (
            MDAnalysis.analysis.polymer.PersistenceLength,
            ("serial", "multiprocessing", "dask"),
        ),
    ],
)
def test_supported_backends(classname, backends):
    """Check that the class advertises exactly the expected backends.

    Parameters
    ----------
    classname : type
        Analysis class under test (a class object, despite the name).
    backends : tuple of str
        Backend names ``get_supported_backends`` is expected to return.
    """
    assert classname.get_supported_backends() == backends

Loading