Skip to content

Commit

Permalink
plumpy.ProcessListener made persistent
Browse files Browse the repository at this point in the history
solves aiidateam#273

We implement the persistence of ProcessListener by deriving the class
ProcessListener and EventHelper from persistence.Savable.
The class EventHelper is moved to a new file because of a circular
import with utils and persistence

Fixing the test

There was a circular reference issue in the test listener that was
storing a reference to the process inside it, making its serialization
impossible. To fix the tests an ugly hack was used: storing the
reference to the process outside the class in a global dict using id as
keys. Some more ugly hacks are needed to check correctly the equality of
two processes. We must ignore the fact that the instances if the
listener are different.

We call del on dict items of the ProcessListener's global implemented in the test suite
to clean the golbal variables

addressed issues in aiidateam#274
  • Loading branch information
rikigigi committed Nov 10, 2023
1 parent 31f85c7 commit 63ab5ec
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 63 deletions.
2 changes: 2 additions & 0 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
wrapped(self, *args, **kwargs)
self._called -= 1

#the following is to show the correct name later in the call_with_super_check error message
wrapper.__name__ = wrapped.__name__
return wrapper


Expand Down
52 changes: 52 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable, Set, Type

from . import persistence

if TYPE_CHECKING:
from .process_listener import ProcessListener # pylint: disable=cyclic-import

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
17 changes: 10 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand Down Expand Up @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:
"""
assert (listener != self), 'Cannot listen to yourself!'
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +894,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down
7 changes: 5 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
#we remove the ProcessSaver instance that is an object used only for testing
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})
#self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
#self.assertDictEqual(bundle1, bundle2)
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
43 changes: 43 additions & 0 deletions test/test_workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,49 @@ def test_checkpointing(self):
if step not in ['isA', 's2', 'isB', 's3']:
self.assertTrue(finished, f'Step {step} was not called by workflow')

def test_listener_persistence(self):
persister = plumpy.InMemoryPersister()
process_finished_count = 0

class TestListener(plumpy.ProcessListener):

def on_process_finished(self, process, output):
nonlocal process_finished_count
process_finished_count += 1

class SimpleWorkChain(plumpy.WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.step1,
cls.step2,
)

def step1(self):
print('step1')
persister.save_checkpoint(self, 'step1')

def step2(self):
print('step2')
persister.save_checkpoint(self, 'step2')

# add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint
globals()['SimpleWorkChain'] = SimpleWorkChain
globals()['TestListener'] = TestListener

workchain = SimpleWorkChain()
workchain.add_process_listener(TestListener())
output = workchain.execute()

self.assertEqual(process_finished_count, 1)

print('reload persister checkpoint:')
workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle()
workchain_checkpoint.execute()
self.assertEqual(process_finished_count, 2)

def test_return_in_outline(self):

class WcWithReturn(WorkChain):
Expand Down
Loading

0 comments on commit 63ab5ec

Please sign in to comment.