Skip to content

Commit

Permalink
Move pert_file option into Config and apply in Runner setup.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Mar 13, 2024
1 parent e2e7a97 commit c913c7d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 95 deletions.
87 changes: 2 additions & 85 deletions src/somd2/app/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,13 @@ def cli():
"or the reference system. If a reference system, then this must be "
"combined with a perturbation file via the --pert-file argument.",
)
parser.add_argument(
"--pert-file",
type=str,
required=False,
help="Path to a file containing the perturbation to apply "
"to the reference system.",
)

# Parse the arguments into a dictionary.
args = vars(parser.parse_args())

# Pop the YAML config, system, and pert file from the arguments dictionary.
# Pop the YAML config and system from the arguments dictionary.
config = args.pop("config")
system = args.pop("system")
pert_file = args.pop("pert_file")

# If set, read the YAML config file.
if config is not None:
Expand All @@ -87,12 +79,9 @@ def cli():
# will override those in the config.
args = vars(parser.parse_args(namespace=Namespace(**config)))

# Re-pop the YAML config, system, and pert file from the arguments
# dictionary.
# Re-pop the YAML config and system from the arguments dictionary.
args.pop("config")
args.pop("system")
if pert_file is None:
pert_file = args.pop("pert_file")

# Instantiate a Config object to validate the arguments.
config = Config(**args)
Expand All @@ -101,80 +90,8 @@ def cli():
_logger.info(f"somd2 version: {__version__}")
_logger.info(f"sire version: {sire_version}+{sire_revisionid}")

# Try to apply the perturbation to the reference system.
if pert_file is not None:
_logger.info(f"Applying perturbation to reference system: {pert_file}")
system = apply_pert(system, pert_file)

# Instantiate a Runner object to run the simulation.
runner = Runner(system, config)

# Run the simulation.
runner.run()


def apply_pert(system, pert_file):
"""
Helper function to apply a perturbation to a reference system.
Parameters
----------
system: str
Path to a stream file containing the reference system.
pert_file: str
Path to a stream file containing the perturbation to apply to the
reference system.
Returns
-------
system: sire.system.System
The perturbable system.
"""

if not isinstance(system, str):
raise TypeError("'system' must be of type 'str'.")

if not isinstance(pert_file, str):
raise TypeError("'pert_file' must be of type 'str'.")

import os as _os

if not _os.path.isfile(system):
raise FileNotFoundError(f"'{system}' does not exist.")

if not _os.path.isfile(pert_file):
raise FileNotFoundError(f"'{pert_file}' does not exist.")

from sire import stream as _stream
from sire import morph as _morph

# Load the reference system.
try:
system = _stream.load(system)
except Exception as e:
raise ValueError(f"Failed to load the reference 'system': {e}")

# Get the non-water molecules in the system.
non_waters = system["not water"]

# Try to apply the perturbation to each non-water molecule.
is_pert = False
for mol in non_waters:
try:
pert_mol = _morph.create_from_pertfile(mol, pert_file)
is_pert = True
break
except:
pass

if not is_pert:
raise ValueError(f"Failed to apply the perturbation in '{pert_file}'.")

# Replace the reference molecule with the perturbed molecule.
system.remove(mol)
system.add(pert_mol)

return system
25 changes: 25 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
write_config=True,
overwrite=False,
somd1_compatibility=False,
pert_file=None,
):
"""
Constructor.
Expand Down Expand Up @@ -243,6 +244,10 @@ def __init__(
somd1_compatibility: bool
Whether to run using a SOMD1 compatible perturbation.
pert_file: str
The path to a SOMD1 perturbation file to apply to the reference system.
When set, this will automatically set 'somd1_compatibility' to True.
"""

