Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for writing energy components to file and preserving the LJ sigma parameter for ghost atoms #44

Merged
merged 9 commits into from
May 30, 2024
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ simulations. Built on top of [Sire](https://github.com/OpenBioSim/sire) and [Ope
First create a conda environment using the provided environment file:

```
mamba create -f environment.yaml
conda env create -f environment.yaml
```

(We recommend using [Miniforge](https://github.com/conda-forge/miniforge).)

Now install `somd2` into the environment:

```
mamba activate somd2
conda activate somd2
pip install --editable .
```

Expand Down
16 changes: 16 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
overwrite=False,
somd1_compatibility=False,
pert_file=None,
save_energy_components=False,
):
"""
Constructor.
Expand Down Expand Up @@ -281,6 +282,10 @@ def __init__(
pert_file: str
The path to a SOMD1 perturbation file to apply to the reference system.
When set, this will automatically set 'somd1_compatibility' to True.

save_energy_components: bool
Whether to save the energy contribution for each force when checkpointing.
This is useful when debugging crashes.
"""

# Setup logger before doing anything else
Expand Down Expand Up @@ -327,6 +332,7 @@ def __init__(
self.restart = restart
self.somd1_compatibility = somd1_compatibility
self.pert_file = pert_file
self.save_energy_components = save_energy_components

self.write_config = write_config

Expand Down Expand Up @@ -1201,6 +1207,16 @@ def pert_file(self, pert_file):
if pert_file is not None:
self._somd1_compatibility = True

@property
def save_energy_components(self):
return self._save_energy_components

@save_energy_components.setter
def save_energy_components(self, save_energy_components):
if not isinstance(save_energy_components, bool):
raise ValueError("'save_energy_components' must be of type 'bool'")
self._save_energy_components = save_energy_components

@property
def output_directory(self):
return self._output_directory
Expand Down
54 changes: 53 additions & 1 deletion src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(
self._config.restart,
)

self._nrg_sample = 0
self._nrg_file = "energy_components.txt"

@staticmethod
def create_filenames(lambda_array, lambda_value, output_directory, restart=False):
# Create incremental file name for current restart.
Expand All @@ -153,6 +156,7 @@ def increment_filename(base_filename, suffix):
filenames["energy_traj"] = f"energy_traj_{lam}.parquet"
filenames["trajectory"] = f"traj_{lam}.dcd"
filenames["trajectory_chunk"] = f"traj_{lam}_"
filenames["energy_components"] = f"energy_components_{lam}.txt"
if restart:
filenames["config"] = increment_filename("config", "yaml")
else:
Expand Down Expand Up @@ -371,7 +375,7 @@ def generate_lam_vals(lambda_base, increment):
)

if self._config.checkpoint_frequency.value() > 0.0:
# Calculate the number of blocks and the remaineder time.
# Calculate the number of blocks and the remainder time.
frac = (
self._config.runtime.value() / self._config.checkpoint_frequency.value()
)
Expand Down Expand Up @@ -409,6 +413,10 @@ def generate_lam_vals(lambda_base, increment):

# Checkpoint.
try:
# Save the energy contribution for each force.
if self._config.save_energy_components:
self._save_energy_components()

# Set to the current block number if this is a restart.
if x == 0:
x = self._current_block
Expand Down Expand Up @@ -584,3 +592,47 @@ def get_timing(self):

def _cleanup(self):
del self._dyn

def _save_energy_components(self):

from copy import deepcopy
import openmm

# Get the current context and system.
context = self._dyn._d._omm_mols
system = deepcopy(context.getSystem())

# Add each force to a unique group.
for i, f in enumerate(system.getForces()):
f.setForceGroup(i)

# Create a new context.
new_context = openmm.Context(system, deepcopy(context.getIntegrator()))
new_context.setPositions(context.getState(getPositions=True).getPositions())

header = f"{'# Sample':>10}"
record = f"{self._nrg_sample:>10}"

# Process the records.
for i, f in enumerate(system.getForces()):
state = new_context.getState(getEnergy=True, groups={i})
header += f"{f.getName():>25}"
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>25.2f}"

# Write to file.
if self._nrg_sample == 0:
with open(
self._config.output_directory / self._filenames["energy_components"],
"w",
) as f:
f.write(header + "\n")
f.write(record + "\n")
else:
with open(
self._config.output_directory / self._filenames["energy_components"],
"a",
) as f:
f.write(record + "\n")

# Increment the sample number.
self._nrg_sample += 1
9 changes: 7 additions & 2 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, system, config):

