Commit c35bcb1

further updates

khsrali committed Jan 15, 2025
1 parent 35e9814 commit c35bcb1
Showing 6 changed files with 82 additions and 100 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -22,7 +22,7 @@ dependencies:
   - importlib-metadata~=6.0
   - numpy~=1.21
   - paramiko~=3.0
-  - plumpy~=0.23.0
+  - plumpy~=0.24.0
   - pgsu~=0.3.0
   - psutil~=5.6
   - psycopg[binary]~=3.0
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
     'importlib-metadata~=6.0',
     'numpy~=1.21',
     'paramiko~=3.0',
-    'plumpy~=0.23.0',
+    'plumpy~=0.24.0',
     'pgsu~=0.3.0',
     'psutil~=5.6',
     'psycopg[binary]~=3.0',
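Note: both environment.yml and pyproject.toml pin plumpy with a compatible-release specifier, so this bump moves the accepted range from the 0.23.x series to 0.24.x. A minimal sketch of what the specifier admits, using the third-party packaging library (not part of this diff):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet('~=0.24.0')  # compatible release: >=0.24.0, <0.25.0
print('0.24.3' in spec)  # True: patch releases still satisfy the pin
print('0.25.0' in spec)  # False: the next minor series requires a new pin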
15 changes: 7 additions & 8 deletions src/aiida/engine/daemon/execmanager.py
@@ -206,15 +206,15 @@ async def upload_calculation(
 
     for file_copy_operation in file_copy_operation_order:
         if file_copy_operation is FileCopyOperation.LOCAL:
-            _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir)
+            await _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir)
         elif file_copy_operation is FileCopyOperation.REMOTE:
            if not dry_run:
-                _copy_remote_files(
+                await _copy_remote_files(
                     logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir=workdir
                 )
         elif file_copy_operation is FileCopyOperation.SANDBOX:
             if not dry_run:
-                _copy_sandbox_files(logger, node, transport, folder, workdir=workdir)
+                await _copy_sandbox_files(logger, node, transport, folder, workdir=workdir)
         else:
             raise RuntimeError(f'file copy operation {file_copy_operation} is not yet implemented.')
 
@@ -279,7 +279,7 @@ async def upload_calculation(
     return None
 
 
-def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path):
+async def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path):
     """Perform the copy instructions of the ``remote_copy_list`` and ``remote_symlink_list``."""
     for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list:
         if remote_computer_uuid == computer.uuid:
@@ -328,7 +328,7 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remo
             )
 
 
-def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path):
+async def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path):
     """Perform the copy instructions of the ``local_copy_list``."""
     for uuid, filename, target in local_copy_list:
         logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}')
@@ -386,7 +386,7 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir:
         transport.put(str(filepath_target), str(workdir.joinpath(target)))
 
 
-def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
+async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
     """Copy the contents of the sandbox folder to the working directory."""
     for filename in folder.get_content_list():
         logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...')
@@ -423,7 +423,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str |
     return result
 
 
-def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
+async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     """Stash files from the working directory of a completed calculation to a permanent remote folder.
 
     After a calculation has been completed, optionally stash files from the work directory to a storage location on the
@@ -587,7 +587,6 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     )
 
 
-
 async def retrieve_files_from_list(
     calculation: CalcJobNode,
     transport: Transport,
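The thread running through this file: the private copy helpers become coroutines, and upload_calculation awaits them instead of calling them synchronously. A self-contained sketch of that calling convention, with toy stand-ins rather than the actual aiida-core classes:

import asyncio
from pathlib import Path


class ToyTransport:
    """Stand-in for an AiiDA transport; the sleep simulates network I/O."""

    async def put(self, localpath: str, remotepath: str) -> None:
        await asyncio.sleep(0.01)  # pretend this is a remote copy
        print(f'copied {localpath} -> {remotepath}')


async def _copy_sandbox_files(transport: ToyTransport, filenames: list[str], workdir: Path) -> None:
    """Coroutine helper: call sites must now await it, as in the hunks above."""
    for filename in filenames:
        await transport.put(filename, str(workdir.joinpath(filename)))


async def upload_calculation(transport: ToyTransport, filenames: list[str], workdir: Path) -> None:
    # Dropping the `await` here would create an unexecuted coroutine object and
    # silently skip the copy, which is why every call site changed with the signature.
    await _copy_sandbox_files(transport, filenames, workdir)


asyncio.run(upload_calculation(ToyTransport(), ['aiida.in'], Path('/tmp')))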
94 changes: 45 additions & 49 deletions src/aiida/engine/processes/calcjobs/tasks.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 ###########################################################################
 # Copyright (c), The AiiDA team. All rights reserved.                     #
 # This file is part of the AiiDA code.                                    #
@@ -8,6 +7,7 @@
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
 """Transport tasks for calculation jobs."""
+
 from __future__ import annotations
 
 import asyncio
@@ -48,10 +48,10 @@
 RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval'
 MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts'
 
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+logger = logging.getLogger(__name__)
 
 
-class PreSubmitException(Exception):
+class PreSubmitException(Exception):  # noqa: N818
     """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`."""
@@ -89,10 +89,12 @@ async def do_upload():
                 # Any exception thrown in `presubmit` call is not transient so we circumvent the exponential backoff
                 try:
                     calc_info = process.presubmit(folder)
-                except Exception as exception:  # pylint: disable=broad-except
+                except Exception as exception:
                     raise PreSubmitException('exception occurred in presubmit call') from exception
                 else:
-                    await execmanager.upload_calculation(node, transport, calc_info, folder)
+                    remote_folder = await execmanager.upload_calculation(node, transport, calc_info, folder)
+                    if remote_folder is not None:
+                        process.out('remote_folder', remote_folder)
                     skip_submit = calc_info.skip_submit or False
 
             return skip_submit
@@ -105,7 +107,7 @@
         )
     except PreSubmitException:
         raise
-    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
         raise
     except Exception as exception:
         logger.warning(f'uploading CalcJob<{node.pk}> failed')
@@ -151,7 +153,7 @@ async def do_submit():
         result = await exponential_backoff_retry(
             do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )
-    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
         raise
     except Exception as exception:
         logger.warning(f'submitting CalcJob<{node.pk}> failed')
@@ -209,7 +211,7 @@ async def do_update():
         job_done = await exponential_backoff_retry(
             do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )
-    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
         raise
     except Exception as exception:
         logger.warning(f'updating CalcJob<{node.pk}> failed')
@@ -249,10 +251,8 @@ async def task_monitor_job(
     authinfo = node.get_authinfo()
 
     async def do_monitor():
-
         with transport_queue.request_transport(authinfo) as request:
             transport = await cancellable.with_interrupt(request)
-            transport.chdir(node.get_remote_workdir())
             return monitors.process(node, transport)
 
     try:
@@ -261,7 +261,7 @@ async def do_monitor():
         monitor_result = await exponential_backoff_retry(
             do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )
-    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
         raise
     except Exception as exception:
         logger.warning(f'monitoring CalcJob<{node.pk}> failed')
@@ -272,36 +272,32 @@ async def do_monitor():
 
 
 async def task_retrieve_job(
-    node: CalcJobNode, transport_queue: TransportQueue, retrieved_temporary_folder: str,
-    cancellable: InterruptableFuture
+    process: 'CalcJob',
+    transport_queue: TransportQueue,
+    retrieved_temporary_folder: str,
+    cancellable: InterruptableFuture,
 ):
     """Transport task that will attempt to retrieve all files of a completed job calculation.
 
     The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
     function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
     retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
     If all retries fail, the task will raise a TransportTaskException
 
-    :param node: the node that represents the job calculation
+    :param process: the job calculation
     :param transport_queue: the TransportQueue from which to request a Transport
     :param retrieved_temporary_folder: the absolute path to a directory to store files
     :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
     """
+    node = process.node
     if node.get_state() == CalcJobState.PARSING:
         logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job')
         return
 
     initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
     max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
 
     authinfo = node.get_authinfo()
 
     async def do_retrieve():
         with transport_queue.request_transport(authinfo) as request:
             transport = await cancellable.with_interrupt(request)
 
             # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this
             # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the
             # accounting was called but could not be set.
@@ -310,25 +306,28 @@ async def do_retrieve():
 
             if node.get_job_id() is None:
                 logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`')
-                return await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
-
-            try:
-                detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id())
-            except FeatureNotAvailable:
-                logger.info(f'detailed job info not available for scheduler of CalcJob<{node.pk}>')
-                node.set_detailed_job_info(None)
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
             else:
-                node.set_detailed_job_info(detailed_job_info)
+                try:
+                    detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id())
+                except FeatureNotAvailable:
+                    logger.info(f'detailed job info not available for scheduler of CalcJob<{node.pk}>')
+                    node.set_detailed_job_info(None)
+                else:
+                    node.set_detailed_job_info(detailed_job_info)
 
-            return await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
 
+            if retrieved is not None:
+                process.out(node.link_label_retrieved, retrieved)
+            return retrieved
 
     try:
         logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
         ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
         result = await exponential_backoff_retry(
             do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )
-    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
         raise
     except Exception as exception:
         logger.warning(f'retrieving CalcJob<{node.pk}> failed')
@@ -367,15 +366,15 @@ async def do_stash():
             transport = await cancellable.with_interrupt(request)
 
             logger.info(f'stashing calculation<{node.pk}>')
-            return execmanager.stash_calculation(node, transport)
+            return await execmanager.stash_calculation(node, transport)
 
     try:
         await exponential_backoff_retry(
             do_stash,
             initial_interval,
             max_attempts,
             logger=node.logger,
-            ignore_exceptions=plumpy.process_states.Interruption
+            ignore_exceptions=plumpy.process_states.Interruption,
         )
     except plumpy.process_states.Interruption:
         raise
@@ -439,11 +438,9 @@ def __init__(
         process: 'CalcJob',
         done_callback: Optional[Callable[..., Any]],
         msg: Optional[str] = None,
-        data: Optional[Any] = None
+        data: Optional[Any] = None,
     ):
-        """
-        :param process: The process this state belongs to
-        """
+        """:param process: The process this state belongs to"""
         super().__init__(process, done_callback, msg, data)
         self._task: InterruptableFuture | None = None
         self._killing: plumpy.futures.Future | None = None
@@ -473,19 +470,16 @@ def monitors(self) -> CalcJobMonitors | None:
 
     @property
     def process(self) -> 'CalcJob':
-        """
-        :return: The process
-        """
+        """:return: The process"""
         return self.state_machine  # type: ignore[return-value]
 
     def load_instance_state(self, saved_state, load_context):
         super().load_instance_state(saved_state, load_context)
         self._task = None
         self._killing = None
 
-    async def execute(self) -> plumpy.process_states.State:  # type: ignore[override] # pylint: disable=invalid-overridden-method
+    async def execute(self) -> plumpy.process_states.State:  # type: ignore[override]
         """Override the execute coroutine of the base `Waiting` state."""
-        # pylint: disable=too-many-branches,too-many-statements,too-many-nested-blocks
         node = self.process.node
         transport_queue = self.process.runner.transport
         result: plumpy.process_states.State = self
@@ -494,7 +488,6 @@ async def execute(self) -> plumpy.process_states.State:  # type: ignore[override
         node.set_process_status(process_status)
 
         try:
-
             if self._command == UPLOAD_COMMAND:
                 skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue)
                 if skip_submit:
@@ -506,7 +499,7 @@
                 result = await self._launch_task(task_submit_job, node, transport_queue)
 
                 if isinstance(result, ExitCode):
-                    # The scheduler plugin returned an exit code from ``Scheduler.submit_from_script`` indicating the
+                    # The scheduler plugin returned an exit code from ``Scheduler.submit_job`` indicating the
                     # job submission failed due to a non-transient problem and the job should be terminated.
                     return self.create_state(ProcessState.RUNNING, self.process.terminate, result)
 
@@ -529,9 +522,7 @@
 
                 if monitor_result and not monitor_result.retrieve:
                     exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=monitor_result.message)
-                    return self.create_state(
-                        ProcessState.RUNNING, self.process.terminate, exit_code
-                    )  # type: ignore[return-value]
+                    return self.create_state(ProcessState.RUNNING, self.process.terminate, exit_code)  # type: ignore[return-value]
 
                 result = self.stash(monitor_result=monitor_result)
 
@@ -542,7 +533,7 @@
 
             elif self._command == RETRIEVE_COMMAND:
                 temp_folder = tempfile.mkdtemp()
-                await self._launch_task(task_retrieve_job, node, transport_queue, temp_folder)
+                await self._launch_task(task_retrieve_job, self.process, transport_queue, temp_folder)
 
                 if not self._monitor_result:
                     result = self.parse(temp_folder)
@@ -592,6 +583,11 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR
 
         monitor_result = await self._launch_task(task_monitor_job, node, transport_queue, monitors=monitors)
 
+        if monitor_result and monitor_result.outputs:
+            for label, output in monitor_result.outputs.items():
+                self.process.out(label, output)
+            self.process.update_outputs()
+
         if monitor_result and monitor_result.action == CalcJobMonitorAction.DISABLE_SELF:
             monitors.monitors[monitor_result.key].disabled = True
 
@@ -674,4 +670,4 @@ def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]:  # type: ig
             self._killing = plumpy.futures.Future()
             return self._killing
 
-        return None
+        return None
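The task docstrings above all lean on the same retry wrapper: each transport task is awaited through exponential_backoff_retry, which widens the pause after every failure and re-raises once the attempts are exhausted. A minimal self-contained sketch of that pattern (the real helper lives in aiida.engine.utils and has a richer signature):

import asyncio


async def exponential_backoff_retry(fct, initial_interval, max_attempts, logger, ignore_exceptions=None):
    """Await fct() until it succeeds, doubling the wait after each failure."""
    interval = initial_interval
    for attempt in range(max_attempts):
        try:
            return await fct()
        except Exception as exception:
            if ignore_exceptions is not None and isinstance(exception, ignore_exceptions):
                raise  # cancellations and interruptions must propagate untouched
            if attempt == max_attempts - 1:
                raise  # attempts exhausted; the caller pauses the process
            logger.warning('attempt %d failed (%s); retrying in %s s', attempt + 1, exception, interval)
            await asyncio.sleep(interval)
            interval *= 2  # the exponential part

With an initial interval of 20 s and five attempts, for example, the waits are 20, 40, 80 and 160 s before the final failure propagates.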
21 changes: 0 additions & 21 deletions src/aiida/transports/transport.py
@@ -440,16 +440,6 @@ def get(self, remotepath, localpath, *args, **kwargs):
         :param localpath: (str) local_folder_path
         """
 
-    async def get_async(self, remotepath, localpath, *args, **kwargs):
-        """
-        Retrieve a file or folder from remote source to local destination
-        dst must be an absolute path (src not necessarily)
-
-        :param remotepath: (str) remote_folder_path
-        :param localpath: (str) local_folder_path
-        """
-        return self.get(remotepath, localpath, *args, **kwargs)
-
     @abc.abstractmethod
     def getfile(self, remotepath, localpath, *args, **kwargs):
         """Retrieve a file from remote source to local destination
@@ -628,17 +618,6 @@ def put(self, localpath, remotepath, *args, **kwargs):
         :param str remotepath: path to remote destination
         """
 
-    async def put_async(self, localpath, remotepath, *args, **kwargs):
-        """
-        Put a file or a directory from local src to remote dst.
-        src must be an absolute path (dst not necessarily))
-        Redirects to putfile and puttree.
-
-        :param str localpath: absolute path to local source
-        :param str remotepath: path to remote destination
-        """
-        return self.put(localpath, remotepath, *args, **kwargs)
-
     @abc.abstractmethod
     def putfile(self, localpath, remotepath, *args, **kwargs):
         """Put a file from local src to remote dst.
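These two deletions remove the default get_async/put_async shims, which simply invoked the blocking get/put on the event loop, so asynchronous transport plugins now provide real implementations themselves. For contrast, a self-contained sketch (not the aiida-core API) of keeping a blocking transfer off the loop by delegating it to a worker thread:

import asyncio
import shutil


class SketchTransport:
    """Toy transport whose 'remote' transfer is simulated by a local copy."""

    def get(self, remotepath: str, localpath: str) -> None:
        # Stand-in for a blocking SFTP download.
        shutil.copyfile(remotepath, localpath)

    async def get_async(self, remotepath: str, localpath: str) -> None:
        # Offload the blocking call to a worker thread so other coroutines keep
        # running, instead of calling self.get() inline as the removed shim did.
        await asyncio.to_thread(self.get, remotepath, localpath)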