diff --git a/setup.py b/setup.py index 5e6f37812..d0131677c 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "networkx", "loguru", "pyarrow", + "dill", # Type stubs "pandas-stubs", ] diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index 96934835b..e3c47ba9a 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -21,8 +21,10 @@ from pathlib import Path from pprint import pformat +from time import time from typing import Any, Dict, List, Optional, Set, Union +import dill import numpy as np import pandas as pd import yaml @@ -42,7 +44,7 @@ from vivarium.framework.randomness import RandomnessInterface from vivarium.framework.resource import ResourceInterface from vivarium.framework.results import ResultsInterface -from vivarium.framework.time import TimeInterface +from vivarium.framework.time import Time, TimeInterface from vivarium.framework.values import ValuesInterface @@ -206,6 +208,11 @@ def __init__( def name(self) -> str: return self._name + @property + def current_time(self) -> Time: + """Returns the current simulation time.""" + return self._clock.time + def get_results(self) -> Dict[str, pd.DataFrame]: """Return the formatted results.""" return self._results.get_results() @@ -246,7 +253,7 @@ def initialize_simulants(self) -> None: self._clock.step_forward(self.get_population().index) def step(self) -> None: - self._logger.debug(self._clock.time) + self._logger.debug(self.current_time) for event in self.time_step_events: self._logger.debug(f"Event: {event}") self._lifecycle.set_state(event) @@ -258,9 +265,22 @@ def step(self) -> None: self.time_step_emitters[event](pop_to_update) self._clock.step_forward(self.get_population().index) - def run(self) -> None: - while self._clock.time < self._clock.stop_time: - self.step() + def run( + self, + backup_path: Optional[Path] = None, + backup_freq: Optional[Union[int, float]] = None, + ) -> None: + if backup_freq: + time_to_save = time() + backup_freq + while self.current_time < self._clock.stop_time: + self.step() + if time() >= time_to_save: + self._logger.debug(f"Writing Simulation Backup to {backup_path}") + self.write_backup(backup_path) + time_to_save = time() + backup_freq + else: + while self.current_time < self._clock.stop_time: + self.step() def finalize(self) -> None: self._lifecycle.set_state("simulation_end") @@ -296,6 +316,10 @@ def _write_results(self, results: dict[str, pd.DataFrame]) -> None: except ConfigurationKeyError: self._logger.info("No results directory set; results are not written to disk.") + def write_backup(self, backup_path: Path) -> None: + with open(backup_path, "wb") as f: + dill.dump(self, f, protocol=dill.HIGHEST_PROTOCOL) + def get_performance_metrics(self) -> pd.DataFrame: timing_dict = self._lifecycle.timings total_time = np.sum([np.sum(v) for v in timing_dict.values()]) diff --git a/tests/framework/test_engine.py b/tests/framework/test_engine.py index 8520e244e..5146f4e99 100644 --- a/tests/framework/test_engine.py +++ b/tests/framework/test_engine.py @@ -1,8 +1,10 @@ import math from itertools import product from pathlib import Path +from time import time from typing import Dict, List +import dill import pandas as pd import pytest @@ -347,6 +349,45 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen assert results.equals(written_results) +def test_SimulationContext_write_backup(mocker, SimulationContext, tmpdir): + # TODO MIC-5216: Remove mocks when we can use dill in pytest. + mocker.patch("vivarium.framework.engine.dill.dump") + mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext()) + sim = SimulationContext() + backup_path = tmpdir / "backup.pkl" + sim.write_backup(backup_path) + assert backup_path.exists() + with open(backup_path, "rb") as f: + sim_backup = dill.load(f) + assert isinstance(sim_backup, SimulationContext) + + +def test_SimulationContext_run_with_backup(mocker, SimulationContext, base_config, tmpdir): + mocker.patch("vivarium.framework.engine.SimulationContext.write_backup") + original_time = time() + + def time_generator(): + current_time = original_time + while True: + yield current_time + current_time += 5 + + mocker.patch("vivarium.framework.engine.time", side_effect=time_generator()) + components = [ + Hogwarts(), + HousePointsObserver(), + NoStratificationsQuidditchWinsObserver(), + QuidditchWinsObserver(), + HogwartsResultsStratifier(), + ] + sim = SimulationContext(base_config, components, configuration=HARRY_POTTER_CONFIG) + backup_path = tmpdir / "backup.pkl" + sim.setup() + sim.initialize_simulants() + sim.run(backup_path=backup_path, backup_freq=5) + assert sim.write_backup.call_count == _get_num_steps(sim) + + def test_get_results_formatting(SimulationContext, base_config): """Test formatted results are as expected""" components = [