Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for writing energy components to file and preserving the LJ sigma parameter for ghost atoms #44

Merged
merged 9 commits into from
May 30, 2024
16 changes: 16 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
overwrite=False,
somd1_compatibility=False,
pert_file=None,
save_energy_components=False,
):
"""
Constructor.
Expand Down Expand Up @@ -281,6 +282,10 @@ def __init__(
pert_file: str
The path to a SOMD1 perturbation file to apply to the reference system.
When set, this will automatically set 'somd1_compatibility' to True.

save_energy_components: bool
Whether to save the energy contribution for each force when checkpointing.
This is useful when debugging crashes.
"""

# Setup logger before doing anything else
Expand Down Expand Up @@ -327,6 +332,7 @@ def __init__(
self.restart = restart
self.somd1_compatibility = somd1_compatibility
self.pert_file = pert_file
self.save_energy_components = save_energy_components

self.write_config = write_config

Expand Down Expand Up @@ -1201,6 +1207,16 @@ def pert_file(self, pert_file):
if pert_file is not None:
self._somd1_compatibility = True

@property
def save_energy_components(self):
return self._save_energy_components

@save_energy_components.setter
def save_energy_components(self, save_energy_components):
if not isinstance(save_energy_components, bool):
raise ValueError("'save_energy_components' must be of type 'bool'")
self._save_energy_components = save_energy_components

@property
def output_directory(self):
return self._output_directory
Expand Down
54 changes: 53 additions & 1 deletion src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(
self._config.restart,
)

self._nrg_sample = 0
self._nrg_file = "energy_components.txt"

@staticmethod
def create_filenames(lambda_array, lambda_value, output_directory, restart=False):
# Create incremental file name for current restart.
Expand All @@ -153,6 +156,7 @@ def increment_filename(base_filename, suffix):
filenames["energy_traj"] = f"energy_traj_{lam}.parquet"
filenames["trajectory"] = f"traj_{lam}.dcd"
filenames["trajectory_chunk"] = f"traj_{lam}_"
filenames["energy_components"] = f"energy_components_{lam}.txt"
if restart:
filenames["config"] = increment_filename("config", "yaml")
else:
Expand Down Expand Up @@ -371,7 +375,7 @@ def generate_lam_vals(lambda_base, increment):
)

if self._config.checkpoint_frequency.value() > 0.0:
# Calculate the number of blocks and the remaineder time.
# Calculate the number of blocks and the remainder time.
frac = (
self._config.runtime.value() / self._config.checkpoint_frequency.value()
)
Expand Down Expand Up @@ -409,6 +413,10 @@ def generate_lam_vals(lambda_base, increment):

# Checkpoint.
try:
# Save the energy contribution for each force.
if self._config.save_energy_components:
self._save_energy_components()

# Set to the current block number if this is a restart.
if x == 0:
x = self._current_block
Expand Down Expand Up @@ -584,3 +592,47 @@ def get_timing(self):

def _cleanup(self):
del self._dyn

def _save_energy_components(self):

from copy import deepcopy
import openmm

# Get the current context and system.
context = self._dyn._d._omm_mols
system = deepcopy(context.getSystem())

# Add each force to a unique group.
for i, f in enumerate(system.getForces()):
f.setForceGroup(i)

# Create a new context.
new_context = openmm.Context(system, deepcopy(context.getIntegrator()))
new_context.setPositions(context.getState(getPositions=True).getPositions())

header = f"{'# Sample':>10}"
record = f"{self._nrg_sample:>10}"

# Process the records.
for i, f in enumerate(system.getForces()):
state = new_context.getState(getEnergy=True, groups={i})
header += f"{f.getName():>25}"
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>25.2f}"

# Write to file.
if self._nrg_sample == 0:
with open(
self._config.output_directory / self._filenames["energy_components"],
"w",
) as f:
f.write(header + "\n")
f.write(record + "\n")
else:
with open(
self._config.output_directory / self._filenames["energy_components"],
"a",
) as f:
f.write(record + "\n")

