From edfe0c7f4ddc24c2ab72561c68e7952e67924c04 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sat, 28 Dec 2024 02:58:15 +0530 Subject: [PATCH 1/2] Enable parallelization for analysis.polymer.PersistenceLength --- package/CHANGELOG | 1 + package/MDAnalysis/analysis/polymer.py | 14 +++++++++ .../MDAnalysisTests/analysis/conftest.py | 7 +++++ .../analysis/test_persistencelength.py | 31 ++++++++++++++++--- 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index 7b85922487..248a2b3537 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -27,6 +27,7 @@ Fixes the function to prevent shared state. (Issue #4655) Enhancements + * Enables parallelization for analysis.polymer.PersistenceLength (Issue #4671) * Enables parallelization for analysis.density.DensityAnalysis (Issue #4677, PR #4729) * Enables parallelization for analysis.contacts.Contacts (Issue #4660) * Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670) diff --git a/package/MDAnalysis/analysis/polymer.py b/package/MDAnalysis/analysis/polymer.py index a38cf68daa..61042a8755 100644 --- a/package/MDAnalysis/analysis/polymer.py +++ b/package/MDAnalysis/analysis/polymer.py @@ -41,6 +41,7 @@ from ..core.groups import requires, AtomGroup from ..lib.distances import calc_bonds from .base import AnalysisBase +from .results import ResultsGroup logger = logging.getLogger(__name__) @@ -228,7 +229,17 @@ class PersistenceLength(AnalysisBase): Former ``results`` are now stored as ``results.bond_autocorrelation``. :attr:`lb`, :attr:`lp`, :attr:`fit` are now stored in a :class:`MDAnalysis.analysis.base.Results` instance. + .. versionchanged:: 2.9.0 + Enabled **parallel execution** with the ``multiprocessing`` and ``dask`` + backends; use the new method :meth:`get_supported_backends` to see all + supported backends. """ + @classmethod + def get_supported_backends(cls): + return ('serial', 'multiprocessing', 'dask',) + + _analysis_algorithm_is_parallelizable = True + def __init__(self, atomgroups, **kwargs): super(PersistenceLength, self).__init__( atomgroups[0].universe.trajectory, **kwargs) @@ -294,6 +305,9 @@ def _conclude(self): self._perform_fit() + def _get_aggregator(self): + return ResultsGroup(lookup={'bond_autocorrelation': ResultsGroup.ndarray_sum}) + def _calc_bond_length(self): """calculate average bond length""" bs = [] diff --git a/testsuite/MDAnalysisTests/analysis/conftest.py b/testsuite/MDAnalysisTests/analysis/conftest.py index df17c05c06..033a928b60 100644 --- a/testsuite/MDAnalysisTests/analysis/conftest.py +++ b/testsuite/MDAnalysisTests/analysis/conftest.py @@ -18,6 +18,7 @@ from MDAnalysis.analysis.contacts import Contacts from MDAnalysis.analysis.density import DensityAnalysis from MDAnalysis.lib.util import is_installed +from MDAnalysis.analysis.polymer import PersistenceLength def params_for_cls(cls, exclude: list[str] = None): @@ -165,3 +166,9 @@ def client_Contacts(request): @pytest.fixture(scope='module', params=params_for_cls(DensityAnalysis)) def client_DensityAnalysis(request): return request.param + +# MDAnalysis.analysis.polymer + +@pytest.fixture(scope="module", params=params_for_cls(PersistenceLength)) +def client_PersistenceLength(request): + return request.param \ No newline at end of file diff --git a/testsuite/MDAnalysisTests/analysis/test_persistencelength.py b/testsuite/MDAnalysisTests/analysis/test_persistencelength.py index 5d8790ab3d..02d90f290b 100644 --- a/testsuite/MDAnalysisTests/analysis/test_persistencelength.py +++ b/testsuite/MDAnalysisTests/analysis/test_persistencelength.py @@ -26,6 +26,7 @@ from MDAnalysis.analysis import polymer from MDAnalysis.exceptions import NoDataError from MDAnalysis.core.topologyattrs import Bonds +import MDAnalysis import numpy as np import matplotlib @@ -53,8 +54,8 @@ def p(u): @staticmethod @pytest.fixture() - def p_run(p): - return p.run() + def p_run(p, client_PersistenceLength): + return p.run(**client_PersistenceLength) def test_ag_ValueError(self, u): ags = [u.atoms[:10], u.atoms[10:110]] @@ -97,8 +98,8 @@ def test_current_axes(self, p_run): assert ax2 is not ax @pytest.mark.parametrize("attr", ("lb", "lp", "fit")) - def test(self, p, attr): - p_run = p.run(step=3) + def test(self, p, attr, client_PersistenceLength): + p_run = p.run(step=3, **client_PersistenceLength) wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0" with pytest.warns(DeprecationWarning, match=wmsg): getattr(p_run, attr) is p_run.results[attr] @@ -167,3 +168,25 @@ def test_circular(self): with pytest.raises(ValueError) as ex: polymer.sort_backbone(u.atoms) assert 'cyclical' in str(ex.value) + +# tests for parallelization + +@pytest.mark.parametrize( + "classname,is_parallelizable", + [ + (MDAnalysis.analysis.polymer.PersistenceLength , True), + ] +) +def test_class_is_parallelizable(classname, is_parallelizable): + assert classname._analysis_algorithm_is_parallelizable == is_parallelizable + + +@pytest.mark.parametrize( + "classname,backends", + [ + (MDAnalysis.analysis.polymer.PersistenceLength, + ('serial', 'multiprocessing', 'dask',)), + ] +) +def test_supported_backends(classname, backends): + assert classname.get_supported_backends() == backends From dc6e903f142cd00514e6650141d57803242e3dfd Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sat, 28 Dec 2024 16:12:53 +0530 Subject: [PATCH 2/2] Modified aggregation method --- package/MDAnalysis/analysis/polymer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/package/MDAnalysis/analysis/polymer.py b/package/MDAnalysis/analysis/polymer.py index 61042a8755..55119a5d6b 100644 --- a/package/MDAnalysis/analysis/polymer.py +++ b/package/MDAnalysis/analysis/polymer.py @@ -40,8 +40,7 @@ from .. import NoDataError from ..core.groups import requires, AtomGroup from ..lib.distances import calc_bonds -from .base import AnalysisBase -from .results import ResultsGroup +from .base import AnalysisBase, ResultsGroup logger = logging.getLogger(__name__) @@ -306,7 +305,7 @@ def _conclude(self): self._perform_fit() def _get_aggregator(self): - return ResultsGroup(lookup={'bond_autocorrelation': ResultsGroup.ndarray_sum}) + return ResultsGroup(lookup={'bond_autocorrelation': ResultsGroup.ndarray_vstack}) def _calc_bond_length(self): """calculate average bond length"""