Skip to content

Commit

Permalink
Adapt message arguments passing to process controller
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 11, 2025
1 parent 15b5caf commit a3e12c9
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 20 deletions.
18 changes: 14 additions & 4 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,13 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
control.kill_processes(
processes,
msg_text='Killed through `verdi process kill`',
all_entries=all_entries,
timeout=timeout,
wait=wait,
)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -371,8 +376,13 @@ def process_pause(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
control.pause_processes(
processes,
msg_text='Paused through `verdi process pause`',
all_entries=all_entries,
timeout=timeout,
wait=wait,
)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down
12 changes: 8 additions & 4 deletions src/aiida/engine/processes/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import collections
import concurrent
import functools
from pydoc import text
import typing as t

import kiwipy
Expand Down Expand Up @@ -135,7 +137,7 @@ def play_processes(
def pause_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
msg_text: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -164,13 +166,14 @@ def pause_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.pause_process, 'pause', 'pausing', timeout, wait, msg=message)
action = functools.partial(controller.pause_process, msg_text=msg_text)
_perform_actions(processes, action, 'pause', 'pausing', timeout, wait)


def kill_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -199,7 +202,8 @@ def kill_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message)
action = functools.partial(controller.kill_process, msg_text=msg_text)
_perform_actions(processes, action, 'kill', 'killing', timeout, wait)


def _perform_actions(
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode
if kwargs and not process_class.spec().inputs.dynamic:
raise ValueError(f'{function.__name__} does not support these kwargs: {kwargs.keys()}')

process = process_class(inputs=inputs, runner=runner)
process: Process = process_class(inputs=inputs, runner=runner)

# Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner.
# Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown
Expand All @@ -235,7 +235,7 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode
def kill_process(_num, _frame):
"""Send the kill signal to the process in the current scope."""
LOGGER.critical('runner received interrupt, killing process %s', process.pid)
result = process.kill(msg='Process was killed because the runner received an interrupt')
result = process.kill(msg_text='Process was killed because the runner received an interrupt')
return result

# Store the current handler on the signal such that it can be restored after process has terminated
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def load_instance_state(

self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]:
def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future]:
"""Kill the process and all the children calculations it called
:param msg: message
Expand All @@ -338,7 +338,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur

had_been_terminated = self.has_terminated()

result = super().kill(msg)
result = super().kill(msg_text)

# Only kill children if we could be killed ourselves
if result is not False and not had_been_terminated:
Expand All @@ -348,7 +348,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur
self.logger.info('no controller available to kill child<%s>', child.pk)
continue
try:
result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>')
result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
result = asyncio.wrap_future(result) # type: ignore[arg-type]
if asyncio.isfuture(result):
killing.append(result)
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def kill_process(_num, _frame):
LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid)
return
LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid)
process_inited.kill(msg='Process was killed because the runner received an interrupt')
process_inited.kill(msg_text='Process was killed because the runner received an interrupt')

original_handler_int = signal.getsignal(signal.SIGINT)
original_handler_term = signal.getsignal(signal.SIGTERM)
Expand Down
10 changes: 4 additions & 6 deletions tests/engine/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ async def do_pause():
assert result
assert calc_node.paused

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate')
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -112,7 +111,7 @@ async def do_pause_play():
await asyncio.sleep(0.1)

pause_message = 'Take a seat'
pause_future = controller.pause_process(calc_node.pk, msg=pause_message)
pause_future = controller.pause_process(calc_node.pk, msg_text=pause_message)
future = await with_timeout(asyncio.wrap_future(pause_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert calc_node.paused
Expand All @@ -126,8 +125,7 @@ async def do_pause_play():
assert not calc_node.paused
assert calc_node.process_status is None

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate')
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -145,7 +143,7 @@ async def do_kill():
await asyncio.sleep(0.1)

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message)
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand Down

0 comments on commit a3e12c9

Please sign in to comment.