_logger.info("Applying SOMD1 perturbation compatibility.")
self._system = _make_compatible(self._system)
self._system = _morph.link_to_reference(self._system)

# Next, swap the water topology so that it is in AMBER format.

Expand Down Expand Up @@ -151,12 +152,14 @@ def __init__(self, system, config):

# Only check for light atoms by the maxium end state mass if running
# in SOMD1 compatibility mode. Ghost atoms are considered light when
# adding bond constraints.
# adding bond constraints. Also fix the LJ sigma for ghost atoms so
# it isn't scaled to zero.
self._config._extra_args["ghosts_are_light"] = True
self._config._extra_args["check_for_h_by_max_mass"] = True
self._config._extra_args["check_for_h_by_mass"] = False
self._config._extra_args["check_for_h_by_element"] = False
self._config._extra_args["check_for_h_by_ambertype"] = False
self._config._extra_args["fix_ghost_sigmas"] = True

# Check for a periodic space.
self._check_space()
Expand Down Expand Up @@ -969,5 +972,7 @@ def _run(sim, is_restart=False):
filename=self._fnames[lambda_value]["energy_traj"],
)
del system
_logger.success(f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1")
_logger.success(
f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1"
)
return True
93 changes: 76 additions & 17 deletions src/somd2/runner/_somd1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def _make_compatible(system):
except KeyError:
raise KeyError("No perturbable molecules in the system")

# Store a dummy element.
dummy = _SireMol.Element("Xx")
# Store a dummy element and ambertype.
element_dummy = _SireMol.Element("Xx")
ambertype_dummy = "du"

for mol in pert_mols:
# Store the molecule info.
Expand All @@ -69,9 +70,43 @@ def _make_compatible(system):
# Get an editable version of the molecule.
edit_mol = mol.edit()

##########################
# First process the bonds.
##########################
##################################
# First fix zero LJ sigmas values.
##################################

# Create a null LJParameter.
null_lj = _SireMM.LJParameter()

for atom in mol.atoms():
# Get the end state LJ sigma values.
lj0 = atom.property("LJ0")
lj1 = atom.property("LJ1")

# Lambda = 0 state has a zero sigma value.
if lj0.sigma() == null_lj.sigma():
# Use the sigma value from the lambda = 1 state.
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ0", _SireMM.LJParameter(lj1.sigma(), lj0.epsilon())
)
.molecule()
)

# Lambda = 1 state has a zero sigma value.
if lj1.sigma() == null_lj.sigma():
# Use the sigma value from the lambda = 0 state.
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ1", _SireMM.LJParameter(lj0.sigma(), lj1.epsilon())
)
.molecule()
)

########################
# Now process the bonds.
########################

new_bonds0 = _SireMM.TwoAtomFunctions(mol.info())
new_bonds1 = _SireMM.TwoAtomFunctions(mol.info())
Expand Down Expand Up @@ -534,17 +569,26 @@ def _has_dummy(mol, idxs, is_lambda1=False):
Whether a dummy atom is present.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

dummy = _SireMol.Element(0)
element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Check whether an of the atoms is a dummy.
for idx in idxs:
if mol.atom(idx).property(prop) == dummy:
if (
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
):
return True

return False
Expand Down Expand Up @@ -573,21 +617,36 @@ def _is_dummy(mol, idxs, is_lambda1=False):
Whether each atom is a dummy.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

# Store a dummy element.
dummy = _SireMol.Element(0)
if is_lambda1:
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
element_prop = "element0"
ambertype_prop = "ambertype0"

element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Initialise a list to store the state of each atom.
is_dummy = []

# Check whether each of the atoms is a dummy.
for idx in idxs:
is_dummy.append(mol.atom(idx).property(prop) == dummy)
is_dummy.append(
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
)

return is_dummy

Expand Down Expand Up @@ -622,7 +681,7 @@ def _apply_pert(system, pert_file):
from sire import morph as _morph

# Get the non-water molecules in the system.
non_waters = system["not water"]
non_waters = system["not water"].molecules()

# Try to apply the perturbation to each non-water molecule.
is_pert = False
Expand Down
Loading