From 63ab5ec4503bbc987804157cbedd3d818c5d9d0e Mon Sep 17 00:00:00 2001 From: Riccardo Bertossa Date: Tue, 29 Aug 2023 13:07:50 +0200 Subject: [PATCH] plumpy.ProcessListener made persistent solves https://github.com/aiidateam/plumpy/issues/273 We implement the persistence of ProcessListener by deriving the class ProcessListener and EventHelper from persistence.Savable. The class EventHelper is moved to a new file because of a circular import with utils and persistence Fixing the test There was a circular reference issue in the test listener that was storing a reference to the process inside it, making its serialization impossible. To fix the tests an ugly hack was used: storing the reference to the process outside the class in a global dict using id as keys. Some more ugly hacks are needed to check correctly the equality of two processes. We must ignore the fact that the instances if the listener are different. We call del on dict items of the ProcessListener's global implemented in the test suite to clean the golbal variables addressed issues in https://github.com/aiidateam/plumpy/pull/274 --- src/plumpy/base/utils.py | 2 + src/plumpy/event_helper.py | 52 +++++++++++++++++++++ src/plumpy/process_listener.py | 26 ++++++++++- src/plumpy/processes.py | 17 ++++--- src/plumpy/utils.py | 41 ----------------- test/test_processes.py | 7 ++- test/test_workchains.py | 43 +++++++++++++++++ test/utils.py | 84 +++++++++++++++++++++++++++++----- 8 files changed, 209 insertions(+), 63 deletions(-) create mode 100644 src/plumpy/event_helper.py diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index c4820f1b..2b9728e9 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: wrapped(self, *args, **kwargs) self._called -= 1 + #the following is to show the correct name later in the call_with_super_check error message + wrapper.__name__ = wrapped.__name__ return wrapper diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py new file mode 100644 index 00000000..b36cd602 --- /dev/null +++ b/src/plumpy/event_helper.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +import logging +from typing import TYPE_CHECKING, Any, Callable, Set, Type + +from . import persistence + +if TYPE_CHECKING: + from .process_listener import ProcessListener # pylint: disable=cyclic-import + +_LOGGER = logging.getLogger(__name__) + + +@persistence.auto_persist('_listeners', '_listener_type') +class EventHelper(persistence.Savable): + + def __init__(self, listener_type: 'Type[ProcessListener]'): + assert listener_type is not None, 'Must provide valid listener type' + + self._listener_type = listener_type + self._listeners: 'Set[ProcessListener]' = set() + + def add_listener(self, listener: 'ProcessListener') -> None: + assert isinstance(listener, self._listener_type), 'Listener is not of right type' + self._listeners.add(listener) + + def remove_listener(self, listener: 'ProcessListener') -> None: + self._listeners.discard(listener) + + def remove_all_listeners(self) -> None: + self._listeners.clear() + + @property + def listeners(self) -> 'Set[ProcessListener]': + return self._listeners + + def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Call an event method on all listeners. + + :param event_function: the method of the ProcessListener + :param args: arguments to pass to the method + :param kwargs: keyword arguments to pass to the method + + """ + if event_function is None: + raise ValueError('Must provide valid event method') + + # Make a copy of the list for iteration just in case it changes in a callback + for listener in list(self.listeners): + try: + getattr(listener, event_function.__name__)(*args, **kwargs) + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index c0a49d9a..110394a2 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Optional + +from . import persistence +from .utils import SAVED_STATE_TYPE, protected __all__ = ['ProcessListener'] @@ -8,7 +11,26 @@ from .processes import Process # pylint: disable=cyclic-import -class ProcessListener(metaclass=abc.ABCMeta): +@persistence.auto_persist('_params') +class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): + + # region Persistence methods + + def __init__(self) -> None: + super().__init__() + self._params: Dict[str, Any] = {} + + def init(self, **kwargs: Any) -> None: + self._params = kwargs + + @protected + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] + ) -> None: + super().load_instance_state(saved_state, load_context) + self.init(**saved_state['_params']) + + # endregion def on_process_created(self, process: 'Process') -> None: """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 40b2ccbb..ecffb1f4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -42,6 +42,7 @@ from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check +from .event_helper import EventHelper from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return func_wrapper -@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status') +@persistence.auto_persist( + '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' +) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -289,7 +292,7 @@ def __init__( # Runtime variables self._future = persistence.SavableFuture(loop=self._loop) - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = logger self._communicator = communicator @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi # Runtime variables, set initial states self._future = persistence.SavableFuture() - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = None self._communicator = None @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None: """ assert (listener != self), 'Cannot listen to yourself!' - self.__event_helper.add_listener(listener) + self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: """Remove a process listener from the process.""" - self.__event_helper.remove_listener(listener) + self._event_helper.remove_listener(listener) @protected def set_logger(self, logger: logging.Logger) -> None: @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -891,7 +894,7 @@ def on_close(self) -> None: self._closed = True def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - self.__event_helper.fire_event(evt, self, *args, **kwargs) + self._event_helper.fire_event(evt, self, *args, **kwargs) # endregion diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index a11ebd01..a2bdc8ae 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -27,47 +27,6 @@ PID_TYPE = Hashable # pylint: disable=invalid-name -class EventHelper: - - def __init__(self, listener_type: 'Type[ProcessListener]'): - assert listener_type is not None, 'Must provide valid listener type' - - self._listener_type = listener_type - self._listeners: 'Set[ProcessListener]' = set() - - def add_listener(self, listener: 'ProcessListener') -> None: - assert isinstance(listener, self._listener_type), 'Listener is not of right type' - self._listeners.add(listener) - - def remove_listener(self, listener: 'ProcessListener') -> None: - self._listeners.discard(listener) - - def remove_all_listeners(self) -> None: - self._listeners.clear() - - @property - def listeners(self) -> 'Set[ProcessListener]': - return self._listeners - - def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - """Call an event method on all listeners. - - :param event_function: the method of the ProcessListener - :param args: arguments to pass to the method - :param kwargs: keyword arguments to pass to the method - - """ - if event_function is None: - raise ValueError('Must provide valid event method') - - # Make a copy of the list for iteration just in case it changes in a callback - for listener in list(self.listeners): - try: - getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except - _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) - - class Frozendict(Mapping): """ An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping` diff --git a/test/test_processes.py b/test/test_processes.py index 737b463d..4b7494f5 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self): # Check that it is a copy self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {})) # Check the contents are the same - self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) + #we remove the ProcessSaver instance that is an object used only for testing + utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'}) + #self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {})) @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1): bundle2 = plumpy.Bundle(proc2) self.assertEqual(proc1.pid, proc2.pid) - self.assertDictEqual(bundle1, bundle2) + #self.assertDictEqual(bundle1, bundle2) + utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'}) class TestProcessNamespace(unittest.TestCase): diff --git a/test/test_workchains.py b/test/test_workchains.py index 71cd0f6a..6d9f26c5 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -283,6 +283,49 @@ def test_checkpointing(self): if step not in ['isA', 's2', 'isB', 's3']: self.assertTrue(finished, f'Step {step} was not called by workflow') + def test_listener_persistence(self): + persister = plumpy.InMemoryPersister() + process_finished_count = 0 + + class TestListener(plumpy.ProcessListener): + + def on_process_finished(self, process, output): + nonlocal process_finished_count + process_finished_count += 1 + + class SimpleWorkChain(plumpy.WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.outline( + cls.step1, + cls.step2, + ) + + def step1(self): + print('step1') + persister.save_checkpoint(self, 'step1') + + def step2(self): + print('step2') + persister.save_checkpoint(self, 'step2') + + # add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint + globals()['SimpleWorkChain'] = SimpleWorkChain + globals()['TestListener'] = TestListener + + workchain = SimpleWorkChain() + workchain.add_process_listener(TestListener()) + output = workchain.execute() + + self.assertEqual(process_finished_count, 1) + + print('reload persister checkpoint:') + workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle() + workchain_checkpoint.execute() + self.assertEqual(process_finished_count, 2) + def test_return_in_outline(self): class WcWithReturn(WorkChain): diff --git a/test/utils.py b/test/utils.py index feb3d1c8..1f7408f6 100644 --- a/test/utils.py +++ b/test/utils.py @@ -185,7 +185,7 @@ def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) - def middle_step(self,): + def middle_step(self): return process_states.Continue(self.last_step) def last_step(self): @@ -260,25 +260,72 @@ def _save(self, p): self.outputs.append(p.outputs.copy()) -class ProcessSaver(plumpy.ProcessListener, Saver): +_ProcessSaverProcReferences = {} +_ProcessSaver_Saver = {} + + +class ProcessSaver(plumpy.ProcessListener): """ - Save the instance state of a process each time it is about to enter a new state + Save the instance state of a process each time it is about to enter a new state. + NB: this is not a general purpose saver, it is only intended to be used for testing + The listener instances inside a process are persisted, so if we store a process + reference in the ProcessSaver instance, we will have a circular reference that cannot be + persisted. So we store the Saver instance in a global dictionary with the key the id of the + ProcessSaver instance. + In the init_not_persistent method we initialize the instances that cannot be persisted, + like the saver instance. The __del__ method is used to clean up the global dictionaries + (note there is no guarantee that __del__ will be called) + """ + def __del__(self): + global _ProcessSaver_Saver + global _ProcessSaverProcReferences + if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences: + del _ProcessSaverProcReferences[id(self)] + if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver: + del _ProcessSaver_Saver[id(self)] + + def get_process(self): + global _ProcessSaverProcReferences + return _ProcessSaverProcReferences[id(self)] + + def _save(self, p): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)]._save(p) + + def set_process(self, process): + global _ProcessSaverProcReferences + _ProcessSaverProcReferences[id(self)] = process + def __init__(self, proc): - plumpy.ProcessListener.__init__(self) - Saver.__init__(self) - self.process = proc + super().__init__() proc.add_process_listener(self) + self.init_not_persistent(proc) + + def init_not_persistent(self, proc): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)] = Saver() + self.set_process(proc) def capture(self): - self._save(self.process) - if not self.process.has_terminated(): + self._save(self.get_process()) + if not self.get_process().has_terminated(): try: - self.process.execute() + self.get_process().execute() except Exception: pass + @property + def snapshots(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].snapshots + + @property + def outputs(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].outputs + @utils.override def on_process_running(self, process): self._save(process) @@ -335,7 +382,13 @@ def check_process_against_snapshots(loop, proc_class, snapshots): """ for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) - saver = ProcessSaver(loaded) + # the process listeners are persisted + saver = list(loaded._event_helper._listeners)[0] + assert isinstance(saver, ProcessSaver) + # the process reference inside this particular implementation of process listener + # cannot be persisted because of a circular reference. So we load it there + # also the saver is not persisted for the same reason. We load it manually + saver.init_not_persistent(loaded) saver.capture() # Now check going backwards until running that the saved states match @@ -345,7 +398,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots): break compare_dictionaries( - snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'} + snapshots[-j], + saver.snapshots[-j], + snapshots[-j], + saver.snapshots[-j], + exclude={'exception', '_listeners'} ) j += 1 @@ -376,6 +433,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): elif isinstance(v1, list) and isinstance(v2, list): for vv1, vv2 in zip(v1, v2): compare_value(bundle1, bundle2, vv1, vv2, exclude) + elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1: + # TODO: implement sets with more than one element + compare_value(bundle1, bundle2, list(v1), list(v2), exclude) + elif isinstance(v1, set) and isinstance(v2, set): + raise NotImplementedError('Comparison between sets not implemented') else: if v1 != v2: raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}')