Skip to content

Commit

Permalink
Re-factor of config to allow for the logger to be used within config …
Browse files Browse the repository at this point in the history
…itself at the correct level.

Also adds a test for the writing of the logfile
  • Loading branch information
mb2055 committed Nov 7, 2023
1 parent d4faed6 commit 257a99c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
38 changes: 29 additions & 9 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class Config:

def __init__(
self,
log_level="info",
log_file=None,
runtime="1ns",
timestep="4fs",
temperature="300K",
Expand Down Expand Up @@ -98,8 +100,6 @@ def __init__(
output_directory="output",
restart=False,
write_config=True,
log_level="info",
log_file=None,
supress_overwrite_warning=False,
):
"""
Expand Down Expand Up @@ -208,8 +208,16 @@ def __init__(
log_file: str
Name of log file, will be saved in output directory.
supress_overwrite_warning: bool
Whether to supress the warning when overwriting files in the output directory.
"""

# Setup logger before doing anything else
self.log_level = log_level
self.log_file = log_file
self.output_directory = output_directory

self.runtime = runtime
self.temperature = temperature
self.pressure = pressure
Expand Down Expand Up @@ -238,10 +246,9 @@ def __init__(
self.max_gpus = max_gpus
self.run_parallel = run_parallel
self.restart = restart
self.output_directory = output_directory

self.write_config = write_config
self.log_level = log_level
self.log_file = log_file

self.supress_overwrite_warning = supress_overwrite_warning

def __str__(self):
Expand Down Expand Up @@ -880,6 +887,10 @@ def restart(self):
def restart(self, restart):
if not isinstance(restart, bool):
raise ValueError("'restart' must be of type 'bool'")
if not restart and self.directory_existed:
_logger.warning(
f"Output directory {self.output_directory} already exists files may be overwritten"
)
self._restart = restart

@property
Expand All @@ -888,6 +899,7 @@ def output_directory(self):

@output_directory.setter
def output_directory(self, output_directory):
self.cirectory_existed = False
if not isinstance(output_directory, _Path):
try:
output_directory = _Path(output_directory)
Expand All @@ -900,10 +912,12 @@ def output_directory(self, output_directory):
raise ValueError(
f"Output directory {output_directory} does not exist and cannot be created"
)
elif not self.restart:
_logger.warning(
f"Output directory {output_directory} already exists files may be overwritten"
)
else:
self.directory_existed = True
if self.log_file is not None:
# Can now add the log file
_logger.add(output_directory / self.log_file, level=self.log_level.upper())
_logger.debug(f"Logging to {output_directory / self.log_file}")
self._output_directory = output_directory

@property
Expand All @@ -929,6 +943,11 @@ def log_level(self, log_level):
raise ValueError(
f"Log level not recognised. Valid log levels are: {', '.join(self._choices['log_level'])}"
)
# Do logging setup here for use in the rest of the ocnfig and all other modules.
import sys

_logger.remove()
_logger.add(sys.stderr, level=log_level.upper(), enqueue=True)
self._log_level = log_level

@property
Expand All @@ -939,6 +958,7 @@ def log_file(self):
def log_file(self, log_file):
if log_file is not None and not isinstance(log_file, str):
raise TypeError("'log_file' must be of type 'str'")
# Can't add the logfile to the logger here as we don't know the output directory yet.
self._log_file = log_file

@property
Expand Down
21 changes: 21 additions & 0 deletions tests/runner/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,24 @@ def test_dynamics_options():
assert config_inp.integrator.lower().replace(
"_", ""
) == d.integrator().__class__.__name__.lower().replace("integrator", "")


def test_logfile_creation():
# Test that the logfile is created by either the initialisation of the runner or of a config
with tempfile.TemporaryDirectory() as tmpdir:
# Load the demo stream file.
mols = sr.load(sr.expand(sr.tutorial_url, "merged_molecule.s3"))
from pathlib import Path

# Instantiate a runner using the default config.
# (All default options, other than platform="cpu".)
config = Config(output_directory=tmpdir, log_file="test.log")
assert config.log_file is not None
assert Path.exists(config.output_directory / config.log_file)
Path.unlink(config.output_directory / config.log_file)

# Instantiate a runner using the default config.
# (All default options, other than platform="cpu".)
runner = Runner(mols, Config(output_directory=tmpdir, log_file="test.log"))
assert runner._config.log_file is not None
assert Path.exists(runner._config.output_directory / runner._config.log_file)

0 comments on commit 257a99c

Please sign in to comment.