# Increment the sample number.
self._nrg_sample += 1
9 changes: 7 additions & 2 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, system, config):

_logger.info("Applying SOMD1 perturbation compatibility.")
self._system = _make_compatible(self._system)
self._system = _morph.link_to_reference(self._system)

# Next, swap the water topology so that it is in AMBER format.

Expand Down Expand Up @@ -151,12 +152,14 @@ def __init__(self, system, config):

# Only check for light atoms by the maxium end state mass if running
# in SOMD1 compatibility mode. Ghost atoms are considered light when
# adding bond constraints.
# adding bond constraints. Also fix the LJ sigma for ghost atoms so
# it isn't scaled to zero.
self._config._extra_args["ghosts_are_light"] = True
self._config._extra_args["check_for_h_by_max_mass"] = True
self._config._extra_args["check_for_h_by_mass"] = False
self._config._extra_args["check_for_h_by_element"] = False
self._config._extra_args["check_for_h_by_ambertype"] = False
self._config._extra_args["fix_ghost_sigmas"] = True

# Check for a periodic space.
self._check_space()
Expand Down Expand Up @@ -969,5 +972,7 @@ def _run(sim, is_restart=False):
filename=self._fnames[lambda_value]["energy_traj"],
)
del system
_logger.success(f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1")
_logger.success(
f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1"
)
return True
91 changes: 75 additions & 16 deletions src/somd2/runner/_somd1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def _make_compatible(system):
except KeyError:
raise KeyError("No perturbable molecules in the system")

# Store a dummy element.
dummy = _SireMol.Element("Xx")
# Store a dummy element and ambertype.
element_dummy = _SireMol.Element("Xx")
ambertype_dummy = "du"

for mol in pert_mols:
# Store the molecule info.
Expand All @@ -69,9 +70,43 @@ def _make_compatible(system):
# Get an editable version of the molecule.
edit_mol = mol.edit()

##########################
# First process the bonds.
##########################
#####################################
# First fix the ghost atom LJ sigmas.
#####################################

for atom in mol.atoms():
# Lambda = 0 state is a dummy, use sigma from the lambda = 1 state.
if (
atom.property("element0") == element_dummy
or atom.property("ambertype0") == ambertype_dummy
):
lj0 = atom.property("LJ0")
lj1 = atom.property("LJ1")
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ0", _SireMM.LJParameter(lj1.sigma(), lj0.epsilon())
)
.molecule()
)
# Lambda = 1 state is a dummy, use sigma from the lambda = 0 state.
elif (
atom.property("element1") == element_dummy
or atom.property("ambertype1") == ambertype_dummy
):
lj0 = atom.property("LJ0")
lj1 = atom.property("LJ1")
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ1", _SireMM.LJParameter(lj0.sigma(), lj1.epsilon())
)
.molecule()
)

########################
# Now process the bonds.
########################

new_bonds0 = _SireMM.TwoAtomFunctions(mol.info())
new_bonds1 = _SireMM.TwoAtomFunctions(mol.info())
Expand Down Expand Up @@ -534,17 +569,26 @@ def _has_dummy(mol, idxs, is_lambda1=False):
Whether a dummy atom is present.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

dummy = _SireMol.Element(0)
element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Check whether an of the atoms is a dummy.
for idx in idxs:
if mol.atom(idx).property(prop) == dummy:
if (
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
):
return True

return False
Expand Down Expand Up @@ -573,21 +617,36 @@ def _is_dummy(mol, idxs, is_lambda1=False):
Whether each atom is a dummy.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

# Store a dummy element.
dummy = _SireMol.Element(0)
if is_lambda1:
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
element_prop = "element0"
ambertype_prop = "ambertype0"

element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Initialise a list to store the state of each atom.
is_dummy = []

# Check whether each of the atoms is a dummy.
for idx in idxs:
is_dummy.append(mol.atom(idx).property(prop) == dummy)
is_dummy.append(
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
)

return is_dummy

Expand Down