Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backup writing to vivarium simulation Context #455

Merged
29 changes: 22 additions & 7 deletions src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from pathlib import Path
from pprint import pformat
from time import time
from typing import Any, Dict, List, Optional, Set, Union

import dill
Expand All @@ -43,7 +44,7 @@
from vivarium.framework.randomness import RandomnessInterface
from vivarium.framework.resource import ResourceInterface
from vivarium.framework.results import ResultsInterface
from vivarium.framework.time import TimeInterface
from vivarium.framework.time import Time, TimeInterface
from vivarium.framework.values import ValuesInterface


Expand Down Expand Up @@ -207,6 +208,11 @@ def __init__(
def name(self) -> str:
return self._name

@property
def current_time(self) -> Time:
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the current simulation time."""
return self._clock.time

def get_results(self) -> Dict[str, pd.DataFrame]:
"""Return the formatted results."""
return self._results.get_results()
Expand Down Expand Up @@ -259,9 +265,21 @@ def step(self) -> None:
self.time_step_emitters[event](pop_to_update)
self._clock.step_forward(self.get_population().index)

def run(self) -> None:
while not self.past_stop_time():
self.step()
def run(
self,
backup_path: Optional[Path] = None,
backup_freq: Optional[Union[int, float]] = None,
) -> None:
if backup_freq:
time_to_save = time() + backup_freq
while self._clock.time < self._clock.stop_time:
self.step()
if time() >= time_to_save:
self.write_backup(backup_path)
time_to_save = time() + backup_freq
else:
while self._clock.time < self._clock.stop_time:
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
self.step()

def finalize(self) -> None:
self._lifecycle.set_state("simulation_end")
Expand Down Expand Up @@ -327,9 +345,6 @@ def add_components(self, component_list: List[Component]) -> None:
def get_population(self, untracked: bool = True) -> pd.DataFrame:
return self._population.get_population(untracked)

def past_stop_time(self) -> bool:
return self._clock.time >= self._clock.stop_time

def __repr__(self):
return f"SimulationContext({self.name})"

Expand Down
33 changes: 31 additions & 2 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from itertools import product
from pathlib import Path
from time import time
from typing import Dict, List

import dill
Expand Down Expand Up @@ -348,8 +349,10 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
assert results.equals(written_results)


@pytest.mark.skip(reason="TODO: Figure out how to make Dill serialize in pytest")
def test_SimulationContext_write_backup(SimulationContext, tmpdir):
def test_SimulationContext_write_backup(mocker, SimulationContext, tmpdir):
# TODO: Remove mocks when we can use dill in pytest.
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
mocker.patch("vivarium.framework.engine.dill.dump")
mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext())
sim = SimulationContext()
backup_path = tmpdir / "backup.pkl"
sim.write_backup(backup_path)
Expand All @@ -359,6 +362,32 @@ def test_SimulationContext_write_backup(SimulationContext, tmpdir):
assert isinstance(sim_backup, SimulationContext)

patricktnast marked this conversation as resolved.
Show resolved Hide resolved

def test_SimulationContext_run_with_backup(mocker, SimulationContext, base_config, tmpdir):
mocker.patch("vivarium.framework.engine.SimulationContext.write_backup")
original_time = time()

def time_generator():
current_time = original_time
while True:
yield current_time
current_time += 5

mocker.patch("vivarium.framework.engine.time", side_effect=time_generator())
components = [
Hogwarts(),
HousePointsObserver(),
NoStratificationsQuidditchWinsObserver(),
QuidditchWinsObserver(),
HogwartsResultsStratifier(),
]
sim = SimulationContext(base_config, components, configuration=HARRY_POTTER_CONFIG)
backup_path = tmpdir / "backup.pkl"
sim.setup()
sim.initialize_simulants()
sim.run(backup_path=backup_path, backup_freq=5)
assert sim.write_backup.call_count == _get_num_steps(sim)


def test_get_results_formatting(SimulationContext, base_config):
"""Test formatted results are as expected"""
components = [
Expand Down
Loading