From c8fe7b8db0dc6a1cb345bd0f1dbe4f45d3b4ebd5 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Tue, 21 May 2024 01:36:48 +0100 Subject: [PATCH] Improve the error handling in the backward and forward convergence (#358) * problem: In the backward and forward convergence, for the initial set of points, which uses for example the first 10% of the date, it could be the cases where due to the fact that there are not many data points, so the overlap is pretty bad, which gives terrible statistical error. solution: If the statistical error is too bad, use the bootstrap error instead, see choderalab/pymbar#519 * new kwarg error_tol in convergence.forward_backward_convergence() to allow the user to specify a error tolerance; if error > error_tol then switch to using bootstrap error * Update CHANGES * add test --- CHANGES | 2 + src/alchemlyb/convergence/convergence.py | 107 +++++++++++++++-------- src/alchemlyb/tests/test_convergence.py | 9 ++ 3 files changed, 83 insertions(+), 35 deletions(-) diff --git a/CHANGES b/CHANGES index 6734d315..02940a79 100644 --- a/CHANGES +++ b/CHANGES @@ -24,6 +24,8 @@ Changes `None` (start from all zeros) as this change provides a sizable speedup (PR #357) Enhancements + - `forward_backward_convergence` uses the bootstrap error when the statistical error + is too large. (PR #358) - `BAR` result is used as initial guess for `MBAR` estimator. (PR #357) - `forward_backward_convergence` uses the result from the previous step as the initial guess for the next step. (PR #357) diff --git a/src/alchemlyb/convergence/convergence.py b/src/alchemlyb/convergence/convergence.py index 65f15b86..d109e1ee 100644 --- a/src/alchemlyb/convergence/convergence.py +++ b/src/alchemlyb/convergence/convergence.py @@ -1,10 +1,12 @@ """Functions for assessing convergence of free energy estimates and raw data.""" +from typing import Any, List, Tuple from warnings import warn import numpy as np import pandas as pd from loguru import logger +from sklearn.base import BaseEstimator from .. import concat from ..estimators import BAR, TI, MBAR, FEP_ESTIMATORS, TI_ESTIMATORS @@ -13,7 +15,9 @@ estimators_dispatch = {"BAR": BAR, "TI": TI, "MBAR": MBAR} -def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): +def forward_backward_convergence( + df_list, estimator="MBAR", num=10, error_tol: float = 3, **kwargs +): """Forward and backward convergence of the free energy estimate. Generate the free energy estimate as a function of time in both directions, @@ -35,6 +39,12 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): Lower case input is also accepted until release 2.0.0. num : int The number of time points. + error_tol : float + The maximum error tolerated for analytic error. If the analytic error is + bigger than the error tolerance, the bootstrap error will be used. + + .. versionadded:: 2.3.0 + kwargs : dict Keyword arguments to be passed to the estimator. @@ -93,23 +103,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): sample = [] for data in df_list: sample.append(data[: len(data) // num * i]) - sample = concat(sample) - result = my_estimator.fit(sample) - if estimator == "MBAR": - my_estimator.initial_f_k = result.delta_f_.iloc[0, :] - forward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == "bar": - error = np.sqrt( - sum( - [ - result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1) - ] - ) - ) - forward_error_list.append(error) - else: - forward_error_list.append(result.d_delta_f_.iloc[0, -1]) + mean, error = _forward_backward_convergence_estimate( + sample, estimator, my_estimator, error_tol, **kwargs + ) + forward_list.append(mean) + forward_error_list.append(error) logger.info( "{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1]) ) @@ -122,23 +120,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): sample = [] for data in df_list: sample.append(data[-len(data) // num * i :]) - sample = concat(sample) - result = my_estimator.fit(sample) - if estimator == "MBAR": - my_estimator.initial_f_k = result.delta_f_.iloc[0, :] - backward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == "bar": - error = np.sqrt( - sum( - [ - result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1) - ] - ) - ) - backward_error_list.append(error) - else: - backward_error_list.append(result.d_delta_f_.iloc[0, -1]) + mean, error = _forward_backward_convergence_estimate( + sample, estimator, my_estimator, error_tol, **kwargs + ) + backward_list.append(mean) + backward_error_list.append(error) logger.info( "{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1]) ) @@ -156,6 +142,57 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): return convergence +def _forward_backward_convergence_estimate( + sample_list: List[pd.DataFrame], + estimator: str, + my_estimator: BaseEstimator, + error_tol: float, + **kwargs: Any, +) -> Tuple[float, float]: + """Use estimator to run the estimation and return the mean and error. + + Parameters + ---------- + sample_list: A list of samples as pandas Dataframe. + estimator: The string of the estimator + my_estimator: The estimator object. + error_tol: The error tolerance. + kwargs + + Returns + ------- + mean: The delta_f between 0 and 1 + error: The d_delta_f between 0 and 1 + """ + sample = concat(sample_list) + result = my_estimator.fit(sample) + if estimator == "MBAR": + my_estimator.initial_f_k = result.delta_f_.iloc[0, :] + mean = result.delta_f_.iloc[0, -1] + if estimator.lower() == "bar": + error = np.sqrt( + sum( + [ + result.d_delta_f_.iloc[i, i + 1] ** 2 + for i in range(len(result.d_delta_f_) - 1) + ] + ) + ) + else: + error = result.d_delta_f_.iloc[0, -1] + if estimator.lower() == "mbar" and error > error_tol: + logger.warning( + f"Statistical Error ({error}) bigger than error tolerance ({error_tol}), use bootstrap error instead." + ) + bootstraps_estimator = estimators_dispatch[estimator]( + n_bootstraps=50, initial_f_k=result.delta_f_.iloc[0, :], **kwargs + ) + bootstraps_estimator.fit(sample) + error = bootstraps_estimator.d_delta_f_.iloc[0, -1] + + return mean, error + + def _cummean(vals, out_length): """The cumulative mean of an array. diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index 32fa1eb9..2bde94b3 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -36,6 +36,15 @@ def test_convergence_wrong_cases(gmx_benzene_Coulomb_u_nk): forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar") +def test_convergence_bootstrap(gmx_benzene_Coulomb_u_nk, caplog): + normal_c = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar", num=2) + bootstrap_c = forward_backward_convergence( + gmx_benzene_Coulomb_u_nk, "mbar", error_tol=0.01, num=2 + ) + assert "use bootstrap error instead." in caplog.text + assert (bootstrap_c["Forward_Error"] != normal_c["Forward_Error"]).all() + + def test_convergence_method(gmx_benzene_Coulomb_u_nk): convergence = forward_backward_convergence( gmx_benzene_Coulomb_u_nk, "MBAR", num=2, method="adaptive"