# Setup logger before doing anything else
Expand Down Expand Up @@ -284,6 +289,7 @@ def __init__(
self.run_parallel = run_parallel
self.restart = restart
self.somd1_compatibility = somd1_compatibility
self.pert_file = pert_file

self.write_config = write_config

Expand Down Expand Up @@ -1026,6 +1032,25 @@ def somd1_compatibility(self, somd1_compatibility):
raise ValueError("'somd1_compatibility' must be of type 'bool'")
self._somd1_compatibility = somd1_compatibility

@property
def pert_file(self):
return self._pert_file

@pert_file.setter
def pert_file(self, pert_file):
import os

if pert_file is not None and not isinstance(pert_file, str):
raise TypeError("'pert_file' must be of type 'str'")

if pert_file is not None and not os.path.exists(pert_file):
raise ValueError(f"Perturbation file does not exist: {pert_file}")

self._pert_file = pert_file

if pert_file is not None:
self._somd1_compatibility = True

@property
def output_directory(self):
return self._output_directory
Expand Down
29 changes: 21 additions & 8 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def __init__(self, system, config):
else:
self._system = system

# Validate the configuration.
if not isinstance(config, _Config):
raise TypeError("'config' must be of type 'somd2.config.Config'")
self._config = config
self._config._extra_args = {}

# Check whether we need to apply a perturbation to the reference system.
if self._config.pert_file is not None:
_logger.info(
f"Applying perturbation to reference system: {self._config.pert_file}"
)
try:
from ._somd1 import _apply_pert

self._system = _apply_pert(self._system, self._config.pert_file)
self._config.somd1_compatibility = True
except Exception as e:
raise IOError(f"Unable to apply perturbation to reference system: {e}")

# Make sure the system contains perturbable molecules.
try:
self._system.molecules("property is_perturbable")
Expand All @@ -90,20 +109,14 @@ def __init__(self, system, config):
# Link properties to the lambda = 0 end state.
self._system = _morph.link_to_reference(self._system)

# Validate the configuration.
if not isinstance(config, _Config):
raise TypeError("'config' must be of type 'somd2.config.Config'")
self._config = config
self._config._extra_args = {}

# We're running in SOMD1 compatibility mode.
if self._config.somd1_compatibility:
from ._somd1 import _apply_somd1_pert
from ._somd1 import _make_compatible

# First, try to make the perturbation SOMD1 compatible.

_logger.info("Applying SOMD1 perturbation compatibility.")
self._system = _apply_somd1_pert(self._system)
self._system = _make_compatible(self._system)

# Next, swap the water topology so that it is in AMBER format.

Expand Down
56 changes: 54 additions & 2 deletions src/somd2/runner/_somd1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import sire.legacy.Mol as _SireMol


def _apply_somd1_pert(system):
def _make_compatible(system):
"""
Applies the somd1 perturbation to the system.
Makes a perturbation SOMD1 compatible.
Parameters
----------
Expand Down Expand Up @@ -611,3 +611,55 @@ def _is_dummy(mol, idxs, is_lambda1=False):
is_dummy.append(mol.atom(idx).property(prop) == dummy)

return is_dummy


def _apply_pert(system, pert_file):
"""
Helper function to apply a perturbation to a reference system.
Parameters
----------
system: sr.system.System
The reference system.
pert_file: str
Path to a stream file containing the perturbation to apply to the
reference system.
Returns
-------
system: sire.system.System
The perturbable system.
"""

if not isinstance(system, _System):
raise TypeError("'system' must be of type 'sr.system.System'.")

if not isinstance(pert_file, str):
raise TypeError("'pert_file' must be of type 'str'.")

from sire import morph as _morph

# Get the non-water molecules in the system.
non_waters = system["not water"]

# Try to apply the perturbation to each non-water molecule.
is_pert = False
for mol in non_waters:
try:
pert_mol = _morph.create_from_pertfile(mol, pert_file)
is_pert = True
break
except:
pass

if not is_pert:
raise ValueError(f"Failed to apply the perturbation in '{pert_file}'.")

# Replace the reference molecule with the perturbed molecule.
system.remove(mol)
system.add(pert_mol)

return system

0 comments on commit c913c7d

Please sign in to comment.