diff --git a/CHANGELOG.md b/CHANGELOG.md
index a7291da..19700ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,13 +7,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [UNRELEASED]
 
-## Added
+### Added
 
+- Add a new `variables` parameter for setting environment variables
+- Add a new error-catching Python execution script (new `exec.py` module)
+- Add checks inside the submit script for the `covalent` and `cloudpickle` versions
+- Clean up job script creation (new `job_script.py` module)
 - export `COVALENT_CONFIG_DIR=/tmp` inside sbatch script to enable filelock
 
-## Changed
+### Changed
 
+- Update plugin defaults to use a pydantic `BaseModel` instead of a `dict`
+- Surface errors from the submit-script version checks instead of silently ignoring them
+- Use `pathlib.Path` throughout instead of `os.path` operations
+- Allow `poll_freq >= 5` seconds, instead of the previous 60-second minimum
+- Miscellaneous cleanups and refactoring
 - Aesthetics and string formatting
+- Remove the automatic `COVALENT_CONFIG_DIR=/tmp` export from the sbatch script
 - Removed the `sshproxy` interface.
 - Updates __init__ signature kwargs replaced with parent for better documentation.
 - Updated license to Apache
diff --git a/covalent_slurm_plugin/exec.py b/covalent_slurm_plugin/exec.py
new file mode 100644
index 0000000..27092a0
--- /dev/null
+++ b/covalent_slurm_plugin/exec.py
@@ -0,0 +1,115 @@
+# Copyright 2024 Agnostiq Inc.
+#
+# This file is part of Covalent.
+#
+# Licensed under the Apache License 2.0 (the "License"). A copy of the
+# License may be obtained with this software package or at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Use of this file is prohibited except in compliance with the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This script executes the electron function on the Slurm cluster."""
+
+import os
+import sys
+
+import cloudpickle as pickle
+
+
+def _import_covalent() -> None:
+    # Wrapped import for convenience in testing.
+    import covalent
+
+
+def _check_setup() -> None:
+    """Use these checks to create more informative error messages."""
+
+    import filelock
+
+    msg = ""
+    exception = None
+
+    try:
+        # covalent is needed because the @electron function
+        # executes inside `wrapper_fn` to apply deps
+        _import_covalent()
+
+    except ImportError as _exception:
+        msg = "The covalent SDK is not installed in the Slurm job environment."
+        exception = _exception
+
+    except filelock._error.Timeout as _exception:
+        config_location = os.getenv("COVALENT_CONFIG_DIR", "~/.config/covalent")
+        config_location = os.path.expanduser(config_location)
+        config_file = os.path.join(config_location, "covalent.conf")
+
+        msg = "\n".join(
+            [
+                f"Failed to acquire file lock '{config_file}.lock' on Slurm cluster filesystem. 
" + f"Consider overriding the current config location ('{config_location}'), e.g:", + ' SlurmExecutor(..., variables={"COVALENT_CONFIG_DIR": "/tmp"})' "", + ] + ) + exception = _exception + + # Raise the exception if one was caught + if exception: + raise RuntimeError(msg) from exception + + +def _execute() -> dict: + """Load and execute the @electron function""" + + func_filename = sys.argv[1] + + with open(func_filename, "rb") as f: + function, args, kwargs = pickle.load(f) + + result = None + exception = None + + try: + result = function(*args, **kwargs) + except Exception as ex: + exception = ex + + return { + "result": result, + "exception": exception, + } + + +def main(): + """Execute the @electron function on the Slurm cluster.""" + + output_data = { + "result": None, + "exception": None, + "result_filename": sys.argv[2], + } + try: + _check_setup() + output_data.update(**_execute()) + + except Exception as ex: + output_data.update(exception=ex) + + finally: + _record_output(**output_data) + + +def _record_output(result, exception, result_filename) -> None: + """Record the output of the @electron function""" + + with open(result_filename, "wb") as f: + pickle.dump((result, exception), f) + + +if __name__ == "__main__": + main() diff --git a/covalent_slurm_plugin/job_script.py b/covalent_slurm_plugin/job_script.py new file mode 100644 index 0000000..57df8e0 --- /dev/null +++ b/covalent_slurm_plugin/job_script.py @@ -0,0 +1,235 @@ +# Copyright 2024 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for formatting the Slurm job submission script.""" + +import re +from typing import Dict, List, Optional + +SLURM_JOB_SCRIPT_TEMPLATE = """\ +#!/bin/bash +{sbatch_directives} + +{shell_env_setup} + +{conda_env_setup} + +if [ $? -ne 0 ] ; then + >&2 echo "Failed to activate conda env '$__env_name' on compute node." + exit 99 +fi + +remote_py_version=$(python -c "print('.'.join(map(str, __import__('sys').version_info[:2])))") +if [[ $remote_py_version != "{python_version}" ]] ; then + >&2 echo "Python version mismatch." + >&2 echo "Environment '$__env_name' (python=$remote_py_version) does not match task (python={python_version})." + exit 199 +fi + +covalent_version=$(python -c "import covalent; print(covalent.__version__)") +if [ $? -ne 0 ] ; then + >&2 echo "Covalent may not be installed in the compute environment." + >&2 echo "Please install covalent=={covalent_version} in the '$__env_name' conda env." + exit 299 +elif [[ $covalent_version != "{covalent_version}" ]] ; then + >&2 echo "Covalent version mismatch." + >&2 echo "Environment '$__env_name' (covalent==$covalent_version) does not match task (covalent=={covalent_version})." + exit 299 +fi + +cloudpickle_version=$(python -c "import cloudpickle; print(cloudpickle.__version__)") +if [ $? -ne 0 ] ; then + >&2 echo "Cloudpickle may not be installed in the compute environment." 
+ >&2 echo "Please install cloudpickle=={cloudpickle_version} in the '$__env_name' conda env." + exit 399 +elif [[ $cloudpickle_version != "{cloudpickle_version}" ]] ; then + >&2 echo "Cloudpickle version mismatch." + >&2 echo "Environment '$__env_name' (cloudpickle==$cloudpickle_version) does not match task (cloudpickle=={cloudpickle_version})." + exit 399 +fi + +{run_commands} + +wait +""" + + +class JobScript: + """Formats an sbatch submit script for the Slurm cluster.""" + + def __init__( + self, + sbatch_options: Optional[Dict[str, str]] = None, + srun_options: Optional[Dict[str, str]] = None, + variables: Optional[Dict[str, str]] = None, + bashrc_path: Optional[str] = "", + conda_env: Optional[str] = "", + prerun_commands: Optional[List[str]] = None, + srun_append: Optional[str] = "", + postrun_commands: Optional[List[str]] = None, + use_srun: bool = True, + ): + """Create a job script formatter. + + Args: + See `covalent_slurm_plugin.slurm.SlurmExecutor` for details. + """ + + self._sbatch_options = sbatch_options or {} + self._srun_options = srun_options or {} + self._variables = variables or {} + self._bashrc_path = bashrc_path + self._conda_env = conda_env + self._prerun_commands = prerun_commands or [] + self._srun_append = srun_append + self._postrun_commands = postrun_commands or [] + self._use_srun = use_srun + + @property + def sbatch_directives(self) -> str: + """Get the sbatch directives.""" + directives = [] + for key, value in self._sbatch_options.items(): + if len(key) == 1: + directives.append(f"#SBATCH -{key}" + (f" {value}" if value else "")) + else: + directives.append(f"#SBATCH --{key}" + (f"={value}" if value else "")) + + return "\n".join(directives) + + @property + def shell_env_setup(self) -> str: + """Get the shell environment setup.""" + setup_lines = [ + f"source {self._bashrc_path}" if self._bashrc_path else "", + ] + for key, value in self._variables.items(): + setup_lines.append(f'export {key}="{value}"') + + return "\n".join(setup_lines) + + @property + def conda_env_setup(self) -> str: + """Get the conda environment setup.""" + setup_lines = [] + if not self._conda_env or self._conda_env == "base": + conda_env_name = "base" + setup_lines.append("conda activate") + else: + conda_env_name = self._conda_env + setup_lines.append(f"conda activate {self._conda_env}") + + setup_lines.insert(0, f'__env_name="{conda_env_name}"') + + return "\n".join(setup_lines) + + @property + def covalent_version(self) -> str: + """Get the version of Covalent installed in the compute environment.""" + import covalent + + return covalent.__version__ + + @property + def cloudpickle_version(self) -> str: + """Get the version of cloudpickle installed in the compute environment.""" + import cloudpickle + + return cloudpickle.__version__ + + def get_run_commands( + self, + remote_py_filename: str, + func_filename: str, + result_filename: str, + ) -> str: + """Get the run commands.""" + + # Commands executed before the user's @electron function. + prerun_cmds = "\n".join(self._prerun_commands) + + # Command that executes the user's @electron function. + python_cmd = "python {remote_py_filename} {func_filename} {result_filename}" + + if not self._use_srun: + # Invoke python directly. + run_cmd = python_cmd + else: + # Invoke python via srun. 
+ srun_options = [] + for key, value in self._srun_options.items(): + if len(key) == 1: + srun_options.append(f"-{key}" + (f" {value}" if value else "")) + else: + srun_options.append(f"--{key}" + (f"={value}" if value else "")) + + run_cmds = [ + f"srun {' '.join(srun_options)} \\" if srun_options else "srun \\", + f" {self._srun_append} \\", + f" {python_cmd}", + ] + if not self._srun_append: + # Remove (empty) commands appended to `srun` call. + run_cmds.pop(1) + + run_cmd = "\n".join(run_cmds) + + run_cmd = run_cmd.format( + remote_py_filename=remote_py_filename, + func_filename=func_filename, + result_filename=result_filename, + ) + + # Commands executed after the user's @electron function. + postrun_cmds = "\n".join(self._postrun_commands) + + # Combine all commands. + run_commands = [prerun_cmds, run_cmd, postrun_cmds] + run_commands = [cmd for cmd in run_commands if cmd] + + return "\n\n".join(run_commands) + + def format( + self, + python_version: str, + remote_py_filename: str, + func_filename: str, + result_filename: str, + ) -> str: + """Render the job script.""" + template_kwargs = { + "sbatch_directives": self.sbatch_directives, + "shell_env_setup": self.shell_env_setup, + "conda_env_setup": self.conda_env_setup, + "covalent_version": self.covalent_version, + "cloudpickle_version": self.cloudpickle_version, + "python_version": python_version, + "run_commands": self.get_run_commands( + remote_py_filename=remote_py_filename, + func_filename=func_filename, + result_filename=result_filename, + ), + } + existing_keys = set(template_kwargs.keys()) + required_keys = set(re.findall(r"\{(\w+)\}", SLURM_JOB_SCRIPT_TEMPLATE)) + + if missing_keys := required_keys - existing_keys: + raise ValueError(f"Missing required keys: {', '.join(missing_keys)}") + + if extra_keys := existing_keys - required_keys: + raise ValueError(f"Unexpected keys: {', '.join(extra_keys)}") + + return SLURM_JOB_SCRIPT_TEMPLATE.format(**template_kwargs) diff --git a/covalent_slurm_plugin/slurm.py b/covalent_slurm_plugin/slurm.py index d8e6551..a316a8d 100644 --- a/covalent_slurm_plugin/slurm.py +++ b/covalent_slurm_plugin/slurm.py @@ -17,12 +17,11 @@ """Slurm executor plugin for the Covalent dispatcher.""" import asyncio -import os import re import sys from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import aiofiles import asyncssh @@ -32,34 +31,56 @@ from covalent._shared_files import logger from covalent._shared_files.config import get_config from covalent.executor.base import AsyncBaseExecutor +from pydantic import BaseModel, Field + +from covalent_slurm_plugin.job_script import JobScript + +__all__ = ["SlurmExecutor"] + +EXECUTOR_PLUGIN_NAME = "SlurmExecutor" + + +class ExecutorPluginDefaults(BaseModel): + """Defaults for the SlurmExecutor plugin.""" + + username: Optional[str] = "" + address: Optional[str] = "" + ssh_key_file: Optional[str] = "" + cert_file: Optional[str] = None + remote_workdir: Optional[str] = "covalent-workdir" + create_unique_workdir: bool = False + variables: Optional[Dict[str, str]] = Field(default_factory=dict) + conda_env: Optional[str] = "" + options: Optional[Dict] = Field(default_factory=lambda: {"parsable": ""}) + prerun_commands: Optional[List[str]] = Field(default_factory=list) + postrun_commands: Optional[List[str]] = Field(default_factory=list) + use_srun: bool = True + srun_options: Optional[Dict] = Field(default_factory=dict) + srun_append: 
Optional[str] = "" + bashrc_path: Optional[str] = "$HOME/.bashrc" + slurm_path: Optional[str] = "/usr/bin" + poll_freq: int = 5 + cleanup: bool = True + cache_dir: Optional[str] = str( + Path.home() / ".config/covalent/executor_plugins/covalent-slurm-cache" + ) + + +_EXECUTOR_PLUGIN_DEFAULTS = ExecutorPluginDefaults().model_dump() app_log = logger.app_log log_stack_info = logger.log_stack_info -_EXECUTOR_PLUGIN_DEFAULTS = { - "username": "", - "address": "", - "ssh_key_file": "", - "cert_file": None, - "remote_workdir": "covalent-workdir", - "create_unique_workdir": False, - "conda_env": "", - "options": { - "parsable": "", - }, - "prerun_commands": None, - "postrun_commands": None, - "use_srun": True, - "srun_options": {}, - "srun_append": None, - "bashrc_path": "$HOME/.bashrc", - "slurm_path": None, - "cache_dir": str(Path(get_config("dispatcher.cache_dir")).expanduser().resolve()), - "poll_freq": 60, - "cleanup": True, -} - -executor_plugin_name = "SlurmExecutor" +_LOAD_SLURM_PREFIX = """\ +source /etc/profile +module whatis slurm &> /dev/null +if [ $? -eq 0 ] ; then + module load slurm +fi +""" + +# TODO: consider enumerating job statuses +# TODO: capture poll_freq errors and inform user class SlurmExecutor(AsyncBaseExecutor): @@ -72,8 +93,9 @@ class SlurmExecutor(AsyncBaseExecutor): cert_file: Certificate file used to authenticate over SSH, if required (usually has extension .pub). remote_workdir: Working directory on the remote cluster. create_unique_workdir: Whether to create a unique working (sub)directory for each task. - conda_env: Name of conda environment on which to run the function. Use "base" for the base environment or "" for no conda. options: Dictionary of parameters used to build a Slurm submit script. + variables: A dictionary of environment variables to declare before job execution. + conda_env: Name of conda environment on which to run the function. Use "base" for the base environment or "" for no conda. prerun_commands: List of shell commands to run before running the pickled function. postrun_commands: List of shell commands to run after running the pickled function. use_srun: Whether or not to run the pickled Python function with srun. If your function itself makes srun or mpirun calls, set this to False. @@ -81,9 +103,9 @@ class SlurmExecutor(AsyncBaseExecutor): srun_append: Command nested into srun call. bashrc_path: Path to the bashrc file to source before running the function. slurm_path: Path to the slurm commands if they are not found automatically. - cache_dir: Local cache directory used by this executor for temporary files. - poll_freq: Frequency with which to poll a submitted job. Always is >= 60. + poll_freq: Frequency with which to poll a submitted job. Always is >= 5. cleanup: Whether to perform cleanup or not on remote machine. + cache_dir: Local cache directory used by this executor for temporary files. log_stdout: The path to the file to be used for redirecting stdout. log_stderr: The path to the file to be used for redirecting stderr. 
time_limit: time limit for the task @@ -92,24 +114,26 @@ class SlurmExecutor(AsyncBaseExecutor): def __init__( self, - username: str = None, - address: str = None, - ssh_key_file: str = None, - cert_file: str = None, - remote_workdir: str = None, - create_unique_workdir: bool = None, - conda_env: str = None, - options: Dict = None, - prerun_commands: List[str] = None, - postrun_commands: List[str] = None, - use_srun: bool = None, - srun_options: Dict = None, - srun_append: str = None, - bashrc_path: str = None, - slurm_path: str = None, - cache_dir: str = None, - poll_freq: int = None, - cleanup: bool = None, + username: Optional[str] = None, + address: Optional[str] = None, + ssh_key_file: Optional[str] = None, + cert_file: Optional[str] = None, + remote_workdir: Optional[str] = None, + create_unique_workdir: bool = False, + options: Optional[Dict] = None, + variables: Optional[Dict[str, str]] = None, + conda_env: Optional[str] = None, + prerun_commands: Optional[List[str]] = None, + postrun_commands: Optional[List[str]] = None, + use_srun: Optional[bool] = None, + srun_options: Optional[Dict] = None, + srun_append: Optional[str] = None, + bashrc_path: Optional[str] = None, + slurm_path: Optional[str] = None, + poll_freq: Optional[int] = None, + cleanup: Optional[bool] = None, + cache_dir: Optional[str] = None, + *, log_stdout: str = "", log_stderr: str = "", time_limit: int = -1, @@ -121,76 +145,42 @@ def __init__( self.username = username or get_config("executors.slurm.username") self.address = address or get_config("executors.slurm.address") - self.ssh_key_file = ssh_key_file or get_config("executors.slurm.ssh_key_file") + self.cert_file = cert_file or get_config("executors.slurm").get("cert_file", None) + self.remote_workdir = remote_workdir or get_config("executors.slurm.remote_workdir") + self.variables = variables or get_config("executors.slurm.variables") + self.conda_env = conda_env or get_config("executors.slurm.conda_env") + self.prerun_commands = prerun_commands or get_config("executors.slurm.prerun_commands") + self.postrun_commands = postrun_commands or get_config("executors.slurm.postrun_commands") + self.srun_append = srun_append or get_config("executors.slurm.srun_append") + self.slurm_path = slurm_path or get_config("executors.slurm.slurm_path") + self.poll_freq = poll_freq or get_config("executors.slurm.poll_freq") + self.cache_dir = Path(cache_dir or get_config("executors.slurm.cache_dir")) - try: - self.cert_file = cert_file or get_config("executors.slurm.cert_file") - except KeyError: - self.cert_file = None + # Allow user to override bashrc_path with empty string. 
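+        # For example (hypothetical call): SlurmExecutor(..., bashrc_path="")
+        # skips sourcing any bashrc, whereas bashrc_path=None falls back to
+        # the configured default ("$HOME/.bashrc").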
+ self.bashrc_path = ( + "" if bashrc_path == "" else (bashrc_path or get_config("executors.slurm.bashrc_path")) + ) - self.remote_workdir = remote_workdir or get_config("executors.slurm.remote_workdir") + self.srun_options = deepcopy(srun_options or get_config("executors.slurm.srun_options")) + self.options = deepcopy(options or get_config("executors.slurm.options")) + self.options.update(parsable="") self.create_unique_workdir = ( get_config("executors.slurm.create_unique_workdir") if create_unique_workdir is None else create_unique_workdir ) - - try: - self.slurm_path = slurm_path or get_config("executors.slurm.slurm_path") - except KeyError: - self.slurm_path = None - - try: - self.conda_env = ( - get_config("executors.slurm.conda_env") if conda_env is None else conda_env - ) - except KeyError: - self.conda_env = None - - try: - self.bashrc_path = ( - get_config("executors.slurm.bashrc_path") if bashrc_path is None else bashrc_path - ) - except KeyError: - self.bashrc_path = None - - self.cache_dir = cache_dir or get_config("executors.slurm.cache_dir") - if not os.path.exists(self.cache_dir): - os.makedirs(self.cache_dir) - - # To allow passing empty dictionary - if options is None: - options = get_config("executors.slurm.options") - self.options = deepcopy(options) - self.use_srun = get_config("executors.slurm.use_srun") if use_srun is None else use_srun - - if srun_options is None: - srun_options = get_config("executors.slurm.srun_options") - self.srun_options = deepcopy(srun_options) - - try: - self.srun_append = srun_append or get_config("executors.slurm.srun_append") - except KeyError: - self.srun_append = None - - self.prerun_commands = list(prerun_commands) if prerun_commands else [] - self.postrun_commands = list(postrun_commands) if postrun_commands else [] - - self.poll_freq = poll_freq or get_config("executors.slurm.poll_freq") - if self.poll_freq < 60: - print("Polling frequency will be increased to the minimum for Slurm: 60 seconds.") - self.poll_freq = 60 - self.cleanup = get_config("executors.slurm.cleanup") if cleanup is None else cleanup + # Force minimum value on `poll_freq`. + if self.poll_freq < 5: + app_log.info("Increasing poll_freq to the minimum allowed: 5 seconds.") + self.poll_freq = 5 - # Ensure that the slurm data is parsable - if "parsable" not in self.options: - self.options["parsable"] = "" - - self.LOAD_SLURM_PREFIX = "source /etc/profile\n module whatis slurm &> /dev/null\n if [ $? -eq 0 ] ; then\n module load slurm\n fi\n" + # Create cache dir if it doesn't exist. 
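+        # Note: Path.mkdir(parents=True) creates intermediate directories,
+        # which the nested default cache path under ~/.config/covalent needs.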
+ if not self.cache_dir.exists(): + self.cache_dir.mkdir(parents=True) async def _client_connect(self) -> asyncssh.SSHClientConnection: """ @@ -212,10 +202,10 @@ async def _client_connect(self) -> asyncssh.SSHClientConnection: if not self.ssh_key_file: raise ValueError("ssh_key_file is a required parameter in the Slurm plugin.") - if self.cert_file and not os.path.exists(self.cert_file): + if self.cert_file and not Path(self.cert_file).exists(): raise FileNotFoundError(f"Certificate file not found: {self.cert_file}") - if not os.path.exists(self.ssh_key_file): + if not Path(self.ssh_key_file).exists(): raise FileNotFoundError(f"SSH key file not found: {self.ssh_key_file}") if self.cert_file: @@ -238,8 +228,10 @@ async def _client_connect(self) -> asyncssh.SSHClientConnection: except Exception as e: raise RuntimeError( - f"Could not connect to host: '{self.address}' as user: '{self.username}'", e - ) + f"Could not connect to host: '{self.address}' " + f"as user: '{self.username}' " + f"with key file: '{self.ssh_key_file}'" + ) from e return conn @@ -277,160 +269,51 @@ async def perform_cleanup( await conn.run(f"rm {remote_stderr_filename}") def _format_submit_script( - self, python_version: str, py_filename: str, current_remote_workdir: str + self, + python_version: str, + py_filename: str, + func_filename: str, + result_filename: str, + current_remote_workdir: str, ) -> str: """Create the SLURM that defines the job, uses srun to run the python script. Args: python_version: Python version required by the pickled function. py_filename: Name of the python script. + func_filename: Name of the pickled function file. + result_filename: Name of the pickled result file. current_remote_workdir: Current working directory on the remote machine. Returns: script: String object containing a script parsable by sbatch. """ - # Add chdir to current working directory self.options["chdir"] = current_remote_workdir - # preamble - slurm_preamble = "#!/bin/bash\n" - for key, value in self.options.items(): - slurm_preamble += "#SBATCH " - if len(key) == 1: - slurm_preamble += f"-{key}" + (f" {value}" if value else "") - else: - slurm_preamble += f"--{key}" + (f"={value}" if value else "") - slurm_preamble += "\n" - slurm_preamble += "\n" - - conda_env_clean = "" if self.conda_env == "base" else self.conda_env - - # Source commands - if self.bashrc_path: - source_text = f"source {self.bashrc_path}\n" - else: - source_text = "" - - slurm_env_vars = { - "COVALENT_CONFIG_DIR": "/tmp", - } - slurm_env_vars = ( - "\n".join([f"export {key}={value}" for key, value in slurm_env_vars.items()]) + "\n\n" + job_script = JobScript( + sbatch_options=self.options, + srun_options=self.srun_options, + variables=self.variables, + bashrc_path=self.bashrc_path, + conda_env=self.conda_env, + prerun_commands=self.prerun_commands, + srun_append=self.srun_append, + postrun_commands=self.postrun_commands, + use_srun=self.use_srun, ) - # sets up conda environment - if self.conda_env: - slurm_conda = f"""\ -conda activate {conda_env_clean} -retval=$? -if [ $retval -ne 0 ] ; then - >&2 echo "Conda environment {self.conda_env} is not present on the compute node. "\ - "Please create the environment and try again." - exit 99 -fi -""" - else: - slurm_conda = "" - - # checks remote python version - slurm_python_version = f""" -remote_py_version=$(python -c "print('.'.join(map(str, __import__('sys').version_info[:2])))") -if [[ "{python_version}" != $remote_py_version ]] ; then - >&2 echo "Python version mismatch. 
Please install Python {python_version} in the compute environment." - exit 199 -fi -""" - # runs pre-run commands - if self.prerun_commands: - slurm_prerun_commands = "\n".join([""] + self.prerun_commands + [""]) - else: - slurm_prerun_commands = "" - - if self.use_srun: - # uses srun to run script calling pickled function - srun_options_str = "" - for key, value in self.srun_options.items(): - srun_options_str += " " - if len(key) == 1: - srun_options_str += f"-{key}" + (f" {value}" if value else "") - else: - srun_options_str += f"--{key}" + (f"={value}" if value else "") - - slurm_srun = f"srun{srun_options_str} \\" - - if self.srun_append: - # insert any appended commands - slurm_srun += f""" - {self.srun_append} \\ - """ - else: - slurm_srun += """ - """ - - else: - slurm_srun = "" - - remote_py_filename = os.path.join(self.remote_workdir, py_filename) - python_cmd = slurm_srun + f"python {remote_py_filename}" - - # runs post-run commands - if self.postrun_commands: - slurm_postrun_commands = "\n".join([""] + self.postrun_commands + [""]) - else: - slurm_postrun_commands = "" - - # assemble commands into slurm body - slurm_body = "\n".join([slurm_prerun_commands, python_cmd, slurm_postrun_commands, "wait"]) - - # assemble script - return "".join( - [ - slurm_preamble, - source_text, - slurm_env_vars, - slurm_conda, - slurm_python_version, - slurm_body, - ] + return job_script.format( + python_version=python_version, + remote_py_filename=py_filename, + func_filename=func_filename, + result_filename=result_filename, ) - def _format_py_script( - self, - func_filename: str, - result_filename: str, - ) -> str: - """Create the Python script that executes the pickled python function. - - Args: - func_filename: Name of the pickled function. - result_filename: Name of the pickled result. - - Returns: - script: String object containing a script parsable by sbatch. - """ - func_filename = os.path.join(self.remote_workdir, func_filename) - result_filename = os.path.join(self.remote_workdir, result_filename) - return f""" -import cloudpickle as pickle - -with open("{func_filename}", "rb") as f: - function, args, kwargs = pickle.load(f) - -result = None -exception = None - -try: - result = function(*args, **kwargs) -except Exception as e: - exception = e - -with open("{result_filename}", "wb") as f: - pickle.dump((result, exception), f) -""" - async def get_status( - self, info_dict: dict, conn: asyncssh.SSHClientConnection + self, + info_dict: dict, + conn: asyncssh.SSHClientConnection, ) -> Union[Result, str]: """Query the status of a job previously submitted to Slurm. @@ -456,17 +339,21 @@ async def get_status( else: app_log.debug("Verifying slurm installation for scontrol...") - proc_verify_scontrol = await conn.run(self.LOAD_SLURM_PREFIX + "which scontrol") + proc_verify_scontrol = await conn.run(_LOAD_SLURM_PREFIX + "which scontrol") if proc_verify_scontrol.returncode != 0: raise RuntimeError("Please provide `slurm_path` to run scontrol command") - cmd_scontrol = self.LOAD_SLURM_PREFIX + cmd_scontrol + cmd_scontrol = _LOAD_SLURM_PREFIX + cmd_scontrol proc = await conn.run(cmd_scontrol) return proc.stdout.strip() - async def _poll_slurm(self, job_id: int, conn: asyncssh.SSHClientConnection) -> None: + async def _poll_slurm( + self, + job_id: int, + conn: asyncssh.SSHClientConnection, + ) -> str: """Poll a Slurm job until completion. 
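+
+        The final Slurm status string (e.g. "COMPLETED") is returned so the
+        caller can decide how to handle failed jobs (see `run`).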
Args: @@ -489,78 +376,99 @@ async def _poll_slurm(self, job_id: int, conn: asyncssh.SSHClientConnection) -> await asyncio.sleep(self.poll_freq) status = await self.get_status({"job_id": str(job_id)}, conn) - if "COMPLETED" not in status: - raise RuntimeError("Job failed with status:\n", status) + return status + + async def _query_logs( + self, + task_results_dir: Path, + current_remote_workdir: Path, + conn: asyncssh.SSHClientConnection, + ) -> Tuple[str, str]: + """Query and retrieve the job logs. + + Args: + task_metadata: Dictionary of metadata associated with the task. + conn: SSH connection object. + current_remote_workdir: Current working directory on the remote filesystem. + + Returns: + (stdout, stderr): Contents of the stdout and stderr log files. + """ + + stdout_file = task_results_dir / self.options["output"] + stderr_file = task_results_dir / self.options["error"] + + await asyncssh.scp((conn, current_remote_workdir / self.options["output"]), stdout_file) + await asyncssh.scp((conn, current_remote_workdir / self.options["error"]), stderr_file) + + async with aiofiles.open(stdout_file, "r") as f: + stdout = await f.read() + await async_os.remove(stdout_file) + + async with aiofiles.open(stderr_file, "r") as f: + stderr = await f.read() + await async_os.remove(stderr_file) + + return stdout, stderr async def _query_result( - self, result_filename: str, task_results_dir: str, conn: asyncssh.SSHClientConnection - ) -> Any: + self, + remote_result_filename: Path, + task_results_dir: Path, + conn: asyncssh.SSHClientConnection, + ) -> Tuple[Any, str, str, Optional[Exception]]: """Query and retrieve the task result including stdout and stderr logs. Args: result_filename: Name of the pickled result file. task_results_dir: Directory on the Covalent server where the result will be copied. conn: SSH connection object. + current_remote_workdir: Current working directory on the remote machine. + Returns: result: Task result. 
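+            stdout: Contents of the job's stdout log file.
+            stderr: Contents of the job's stderr log file.
+            exception: Exception raised during task execution, if any, else None.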
""" # Check the result file exists on the remote backend - remote_result_filename = os.path.join(self.remote_workdir, result_filename) - - proc = await conn.run(f"test -e {remote_result_filename}") + proc = await conn.run(f"test -e {remote_result_filename!s}") if proc.returncode != 0: raise FileNotFoundError(proc.returncode, proc.stderr.strip(), remote_result_filename) # Copy result file from remote machine to Covalent server - local_result_filename = os.path.join(task_results_dir, result_filename) + local_result_filename = task_results_dir / remote_result_filename.name await asyncssh.scp((conn, remote_result_filename), local_result_filename) - # Copy stdout, stderr from remote machine to Covalent server - stdout_file = os.path.join(task_results_dir, os.path.basename(self.options["output"])) - stderr_file = os.path.join(task_results_dir, os.path.basename(self.options["error"])) - - await asyncssh.scp((conn, self.options["output"]), stdout_file) - await asyncssh.scp((conn, self.options["error"]), stderr_file) - async with aiofiles.open(local_result_filename, "rb") as f: contents = await f.read() result, exception = pickle.loads(contents) await async_os.remove(local_result_filename) - async with aiofiles.open(stdout_file, "r") as f: - stdout = await f.read() - await async_os.remove(stdout_file) - - async with aiofiles.open(stderr_file, "r") as f: - stderr = await f.read() - await async_os.remove(stderr_file) + stdout, stderr = await self._query_logs( + task_results_dir=task_results_dir, + current_remote_workdir=remote_result_filename.parent, + conn=conn, + ) return result, stdout, stderr, exception - async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict): - """Run a function on a remote machine using Slurm. + async def _copy_files( + self, + dispatch_id: str, + node_id: str, + func_tuple: Tuple[Callable, List, Dict], + conn: asyncssh.SSHClientConnection, + ) -> Dict[str, Path]: + """Copy files to remote machine. Args: - function: Function to be executed. - args: List of positional arguments to be passed to the function. - kwargs: Dictionary of keyword arguments to be passed to the function. - task_metadata: Dictionary of metadata associated with the task. + dispatch_id: Workflow dispatch ID. + node_id: Workflow node ID. + func_tuple: The wrapped electron function, its args, and its kwargs. + conn: SSH connection object. Returns: - result: Result object containing the result of the function execution. + remote_paths: Dictionary of paths on the remote filesystem. 
""" - dispatch_id = task_metadata["dispatch_id"] - node_id = task_metadata["node_id"] - results_dir = task_metadata["results_dir"] - task_results_dir = os.path.join(results_dir, dispatch_id) - - if self.create_unique_workdir: - current_remote_workdir = os.path.join( - self.remote_workdir, dispatch_id, "node_" + str(node_id) - ) - else: - current_remote_workdir = self.remote_workdir result_filename = f"result-{dispatch_id}-{node_id}.pkl" slurm_filename = f"slurm-{dispatch_id}-{node_id}.sh" @@ -568,91 +476,151 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: func_filename = f"func-{dispatch_id}-{node_id}.pkl" if "output" not in self.options: - self.options["output"] = os.path.join( - current_remote_workdir, f"stdout-{dispatch_id}-{node_id}.log" - ) + self.options["output"] = f"stdout-{dispatch_id}-{node_id}.log" if "error" not in self.options: - self.options["error"] = os.path.join( - current_remote_workdir, f"stderr-{dispatch_id}-{node_id}.log" - ) + self.options["error"] = f"stderr-{dispatch_id}-{node_id}.log" - result = None - - conn = await self._client_connect() - - py_version_func = ".".join(function.args[0].python_version.split(".")[:2]) - app_log.debug(f"Python version: {py_version_func}") + remote_workdir = Path(self.remote_workdir) + if self.create_unique_workdir: + current_remote_workdir = remote_workdir / f"{dispatch_id}/node_{node_id}" + else: + current_remote_workdir = remote_workdir # Create the remote directory - app_log.debug(f"Creating remote work directory {current_remote_workdir} ...") - cmd_mkdir_remote = f"mkdir -p {current_remote_workdir}" + app_log.debug("Creating remote work directory %s ...", current_remote_workdir) + cmd_mkdir_remote = f"mkdir -p {current_remote_workdir!s}" proc_mkdir_remote = await conn.run(cmd_mkdir_remote) if client_err := proc_mkdir_remote.stderr.strip(): raise RuntimeError(client_err) + function = func_tuple[0] + py_version_func = ".".join(function.args[0].python_version.split(".")[:2]) + app_log.debug("Python version: %s", py_version_func) + + # Pickle the function, write to file, and copy to remote filesystem async with aiofiles.tempfile.NamedTemporaryFile(dir=self.cache_dir) as temp_f: - # Pickle the function, write to file, and copy to remote filesystem app_log.debug("Writing pickled function, args, kwargs to file...") - await temp_f.write(pickle.dumps((function, args, kwargs))) + await temp_f.write(pickle.dumps(func_tuple)) await temp_f.flush() - remote_func_filename = os.path.join(self.remote_workdir, func_filename) - app_log.debug(f"Copying pickled function to remote fs: {remote_func_filename} ...") + remote_func_filename = current_remote_workdir / func_filename + app_log.debug("Copying pickled function to remote fs: %s ...", remote_func_filename) await asyncssh.scp(temp_f.name, (conn, remote_func_filename)) + # Format the function execution script, write to file, and copy to remote filesystem async with aiofiles.tempfile.NamedTemporaryFile(dir=self.cache_dir, mode="w") as temp_g: - # Format the function execution script, write to file, and copy to remote filesystem - python_exec_script = self._format_py_script(func_filename, result_filename) + python_exec_script = (Path(__file__).parent / "exec.py").read_text("utf-8") app_log.debug("Writing python run-function script to tempfile...") await temp_g.write(python_exec_script) await temp_g.flush() - remote_py_script_filename = os.path.join(self.remote_workdir, py_script_filename) - app_log.debug(f"Copying python run-function to remote fs: 
{remote_py_script_filename}") + remote_py_script_filename = current_remote_workdir / py_script_filename + app_log.debug( + "Copying python run-function to remote fs: %s", remote_py_script_filename + ) await asyncssh.scp(temp_g.name, (conn, remote_py_script_filename)) + # Format the SLURM submit script, write to file, and copy to remote filesystem async with aiofiles.tempfile.NamedTemporaryFile(dir=self.cache_dir, mode="w") as temp_h: - # Format the SLURM submit script, write to file, and copy to remote filesystem + self.options["job-name"] = ( + self.options.get("job-name", None) or f"covalent-{node_id}-{dispatch_id}" + ) slurm_submit_script = self._format_submit_script( - py_version_func, py_script_filename, current_remote_workdir + python_version=py_version_func, + py_filename=py_script_filename, + func_filename=func_filename, + result_filename=result_filename, + current_remote_workdir=current_remote_workdir, ) app_log.debug("Writing slurm submit script to tempfile...") await temp_h.write(slurm_submit_script) await temp_h.flush() - remote_slurm_filename = os.path.join(current_remote_workdir, slurm_filename) - app_log.debug(f"Copying slurm submit script to remote fs: {remote_slurm_filename} ...") + remote_slurm_filename = current_remote_workdir / slurm_filename + app_log.debug("Copying slurm submit script to remote fs: %s", remote_slurm_filename) await asyncssh.scp(temp_h.name, (conn, remote_slurm_filename)) - # Execute the script - app_log.debug(f"Running the script: {remote_slurm_filename} ...") - cmd_sbatch = f"sbatch {remote_slurm_filename}" + return { + "slurm": remote_slurm_filename, + "py": remote_py_script_filename, + "func": remote_func_filename, + "result": current_remote_workdir / result_filename, + } + + async def run( + self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict + ) -> Optional[Any]: + """Run a function on a remote machine using Slurm. + + Args: + function: Function to be executed. + args: List of positional arguments to be passed to the function. + kwargs: Dictionary of keyword arguments to be passed to the function. + task_metadata: Dictionary of metadata associated with the task. + + Returns: + result: Result object containing the result of the function execution. + """ + + dispatch_id = task_metadata["dispatch_id"] + node_id = task_metadata["node_id"] + result = None + exception = None + stdout = "" + stderr = "" + task_results_dir = Path(task_metadata["results_dir"]) / dispatch_id + + conn = await self._client_connect() + + # Copy files to remote + remote_paths = await self._copy_files( + dispatch_id=dispatch_id, + node_id=node_id, + func_tuple=(function, args, kwargs), + conn=conn, + ) + + # Submit the job script with `sbatch`. 
+ app_log.debug("Running the script: %s", remote_paths["slurm"]) + cmd_sbatch = f"sbatch {remote_paths['slurm']}" if self.slurm_path: app_log.debug("Exporting slurm path for sbatch...") cmd_sbatch = f"export PATH=$PATH:{self.slurm_path} && {cmd_sbatch}" else: app_log.debug("Verifying slurm installation for sbatch...") - proc_verify_sbatch = await conn.run(self.LOAD_SLURM_PREFIX + "which sbatch") + proc_verify_sbatch = await conn.run(_LOAD_SLURM_PREFIX + "which sbatch") if proc_verify_sbatch.returncode != 0: raise RuntimeError("Please provide `slurm_path` to run sbatch command") - cmd_sbatch = self.LOAD_SLURM_PREFIX + cmd_sbatch + cmd_sbatch = _LOAD_SLURM_PREFIX + cmd_sbatch proc = await conn.run(cmd_sbatch) - if proc.returncode != 0: raise RuntimeError(proc.stderr.strip()) - app_log.debug(f"Job submitted with stdout: {proc.stdout.strip()}") + # Poll for result. + app_log.debug("Job submitted with stdout: %s", proc.stdout.strip()) slurm_job_id = int(re.findall("[0-9]+", proc.stdout.strip())[0]) - app_log.debug(f"Polling slurm with job_id: {slurm_job_id} ...") - await self._poll_slurm(slurm_job_id, conn) + app_log.debug("Polling slurm with job_id: %s ...", slurm_job_id) + status = await self._poll_slurm(slurm_job_id, conn) - app_log.debug(f"Querying result with job_id: {slurm_job_id} ...") - result, stdout, stderr, exception = await self._query_result( - result_filename, task_results_dir, conn - ) + if "COMPLETED" in status: + app_log.debug("Querying result with job_id: %s ...", slurm_job_id) + result, stdout, stderr, exception = await self._query_result( + remote_result_filename=remote_paths["result"], + task_results_dir=task_results_dir, + conn=conn, + ) + else: + # Result will be None. Recover errors from Slurm job log files. + app_log.debug("Job submission FAILED with status:\n%s\n", status) + _, stderr = await self._query_logs( + task_results_dir=task_results_dir, + current_remote_workdir=remote_paths["result"].parent, + conn=conn, + ) + exception = RuntimeError(stderr) print(stdout) print(stderr, file=sys.stderr) @@ -661,10 +629,10 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: raise RuntimeError(exception) app_log.debug("Preparing for teardown...") - self._remote_func_filename = remote_func_filename - self._remote_slurm_filename = remote_slurm_filename - self._result_filename = result_filename - self._remote_py_script_filename = remote_py_script_filename + self._remote_func_filename = remote_paths["func"] + self._remote_slurm_filename = remote_paths["slurm"] + self._remote_py_script_filename = remote_paths["py"] + self._result_filename = remote_paths["result"] app_log.debug("Closing SSH connection...") conn.close() @@ -673,7 +641,7 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: return result - async def teardown(self, task_metadata: Dict): + async def teardown(self, task_metadata: Dict) -> None: """Perform cleanup on remote machine. 
Args: @@ -686,16 +654,15 @@ async def teardown(self, task_metadata: Dict): try: app_log.debug("Performing cleanup on remote...") conn = await self._client_connect() + current_remote_workdir = self._result_filename.parent await self.perform_cleanup( conn=conn, remote_func_filename=self._remote_func_filename, remote_slurm_filename=self._remote_slurm_filename, remote_py_filename=self._remote_py_script_filename, - remote_result_filename=os.path.join( - self.remote_workdir, self._result_filename - ), - remote_stdout_filename=self.options["output"], - remote_stderr_filename=self.options["error"], + remote_result_filename=self._result_filename, + remote_stdout_filename=current_remote_workdir / self.options["output"], + remote_stderr_filename=current_remote_workdir / self.options["error"], ) app_log.debug("Closing SSH connection...") diff --git a/requirements.txt b/requirements.txt index f6bdc6d..d4a9ffa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiofiles==0.8.0 asyncssh>=2.10.1 -covalent>=0.202.0,<1 +covalent>=0.232.0,<1 diff --git a/tests/exec_test.py b/tests/exec_test.py new file mode 100644 index 0000000..3bfd7ad --- /dev/null +++ b/tests/exec_test.py @@ -0,0 +1,174 @@ +# Copyright 2024 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the @electron task execution script that runs on the SLURM cluster""" + +import sys +import tempfile +from unittest import mock + +import cloudpickle as pickle +import filelock +import pytest + +from covalent_slurm_plugin import exec as exec_script + + +def test_setup_import_error(): + """Test error message correctly reported in case of ImportError""" + + patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", side_effect=ImportError() + ) + + with patched_import_covalent: + with pytest.raises( + RuntimeError, match="The covalent SDK is not installed in the Slurm job environment." + ): + exec_script._check_setup() + + +def test_setup_filelock_timeout_error(): + """Test error message correctly reported in case of filelock Timeout error""" + + patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", + side_effect=filelock._error.Timeout("nonexistent_lock_file"), + ) + + with patched_import_covalent: + with pytest.raises(RuntimeError, match="Failed to acquire file lock"): + exec_script._check_setup() + + +def test_execute(): + """Test _execute correctly reports the result""" + + # Disable this - don't need `covalent` with fake electron functions. 
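+    # (Patching the wrapped import means the real covalent package need not
+    # be installed in the test environment.)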
+ patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", + ) + + def _fake_electron_function(x, y): + return x + y + + args = (1, 2) + kwargs = {} + with tempfile.NamedTemporaryFile("wb") as f: + pickle.dump((_fake_electron_function, args, kwargs), f) + f.flush() + f.seek(0) + + patched_sys_argv = mock.patch.object(sys, "argv", ["", f.name]) + + with patched_import_covalent, patched_sys_argv: + result_dict = exec_script._execute() + + assert result_dict["result"] == 3 + assert result_dict["exception"] is None + + +def test_execute_with_error(): + """Test _execute correctly reports an error""" + + _fake_exception = ValueError("bad bad not good") + + # Disable this - don't need `covalent` with fake electron functions. + patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", + ) + + def _fake_electron_function(x, y): + raise _fake_exception + + args = (1, 2) + kwargs = {} + with tempfile.NamedTemporaryFile("wb") as f: + pickle.dump((_fake_electron_function, args, kwargs), f) + f.flush() + f.seek(0) + + patched_sys_argv = mock.patch.object(sys, "argv", ["", f.name]) + + with patched_import_covalent, patched_sys_argv: + result_dict = exec_script._execute() + + assert result_dict["result"] is None + assert isinstance(result_dict["exception"], type(_fake_exception)) + assert "bad bad not good" in str(result_dict["exception"]) + + +def test_main(): + """Test the main function correctly reports the result""" + + _fake_result = 123456789 + + # Disable this - don't need `covalent` with fake electron functions. + patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", + ) + patched_execute = mock.patch( + "covalent_slurm_plugin.exec._execute", + return_value={"result": _fake_result, "exception": None}, + ) + + with tempfile.NamedTemporaryFile("wb") as f: + patched_sys_argv = mock.patch.object(sys, "argv", ["", "fake_func_filename", f.name]) + + with patched_import_covalent, patched_sys_argv, patched_execute: + # Write the result to the temporary file. + exec_script.main() + + f.flush() + f.seek(0) + + with open(f.name, "rb") as f: + result, exception = pickle.load(f) + + assert result == _fake_result + assert exception is None + + +def test_main_with_error(): + """Test the main function correctly reports an error""" + + _fake_exception = SyntaxError("bad bad not good") + + # Disable this - don't need `covalent` with fake electron functions. + patched_import_covalent = mock.patch( + "covalent_slurm_plugin.exec._import_covalent", + ) + patched_execute = mock.patch( + "covalent_slurm_plugin.exec._execute", + return_value={"result": None, "exception": _fake_exception}, + ) + + with tempfile.NamedTemporaryFile("wb") as f: + patched_sys_argv = mock.patch.object(sys, "argv", ["", "fake_func_filename", f.name]) + + with patched_import_covalent, patched_sys_argv, patched_execute: + # Write the result to the temporary file. 
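+            # main() treats sys.argv[2] as the result path and pickles a
+            # (result, exception) tuple there via _record_output.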
+ exec_script.main() + + f.flush() + f.seek(0) + + with open(f.name, "rb") as f: + result, exception = pickle.load(f) + + assert result is None + assert isinstance(exception, type(_fake_exception)) + assert "bad bad not good" in str(exception) diff --git a/tests/requirements.txt b/tests/requirements.txt index 9ae9f41..122b7a7 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ aiofiles==0.8.0 asyncssh>=2.10.1 -covalent>=0.202.0 +covalent>=0.232.0,<1 flake8==3.9.2 isort==5.7.0 mock==4.0.3 diff --git a/tests/slurm_test.py b/tests/slurm_test.py index 564b0d6..6f5e7e1 100644 --- a/tests/slurm_test.py +++ b/tests/slurm_test.py @@ -17,7 +17,6 @@ """Tests for the SLURM executor plugin.""" import os -from copy import deepcopy from functools import partial from pathlib import Path from unittest import mock @@ -25,11 +24,11 @@ import aiofiles import pytest from covalent._results_manager.result import Result -from covalent._shared_files.config import get_config, set_config from covalent._workflow.transport import TransportableObject -from covalent.executor.base import wrapper_fn +from covalent.executor.utils.wrappers import wrapper_fn from covalent_slurm_plugin import SlurmExecutor +from covalent_slurm_plugin.slurm import _EXECUTOR_PLUGIN_DEFAULTS aiofiles.threadpool.wrap.register(mock.MagicMock)( lambda *args, **kwargs: aiofiles.threadpool.AsyncBufferedIOBase(*args, **kwargs) @@ -79,17 +78,15 @@ def test_init(): assert executor.cert_file is None assert executor.remote_workdir == "covalent-workdir" assert executor.create_unique_workdir is False - assert executor.slurm_path is None + assert executor.slurm_path == "/usr/bin" assert executor.conda_env == "" - assert executor.poll_freq == 60 + assert executor.poll_freq == 5 assert executor.options == {"parsable": ""} assert executor.srun_options == {} - assert executor.srun_append is None + assert executor.srun_append == "" assert executor.prerun_commands == [] assert executor.postrun_commands == [] - assert executor.cache_dir == str( - Path(get_config("dispatcher.cache_dir")).expanduser().resolve() - ) + assert executor.cache_dir == Path(_EXECUTOR_PLUGIN_DEFAULTS["cache_dir"]) assert executor.cleanup is True # Test with non-defaults @@ -148,7 +145,7 @@ def test_init(): assert executor.srun_append == srun_append assert executor.prerun_commands == prerun_commands assert executor.postrun_commands == postrun_commands - assert executor.cache_dir == cache_dir + assert executor.cache_dir == Path(cache_dir) assert executor.poll_freq == poll_freq assert executor.cleanup == cleanup @@ -157,52 +154,10 @@ def test_init(): username=username, address=host, ssh_key_file=key_file, - poll_freq=30, - ) - - assert executor.poll_freq == 60 - - -def test_failed_init(): - """Test for failed inits""" - - start_config = deepcopy(get_config()) - for key in ["cert_file", "slurm_path", "conda_env", "bashrc_path", "srun_append"]: - config = get_config() - config["executors"]["slurm"].pop(key, None) - set_config(config) - executor = SlurmExecutor(username="username", address="host", ssh_key_file=SSH_KEY_FILE) - assert not executor.__dict__[key] - set_config(start_config) - - -def test_format_py_script(): - """Test that the python script (in string form) which is to be executed (via srun) - on the remote server is created with no errors.""" - - executor_0 = SlurmExecutor( - username="test_user", - address="test_address", - ssh_key_file=SSH_KEY_FILE, - cert_file=CERT_FILE, - remote_workdir="/federation/test_user/.cache/covalent", - options={}, - 
cache_dir="~/.cache/covalent", - poll_freq=60, + poll_freq=1, ) - dispatch_id = "148dedae-1b58-3870-z08d-db89bceec915" - task_id = 2 - func_filename = f"func-{dispatch_id}-{task_id}.pkl" - result_filename = f"result-{dispatch_id}-{task_id}.pkl" - - try: - py_script_str = executor_0._format_py_script(func_filename, result_filename) - print(py_script_str) - except Exception as exc: - assert False, f"Exception while running _format_py_script: {exc}" - assert func_filename in py_script_str - assert result_filename in py_script_str + assert executor.poll_freq == 5 def test_format_submit_script_default(): @@ -230,10 +185,16 @@ def simple_task(x): dispatch_id = "259efebf-2c69-4981-a19e-ec90cdffd026" task_id = 3 py_filename = f"script-{dispatch_id}-{task_id}.py" + func_filename = f"func-{dispatch_id}-{task_id}.pkl" + result_filename = f"result-{dispatch_id}-{task_id}.pkl" try: submit_script_str = executor_0._format_submit_script( - python_version, py_filename, remote_workdir + python_version=python_version, + py_filename=py_filename, + func_filename=func_filename, + result_filename=result_filename, + current_remote_workdir=remote_workdir, ) print(submit_script_str) except Exception as exc: @@ -244,7 +205,8 @@ def simple_task(x): assert submit_script_str.startswith( shebang ), f"Missing '{shebang[:-1]}' in sbatch shell script" - assert "conda" not in submit_script_str + + assert "conda activate" in submit_script_str assert "source $HOME/.bashrc" in submit_script_str assert "srun" in submit_script_str assert "--chdir=" + remote_workdir in submit_script_str @@ -287,10 +249,17 @@ def simple_task(x): dispatch_id = "259efebf-2c69-4981-a19e-ec90cdffd026" task_id = 3 py_filename = f"script-{dispatch_id}-{task_id}.py" + func_filename = f"func-{dispatch_id}-{task_id}.pkl" + result_filename = f"result-{dispatch_id}-{task_id}.pkl" + current_remote_workdir = os.path.join(remote_workdir, dispatch_id, "node_" + str(task_id)) try: submit_script_str = executor_1._format_submit_script( - python_version, py_filename, current_remote_workdir + python_version=python_version, + py_filename=py_filename, + func_filename=func_filename, + result_filename=result_filename, + current_remote_workdir=current_remote_workdir, ) print(submit_script_str) except Exception as exc: @@ -330,10 +299,16 @@ def simple_task(x): dispatch_id = "259efebf-2c69-4981-a19e-ec90cdffd026" task_id = 3 py_filename = f"script-{dispatch_id}-{task_id}.py" + func_filename = f"func-{dispatch_id}-{task_id}.pkl" + result_filename = f"result-{dispatch_id}-{task_id}.pkl" try: submit_script_str = executor_1._format_submit_script( - python_version, py_filename, remote_workdir + python_version=python_version, + py_filename=py_filename, + func_filename=func_filename, + result_filename=result_filename, + current_remote_workdir=remote_workdir, ) print(submit_script_str) except Exception as exc: @@ -369,16 +344,22 @@ def simple_task(x): dispatch_id = "259efebf-2c69-4981-a19e-ec90cdffd026" task_id = 3 py_filename = f"script-{dispatch_id}-{task_id}.py" + func_filename = f"func-{dispatch_id}-{task_id}.pkl" + result_filename = f"result-{dispatch_id}-{task_id}.pkl" try: submit_script_str = executor_2._format_submit_script( - python_version, py_filename, remote_workdir + python_version=python_version, + py_filename=py_filename, + func_filename=func_filename, + result_filename=result_filename, + current_remote_workdir=remote_workdir, ) print(submit_script_str) except Exception as exc: assert False, f"Exception while running _format_submit_script with default options: {exc}" 
- assert "conda" not in submit_script_str + assert "conda activate" in submit_script_str assert "source" not in submit_script_str @@ -451,7 +432,7 @@ async def test_get_status(proc_mock, conn_mock): status = await executor.get_status({"job_id": 0}, conn_mock) assert status == "Fake Status" - assert conn_mock.run.call_count == 2 + assert conn_mock.run.call_count == 1 @pytest.mark.asyncio @@ -519,7 +500,7 @@ async def test_query_result(mocker, proc_mock, conn_mock): try: await executor._query_result( - result_filename="mock_result", task_results_dir="", conn=conn_mock + remote_result_filename=Path("mock_result"), task_results_dir=Path("."), conn=conn_mock ) except Exception as raised_exception: expected_exception = FileNotFoundError(1, "stderr") @@ -547,18 +528,18 @@ async def test_query_result(mocker, proc_mock, conn_mock): unpatched_open = open def mock_open(*args, **kwargs): - if args[0] == "mock_result": + if args[0].name == "mock_result": return mock.mock_open(read_data=None)(*args, **kwargs) - elif args[0] == executor.options["output"]: + elif args[0].name == executor.options["output"]: return mock.mock_open(read_data=expected_stdout)(*args, **kwargs) - elif args[0] == executor.options["error"]: + elif args[0].name == executor.options["error"]: return mock.mock_open(read_data=expected_stderr)(*args, **kwargs) else: return unpatched_open(*args, **kwargs) with mock.patch("aiofiles.threadpool.sync_open", mock_open): result, stdout, stderr, exception = await executor._query_result( - result_filename="mock_result", task_results_dir="", conn=conn_mock + remote_result_filename=Path("mock_result"), task_results_dir=Path("."), conn=conn_mock ) assert result == expected_results @@ -627,19 +608,19 @@ def reset_proc_mock(): proc_mock.stderr = "" proc_mock.returncode = 0 - async def __client_connect_fail(*_): + async def __client_connect_fail(*_, **__): return conn_mock - async def __client_connect_succeed(*_): + async def __client_connect_succeed(*_, **__): return conn_mock - async def __poll_slurm_succeed(*_): + async def __poll_slurm_succeed(*_, **__): return - async def __query_result_fail(*_): + async def __query_result_fail(*_, **__): return None, proc_mock.stdout, proc_mock.stderr, dummy_error_msg - async def __query_result_succeed(*_): + async def __query_result_succeed(*_, **__): return "result", "", "", None # patches @@ -765,10 +746,10 @@ def f(x, y): conn_mock.run = mock.AsyncMock(return_value=proc_mock) conn_mock.wait_closed = mock.AsyncMock(return_value=None) - async def __client_connect_succeed(*_): + async def __client_connect_succeed(*_, **__): return conn_mock - async def __query_result_succeed(*_): + async def __query_result_succeed(*_, **__): return "result", "", "", None async def __perform_cleanup(*_, **__):