Skip to content

Commit

Permalink
Switch out black for Ruff for formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
fjclark committed Jun 13, 2024
1 parent cae2d38 commit 8b45f71
Show file tree
Hide file tree
Showing 20 changed files with 145 additions and 94 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ on:
jobs:
# Check style and formatting with Ruff
lint_and_format:
name: Check for style and formatting violations with Ruff and Black
name: Check for style and formatting violations with Ruff
runs-on: ubuntu-latest
steps:
- name: Check out code
Expand All @@ -32,13 +32,13 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff black
pip install ruff
- name: Run Ruff
- name: Check style with Ruff
run: ruff check .

- name: Check Black formatting
run: black --check .
- name: Check formatting with Ruff
run: ruff format --check .

test:
name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ a3fe
[![codecov](https://codecov.io/gh/michellab/a3fe/graph/badge.svg?token=5IGO8SCRRQ)](https://codecov.io/gh/michellab/a3fe)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
[![Documentation Status](https://readthedocs.org/projects/a3fe/badge/?version=latest)](https://a3fe.readthedocs.io/en/latest/?badge=latest)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)

<img src="./a3fe_logo.png" alt="a3fe logo" style="width: 50%; height: 50%;">

Expand Down
4 changes: 2 additions & 2 deletions a3fe/analyse/autocorrelation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functionality to calculate autocorrelation
Aknowledgement: "get_statistical_inefficiency" is copied almost exactly from
"statisticalInefficiency_multiscale" here:
Acknowledgement: "get_statistical_inefficiency" is copied almost exactly from
"statisticalInefficiency_multiscale" here:
https://github.com/choderalab/automatic-equilibration-detection/blob/master/examples/liquid-argon/equilibration.py
The license and original authorship are preserved below:
Expand Down
35 changes: 23 additions & 12 deletions a3fe/analyse/detect_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
)


def check_equil_block_gradient(lam_win: "LamWindow", run_nos: _Optional[_List[int]]) -> _Tuple[bool, _Optional[float]]: # type: ignore # noqa: F821
def check_equil_block_gradient(
lam_win: "LamWindow", # type: ignore # noqa: F821
run_nos: _Optional[_List[int]],
) -> _Tuple[bool, _Optional[float]]:
"""
Check if the ensemble of simulations at the lambda window is
equilibrated based on the ensemble gradient between averaged blocks.
Expand Down Expand Up @@ -149,7 +152,10 @@ def check_equil_block_gradient(lam_win: "LamWindow", run_nos: _Optional[_List[in
return equilibrated, equil_time


def check_equil_chodera(lam_win: "LamWindow", run_nos: _Optional[_List[int]] = None) -> _Tuple[bool, _Optional[float]]: # type: ignore # noqa: F821
def check_equil_chodera(
lam_win: "LamWindow", # type: ignore # noqa: F821
run_nos: _Optional[_List[int]] = None,
) -> _Tuple[bool, _Optional[float]]:
"""
Check if the ensemble of simulations at the lambda window is
equilibrated based Chodera's method of maximising the number
Expand Down Expand Up @@ -337,7 +343,9 @@ def check_equil_multiwindow(
# Write out data
with open(f"{output_dir}/check_equil_multiwindow.txt", "w") as ofile:
ofile.write(f"Equilibrated: {equilibrated}\n")
ofile.write(f"Overall gradient {overall_grad_dg} +/- {conf_int[1] - overall_grad_dg} kcal mol^-1 ns^-1\n") # type: ignore
ofile.write(
f"Overall gradient {overall_grad_dg} +/- {conf_int[1] - overall_grad_dg} kcal mol^-1 ns^-1\n"
) # type: ignore
ofile.write(f"Fractional equilibration time: {fractional_equil_time} \n")
ofile.write(f"Equilibration time: {equil_time} ns\n")
ofile.write(f"Run numbers: {run_nos}\n")
Expand Down Expand Up @@ -549,13 +557,14 @@ def check_equil_multiwindow_modified_geweke(
last_slice_means,
equal_var=False, # Welches t-test
alternative="two-sided",
)[
1
] # First value is the t statistic - we want p
)[1] # First value is the t statistic - we want p

# Store results - note that time is the per-run time
p_vals_and_times.append(
(p_value, overall_times[0][0])
(
p_value,
overall_times[0][0],
)
) # second value is the equilibration time

# Check if the p-value is greater than the cutoff
Expand Down Expand Up @@ -704,13 +713,14 @@ def check_equil_multiwindow_paired_t(
first_slice_means,
last_slice_means,
alternative="two-sided",
)[
1
] # First value is the t statistic - we want p
)[1] # First value is the t statistic - we want p

# Store results - note that time is the per-run time
p_vals_and_times.append(
(p_value, overall_times[0][0])
(
p_value,
overall_times[0][0],
)
) # second value is the equilibration time

# Check if the p-value is greater than the cutoff
Expand Down Expand Up @@ -765,7 +775,8 @@ def check_equil_multiwindow_paired_t(


def dummy_check_equil_multiwindow(
lam_win: "LamWindow", run_nos: _Optional[_List[int]] = None # type: ignore # noqa: F821
lam_win: "LamWindow", # type: ignore # noqa: F821
run_nos: _Optional[_List[int]] = None,
) -> _Tuple[bool, _Optional[float]]:
"""
Because "check_equil_multiwindow" checks multiple windows at once and sets the _equilibrated
Expand Down
64 changes: 39 additions & 25 deletions a3fe/analyse/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def general_plot(
fig, ax = _plt.subplots(figsize=(8, 6))
ax.plot(x_vals, y_avg, label="Mean", linewidth=2)
for i, entry in enumerate(y_vals):
ax.plot(x_vals, entry, alpha=0.5, label=f"run {run_nos[i] if run_nos else i+1}")
ax.plot(
x_vals, entry, alpha=0.5, label=f"run {run_nos[i] if run_nos else i + 1}"
)
if vline_val is not None:
ax.axvline(x=vline_val, color="red", linestyle="dashed")
if hline_val is not None:
Expand Down Expand Up @@ -198,9 +200,11 @@ def plot_gradient_stats(
edgecolor="black",
yerr=gradients_data.sems_overall,
)
ax.set_ylabel(
r"$\langle \frac{\mathrm{d}h}{\mathrm{d}\lambda}\rangle _{\lambda} $ / kcal mol$^{-1}$"
),
(
ax.set_ylabel(
r"$\langle \frac{\mathrm{d}h}{\mathrm{d}\lambda}\rangle _{\lambda} $ / kcal mol$^{-1}$"
),
)

elif plot_type == "stat_ineff":
ax.bar(
Expand All @@ -219,9 +223,11 @@ def plot_gradient_stats(
width=0.02,
edgecolor="black",
)
ax.set_ylabel(
r"$\sqrt{t}$SEM($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $) / kcal mol$^{-1}$ ns$^{1/2}$"
),
(
ax.set_ylabel(
r"$\sqrt{t}$SEM($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $) / kcal mol$^{-1}$ ns$^{1/2}$"
),
)
ax.legend()
# Get second y axis so we can plot on different scales
ax2 = ax.twinx()
Expand Down Expand Up @@ -253,9 +259,11 @@ def plot_gradient_stats(
# Add vertical lines at optimal lambda vals
for lam_val in optimal_lam_vals:
ax2.axvline(x=lam_val, color="black", linestyle="dashed", linewidth=0.5)
ax2.set_ylabel(
r"Integrated $\sqrt{t}$SEM($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $) / kcal mol$^{-1}$ ns$^{1/2}$"
),
(
ax2.set_ylabel(
r"Integrated $\sqrt{t}$SEM($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $) / kcal mol$^{-1}$ ns$^{1/2}$"
),
)

elif plot_type == "pred_best_simtime":
# Calculate the predicted optimum simulation time
Expand All @@ -277,7 +285,7 @@ def plot_gradient_stats(
width=0.02,
edgecolor="black",
)
ax.set_ylabel(r"Predicted most efficient runtimes per run / ns"),
(ax.set_ylabel(r"Predicted most efficient runtimes per run / ns"),)
ax.legend()

elif plot_type == "integrated_var":
Expand All @@ -288,9 +296,11 @@ def plot_gradient_stats(
width=0.02,
edgecolor="black",
)
ax.set_ylabel(
r"(Var($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $))$^{1/2}$ / kcal mol$^{-1}$"
),
(
ax.set_ylabel(
r"(Var($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $))$^{1/2}$ / kcal mol$^{-1}$"
),
)
ax.legend()
# Get second y axis so we can plot on different scales
ax2 = ax.twinx()
Expand Down Expand Up @@ -318,9 +328,11 @@ def plot_gradient_stats(
# Add vertical lines at optimal lambda vals
for lam_val in optimal_lam_vals:
ax2.axvline(x=lam_val, color="black", linestyle="dashed", linewidth=0.5)
ax2.set_ylabel(
r"Integrated (Var($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $))$^{1/2}$ / kcal mol$^{-1}$"
),
(
ax2.set_ylabel(
r"Integrated (Var($\frac{\mathrm{d}h}{\mathrm{d}\lambda} $))$^{1/2}$ / kcal mol$^{-1}$"
),
)
ax2.legend()
ax2.legend()

Expand Down Expand Up @@ -374,7 +386,7 @@ def plot_gradient_hists(
bins=50,
density=True,
alpha=0.5,
label=f"Run {run_nos[j] if run_nos else j+1}",
label=f"Run {run_nos[j] if run_nos else j + 1}",
)
ax.legend()
ax.set_title(f"$\lambda$ = {gradients_data.lam_vals[i]}")
Expand Down Expand Up @@ -455,7 +467,7 @@ def plot_gradient_timeseries(
gradients_data.times[i],
gradients,
alpha=0.5,
label=f"Run {run_nos[j] if run_nos else j+1}",
label=f"Run {run_nos[j] if run_nos else j + 1}",
)
ax.legend()
ax.set_title(f"$\lambda$ = {gradients_data.lam_vals[i]}")
Expand Down Expand Up @@ -709,7 +721,7 @@ def plot_overlap_mats(
for i in range(n_runs):
plot_overlap_mat(
ax=axs[i],
name=f"Run {i+1}" if not predicted else "Predicted",
name=f"Run {i + 1}" if not predicted else "Predicted",
mbar_file=mbar_outfiles[i] if mbar_outfiles else None,
predicted=predicted,
gradient_data=gradient_data,
Expand Down Expand Up @@ -891,9 +903,11 @@ def _plot_mbar_gradient_convergence_single_run(

# Labels
ax.set_xlabel(r"$\lambda$")
ax.set_ylabel(
r"$\langle \frac{\mathrm{d}h}{\mathrm{d}\lambda}\rangle _{\lambda} $ / kcal mol$^{-1}$"
),
(
ax.set_ylabel(
r"$\langle \frac{\mathrm{d}h}{\mathrm{d}\lambda}\rangle _{\lambda} $ / kcal mol$^{-1}$"
),
)
ax.set_title(run_name)

# Return the colour mapper so we can add it to the plot
Expand Down Expand Up @@ -1053,7 +1067,7 @@ def plot_rmsds(
ax.set_xlabel("Time (ns)")
ax.set_ylabel(r"RMSD ($\AA$)")
for j, rmsd in enumerate(rmsds):
ax.plot(times, rmsd, label=f"Run {j+1}")
ax.plot(times, rmsd, label=f"Run {j + 1}")
ax.legend()

# If we have equilibration data, plot this
Expand All @@ -1069,7 +1083,7 @@ def plot_rmsds(
group_selection_name = (
"none" if not group_selection else group_selection.replace(" ", "")
)
name = f"{output_dir}/rmsd_{selection.replace(' ','')}_{group_selection_name}" # Use selection string to make sure save name is unique
name = f"{output_dir}/rmsd_{selection.replace(' ', '')}_{group_selection_name}" # Use selection string to make sure save name is unique
fig.savefig(
name, dpi=300, bbox_inches="tight", facecolor="white", transparent=False
)
Expand Down
14 changes: 9 additions & 5 deletions a3fe/analyse/process_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,10 @@ def get_time_series_multiwindow(
# simulation time for the whole calculation
n_runs = len(run_nos)
overall_dgs = _np.zeros(
[n_runs, 100]
[
n_runs,
100,
]
) # One point for each % of the total simulation time
overall_times = _np.zeros([n_runs, 100])
for lam_win in lambda_windows:
Expand Down Expand Up @@ -696,15 +699,16 @@ def get_time_series_multiwindow_mbar(
n_runs = len(run_nos)
n_points = 100
overall_dgs = _np.zeros(
[n_runs, n_points]
[
n_runs,
n_points,
]
) # One point for each % of the total simulation time
overall_times = _np.zeros([n_runs, n_points])
start_and_end_fracs = [
(i, i + (end_frac - start_frac) / n_points)
for i in _np.linspace(start_frac, end_frac, n_points + 1)
][
:-1
] # Throw away the last point as > 1
][:-1] # Throw away the last point as > 1
# Round the values to avoid floating point errors
start_and_end_fracs = [
(round(x[0], 5), round(x[1], 5)) for x in start_and_end_fracs
Expand Down
2 changes: 1 addition & 1 deletion a3fe/analyse/waters.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_av_waters_lambda_window(
avg_close_waters = _np.full(n_runs, _np.nan)
for i, run_no in enumerate(run_nos):
print_fn(
f"Calculating average number of waters for run {i+1} of {n_runs} for lambda window {lam_val}"
f"Calculating average number of waters for run {i + 1} of {n_runs} for lambda window {lam_val}"
)
sim = simulations[run_no - 1]
avg_close_waters[i] = get_av_waters_simulation(
Expand Down
4 changes: 3 additions & 1 deletion a3fe/read/_process_bss_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from sire.legacy import Mol as _SireMol


def rename_lig(bss_system: _BSS._SireWrappers._system.System, new_name: str = "LIG") -> None: # type: ignore
def rename_lig(
bss_system: _BSS._SireWrappers._system.System, new_name: str = "LIG"
) -> None: # type: ignore
"""Rename the ligand in a BSS system.
Parameters
Expand Down
4 changes: 1 addition & 3 deletions a3fe/read/_process_somd_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ def write_truncated_sim_datafile(
raise ValueError(f"No data found in simfile: {simfile}.")
start_reading_idx = (
# + 1 and -1 because no data is written at t = 0
round((final_idx - start_data_idx + 1) * fraction_initial)
+ start_data_idx
- 1
round((final_idx - start_data_idx + 1) * fraction_initial) + start_data_idx - 1
)
end_reading_idx = (
round((final_idx - start_data_idx + 1) * fraction_final) + start_data_idx - 1
Expand Down
2 changes: 1 addition & 1 deletion a3fe/read/_read_exp_dgs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Functionality to read experimental free energies from a supplied csv file.
This must have the columns: calc_base_dir, name, exp_dg, exp_er"""
This must have the columns: calc_base_dir, name, exp_dg, exp_er"""

import os as _os

Expand Down
6 changes: 4 additions & 2 deletions a3fe/run/_simulation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def analyse(
)
for i in range(self.ensemble_size):
ofile.write(
f"Free energy from run {i+1}: {dg_overall[i]: .3f} +/- {er_overall[i]:.3f} kcal/mol\n"
f"Free energy from run {i + 1}: {dg_overall[i]: .3f} +/- {er_overall[i]:.3f} kcal/mol\n"
)
ofile.write(
"Errors are 95 % C.I.s based on the assumption of a Gaussian distribution of free energies\n"
Expand Down Expand Up @@ -831,7 +831,9 @@ def update_paths(self, old_sub_path: str, new_sub_path: str) -> None:
if hasattr(self, "virtual_queue"):
# Virtual queue may have already been updated
if new_sub_path not in self.virtual_queue.log_dir: # type: ignore
self.virtual_queue.log_dir = self.virtual_queue.log_dir.replace(old_sub_path, new_sub_path) # type: ignore
self.virtual_queue.log_dir = self.virtual_queue.log_dir.replace(
old_sub_path, new_sub_path
) # type: ignore
self.virtual_queue._set_up_logging() # type: ignore

# Update the paths of any sub-simulation runners
Expand Down
5 changes: 3 additions & 2 deletions a3fe/run/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""""Utilities for SimulationRunners."""
"""Utilities for SimulationRunners."""

from __future__ import annotations

Expand Down Expand Up @@ -29,7 +29,8 @@ def check_has_wat_and_box(system: _BSS._SireWrappers._system.System) -> None: #


def get_simtime(
sim_runner: "SimulationRunner", run_nos: _Optional[_List[int]] = None # noqa: F821
sim_runner: "SimulationRunner", # noqa: F821
run_nos: _Optional[_List[int]] = None,
) -> float:
"""
Get the simulation time of a sub simulation runner, in ns. This function
Expand Down
Loading

0 comments on commit 8b45f71

Please sign in to comment.