diff --git a/.ci/Dockerfile b/.ci/Dockerfile deleted file mode 100644 index e8458d74eb..0000000000 --- a/.ci/Dockerfile +++ /dev/null @@ -1,111 +0,0 @@ -FROM ubuntu:18.04 -MAINTAINER AiiDA Team - -# This is necessary such that the setup of `tzlocal` is non-interactive -ENV DEBIAN_FRONTEND=noninteractive - -ARG uid=1000 -ARG gid=1000 - -# Set correct locale -# For something more complex, as reported by https://hub.docker.com/_/ubuntu/ -# and taken from postgres: -# make the "en_US.UTF-8" locale so postgres will be utf-8 enabled by default -# The `software-properties-common` is necessary to get the command `add-apt-repository` -RUN apt-get update && apt-get install -y locales software-properties-common && rm -rf /var/lib/apt/lists/* \ - && localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8 -ENV LANG en_US.utf8 - -# Putting the LANG also in the root .bashrc, so that the DB is later -# Created with the UTF8 locale -RUN sed -i '/interactively/iexport LANG=en_US.utf8' /root/.bashrc -# This is probably the right command to issue to make sure all users see it as the default locale -RUN update-locale LANG=en_US.utf8 - -# I don't define it for now (should use the one of ubuntu by default, anyway -# jenkins will replace it with 'cat') -#CMD ["/bin/true"] - -RUN add-apt-repository ppa:deadsnakes/ppa - -# install required software -RUN apt-get update \ - && apt-get -y install \ - git \ - vim \ - openssh-client \ - postgresql-client-10 \ - postgresql-10 \ - postgresql-server-dev-10 \ - && apt-get -y install \ - python3.7 python3.7-dev \ - python3-pip \ - ipython \ - texlive-base \ - texlive-generic-recommended \ - texlive-fonts-recommended \ - texlive-latex-base \ - texlive-latex-recommended \ - texlive-latex-extra \ - dvipng \ - dvidvi \ - graphviz \ - bc \ - time \ - && rm -rf /var/lib/apt/lists/* \ - && apt-get clean all - -# Disable password requests for requests coming from localhost -# Of course insecure, but ok for testing -RUN cp /etc/postgresql/10/main/pg_hba.conf /etc/postgresql/10/main/pg_hba.conf~ && \ - perl -npe 's/^([^#]*)md5$/$1trust/' /etc/postgresql/10/main/pg_hba.conf~ > /etc/postgresql/10/main/pg_hba.conf - -# install sudo otherwise tests for quicksetup fail, -# see #1382. I think this part should be removed in the -# future and AiiDA should work also without sudo. -## Also install openssh-server needed for AiiDA tests, -## and openmpi-bin to have 'mpirun', -## and rabbitmq-server needed by AiiDA as the event queue -## and libkrb5-dev for gssapi.h -RUN apt-get update \ - && apt-get -y install \ - sudo \ - locate \ - openssh-server \ - openmpi-bin \ - rabbitmq-server \ - libkrb5-dev \ - && rm -rf /var/lib/apt/lists/* \ - && apt-get clean all - -# locate will not find anything if the DB is not updated. 
-# Should take ~3-4 secs, so ok -RUN updatedb - -# update pip and setuptools to get a relatively-recent version -# This can be updated in the future -RUN pip3 install pip==19.2.3 setuptools==42.0.2 - -# Put the doubler script -COPY doubler.sh /usr/local/bin/ - -# Use messed-up filename to test quoting robustness -RUN mv /usr/local/bin/doubler.sh /usr/local/bin/d\"o\'ub\ ler.sh - -# add USER (no password); 1000 is the uid of the user in the jenkins docker -RUN groupadd -g ${gid} jenkins && useradd -m -s /bin/bash -u ${uid} -g ${gid} jenkins - -# add to sudoers and don't ask password -RUN adduser jenkins sudo && adduser jenkins adm -RUN echo "%sudo ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/nopwd -RUN mkdir -p /scratch/jenkins/ && chown jenkins /scratch/jenkins/ && chmod o+rX /scratch/ - -########################################## -############ Installation Setup ########## -########################################## - -# install rest of the packages as normal user -USER jenkins - -# set $HOME, create git directory -ENV HOME /home/jenkins diff --git a/.ci/Jenkinsfile b/.ci/Jenkinsfile deleted file mode 100644 index 74c6a2078f..0000000000 --- a/.ci/Jenkinsfile +++ /dev/null @@ -1,179 +0,0 @@ -// Note: this part might happen on a different node than -// the one that will run the pipeline below, see -// https://stackoverflow.com/questions/44805076 -// but it should be ok for us as we only have one node -def user_id -def group_id -node { - user_id = sh(returnStdout: true, script: 'id -u').trim() - group_id = sh(returnStdout: true, script: 'id -g').trim() -} - -pipeline { - /* The tutorial was setting here agent none, and setting the - agent in each stage, using therefore different agents in each - stage. I think that for what we are trying to achieve, having - a single agent and running all in the same docker image is better, - but we need to check this for more advanced usages. */ - // agent none - agent { - // Documentation: https://jenkins.io/doc/book/pipeline/syntax/#agent - // Note: we reuse the pip cache for speed - // TMPFS: we make sure that postgres is different for every run, - // but also runs fast - dockerfile { - filename 'Dockerfile' - dir '.ci' - args '-v jenkins-pip-cache:/home/jenkins/.cache/pip/ --tmpfs /var/lib/postgresql-tmp --tmpfs /tmp:exec' - additionalBuildArgs "--build-arg uid=${user_id} --build-arg gid=${group_id}" - } - } - environment { - WORKSPACE_PATH="." 
- COMPUTER_SETUP_TYPE="jenkins" - // The following two variables allow to run selectively tests only for one backend - RUN_ALSO_DJANGO="true" - RUN_ALSO_SQLALCHEMY="true" - // To avoid that different pipes (stderr, stdout, different processes) get in the wrong order - PYTHONUNBUFFERED="yes" - } - stages { - stage('Pre-build') { - steps { - // Clean work dir (often runs reshare the same folder, and it might - // contain old data from previous runs - this is particularly - // problematic when a folder is deleted from git but .pyc files - // are left in) - sh 'git clean -fdx' - sh 'sudo /etc/init.d/ssh restart' - sh 'sudo chown -R jenkins:jenkins /home/jenkins/.cache/' - // (re)start rabbitmq (both to start it or to reload the configuration) - sh 'sudo /etc/init.d/rabbitmq-server restart' - - // Make sure the tmpfs folder is owned by postgres, and that it - // contains the right data - sh 'sudo chown postgres:postgres /var/lib/postgresql-tmp' - sh 'sudo mv /var/lib/postgresql/* /var/lib/postgresql-tmp/' - sh 'sudo rmdir /var/lib/postgresql/' - sh 'sudo ln -s /var/lib/postgresql-tmp/ /var/lib/postgresql' - - // (re)start postgres (both to start it or to reload the configuration) - sh 'sudo /etc/init.d/postgresql restart' - - // rerun updatedb otherwise 'locate' prints a warning that the DB is old... - sh 'sudo updatedb' - - // Debug: check that I can connect without password - sh 'echo "SELECT datname FROM pg_database" | psql -h localhost -U postgres -w' - - // Add the line to the .bashrc, but before it stops when non-interactive - // So it can find the location of 'verdi' - sh "sed -i '/interactively/iexport PATH=\${PATH}:~/.local/bin' ~/.bashrc" - // Add path needed by the daemon to find the workchains - sh "sed -i '/interactively/iexport PYTHONPATH=\${PYTHONPATH}:'`pwd`'/.ci/' ~/.bashrc" - sh "cat ~/.bashrc" - } - } - stage('Build') { - steps { - sh 'pip install --upgrade --user pip' - sh 'pip install --user .[all]' - // To be able to do ssh localhost - sh 'ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa' - sh 'cp ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys' - sh 'ssh-keyscan -H localhost >> ~/.ssh/known_hosts' - } - post { - always { - sh 'pip freeze > pip-freeze.txt' - archiveArtifacts artifacts: 'pip-freeze.txt', fingerprint: true - } - } - } - stage('Test') { - failFast false // It is the default, but I still put it for future reference - // failFast would stop as soon as there is a failing test - parallel { - stage('Test-Django') { - environment { - AIIDA_TEST_BACKEND="django" - // I run the two tests in two different folders, otherwise - // they might get at the point of writing the config.json at the - // same time and one of the two would crash - AIIDA_PATH="/tmp/aiida-django-folder" - } - when { - // This allows to selectively run only one backend - environment name: 'RUN_ALSO_DJANGO', value: 'true' - } - steps { - sh '.ci/setup.sh' - sh '.ci/test_rpn.sh' - } - } - stage('Test-SQLAlchemy') { - environment { - AIIDA_TEST_BACKEND="sqlalchemy" - AIIDA_PATH="/tmp/aiida-sqla-folder" - } - when { - // This allows to selectively run only one backend - environment name: 'RUN_ALSO_SQLALCHEMY', value: 'true' - } - steps { - sh '.ci/setup.sh' - sh '.ci/test_rpn.sh' - } - } - } - } - } - post { - always { - // Some debug stuff - sh 'whoami ; pwd; echo $AIIDA_TEST_BACKEND' - cleanWs() - } - success { - echo 'The run finished successfully!' - } - unstable { - echo 'This run is unstable...' - } - failure { - echo "This run failed..." - } - // You can trigger actions when the status change (e.g. 
it starts failing, - // or it starts working again - e.g. sending emails or similar) - // possible variables: see e.g. https://qa.nuxeo.org/jenkins/pipeline-syntax/globals - // Other valid names: fixed, regression (opposite of fixed), aborted (by user, typically) - // Note that I had problems with email, I don't know if it is a configuration problem - // or a missing plugin. - changed { - script { - if (currentBuild.getPreviousBuild()) { - echo "The state changed from ${currentBuild.getPreviousBuild().result} to ${currentBuild.currentResult}." - } - else { - echo "This is the first build, and its status is: ${currentBuild.currentResult}." - } - } - } - } - options { - // we do not want the whole run to hang forever - - // we set a total timeout of 1 hour - timeout(time: 60, unit: 'MINUTES') - } -} - - -// Other things to add possibly: -// global options (or per-stage options) with timeout: https://jenkins.io/doc/book/pipeline/syntax/#options-example -// retry-on-failure for some specific tasks: https://jenkins.io/doc/book/pipeline/syntax/#available-stage-options -// parameters: https://jenkins.io/doc/book/pipeline/syntax/#parameters -// input: interesting for user input before continuing: https://jenkins.io/doc/book/pipeline/syntax/#input -// when conditions, e.g. to depending on details on the commit (e.g. only when specific -// files are changed, where there is a string in the commit log, for a specific branch, -// for a Pull Request,for a specific environment variable, ...): -// https://jenkins.io/doc/book/pipeline/syntax/#when diff --git a/.ci/setup.sh b/.ci/setup.sh deleted file mode 100755 index c28fdcadc9..0000000000 --- a/.ci/setup.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env bash -set -ev - -# The following is needed on jenkins, for some reason bashrc is not reloaded automatically -if [ -e ~/.bashrc ] ; then source ~/.bashrc ; fi - -# Add the .ci and the polish folder to the python path such that defined workchains can be found by the daemon -export PYTHONPATH="${PYTHONPATH}:${WORKSPACE_PATH}/.ci" -export PYTHONPATH="${PYTHONPATH}:${WORKSPACE_PATH}/.ci/polish" - -PSQL_COMMAND="CREATE DATABASE $AIIDA_TEST_BACKEND ENCODING \"UTF8\" LC_COLLATE=\"en_US.UTF-8\" LC_CTYPE=\"en_US.UTF-8\" TEMPLATE=template0;" -psql -h localhost -c "${PSQL_COMMAND}" -U postgres -w - -verdi setup --profile $AIIDA_TEST_BACKEND \ - --email="aiida@localhost" --first-name=AiiDA --last-name=test --institution="AiiDA Team" \ - --db-engine 'postgresql_psycopg2' --db-backend=$AIIDA_TEST_BACKEND --db-host="localhost" --db-port=5432 \ - --db-name="$AIIDA_TEST_BACKEND" --db-username=postgres --db-password='' \ - --repository="/tmp/repository_${AIIDA_TEST_BACKEND}/" --non-interactive - -verdi profile setdefault $AIIDA_TEST_BACKEND -verdi config runner.poll.interval 0 - -# Start the daemon for the correct profile and add four additional workers to prevent deadlock with integration tests -verdi -p $AIIDA_TEST_BACKEND daemon start -verdi -p $AIIDA_TEST_BACKEND daemon incr 4 - -verdi -p $AIIDA_TEST_BACKEND computer setup --non-interactive --label=localhost --hostname=localhost --transport=local \ - --scheduler=direct --mpiprocs-per-machine=1 --prepend-text="" --append-text="" -verdi -p $AIIDA_TEST_BACKEND computer configure local localhost --non-interactive --safe-interval=0 - -# Configure the 'add' code inside localhost -verdi -p $AIIDA_TEST_BACKEND code setup -n -L add \ - -D "simple script that adds two numbers" --on-computer -P arithmetic.add \ - -Y localhost --remote-abs-path=/bin/bash diff --git 
a/.ci/test_rpn.sh b/.ci/test_rpn.sh deleted file mode 100755 index 7af950b26a..0000000000 --- a/.ci/test_rpn.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env bash - -# Be verbose, and stop with error as soon there's one -set -ev - -declare -a EXPRESSIONS=("1 -2 -1 4 -5 -5 * * * * +" "2 1 3 3 -1 + ^ ^ +" "3 -5 -1 -4 + * ^" "2 4 2 -4 * * +" "3 1 1 5 ^ ^ ^" "3 1 3 4 -4 2 * + + ^ ^") -NUMBER_WORKCHAINS=5 -TIMEOUT=600 -CODE='add!' # Note the exclamation point is necessary to force the value to be interpreted as LABEL type identifier - -# Needed on Jenkins -if [ -e ~/.bashrc ] ; then source ~/.bashrc ; fi - -# Define the absolute path to the RPN cli script -DATA_DIR="${WORKSPACE_PATH}/.ci" -CLI_SCRIPT="${DATA_DIR}/polish/cli.py" - -# Export the polish module to the python path so generated workchains can be imported -export PYTHONPATH="${PYTHONPATH}:${DATA_DIR}/polish" - -# Get the absolute path for verdi -VERDI=$(which verdi) - -if [ -n "$EXPRESSIONS" ]; then - for expression in "${EXPRESSIONS[@]}"; do - $VERDI -p ${AIIDA_TEST_BACKEND} run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT "$expression" - done -else - for i in $(seq 1 $NUMBER_WORKCHAINS); do - $VERDI -p ${AIIDA_TEST_BACKEND} run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT - done -fi diff --git a/.coveragerc b/.coveragerc index b27dfc7b30..6b8baa55af 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,4 +2,4 @@ source = aiida [html] -directory = .ci/coverage/html +directory = .coverage_html diff --git a/.ci/docker-rabbitmq.yml b/.docker/docker-rabbitmq.yml similarity index 77% rename from .ci/docker-rabbitmq.yml rename to .docker/docker-rabbitmq.yml index 894f81f587..da266790ff 100644 --- a/.ci/docker-rabbitmq.yml +++ b/.docker/docker-rabbitmq.yml @@ -2,10 +2,10 @@ # if you wish to control the rabbitmq used. # Simply install docker, then run: -# $ docker-compose -f .ci/docker-rabbitmq.yml up -d +# $ docker-compose -f .docker/docker-rabbitmq.yml up -d # and to power down, after testing: -# $ docker-compose -f .ci/docker-rabbitmq.yml down +# $ docker-compose -f .docker/docker-rabbitmq.yml down # you can monitor rabbitmq use at: http://localhost:15672 @@ -27,3 +27,8 @@ services: interval: 30s timeout: 30s retries: 5 + networks: + - aiida-rmq + +networks: + aiida-rmq: diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..dfe06bad59 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +.benchmarks +.cache +.coverage +.mypy_cache +.pytest_cache +.tox +.vscode +aiida_core.egg-info +docs/build +pip-wheel-metadata +**/.DS_Store +**/*.pyc +**/__pycache__ diff --git a/.github/config/README.md b/.github/config/README.md new file mode 100644 index 0000000000..6b43da3d5b --- /dev/null +++ b/.github/config/README.md @@ -0,0 +1,5 @@ +# AiiDA configuration files + +This folder contains configuration files for AiiDA computers, codes etc. 
+ + - `slurm_rsa`: private key that provides access to the `slurm-ssh` container diff --git a/.ci/doubler.sh b/.github/config/doubler.sh similarity index 100% rename from .ci/doubler.sh rename to .github/config/doubler.sh diff --git a/.github/config/slurm-ssh-config.yaml b/.github/config/slurm-ssh-config.yaml new file mode 100644 index 0000000000..48332209de --- /dev/null +++ b/.github/config/slurm-ssh-config.yaml @@ -0,0 +1,7 @@ +--- +safe_interval: 0 +username: xenon +look_for_keys: true +key_filename: "PLACEHOLDER_SSH_KEY" +key_policy: AutoAddPolicy +port: 5001 diff --git a/.github/config/slurm-ssh.yaml b/.github/config/slurm-ssh.yaml new file mode 100644 index 0000000000..43e5919e5b --- /dev/null +++ b/.github/config/slurm-ssh.yaml @@ -0,0 +1,12 @@ +--- +label: slurm-ssh +description: slurm container +hostname: localhost +transport: ssh +scheduler: slurm +shebang: "#!/bin/bash" +work_dir: /home/{username}/workdir +mpirun_command: "mpirun -np {tot_num_mpiprocs}" +mpiprocs_per_machine: 1 +prepend_text: "" +append_text: "" diff --git a/.github/config/slurm_rsa b/.github/config/slurm_rsa new file mode 100644 index 0000000000..20123b7d8c --- /dev/null +++ b/.github/config/slurm_rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAnCqpTQFbmi1WPX4uTUFCHAvf61AhvqXUFoJEHQEvtDYibWJZ +bI7LueA2eEKw68oynIfPeinr4+DOnejMG1+HKCWi03DzWoorBOYc0e9i3nxkU93j +hZZsiQZfBgcCenqh2t1ZLbEFdFnCqLDw6gbDH0F3W3NJW0Q30a8HQ01lqdSKyVdf +UghVLCx1HM53BxXEYGU2m2ii+uyoMIsz9TSCJdKXIAb5N4tZYqKPF8q0vf1eP2BB +SUsn4bAHpPqvx3I0HkyR6qV5UT4K91FteULLTJHjK3Y0bBUMOmNQPh0JTmfj/KNB +EtJdlGYE0Tce1XINvhHItSpdFZs8GTnmOzUaVQIDAQABAoIBAEpWsILcm5tX646Y +KzhRUUQCjxP38ChNzhjs57ma3/d8MYU6ZPEdRHN1/Nfgf1Guzcrfh29S11yBnjlj +IQ4CulbtG4ZlZSJ7VSEe3Sc+OiVIt4WIwY7M3VuY8dDvs0lUaQnDhnkOpFcPh28/ +017D20xcoJGi3o+YeK3TELUD+doOeaot4+5TvR0PiLEmyjlnWB1FRkYpGAVDRKKa +F3dSAGf41ygoDOaGmtNmpH/Fn1k9cSDZsRsMKjZQTjgKfX+y/H6eOpORgHYHVmlu +eFIK8+yVVBy5k+m7nTIAUzXm01yJ5fQuT/75EcILUvjloTwmykaTfO1Ez6rNf+BC +VCdD9H0CgYEAyBjEB9vbZ5gDnnkdG0WCr34xPtBztTuVADWz5HorHYFircBUIaJ0 +XOIUioXMmpgSRTzbryAXVznh+g3LeS8QgiGQJoRhIknN8rrRUWd25tgImCMte0eb +bTieJYpvUk8RPan/Arb6f1MLZjWYfJelSw8qQS6R4ydk1L2M78sri/8CgYEAx8vy +KP1e5gGfA42Q0aHvocH7vqbEAOfDK8J+RpT/EoSJ6kSu2oPvblF1CBqHo/nQMhfK +AGbAtWIfy8rs1Md2k+Y+8PXtY8sJJ/HA8laVnEvTHbPSt4X7TtrLx27a8ZWtTNYu +JH/kK8rFBHEGqLnS6VJmqvHKqglp7FIQmHNNaasCgYEApGSMcXR0zqh6mLEic6xp +EOtZZCT4WzZHVTPJxvWEBKqvOtbfh/6jIUhw3dnNXll/8ThtuHRiGLyqZrj8qWQ8 +aN1QRATQlM4UEM7hd8LMUh28+dk03arYDCTO8ULJ8NKa9JF8vGs+ZGsC24c+72Xb +XE5qRcEQBJLx6UKNztiZv1sCgYACHBEuhZ5e5116eCAzVnZlStsRpEkliUzyRVd3 +/1LCK0wZgSgnfoUksQ9/SmhsPtMH9GBZqLwYLjUPvdDKXmDOJvw7Jx2elCJAnbjf +1jI2OEa+ZYuwDGYe6wiDzpPZQS9XRFuwXvlVzQpPhbIAThYACLK002DEctz/dc5f +DbifiQKBgQCdXgr7tdEAmusvIcTRA1KMIOGE5pMGYfbMnDTTIihUfRMJbCnn9sHe +PrDKVVgD3W4hjOABN24KOlCZPtWZfKUKe893ali7mFAIwKNV/AKhQhDgGzJPidqc +6DIL2GhDwqtPIf3b6sI21ZvyAFDROZMKnoL5Q1xbbp5EADi2wPO55Q== +-----END RSA PRIVATE KEY----- diff --git a/.github/system_tests/README.md b/.github/system_tests/README.md new file mode 100644 index 0000000000..de7976fc15 --- /dev/null +++ b/.github/system_tests/README.md @@ -0,0 +1,3 @@ +This folder contains tests that must be run directly in the GitHub Actions container environment. + +This is usually because they require an active daemon or have other specific environment requirements. 
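The `slurm-ssh.yaml`, `slurm-ssh-config.yaml` and `slurm_rsa` files introduced above define the SSH computer that the new system tests (for example the SSH memory-leak test in the next hunk) rely on. As a minimal sketch of how they are consumed — assuming the same commands that this diff adds to `.github/workflows/setup.sh` and the `xenonmiddleware/slurm:17` service exposed on port 5001 in the workflow files — the computer is registered and verified non-interactively before the tests run:

```console
# copy the private key and point the transport configuration at it
$ cp .github/config/slurm_rsa ${HOME}/.ssh/slurm_rsa
$ sed -i "s|PLACEHOLDER_SSH_KEY|${HOME}/.ssh/slurm_rsa|" .github/config/slurm-ssh-config.yaml
# register, configure and test the computer (requires the slurm container to be up)
$ verdi computer setup --config .github/config/slurm-ssh.yaml
$ verdi computer configure ssh slurm-ssh --config .github/config/slurm-ssh-config.yaml -n
$ verdi computer test slurm-ssh --print-traceback
```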
diff --git a/.github/system_tests/pytest/test_memory_leaks.py b/.github/system_tests/pytest/test_memory_leaks.py new file mode 100644 index 0000000000..b9f57a7e6a --- /dev/null +++ b/.github/system_tests/pytest/test_memory_leaks.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utilities for testing memory leakage.""" +from tests.utils import processes as test_processes # pylint: disable=no-name-in-module,import-error +from tests.utils.memory import get_instances # pylint: disable=no-name-in-module,import-error +from aiida.engine import processes, run_get_node +from aiida.plugins import CalculationFactory +from aiida import orm + +ArithmeticAddCalculation = CalculationFactory('arithmetic.add') + + +def run_finished_ok(*args, **kwargs): + """Convenience function to check that run worked fine.""" + _, node = run_get_node(*args, **kwargs) + assert node.is_finished_ok, (node.exit_status, node.exit_message) + + +def test_leak_run_process(): + """Test whether running a dummy process leaks memory.""" + inputs = {'a': orm.Int(2), 'b': orm.Str('test')} + run_finished_ok(test_processes.DummyProcess, **inputs) + + # check that no reference to the process is left in memory + # some delay is necessary in order to allow for all callbacks to finish + process_instances = get_instances(processes.Process, delay=0.2) + assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' + + +def test_leak_local_calcjob(aiida_local_code_factory): + """Test whether running a local CalcJob leaks memory.""" + inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash')} + run_finished_ok(ArithmeticAddCalculation, **inputs) + + # check that no reference to the process is left in memory + # some delay is necessary in order to allow for all callbacks to finish + process_instances = get_instances(processes.Process, delay=0.2) + assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' + + +def test_leak_ssh_calcjob(): + """Test whether running a CalcJob over SSH leaks memory. + + Note: This relies on the 'slurm-ssh' computer being set up. 
+ """ + code = orm.Code( + input_plugin_name='arithmetic.add', remote_computer_exec=[orm.load_computer('slurm-ssh'), '/bin/bash'] + ) + inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': code} + run_finished_ok(ArithmeticAddCalculation, **inputs) + + # check that no reference to the process is left in memory + # some delay is necessary in order to allow for all callbacks to finish + process_instances = get_instances(processes.Process, delay=0.2) + assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' diff --git a/.ci/pytest/test_pytest_fixtures.py b/.github/system_tests/pytest/test_pytest_fixtures.py similarity index 100% rename from .ci/pytest/test_pytest_fixtures.py rename to .github/system_tests/pytest/test_pytest_fixtures.py diff --git a/.ci/pytest/test_unittest_example.py b/.github/system_tests/pytest/test_unittest_example.py similarity index 100% rename from .ci/pytest/test_unittest_example.py rename to .github/system_tests/pytest/test_unittest_example.py diff --git a/.ci/test_daemon.py b/.github/system_tests/test_daemon.py similarity index 85% rename from .ci/test_daemon.py rename to .github/system_tests/test_daemon.py index 6dfd352341..e91112095f 100644 --- a/.ci/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -9,15 +9,19 @@ ########################################################################### # pylint: disable=no-name-in-module """Tests to run with a running daemon.""" +import os +import shutil import subprocess import sys +import tempfile import time -from aiida.common import exceptions +from aiida.common import exceptions, StashMode from aiida.engine import run, submit from aiida.engine.daemon.client import get_daemon_client from aiida.engine.persistence import ObjectLoader from aiida.manage.caching import enable_caching +from aiida.engine.processes import Process from aiida.orm import CalcJobNode, load_node, Int, Str, List, Dict, load_code from aiida.plugins import CalculationFactory, WorkflowFactory from aiida.workflows.arithmetic.add_multiply import add_multiply, add @@ -26,8 +30,10 @@ WorkFunctionRunnerWorkChain, NestedInputNamespace, SerializeWorkChain, ArithmeticAddBaseWorkChain ) +from tests.utils.memory import get_instances # pylint: disable=import-error + CODENAME_ADD = 'add@localhost' -CODENAME_DOUBLER = 'doubler' +CODENAME_DOUBLER = 'doubler@localhost' TIMEOUTSECS = 4 * 60 # 4 minutes NUMBER_CALCULATIONS = 15 # Number of calculations to submit NUMBER_WORKCHAINS = 8 # Number of workchains to submit @@ -389,9 +395,12 @@ def run_multiply_add_workchain(): assert results['result'].value == 5 -def main(): - """Launch a bunch of calculation jobs and workchains.""" - # pylint: disable=too-many-locals,too-many-statements,too-many-branches +def launch_all(): + """Launch a bunch of calculation jobs and workchains. 
+ + :returns: dictionary with expected results and pks of all launched calculations and workchains + """ + # pylint: disable=too-many-locals,too-many-statements expected_results_process_functions = {} expected_results_calculations = {} expected_results_workchains = {} @@ -409,6 +418,24 @@ def main(): print('Running the `MultiplyAddWorkChain`') run_multiply_add_workchain() + # Testing the stashing functionality + process, inputs, expected_result = create_calculation_process(code=code_doubler, inputval=1) + with tempfile.TemporaryDirectory() as tmpdir: + + # Delete the temporary directory to test that the stashing functionality will create it if necessary + shutil.rmtree(tmpdir, ignore_errors=True) + + source_list = ['output.txt', 'triple_value.tmp'] + inputs['metadata']['options']['stash'] = {'target_base': tmpdir, 'source_list': source_list} + _, node = run.get_node(process, **inputs) + assert node.is_finished_ok + assert 'remote_stash' in node.outputs + remote_stash = node.outputs.remote_stash + assert remote_stash.stash_mode == StashMode.COPY + assert remote_stash.target_basepath.startswith(tmpdir) + assert sorted(remote_stash.source_list) == sorted(source_list) + assert sorted(p for p in os.listdir(remote_stash.target_basepath)) == sorted(source_list) + # Submitting the calcfunction through the launchers print('Submitting calcfunction to the daemon') proc, expected_result = launch_calcfunction(inputval=1) @@ -437,8 +464,8 @@ def main(): builder = NestedWorkChain.get_builder() input_val = 4 builder.inp = Int(input_val) - proc = submit(builder) - expected_results_workchains[proc.pk] = input_val + pk = submit(builder).pk + expected_results_workchains[pk] = input_val print('Submitting a workchain with a nested input namespace.') value = Int(-12) @@ -483,9 +510,46 @@ def main(): calculation_pks = sorted(expected_results_calculations.keys()) workchains_pks = sorted(expected_results_workchains.keys()) process_functions_pks = sorted(expected_results_process_functions.keys()) - pks = calculation_pks + workchains_pks + process_functions_pks - print('Wating for end of execution...') + return { + 'pks': calculation_pks + workchains_pks + process_functions_pks, + 'calculations': expected_results_calculations, + 'process_functions': expected_results_process_functions, + 'workchains': expected_results_workchains, + } + + +def relaunch_cached(results): + """Launch the same calculations but with caching enabled -- these should be FINISHED immediately.""" + code_doubler = load_code(CODENAME_DOUBLER) + cached_calcs = [] + with enable_caching(identifier='aiida.calculations:templatereplacer'): + for counter in range(1, NUMBER_CALCULATIONS + 1): + inputval = counter + calc, expected_result = run_calculation(code=code_doubler, counter=counter, inputval=inputval) + cached_calcs.append(calc) + results['calculations'][calc.pk] = expected_result + + if not ( + validate_calculations(results['calculations']) and validate_workchains(results['workchains']) and + validate_cached(cached_calcs) and validate_process_functions(results['process_functions']) + ): + print_daemon_log() + print('') + print('ERROR! 
Some return values are different from the expected value') + sys.exit(3) + + print_daemon_log() + print('') + print('OK, all calculations have the expected parsed result') + + +def main(): + """Launch a bunch of calculation jobs and workchains.""" + + results = launch_all() + + print('Waiting for end of execution...') start_time = time.time() exited_with_timeout = True while time.time() - start_time < TIMEOUTSECS: @@ -515,7 +579,7 @@ def main(): except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') - if jobs_have_finished(pks): + if jobs_have_finished(results['pks']): print('Calculation terminated its execution') exited_with_timeout = False break @@ -525,30 +589,18 @@ def main(): print('') print(f'Timeout!! Calculation did not complete after {TIMEOUTSECS} seconds') sys.exit(2) - else: - # Launch the same calculations but with caching enabled -- these should be FINISHED immediately - cached_calcs = [] - with enable_caching(identifier='aiida.calculations:templatereplacer'): - for counter in range(1, NUMBER_CALCULATIONS + 1): - inputval = counter - calc, expected_result = run_calculation(code=code_doubler, counter=counter, inputval=inputval) - cached_calcs.append(calc) - expected_results_calculations[calc.pk] = expected_result - - if ( - validate_calculations(expected_results_calculations) and - validate_workchains(expected_results_workchains) and validate_cached(cached_calcs) and - validate_process_functions(expected_results_process_functions) - ): - print_daemon_log() - print('') - print('OK, all calculations have the expected parsed result') - sys.exit(0) - else: - print_daemon_log() - print('') - print('ERROR! Some return values are different from the expected value') - sys.exit(3) + + relaunch_cached(results) + + # Check that no references to processes remain in memory + # Note: This tests only processes that were `run` in the same interpreter, not those that were `submitted` + del results + processes = get_instances(Process, delay=1.0) + if processes: + print(f'Memory leak! Process instances remained in memory: {processes}') + sys.exit(4) + + sys.exit(0) if __name__ == '__main__': diff --git a/.github/system_tests/test_ipython_magics.py b/.github/system_tests/test_ipython_magics.py new file mode 100644 index 0000000000..6378f430e8 --- /dev/null +++ b/.github/system_tests/test_ipython_magics.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Test the AiiDA iPython magics.""" +from IPython.testing.globalipapp import get_ipython +from aiida.tools.ipython.ipython_magics import register_ipython_extension + + +def test_ipython_magics(): + """Test that the %aiida magic can be loaded and adds the QueryBuilder and Node variables.""" + ipy = get_ipython() + register_ipython_extension(ipy) + + cell = """ +%aiida +qb=QueryBuilder() +qb.append(Node) +qb.all() +Dict().store() +""" + result = ipy.run_cell(cell) + + assert result.success diff --git a/.ci/test_plugin_testcase.py b/.github/system_tests/test_plugin_testcase.py similarity index 100% rename from .ci/test_plugin_testcase.py rename to .github/system_tests/test_plugin_testcase.py diff --git a/.ci/test_profile_manager.py b/.github/system_tests/test_profile_manager.py similarity index 100% rename from .ci/test_profile_manager.py rename to .github/system_tests/test_profile_manager.py diff --git a/.ci/test_test_manager.py b/.github/system_tests/test_test_manager.py similarity index 100% rename from .ci/test_test_manager.py rename to .github/system_tests/test_test_manager.py diff --git a/.ci/test_verdi_load_time.sh b/.github/system_tests/test_verdi_load_time.sh similarity index 100% rename from .ci/test_verdi_load_time.sh rename to .github/system_tests/test_verdi_load_time.sh diff --git a/.ci/workchains.py b/.github/system_tests/workchains.py similarity index 100% rename from .ci/workchains.py rename to .github/system_tests/workchains.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 450b33d1a8..5f33585d3a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -78,79 +78,3 @@ jobs: comment-on-alert: true fail-on-alert: false alert-comment-cc-users: '@chrisjsewell,@giovannipizzi' - - run-on-comment: - - if: ${{ github.event_name == 'pull_request' }} - - strategy: - matrix: - os: [ubuntu-18.04] - postgres: [12.3] - rabbitmq: [3.8.3] - backend: ['django'] - - runs-on: ${{ matrix.os }} - timeout-minutes: 30 - - services: - postgres: - image: "postgres:${{ matrix.postgres }}" - env: - POSTGRES_DB: test_${{ matrix.backend }} - POSTGRES_PASSWORD: '' - POSTGRES_HOST_AUTH_METHOD: trust - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - rabbitmq: - image: "rabbitmq:${{ matrix.rabbitmq }}" - ports: - - 5672:5672 - - steps: - # v2 was checking out the wrong commit! https://github.com/actions/checkout/issues/299 - - uses: actions/checkout@v1 - - - name: get commit message - run: echo ::set-env name=commitmsg::$(git log --format=%B -n 1 "${{ github.event.after }}") - - - if: contains( env.commitmsg , '[run bench]' ) - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - if: contains( env.commitmsg , '[run bench]' ) - name: Install python dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements/requirements-py-3.8.txt - pip install --no-deps -e . 
- reentry scan - pip freeze - - - if: contains( env.commitmsg , '[run bench]' ) - name: Run benchmarks - env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} - run: pytest --benchmark-only --benchmark-json benchmark.json - - - if: contains( env.commitmsg , '[run bench]' ) - name: Compare benchmark results - uses: aiidateam/github-action-benchmark@v3 - with: - output-file-path: benchmark.json - name: "pytest-benchmarks:${{ matrix.os }},${{ matrix.backend }}" - benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/${{ matrix.backend }}" - metadata: "postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }}" - github-token: ${{ secrets.GITHUB_TOKEN }} - auto-push: false - # Show alert with commit comment on detecting possible performance regression - alert-threshold: '200%' - comment-always: true - fail-on-alert: true diff --git a/.github/workflows/check_release_tag.py b/.github/workflows/check_release_tag.py new file mode 100644 index 0000000000..2501a1c957 --- /dev/null +++ b/.github/workflows/check_release_tag.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +"""Check that the GitHub release tag matches the package version.""" +import argparse +import json + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('GITHUB_REF', help='The GITHUB_REF environmental variable') + parser.add_argument('SETUP_PATH', help='Path to the setup.json') + args = parser.parse_args() + assert args.GITHUB_REF.startswith('refs/tags/v'), f'GITHUB_REF should start with "refs/tags/v": {args.GITHUB_REF}' + tag_version = args.GITHUB_REF[11:] + with open(args.SETUP_PATH) as handle: + data = json.load(handle) + pypi_version = data['version'] + assert tag_version == pypi_version, f'The tag version {tag_version} != {pypi_version} specified in `setup.json`' diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index d0e4ee7900..8ff88267ec 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -49,8 +49,8 @@ jobs: strategy: fail-fast: false matrix: + python-version: [3.7, 3.8] backend: ['django', 'sqlalchemy'] - python-version: [3.6, 3.8] services: postgres: @@ -70,6 +70,10 @@ jobs: image: rabbitmq:latest ports: - 5672:5672 + slurm: + image: xenonmiddleware/slurm:17 + ports: + - 5001:22 steps: - uses: actions/checkout@v2 @@ -81,22 +85,15 @@ jobs: - name: Install system dependencies run: | - sudo rm -f /etc/apt/sources.list.d/dotnetdev.list /etc/apt/sources.list.d/microsoft-prod.list sudo apt update sudo apt install postgresql-10 graphviz - - name: Upgrade pip + - name: Upgrade pip and setuptools + # It is crucial to update `setuptools` or the installation of `pymatgen` can break run: | - pip install --upgrade pip + pip install --upgrade pip setuptools pip --version - # Work-around issue caused by pymatgen's setup process, which will install the latest - # numpy version (including release candidates) regardless of our actual specification - # By installing the version from the requirements file, we should get a compatible version - - name: Install numpy - run: | - pip install `grep 'numpy==' requirements/requirements-py-${{ matrix.python-version }}.txt` - - name: Install aiida-core run: | pip install --use-feature=2020-resolver -r requirements/requirements-py-${{ matrix.python-version }}.txt @@ -117,10 +114,10 @@ jobs: .github/workflows/tests.sh - name: Upload coverage report - if: matrix.python-version == 3.6 && github.repository == 'aiidateam/aiida-core' + if: matrix.python-version == 3.7 && github.repository == 'aiidateam/aiida-core' uses: 
codecov/codecov-action@v1 with: - name: aiida-pytests-py3.6-${{ matrix.backend }} + name: aiida-pytests-py3.7-${{ matrix.backend }} flags: ${{ matrix.backend }} file: ./coverage.xml fail_ci_if_error: false # don't fail job, if coverage upload fails diff --git a/.github/workflows/ci-style.yml b/.github/workflows/ci-style.yml index b8c2b75720..42175952a1 100644 --- a/.github/workflows/ci-style.yml +++ b/.github/workflows/ci-style.yml @@ -23,7 +23,6 @@ jobs: - name: Install system dependencies run: | - sudo rm -f /etc/apt/sources.list.d/dotnetdev.list /etc/apt/sources.list.d/microsoft-prod.list sudo apt update sudo apt install libkrb5-dev ruby ruby-dev diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml new file mode 100644 index 0000000000..6f965dba00 --- /dev/null +++ b/.github/workflows/post-release.yml @@ -0,0 +1,49 @@ +name: post-release + +on: + release: + types: [published, edited] + +jobs: + + upload-transifex: + # Every time when a new version is released, + # upload the latest pot files to transifex services for team transilation. + # https://www.transifex.com/aiidateam/aiida-core/dashboard/ + + name: Upload pot files to transifex + runs-on: ubuntu-latest + timeout-minutes: 30 + + # Build doc to pot files and register them to `.tx/config` file + # Installation steps are modeled after the docs job in `ci.yml` + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Install python dependencies + run: | + pip install transifex-client sphinx-intl + pip install -e .[docs,tests] + + - name: Build pot files + env: + READTHEDOCS: 'True' + RUN_APIDOC: 'False' + run: + sphinx-build -b gettext docs/source locale + + - name: Setting transifex configuration and upload pot files + env: + PROJECT_NAME: aiida-core + USER: ${{ secrets.TRANSIFEX_USER }} + PASSWD: ${{ secrets.TRANSIFEX_PASSWORD }} + run: | + tx init --no-interactive + sphinx-intl update-txconfig-resources --pot-dir locale --transifex-project-name ${PROJECT_NAME} + echo $'[https://www.transifex.com]\nhostname = https://www.transifex.com\nusername = '"${USER}"$'\npassword = '"${PASSWD}"$'\n' > ~/.transifexrc + tx push -s diff --git a/.github/workflows/rabbitmq.yml b/.github/workflows/rabbitmq.yml new file mode 100644 index 0000000000..674977945f --- /dev/null +++ b/.github/workflows/rabbitmq.yml @@ -0,0 +1,67 @@ +name: rabbitmq + +on: + push: + branches-ignore: [gh-pages] + pull_request: + branches-ignore: [gh-pages] + paths-ignore: ['docs/**'] + +jobs: + + tests: + + runs-on: ubuntu-latest + timeout-minutes: 30 + + strategy: + fail-fast: false + matrix: + rabbitmq: [3.5, 3.6, 3.7, 3.8] + + services: + postgres: + image: postgres:10 + env: + POSTGRES_DB: test_django + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: rabbitmq:${{ matrix.rabbitmq }} + ports: + - 5672:5672 + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install system dependencies + run: | + sudo apt update + sudo apt install postgresql-10 + + - name: Upgrade pip + run: | + pip install --upgrade pip + pip --version + + - name: Install aiida-core + run: | + pip install -r requirements/requirements-py-3.8.txt + pip install --no-deps -e . 
+ reentry scan + pip freeze + + - name: Run tests + run: pytest -sv -k 'requires_rmq' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d55fcfa698..b130b687ea 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,49 +1,110 @@ name: release +# Automate deployment to PyPI when creating a release tag vX.Y.Z +# will only be published to PyPI if the git tag matches the release version +# and the pre-commit and tests pass + on: - release: - types: [published, edited] + push: + tags: + - "v[0-9]+.[0-9]+.[0-9]+*" jobs: - upload-transifex: - # Every time when a new version is released, - # upload the latest pot files to transifex services for team transilation. - # https://www.transifex.com/aiidateam/aiida-core/dashboard/ + check-release-tag: - name: Upload pot files to transifex runs-on: ubuntu-latest - timeout-minutes: 30 - # Build doc to pot files and register them to `.tx/config` file - # Installation steps are modeled after the docs job in `ci.yml` steps: - uses: actions/checkout@v2 - - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 + - run: python .github/workflows/check_release_tag.py $GITHUB_REF setup.json + + pre-commit: + needs: [check-release-tag] + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 - name: Install python dependencies + run: pip install -e .[all] + - name: Run pre-commit + run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) + + tests: + + needs: [check-release-tag] + runs-on: ubuntu-latest + timeout-minutes: 30 + + services: + postgres: + image: postgres:10 + env: + POSTGRES_DB: test_django + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: rabbitmq:latest + ports: + - 5672:5672 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install system dependencies + run: | + sudo apt update + sudo apt install postgresql-10 graphviz + - name: Install aiida-core run: | - pip install transifex-client sphinx-intl - pip install -e .[docs,tests] - - - name: Build pot files - env: - READTHEDOCS: 'True' - RUN_APIDOC: 'False' - run: - sphinx-build -b gettext docs/source locale - - - name: Setting transifex configuration and upload pot files - env: - PROJECT_NAME: aiida-core - USER: ${{ secrets.TRANSIFEX_USER }} - PASSWD: ${{ secrets.TRANSIFEX_PASSWORD }} + pip install --upgrade pip setuptools + pip install -r requirements/requirements-py-3.8.txt + pip install --no-deps -e . 
+ reentry scan + - name: Run sub-set of test suite + run: pytest -sv -k 'requires_rmq' + + publish: + + name: Publish to PyPI + + needs: [check-release-tag, pre-commit, tests] + + runs-on: ubuntu-latest + + steps: + - name: Checkout source + uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Build package run: | - tx init --no-interactive - sphinx-intl update-txconfig-resources --pot-dir locale --transifex-project-name ${PROJECT_NAME} - echo $'[https://www.transifex.com]\nhostname = https://www.transifex.com\nusername = '"${USER}"$'\npassword = '"${PASSWD}"$'\n' > ~/.transifexrc - tx push -s + pip install wheel + python setup.py sdist bdist_wheel + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@v1.1.0 + with: + user: __token__ + password: ${{ secrets.PYPI_KEY }} diff --git a/.github/workflows/setup.sh b/.github/workflows/setup.sh index 6ff5c4c6e0..ab508b0897 100755 --- a/.github/workflows/setup.sh +++ b/.github/workflows/setup.sh @@ -10,18 +10,28 @@ chmod 755 "${HOME}" # Replace the placeholders in configuration files with actual values CONFIG="${GITHUB_WORKSPACE}/.github/config" +cp "${CONFIG}/slurm_rsa" "${HOME}/.ssh/slurm_rsa" sed -i "s|PLACEHOLDER_BACKEND|${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" sed -i "s|PLACEHOLDER_PROFILE|test_${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" sed -i "s|PLACEHOLDER_DATABASE_NAME|test_${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" sed -i "s|PLACEHOLDER_REPOSITORY|/tmp/test_repository_test_${AIIDA_TEST_BACKEND}/|" "${CONFIG}/profile.yaml" sed -i "s|PLACEHOLDER_WORK_DIR|${GITHUB_WORKSPACE}|" "${CONFIG}/localhost.yaml" -sed -i "s|PLACEHOLDER_REMOTE_ABS_PATH_DOUBLER|${GITHUB_WORKSPACE}/.ci/doubler.sh|" "${CONFIG}/doubler.yaml" +sed -i "s|PLACEHOLDER_REMOTE_ABS_PATH_DOUBLER|${CONFIG}/doubler.sh|" "${CONFIG}/doubler.yaml" +sed -i "s|PLACEHOLDER_SSH_KEY|${HOME}/.ssh/slurm_rsa|" "${CONFIG}/slurm-ssh-config.yaml" verdi setup --config "${CONFIG}/profile.yaml" + +# set up localhost computer verdi computer setup --config "${CONFIG}/localhost.yaml" verdi computer configure local localhost --config "${CONFIG}/localhost-config.yaml" +verdi computer test localhost verdi code setup --config "${CONFIG}/doubler.yaml" verdi code setup --config "${CONFIG}/add.yaml" +# set up slurm-ssh computer +verdi computer setup --config "${CONFIG}/slurm-ssh.yaml" +verdi computer configure ssh slurm-ssh --config "${CONFIG}/slurm-ssh-config.yaml" -n # needs slurm container +verdi computer test slurm-ssh --print-traceback + verdi profile setdefault test_${AIIDA_TEST_BACKEND} verdi config runner.poll.interval 0 diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index 281078a7b9..a51c64015b 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -42,11 +42,9 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 - continue-on-error: ${{ contains(matrix.pip-feature-flag, '2020-resolver') }} strategy: fail-fast: false matrix: - pip-feature-flag: [ '', '--use-feature=2020-resolver' ] extras: [ '', '[atomic_tools,docs,notebook,rest,tests]' ] steps: @@ -61,10 +59,9 @@ jobs: - name: Pip install id: pip_install - continue-on-error: ${{ contains(matrix.pip-feature-flag, '2020-resolver') }} run: | python -m pip --version - python -m pip install -e .${{ matrix.extras }} ${{ matrix.pip-feature-flag }} + python -m pip install -e .${{ matrix.extras }} python -m pip freeze - name: Test importing aiida @@ -72,10 +69,13 @@ jobs: run: python -c 
"import aiida" - - name: Warn about pip 2020 resolver issues. - if: steps.pip_install.outcome == 'failure' && contains(matrix.pip-feature-flag, '2020-resolver') - run: | - echo "::warning ::Encountered issues with the pip 2020-resolver." + - name: Send Slack notification + if: ${{ failure() && github.event_name == 'schedule' }} + uses: kpritam/slack-job-status-action@v1 + with: + job-status: ${{ job.status }} + slack-bot-token: ${{ secrets.SLACK_BOT_TOKEN }} + channel: dev-aiida-core install-with-conda: @@ -109,6 +109,14 @@ jobs: source activate test-environment python -c "import aiida" + - name: Send Slack notification + if: ${{ failure() && github.event_name == 'schedule' }} + uses: kpritam/slack-job-status-action@v1 + with: + job-status: ${{ job.status }} + slack-bot-token: ${{ secrets.SLACK_BOT_TOKEN }} + channel: dev-aiida-core + tests: needs: [install-with-pip, install-with-conda] @@ -118,7 +126,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] backend: ['django', 'sqlalchemy'] services: @@ -139,6 +147,10 @@ jobs: image: rabbitmq:latest ports: - 5672:5672 + slurm: + image: xenonmiddleware/slurm:17 + ports: + - 5001:22 steps: - uses: actions/checkout@v2 @@ -150,18 +162,14 @@ jobs: - name: Install system dependencies run: | - sudo rm -f /etc/apt/sources.list.d/dotnetdev.list /etc/apt/sources.list.d/microsoft-prod.list sudo apt update sudo apt install postgresql-10 graphviz - - run: pip install --upgrade pip - - # Work-around issue caused by pymatgen's setup process, which will install the latest - # numpy version (including release candidates) regardless of our actual specification - # By installing the version from the requirements file, we should get a compatible version - - name: Install numpy + - name: Upgrade pip and setuptools + # It is crucial to update `setuptools` or the installation of `pymatgen` can break run: | - pip install `grep 'numpy==' requirements/requirements-py-${{ matrix.python-version }}.txt` + pip install --upgrade pip setuptools + pip --version - name: Install aiida-core run: | @@ -182,6 +190,14 @@ jobs: run: .github/workflows/tests.sh + - name: Send Slack notification + if: ${{ failure() && github.event_name == 'schedule' }} + uses: kpritam/slack-job-status-action@v1 + with: + job-status: ${{ job.status }} + slack-bot-token: ${{ secrets.SLACK_BOT_TOKEN }} + channel: dev-aiida-core + - name: Freeze test environment run: pip freeze | sed '1d' | tee requirements-py-${{ matrix.python-version }}.txt diff --git a/.github/workflows/tests.sh b/.github/workflows/tests.sh index 66ad9b5e76..db69cfa29c 100755 --- a/.github/workflows/tests.sh +++ b/.github/workflows/tests.sh @@ -2,7 +2,9 @@ set -ev # Make sure the folder containing the workchains is in the python path before the daemon is started -export PYTHONPATH="${PYTHONPATH}:${GITHUB_WORKSPACE}/.ci" +SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" + +export PYTHONPATH="${PYTHONPATH}:${SYSTEM_TESTS}" # pytest options: # - report timings of tests @@ -16,20 +18,22 @@ export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-config=${GITHUB_WORKSPACE}/.cover export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-report xml" export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-append" export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov=aiida" +export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --verbose" # daemon tests verdi daemon start 4 -verdi -p test_${AIIDA_TEST_BACKEND} run .ci/test_daemon.py +verdi -p test_${AIIDA_TEST_BACKEND} run ${SYSTEM_TESTS}/test_daemon.py verdi daemon stop 
# tests for the testing infrastructure -pytest --noconftest .ci/test_test_manager.py -pytest --noconftest .ci/test_profile_manager.py -python .ci/test_plugin_testcase.py # uses custom unittest test runner +pytest --noconftest ${SYSTEM_TESTS}/test_test_manager.py +pytest --noconftest ${SYSTEM_TESTS}/test_ipython_magics.py +pytest --noconftest ${SYSTEM_TESTS}/test_profile_manager.py +python ${SYSTEM_TESTS}/test_plugin_testcase.py # uses custom unittest test runner -# Until the `.ci/pytest` tests are moved within `tests` we have to run them separately and pass in the path to the +# Until the `${SYSTEM_TESTS}/pytest` tests are moved within `tests` we have to run them separately and pass in the path to the # `conftest.py` explicitly, because otherwise it won't be able to find the fixtures it provides -AIIDA_TEST_PROFILE=test_$AIIDA_TEST_BACKEND pytest tests/conftest.py .ci/pytest +AIIDA_TEST_PROFILE=test_$AIIDA_TEST_BACKEND pytest tests/conftest.py ${SYSTEM_TESTS}/pytest # main aiida-core tests AIIDA_TEST_PROFILE=test_$AIIDA_TEST_BACKEND pytest tests diff --git a/.gitignore b/.gitignore index 624d61e194..52e6f940fe 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ coverage.xml # Files created by RPN tests -.ci/polish/polish_workchains/polish* +**/polish_workchains/polish* # Build files dist/ diff --git a/.jenkins/Dockerfile b/.jenkins/Dockerfile new file mode 100644 index 0000000000..33cf8b1057 --- /dev/null +++ b/.jenkins/Dockerfile @@ -0,0 +1,25 @@ +FROM aiidateam/aiida-prerequisites:0.3.0 + +# to run the tests +RUN pip install ansible~=2.10.0 molecule~=3.1.0 + +RUN apt-get update && \ + apt-get install -y sudo && \ + apt-get autoclean + +ARG uid=1000 +ARG gid=1000 + +# add USER (no password); 1000 is the uid of the user in the jenkins docker +RUN groupadd -g ${gid} jenkins && useradd -m -s /bin/bash -u ${uid} -g ${gid} jenkins + +# add to sudoers and don't ask password +RUN adduser jenkins sudo && adduser jenkins adm && adduser jenkins root +RUN echo "%sudo ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/nopwd +RUN mkdir -p /scratch/jenkins/ && chown jenkins /scratch/jenkins/ && chmod o+rX /scratch/ + +# set $HOME to the directory where the repository is mounted +ENV HOME /home/jenkins + +# this is added since otherwise jenkins prints /etc/profile contents for all sh commands +RUN echo 'set +x' | cat - /etc/profile > temp && mv temp /etc/profile diff --git a/.jenkins/Jenkinsfile b/.jenkins/Jenkinsfile new file mode 100644 index 0000000000..61d5fc43cf --- /dev/null +++ b/.jenkins/Jenkinsfile @@ -0,0 +1,135 @@ +// Note: this part might happen on a different node than +// the one that will run the pipeline below, see +// https://stackoverflow.com/questions/44805076 +// but it should be ok for us as we only have one node +def user_id +def group_id +node { + user_id = sh(returnStdout: true, script: 'id -u').trim() + group_id = sh(returnStdout: true, script: 'id -g').trim() +} + +pipeline { + agent { + dockerfile { + filename 'Dockerfile' + dir '.jenkins' + args '-u root:root -v jenkins-pip-cache:/home/jenkins/.cache/pip/' + additionalBuildArgs "--build-arg uid=${user_id} --build-arg gid=${group_id}" + } + } + environment { + MOLECULE_GLOB = ".molecule/*/config_jenkins.yml" + AIIDA_TEST_WORKERS = 2 + RUN_ALSO_DJANGO = "true" + RUN_ALSO_SQLALCHEMY = "true" + } + stages { + stage ('Init services') { + steps { + // we must run /sbin/my_init directly (rather than in a separate process) + // see: https://github.com/phusion/baseimage-docker/blob/18.04-1.0.0/image/bin/my_init + sh 
'/etc/my_init.d/00_regen_ssh_host_keys.sh' + sh '/etc/my_init.d/10_create-system-user.sh' + // we cannot run this task because it tries to write to the jenkins log file without permission: + // Cannot contact : java.io.FileNotFoundException: /var/jenkins_home/workspace/aiida_core_aiidateam_PR-4565@2@tmp/durable-65ec45aa/jenkins-log.txt (Permission denied) + // sh '/etc/my_init.d/10_syslog-ng.init' + sh '/etc/my_init.d/20_start-rabbitmq.sh' + sh '/etc/my_init.d/30_start-postgres.sh' + sh '/sbin/my_init --skip-startup-files --no-kill-all-on-exit 2> /dev/null &' + } + } + stage ('Prepare environment') { + steps { + // Clean work dir + // often runs reshare the same folder, and it might contain old data from previous runs + // this is particularly problematic when a folder is deleted from git but .pyc files are left in + sh 'git clean -fdx' + // this folder is mounted from a volume so will have wrong permissions + sh 'sudo chown root:root /home/jenkins/.cache' + // prepare environment (install python dependencies etc) + sh 'pip install -r requirements/requirements-py-3.7.txt --cache-dir /home/jenkins/.cache/pip' + sh 'pip install --no-deps .' + // for some reason if we don't change permissions here then python can't import the modules + sh 'sudo chmod -R a+rwX /opt/conda/lib/python3.7/site-packages/' + } + } + stage('Test') { + failFast false // Do not kill one if the other fails + parallel { + stage('Test-Django') { + environment { + AIIDA_TEST_BACKEND="django" + } + when { + environment name: 'RUN_ALSO_DJANGO', value: 'true' + } + steps { + sh 'molecule test --parallel' + } + } + stage('Test-SQLAlchemy') { + environment { + AIIDA_TEST_BACKEND="sqlalchemy" + } + when { + environment name: 'RUN_ALSO_SQLALCHEMY', value: 'true' + } + steps { + sh 'molecule test --parallel' + } + } + } + } + } + post { + always { + // Some debug stuff + sh ''' + whoami + pwd + ''' + cleanWs() + } + success { + echo 'The run finished successfully!' + } + unstable { + echo 'This run is unstable...' + } + failure { + echo "This run failed..." + } + // You can trigger actions when the status change (e.g. it starts failing, + // or it starts working again - e.g. sending emails or similar) + // possible variables: see e.g. https://qa.nuxeo.org/jenkins/pipeline-syntax/globals + // Other valid names: fixed, regression (opposite of fixed), aborted (by user, typically) + // Note that I had problems with email, I don't know if it is a configuration problem + // or a missing plugin. + changed { + script { + if (currentBuild.getPreviousBuild()) { + echo "The state changed from ${currentBuild.getPreviousBuild().result} to ${currentBuild.currentResult}." + } + else { + echo "This is the first build, and its status is: ${currentBuild.currentResult}." + } + } + } + } + options { + // we do not want the whole run to hang forever - + timeout(time: 40, unit: 'MINUTES') + } +} + + +// Other things to add possibly: +// global options (or per-stage options) with timeout: https://jenkins.io/doc/book/pipeline/syntax/#options-example +// retry-on-failure for some specific tasks: https://jenkins.io/doc/book/pipeline/syntax/#available-stage-options +// parameters: https://jenkins.io/doc/book/pipeline/syntax/#parameters +// input: interesting for user input before continuing: https://jenkins.io/doc/book/pipeline/syntax/#input +// when conditions, e.g. to depending on details on the commit (e.g. 
only when specific +// files are changed, where there is a string in the commit log, for a specific branch, +// for a Pull Request,for a specific environment variable, ...): +// https://jenkins.io/doc/book/pipeline/syntax/#when diff --git a/.ci/check-jenkinsfile.sh b/.jenkins/check-jenkinsfile.sh similarity index 100% rename from .ci/check-jenkinsfile.sh rename to .jenkins/check-jenkinsfile.sh diff --git a/.molecule/README.md b/.molecule/README.md new file mode 100644 index 0000000000..21dbcd94a5 --- /dev/null +++ b/.molecule/README.md @@ -0,0 +1,61 @@ +# Molecule System Integration/Stress Testing + +This folder contains configuration for running automated system integration tests against an isolated AiiDA environment. + +This utilises [molecule](https://molecule.readthedocs.io) to automate the creation/destruction of a docker container environment and the setup and testing within it. + +The tests are currently set up to stress-test the AiiDA engine by launching a number of workchains of varying complexity, defined by [reverse polish notation](https://en.wikipedia.org/wiki/Reverse_Polish_notation). +They are part of the continuous integration pipeline of AiiDA and are run using [Jenkins](https://www.jenkins.io/) on our own test runner. + +## Running the tests locally + +The simplest way to run these tests is to use the `tox` environment provided in this repository's `pyproject.toml` file: + +```console +$ pip install tox +$ tox -e molecule-django +``` + +**NOTE**: if you wan to run molecule directly, ensure that you set `export MOLECULE_GLOB=.molecule/*/config_local.yml`. + +This runs the `test` scenario (defined in `config_local.yml`) which: + +1. Deletes any existing container with the same label +2. Creates a docker container, based on the `Dockerfile` in this folder, which also copies the repository code into the container (see `create_docker.yml`). +3. Installs aiida-core (see `setup_python.yml`) +4. Sets up an AiiDA profile and computer (see `setup_aiida.yml`). +5. Sets up a number of workchains of varying complexity,defined by [reverse polish notation](https://en.wikipedia.org/wiki/Reverse_Polish_notation), and runs them (see `run_tests.yml`). +6. Deletes the container. + +If you wish to setup the container for manual inspection (i.e. 
only run steps 2 - 4) you can run: + +```console +$ tox -e molecule-django converge +``` + +Then you can jump into this container or run the tests (step 5) separately with: + +```console +$ tox -e molecule-django validate +``` + +and finally run step 6: + +```console +$ tox -e molecule-django destroy +``` + +You can set up the aiida profile with either django or sqla, +and even run both in parallel: + +```console +$ tox -e molecule-django,molecule-sqla -p -- test --parallel +``` + +## Additional variables + +You can specify the number of daemon workers to spawn using the `AIIDA_TEST_WORKERS` environment variable: + +```console +$ AIIDA_TEST_WORKERS=4 tox -e molecule-django +``` diff --git a/.molecule/default/Dockerfile b/.molecule/default/Dockerfile new file mode 100644 index 0000000000..1dff46e6b2 --- /dev/null +++ b/.molecule/default/Dockerfile @@ -0,0 +1,11 @@ +FROM aiidateam/aiida-prerequisites:0.3.0 + +# allow for collection of query statistics +# (must also be intialised on each database) +RUN sed -i '/.*initdb -D.*/a echo "shared_preload_libraries='pg_stat_statements'" >> /home/${SYSTEM_USER}/.postgresql/postgresql.conf' /opt/start-postgres.sh +# other options +# pg_stat_statements.max = 10000 +# pg_stat_statements.track = all + +# Copy AiiDA repository +COPY . aiida-core diff --git a/.molecule/default/config_jenkins.yml b/.molecule/default/config_jenkins.yml new file mode 100644 index 0000000000..dff5dd8ed7 --- /dev/null +++ b/.molecule/default/config_jenkins.yml @@ -0,0 +1,53 @@ +# On Jenkins we are already inside the container, +# so we simply run the playbooks in the local environment + +scenario: + converge_sequence: + - prepare + - converge + test_sequence: + - converge + - verify +# connect to local environment +driver: + name: delegated + options: + managed: False + ansible_connection_options: + ansible_connection: local +platforms: +- name: molecule-aiida-${AIIDA_TEST_BACKEND:-django} +# configuration for how to run the playbooks +provisioner: + name: ansible + # log: true # for debugging + playbooks: + prepare: setup_python.yml + converge: setup_aiida.yml + verify: run_tests.yml + config_options: + defaults: + # nicer stdout printing + stdout_callback: yaml + bin_ansible_callbacks: true + # add timing to tasks + callback_whitelist: timer, profile_tasks + # reduce CPU load + internal_poll_interval: 0.002 + ssh_connection: + # reduce network operations + pipelining: True + inventory: + hosts: + all: + vars: + become_method: sudo + aiida_user: aiida + aiida_core_dir: $WORKSPACE + aiida_pip_cache: /home/jenkins/.cache/pip + aiida_pip_editable: false + venv_bin: /opt/conda/bin + ansible_python_interpreter: "{{ venv_bin }}/python" + aiida_backend: ${AIIDA_TEST_BACKEND:-django} + aiida_workers: ${AIIDA_TEST_WORKERS:-2} + aiida_path: /tmp/.aiida_${AIIDA_TEST_BACKEND:-django} diff --git a/.molecule/default/config_local.yml b/.molecule/default/config_local.yml new file mode 100644 index 0000000000..c9168f35ac --- /dev/null +++ b/.molecule/default/config_local.yml @@ -0,0 +1,69 @@ +# when we run locally, we must first create a docker container +# then we run the playbooks inside that + +scenario: + create_sequence: + - create + - prepare + converge_sequence: + - create + - prepare + - converge + destroy_sequence: + - destroy + test_sequence: + - destroy + - create + - prepare + - converge + - verify + - destroy +# configuration for building the isolated container +driver: + name: docker +platforms: +- name: molecule-aiida-${AIIDA_TEST_BACKEND:-django} + image: molecule_tests + 
context: "../.." + command: /sbin/my_init + healthcheck: + test: wait-for-services + volumes: + - molecule-pip-cache-${AIIDA_TEST_BACKEND:-django}:/home/.cache/pip + privileged: true + retries: 3 +# configuration for how to run the playbooks +provisioner: + name: ansible + # log: true # for debugging + playbooks: + create: create_docker.yml + prepare: setup_python.yml + converge: setup_aiida.yml + verify: run_tests.yml + config_options: + defaults: + # nicer stdout printing + stdout_callback: yaml + bin_ansible_callbacks: true + # add timing to tasks + callback_whitelist: timer, profile_tasks + # reduce CPU load + internal_poll_interval: 0.002 + ssh_connection: + # reduce network operations + pipelining: True + inventory: + hosts: + all: + vars: + become_method: su + aiida_user: aiida + aiida_core_dir: /aiida-core + aiida_pip_cache: /home/.cache/pip + venv_bin: /opt/conda/bin + ansible_python_interpreter: "{{ venv_bin }}/python" + aiida_backend: ${AIIDA_TEST_BACKEND:-django} + aiida_workers: ${AIIDA_TEST_WORKERS:-2} + aiida_path: /tmp/.aiida_${AIIDA_TEST_BACKEND:-django} + aiida_query_stats: true diff --git a/.molecule/default/create_docker.yml b/.molecule/default/create_docker.yml new file mode 100644 index 0000000000..2bef943879 --- /dev/null +++ b/.molecule/default/create_docker.yml @@ -0,0 +1,120 @@ +# this is mainly a copy of https://github.com/ansible-community/molecule-docker/blob/master/molecule_docker/playbooks/create.yml +# with fix: https://github.com/ansible-community/molecule-docker/pull/30 +- name: Create + hosts: localhost + connection: local + gather_facts: false + no_log: "{{ molecule_no_log }}" + vars: + molecule_labels: + owner: molecule + tasks: + + - name: Discover local Docker images + docker_image_info: + name: "molecule_local/{{ item.name }}" + docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" + cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" + with_items: "{{ molecule_yml.platforms }}" + register: docker_images + + - name: Build the container image + when: + - docker_images.results | map(attribute='images') | select('equalto', []) | list | count >= 0 + docker_image: + build: + path: "{{ item.context | default(molecule_ephemeral_directory) }}" + dockerfile: "{{ item.dockerfile | default(molecule_scenario_directory + '/Dockerfile') }}" + pull: "{{ item.pull | default(true) }}" + network: "{{ item.network_mode | default(omit) }}" + args: "{{ item.buildargs | default(omit) }}" + name: "molecule_local/{{ item.image }}" + docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" + cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else 
omit) }}" + tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" + force_source: "{{ item.force | default(true) }}" + source: build + with_items: "{{ molecule_yml.platforms }}" + loop_control: + label: "molecule_local/{{ item.image }}" + no_log: false + register: result + until: result is not failed + retries: "{{ item.retries | default(3) }}" + delay: 30 + + - debug: + var: result + + - name: Determine the CMD directives + set_fact: + command_directives_dict: >- + {{ command_directives_dict | default({}) | + combine({ item.name: item.command | default('bash -c "while true; do sleep 10000; done"') }) + }} + with_items: "{{ molecule_yml.platforms }}" + when: item.override_command | default(true) + + - name: Create molecule instance(s) + docker_container: + name: "{{ item.name }}" + docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" + cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" + tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" + hostname: "{{ item.hostname | default(item.name) }}" + image: "{{ item.pre_build_image | default(false) | ternary('', 'molecule_local/') }}{{ item.image }}" + pull: "{{ item.pull | default(omit) }}" + memory: "{{ item.memory | default(omit) }}" + memory_swap: "{{ item.memory_swap | default(omit) }}" + state: started + recreate: false + log_driver: json-file + command: "{{ (command_directives_dict | default({}))[item.name] | default(omit) }}" + user: "{{ item.user | default(omit) }}" + pid_mode: "{{ item.pid_mode | default(omit) }}" + privileged: "{{ item.privileged | default(omit) }}" + security_opts: "{{ item.security_opts | default(omit) }}" + devices: "{{ item.devices | default(omit) }}" + volumes: "{{ item.volumes | default(omit) }}" + tmpfs: "{{ item.tmpfs | default(omit) }}" + capabilities: "{{ item.capabilities | default(omit) }}" + sysctls: "{{ item.sysctls | default(omit) }}" + exposed_ports: "{{ item.exposed_ports | default(omit) }}" + published_ports: "{{ item.published_ports | default(omit) }}" + ulimits: "{{ item.ulimits | default(omit) }}" + networks: "{{ item.networks | default(omit) }}" + network_mode: "{{ item.network_mode | default(omit) }}" + networks_cli_compatible: "{{ item.networks_cli_compatible | default(true) }}" + purge_networks: "{{ item.purge_networks | default(omit) }}" + dns_servers: "{{ item.dns_servers | default(omit) }}" + etc_hosts: "{{ item.etc_hosts | default(omit) }}" + env: "{{ item.env | default(omit) }}" + restart_policy: "{{ item.restart_policy | default(omit) }}" + restart_retries: "{{ item.restart_retries | default(omit) }}" + tty: "{{ item.tty | default(omit) }}" + labels: "{{ molecule_labels | combine(item.labels | default({})) }}" + container_default_behavior: "{{ item.container_default_behavior | default('compatibility' if ansible_version.full is version_compare('2.10', '>=') else omit) }}" + healthcheck: "{{ item.healthcheck | default(omit) }}" + register: server + with_items: "{{ molecule_yml.platforms }}" + loop_control: + label: "{{ item.name }}" + no_log: false + async: 7200 + poll: 0 + + - name: Wait 
for instance(s) creation to complete + async_status: + jid: "{{ item.ansible_job_id }}" + register: docker_jobs + until: docker_jobs.finished + retries: 300 + with_items: "{{ server.results }}" + no_log: false diff --git a/.ci/polish/__init__.py b/.molecule/default/files/polish/__init__.py similarity index 100% rename from .ci/polish/__init__.py rename to .molecule/default/files/polish/__init__.py diff --git a/.ci/polish/cli.py b/.molecule/default/files/polish/cli.py similarity index 70% rename from .ci/polish/cli.py rename to .molecule/default/files/polish/cli.py index 362c398a5c..69942a326e 100755 --- a/.ci/polish/cli.py +++ b/.molecule/default/files/polish/cli.py @@ -9,6 +9,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Command line interface to dynamically create and run a WorkChain that can evaluate a reversed polish expression.""" +import importlib +import sys +import time import click @@ -71,8 +74,16 @@ default=False, help='Only evaluate the expression and generate the workchain but do not launch it' ) +@click.option( + '-r', + '--retries', + type=click.INT, + default=1, + show_default=True, + help='Number of retries for running via the daemon' +) @decorators.with_dbenv() -def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout, modulo, dry_run, daemon): +def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout, modulo, dry_run, daemon, retries): """ Evaluate the expression in Reverse Polish Notation in both a normal way and by procedurally generating a workchain that encodes the sequence of operators and gets the stack of operands as an input. Multiplications @@ -96,32 +107,28 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout If no expression is specified, a random one will be generated that adheres to these rules """ # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches - import importlib - import sys - import time - import uuid from aiida.orm import Code, Int, Str - from aiida.engine import run_get_node, submit - from lib.expression import generate, validate, evaluate # pylint: disable=import-error - from lib.workchain import generate_outlines, format_outlines, write_workchain # pylint: disable=import-error + from aiida.engine import run_get_node + + lib_expression = importlib.import_module('lib.expression') + lib_workchain = importlib.import_module('lib.workchain') if use_calculations and not isinstance(code, Code): raise click.BadParameter('if you specify the -C flag, you have to specify a code as well') if expression is None: - expression = generate() + expression = lib_expression.generate() - valid, error = validate(expression) + valid, error = lib_expression.validate(expression) if not valid: click.echo(f"the expression '{expression}' is invalid: {error}") sys.exit(1) - filename = f'polish_{str(uuid.uuid4().hex)}.py' - evaluated = evaluate(expression, modulo) - outlines, stack = generate_outlines(expression) - outlines_string = format_outlines(outlines, use_calculations, use_calcfunctions) - write_workchain(outlines_string, filename=filename) + evaluated = lib_expression.evaluate(expression, modulo) + outlines, stack = lib_workchain.generate_outlines(expression) + outlines_string = lib_workchain.format_outlines(outlines, use_calculations, use_calcfunctions) + filename = lib_workchain.write_workchain(outlines_string).name click.echo(f'Expression: {expression}') 
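For orientation only (not part of this patch): reverse polish notation is evaluated left to right with an operand stack. A plain-Python sketch, independent of the `lib.expression` helpers imported above and assuming `^` denotes exponentiation as in the test expressions:

```python
def evaluate_rpn(expression, modulo=None):
    """Evaluate a reverse polish expression with +, * and ^ operators (illustrative only)."""
    stack = []
    for token in expression.split():
        if token in ('+', '*', '^'):
            right, left = stack.pop(), stack.pop()
            value = {'+': left + right, '*': left * right, '^': left ** right}[token]
            stack.append(value % modulo if modulo else value)
        else:
            stack.append(int(token))
    assert len(stack) == 1, 'malformed expression'
    return stack[0]

assert evaluate_rpn('1 2 + 3 *') == 9  # (1 + 2) * 3
```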
@@ -139,33 +146,20 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout inputs['code'] = code if daemon: - workchain = submit(workchains.Polish00WorkChain, **inputs) - start_time = time.time() - timed_out = True - - while time.time() - start_time < timeout: - time.sleep(sleep) - - if workchain.is_terminated: - timed_out = False + # the daemon tests have been known to fail on Jenkins, when the result node cannot be found + # to mitigate this, we can retry multiple times + for _ in range(retries): + output = run_via_daemon(workchains, inputs, sleep, timeout) + if output is not None: break - - if timed_out: - click.secho('Failed: ', fg='red', bold=True, nl=False) - click.secho( - f'the workchain<{workchain.pk}> did not finish in time and the operation timed out', bold=True - ) - sys.exit(1) - - try: - result = workchain.outputs.result - except AttributeError: - click.secho('Failed: ', fg='red', bold=True, nl=False) - click.secho(f'the workchain<{workchain.pk}> did not return a result output node', bold=True) + if output is None: sys.exit(1) + result, workchain, total_time = output else: + start_time = time.time() results, workchain = run_get_node(workchains.Polish00WorkChain, **inputs) + total_time = time.time() - start_time result = results['result'] click.echo(f'Evaluated : {evaluated}') @@ -179,9 +173,41 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout sys.exit(1) else: click.secho('Success: ', fg='green', bold=True, nl=False) - click.secho('the workchain accurately reproduced the evaluated value', bold=True) + click.secho(f'the workchain accurately reproduced the evaluated value in {total_time:.2f}s', bold=True) sys.exit(0) +def run_via_daemon(workchains, inputs, sleep, timeout): + """Run via the daemon, polling until it is terminated or timeout.""" + from aiida.engine import submit + + workchain = submit(workchains.Polish00WorkChain, **inputs) + start_time = time.time() + timed_out = True + + while time.time() - start_time < timeout: + time.sleep(sleep) + + if workchain.is_terminated: + timed_out = False + total_time = time.time() - start_time + break + + if timed_out: + click.secho('Failed: ', fg='red', bold=True, nl=False) + click.secho(f'the workchain<{workchain.pk}> did not finish in time and the operation timed out', bold=True) + return None + + try: + result = workchain.outputs.result + except AttributeError: + click.secho('Failed: ', fg='red', bold=True, nl=False) + click.secho(f'the workchain<{workchain.pk}> did not return a result output node', bold=True) + click.echo(str(workchain.attributes)) + return None + + return result, workchain, total_time + + if __name__ == '__main__': launch() # pylint: disable=no-value-for-parameter diff --git a/.ci/polish/lib/__init__.py b/.molecule/default/files/polish/lib/__init__.py similarity index 100% rename from .ci/polish/lib/__init__.py rename to .molecule/default/files/polish/lib/__init__.py diff --git a/.ci/polish/lib/expression.py b/.molecule/default/files/polish/lib/expression.py similarity index 100% rename from .ci/polish/lib/expression.py rename to .molecule/default/files/polish/lib/expression.py diff --git a/.ci/polish/lib/template/base.tpl b/.molecule/default/files/polish/lib/template/base.tpl similarity index 100% rename from .ci/polish/lib/template/base.tpl rename to .molecule/default/files/polish/lib/template/base.tpl diff --git a/.ci/polish/lib/template/workchain.tpl b/.molecule/default/files/polish/lib/template/workchain.tpl similarity index 100% rename from 
.ci/polish/lib/template/workchain.tpl rename to .molecule/default/files/polish/lib/template/workchain.tpl diff --git a/.ci/polish/lib/workchain.py b/.molecule/default/files/polish/lib/workchain.py similarity index 83% rename from .ci/polish/lib/workchain.py rename to .molecule/default/files/polish/lib/workchain.py index e20a9a3308..7dd4072d1a 100644 --- a/.ci/polish/lib/workchain.py +++ b/.molecule/default/files/polish/lib/workchain.py @@ -10,8 +10,9 @@ """Functions to dynamically generate a WorkChain from a reversed polish notation expression.""" import collections -import errno +import hashlib import os +from pathlib import Path from string import Template from .expression import OPERATORS # pylint: disable=relative-beyond-top-level @@ -185,9 +186,11 @@ def format_indent(level=0, width=INDENTATION_WIDTH): return ' ' * level * width -def write_workchain(outlines, directory=None, filename=None): +def write_workchain(outlines, directory=None) -> Path: """ Given a list of string formatted outlines, write the corresponding workchains to file + + :returns: file path """ dirpath = os.path.dirname(os.path.realpath(__file__)) template_dir = os.path.join(dirpath, 'template') @@ -197,22 +200,10 @@ def write_workchain(outlines, directory=None, filename=None): if directory is None: directory = os.path.join(dirpath, os.path.pardir, 'polish_workchains') - if filename is None: - filename = os.path.join(directory, 'polish.py') - else: - filename = os.path.join(directory, filename) - - try: - os.makedirs(directory) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise + directory = Path(directory) - try: - init_file = os.path.join(directory, '__init__.py') - os.utime(init_file, None) - except OSError: - open(init_file, 'a').close() + directory.mkdir(parents=True, exist_ok=True) + (directory / '__init__.py').touch() with open(template_file_base, 'r') as handle: template_base = handle.readlines() @@ -220,32 +211,39 @@ def write_workchain(outlines, directory=None, filename=None): with open(template_file_workchain, 'r') as handle: template_workchain = Template(handle.read()) - with open(filename, 'w') as handle: + code_strings = [] - for line in template_base: - handle.write(line) - handle.write('\n') + for line in template_base: + code_strings.append(line) + code_strings.append('\n') - counter = len(outlines) - 1 - for outline in outlines: + counter = len(outlines) - 1 + for outline in outlines: - outline_string = '' - for subline in outline.split('\n'): - outline_string += f'\t\t\t{subline}\n' + outline_string = '' + for subline in outline.split('\n'): + outline_string += f'\t\t\t{subline}\n' - if counter == len(outlines) - 1: - child_class = None - else: - child_class = f'Polish{counter + 1:02d}WorkChain' + if counter == len(outlines) - 1: + child_class = None + else: + child_class = f'Polish{counter + 1:02d}WorkChain' + + subs = { + 'class_name': f'Polish{counter:02d}WorkChain', + 'child_class': child_class, + 'outline': outline_string, + } + code_strings.append(template_workchain.substitute(**subs)) + code_strings.append('\n\n') + + counter -= 1 + + code_string = '\n'.join(code_strings) + hashed = hashlib.md5(code_string.encode('utf8')).hexdigest() - subs = { - 'class_name': f'Polish{counter:02d}WorkChain', - 'child_class': child_class, - 'outline': outline_string, - } - handle.write(template_workchain.substitute(**subs)) - handle.write('\n\n') + filepath = directory / f'polish_{hashed}.py' - counter -= 1 + filepath.write_text(code_string) - return filename + return filepath diff 
--git a/.molecule/default/run_tests.yml b/.molecule/default/run_tests.yml new file mode 100644 index 0000000000..a1a617ca4b --- /dev/null +++ b/.molecule/default/run_tests.yml @@ -0,0 +1 @@ +- import_playbook: test_polish_workchains.yml diff --git a/.molecule/default/setup_aiida.yml b/.molecule/default/setup_aiida.yml new file mode 100644 index 0000000000..5faca0f399 --- /dev/null +++ b/.molecule/default/setup_aiida.yml @@ -0,0 +1,95 @@ +- name: Set up AiiDa Environment + hosts: all + gather_facts: false + + # run as aiida user + become: true + become_method: "{{ become_method }}" + become_user: "{{ aiida_user | default('aiida') }}" + + environment: + AIIDA_PATH: "{{ aiida_path }}" + + tasks: + + - name: reentry scan + command: "{{ venv_bin }}/reentry scan" + changed_when: false + + - name: Create a new database with name "{{ aiida_backend }}" + postgresql_db: + name: "{{ aiida_backend }}" + login_host: localhost + login_user: aiida + login_password: '' + encoding: UTF8 + lc_collate: en_US.UTF-8 + lc_ctype: en_US.UTF-8 + template: template0 + + - name: Add pg_stat_statements extension to the database + when: aiida_query_stats | default(false) | bool + postgresql_ext: + name: pg_stat_statements + login_host: localhost + login_user: aiida + login_password: '' + db: "{{ aiida_backend }}" + + - name: verdi setup for "{{ aiida_backend }}" + command: > + {{ venv_bin }}/verdi setup + --non-interactive + --profile "{{ aiida_backend }}" + --email "aiida@localhost" + --first-name "ringo" + --last-name "starr" + --institution "the beatles" + --db-backend "{{ aiida_backend }}" + --db-host=localhost + --db-name="{{ aiida_backend }}" + --db-username=aiida + --db-password='' + args: + creates: "{{ aiida_path }}/.aiida/config.json" + + - name: "Check if computer is already present" + command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} computer show localhost" + ignore_errors: true + changed_when: false + no_log: true + register: aiida_check_computer + + - name: verdi computer setup localhost + when: aiida_check_computer.rc != 0 + command: > + {{ venv_bin }}/verdi -p {{ aiida_backend }} computer setup + --non-interactive + --label "localhost" + --description "this computer" + --hostname "localhost" + --transport local + --scheduler direct + --work-dir {{ aiida_path }}/local_work_dir/ + --mpirun-command "mpirun -np {tot_num_mpiprocs}" + --mpiprocs-per-machine 1 + + - name: verdi computer configure localhost + when: aiida_check_computer.rc != 0 + command: > + {{ venv_bin }}/verdi -p {{ aiida_backend }} computer configure local "localhost" + --non-interactive + --safe-interval 0.0 + + # we restart the daemon in run_tests.yml, so no need to start here + # - name: verdi start daemon with {{ aiida_workers }} workers + # command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} daemon start {{ aiida_workers }}" + + - name: get verdi status + command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} status" + register: verdi_status + changed_when: false + + - name: print verdi status + debug: + var: verdi_status.stdout diff --git a/.molecule/default/setup_python.yml b/.molecule/default/setup_python.yml new file mode 100644 index 0000000000..eba59ea303 --- /dev/null +++ b/.molecule/default/setup_python.yml @@ -0,0 +1,27 @@ +- name: Set up Python Environment + hosts: all + gather_facts: false + + # run as root user + become: true + become_method: "{{ become_method }}" + become_user: root + + tasks: + + - name: pip install aiida-core requirements + pip: + chdir: "{{ aiida_core_dir }}" + # TODO dynamically change for python 
version + requirements: requirements/requirements-py-3.7.txt + executable: "{{ venv_bin }}/pip" + extra_args: --cache-dir {{ aiida_pip_cache }} + register: pip_install_deps + + - name: pip install aiida-core + pip: + chdir: "{{ aiida_core_dir }}" + name: . + executable: "{{ venv_bin }}/pip" + editable: "{{ aiida_pip_editable | default(true) }}" + extra_args: --no-deps diff --git a/.molecule/default/tasks/log_query_stats.yml b/.molecule/default/tasks/log_query_stats.yml new file mode 100644 index 0000000000..b62c53e2d5 --- /dev/null +++ b/.molecule/default/tasks/log_query_stats.yml @@ -0,0 +1,59 @@ +- name: Get DB summary statistics + postgresql_query: + login_host: localhost + login_user: "{{ aiida_user | default('aiida') }}" + login_password: '' + db: "{{ aiida_backend }}" + query: | + SELECT + CAST(sum(calls) AS INTEGER) as calls, + CAST(sum(rows) AS INTEGER) as rows, + to_char(sum(total_time), '9.99EEEE') as time_ms + FROM pg_stat_statements + WHERE query !~* 'pg_stat_statements'; + register: db_query_stats_summary + +- debug: + var: db_query_stats_summary.query_result + +- name: Get DB statistics for largest queries by time + postgresql_query: + login_host: localhost + login_user: "{{ aiida_user | default('aiida') }}" + login_password: '' + db: "{{ aiida_backend }}" + query: | + SELECT + to_char(total_time, '9.99EEEE') AS time_ms, + calls, + rows, + query + FROM pg_stat_statements + WHERE query !~* 'pg_stat_statements' + ORDER BY calls DESC + LIMIT {{ query_stats_limit | default(5) }}; + register: db_query_stats_time + +- debug: + var: db_query_stats_time.query_result + +- name: Get DB statistics for largest queries by calls + postgresql_query: + login_host: localhost + login_user: "{{ aiida_user | default('aiida') }}" + login_password: '' + db: "{{ aiida_backend }}" + query: | + SELECT + to_char(total_time, '9.99EEEE') AS time_ms, + calls, + rows, + query + FROM pg_stat_statements + WHERE query !~* 'pg_stat_statements' + ORDER BY calls DESC + LIMIT {{ query_stats_limit | default(5) }}; + register: db_query_stats_calls + +- debug: + var: db_query_stats_calls.query_result diff --git a/.molecule/default/tasks/reset_query_stats.yml b/.molecule/default/tasks/reset_query_stats.yml new file mode 100644 index 0000000000..44fd9e3827 --- /dev/null +++ b/.molecule/default/tasks/reset_query_stats.yml @@ -0,0 +1,7 @@ +- name: Reset database query statistics + postgresql_query: + login_host: localhost + login_user: "{{ aiida_user | default('aiida') }}" + login_password: '' + db: "{{ aiida_backend }}" + query: SELECT pg_stat_statements_reset(); diff --git a/.molecule/default/test_polish_workchains.yml b/.molecule/default/test_polish_workchains.yml new file mode 100644 index 0000000000..95b060a182 --- /dev/null +++ b/.molecule/default/test_polish_workchains.yml @@ -0,0 +1,82 @@ +- name: Test the runnning of complex polish notation workchains + hosts: all + gather_facts: false + + # run as aiida user + become: true + become_method: "{{ become_method }}" + become_user: "{{ aiida_user | default('aiida') }}" + + environment: + AIIDA_PATH: "{{ aiida_path }}" + + tasks: + + - name: "Check if add code is already present" + command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} code show add@localhost" + ignore_errors: true + changed_when: false + no_log: true + register: aiida_check_code + + - name: verdi add code setup + when: aiida_check_code.rc != 0 + command: > + {{ venv_bin }}/verdi -p {{ aiida_backend }} code setup + -D "simple script that adds two numbers" + -n -L add -P arithmetic.add + -Y localhost 
--remote-abs-path=/bin/bash + + - name: Copy workchain files + copy: + src: polish + dest: "${HOME}/{{ aiida_backend }}" + + - name: get python path including workchains + command: echo "${PYTHONPATH}:${HOME}/{{ aiida_backend }}/polish" + register: echo_pythonpath + + - set_fact: + aiida_pythonpath: "{{ echo_pythonpath.stdout }}" + + - name: Reset pythonpath of daemon ({{ aiida_workers }} workers) + # note `verdi daemon restart` did not seem to update the environmental variables? + shell: | + {{ venv_bin }}/verdi -p {{ aiida_backend }} daemon stop + {{ venv_bin }}/verdi -p {{ aiida_backend }} daemon start {{ aiida_workers }} + environment: + PYTHONPATH: "{{ aiida_pythonpath }}" + + - when: aiida_query_stats | default(false) | bool + include_tasks: tasks/reset_query_stats.yml + + - name: "run polish workchains" + # Note the exclamation point after the code is necessary to force the value to be interpreted as LABEL type identifier + shell: | + set -e + declare -a EXPRESSIONS=({{ polish_expressions | map('quote') | join(' ') }}) + for expression in "${EXPRESSIONS[@]}"; do + {{ venv_bin }}/verdi -p {{ aiida_backend }} run --auto-group -l polish -- "{{ polish_script }}" -X add! -C -F -d -t {{ polish_timeout }} -r 2 "$expression" + done + args: + executable: /bin/bash + vars: + polish_script: "${HOME}/{{ aiida_backend }}/polish/cli.py" + polish_timeout: 600 + polish_expressions: + - "1 -2 -1 4 -5 -5 * * * * +" + - "2 1 3 3 -1 + ^ ^ +" + - "3 -5 -1 -4 + * ^" + - "2 4 2 -4 * * +" + - "3 1 1 5 ^ ^ ^" + # - "3 1 3 4 -4 2 * + + ^ ^" # this takes a longer time to run + environment: + PYTHONPATH: "{{ aiida_pythonpath }}" + register: polish_output + + - name: print polish workchain output + debug: + msg: "{{ polish_output.stdout }}" + + - when: aiida_query_stats | default(false) | bool + include_tasks: tasks/log_query_stats.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5353cc9842..a4204ac27f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,12 @@ repos: files: >- (?x)^( aiida/common/progress_reporter.py| - aiida/engine/processes/calcjobs/calcjob.py| + aiida/engine/.*py| + aiida/manage/manager.py| + aiida/manage/database/delete/nodes.py| + aiida/orm/nodes/node.py| + aiida/orm/nodes/process/.*py| + aiida/tools/graph/graph_traversers.py| aiida/tools/groups/paths.py| aiida/tools/importexport/archive/.*py| aiida/tools/importexport/dbexport/__init__.py| diff --git a/CHANGELOG.md b/CHANGELOG.md index e79a0e1134..d5ed12d88b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,112 @@ # Changelog -## v1.5.2 +## v1.6.0 - 2021-03-15 + +[full changelog](https://github.com/aiidateam/aiida-core/compare/v1.5.2...v1.6.0) | [GitHub contributors page for this release](https://github.com/aiidateam/aiida-core/graphs/contributors?from=2020-12-07&to=2021-03-15&type=c) + +As well as introducing a number of improvements and new features listed below, this release marks the "under-the-hood" migration from the `tornado` package to the Python built-in module `asyncio`, for handling asynchronous processing within the AiiDA engine. +This removes a number of blocking dependency version clashes with other tools, in particular with the newest Jupyter shell and notebook environments. +The migration does not present any backward incompatible changes to AiiDA's public API. 
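Because the public API is preserved, code that launches processes does not need to change. As a quick illustration, a minimal sketch using only documented `aiida.engine` entry points (assumes a configured default profile):

```python
from aiida import load_profile
from aiida.engine import calcfunction, run_get_node
from aiida.orm import Int

load_profile()  # load the default AiiDA profile

@calcfunction
def add(x, y):
    """Toy calcfunction: a provenance-tracked addition."""
    return x + y

# Behaves exactly as it did on the tornado-based engine.
result, node = run_get_node(add, Int(2), Int(3))
print(result.value, node.process_state)
```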
+A substantial effort has been made to test and debug the new implementation, and ensure it performs at least equivalent to the previous code (or improves it!), but please let us know if you uncover any additional issues. + +This release also drops support for Python 3.6 (testing is carried out against `3.7`, `3.8` and `3.9`). + +NOTE: `v1.6` is tentatively intended to be the final minor `v1.x` release before `v2.x`, that will include a new file repository implementation and remove all deprecated code. + +### New calculation features ✨ + +The `additional_retrieve_list` metadata option has been added to `CalcJob` ([#4437](https://github.com/aiidateam/aiida-core/pull/4437)). +This new option allows one to specify additional files to be retrieved on a per-instance basis, in addition to the files that are already defined by the plugin to be retrieved. + +A **new namespace `stash`** has bee added to the `metadata.options` input namespace of the `CalcJob` process ([#4424](https://github.com/aiidateam/aiida-core/pull/4424)). +This option namespace allows a user to specify certain files that are created by the calculation job to be stashed somewhere on the remote. +This can be useful if those files need to be stored for a longer time than the scratch space (where the job was run) is available for, but need to be kept on the remote machine and not retrieved. +Examples are files that are necessary to restart a calculation but are too big to be retrieved and stored permanently in the local file repository. + +See [Stashing files on the remote](https://aiida.readthedocs.io/projects/aiida-core/en/v1.6.0/topics/calculations/usage.html#stashing-files-on-the-remote) for more details. + +The **new `TransferCalcjob` plugin** ([#4194](https://github.com/aiidateam/aiida-core/pull/4194)) allows the user to copy files between a remote machine and the local machine running AiiDA. +More specifically, it can do any of the following: + +- Take any number of files from any number of `RemoteData` folders in a remote machine and copy them in the local repository of a single newly created `FolderData` node. +- Take any number of files from any number of `FolderData` nodes in the local machine and copy them in a single newly created `RemoteData` folder in a given remote machine. + +See the [Transferring data](https://aiida.readthedocs.io/projects/aiida-core/en/v1.6.0/howto/data.html#transferring-data) how-to for more details. + +### Profile configuration improvements 👌 + +The way the global/profile configuration is accessed has undergone a number of distinct changes ([#4712](https://github.com/aiidateam/aiida-core/pull/4712)): + +- When loaded, the `config.json` (found in the `.aiida` folder) is now validated against a [JSON Schema](https://json-schema.org/) that can be found in [`aiida/manage/configuration/schema`](https://github.com/aiidateam/aiida-core/tree/develop/aiida/manage/configuration/schema). +- The schema includes a number of new global/profile options, including: `transport.task_retry_initial_interval`, `transport.task_maximum_attempts`, `rmq.task_timeout` and `logging.aiopika_loglevel` ([#4583](https://github.com/aiidateam/aiida-core/pull/4583)). +- The `cache_config.yml` has now also been **deprecated** and merged into the `config.json`, as part of the profile options. + This merge will be handled automatically, upon first load of the `config.json` using the new AiiDA version. 
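The new options can also be read and set programmatically. This is a sketch only, assuming the `get_option`/`set_option`/`store` methods of the `Config` object; the documented route is the refactored `verdi config` command described next:

```python
from aiida.manage.configuration import get_config

config = get_config()
# Option names are taken from the new JSON schema mentioned above;
# the set_option/get_option/store calls are an assumption of this sketch.
config.set_option('transport.task_maximum_attempts', 3)
print(config.get_option('transport.task_maximum_attempts'))
config.store()  # persist the change back to config.json
```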
+ +In-line with these changes, the `verdi config` command has been refactored into separate commands, including `verdi config list`, `verdi config set`, `verdi config unset` and `verdi config caching`. + +See the [Configuring profile options](https://aiida.readthedocs.io/projects/aiida-core/en/v1.6.0/howto/installation.html#configuring-profile-options) and [Configuring caching](https://aiida.readthedocs.io/projects/aiida-core/en/v1.6.0/howto/run_codes.html#how-to-save-compute-time-with-caching) how-tos for more details. + +### Command-line additions and improvements 👌 + +In addition to `verdi config`, numerous other new commands and options have been added to `verdi`: + +- **Deprecated** `verdi export` and `verdi import` commands (replaced by new `verdi archive`) ([#4710](https://github.com/aiidateam/aiida-core/pull/4710)) +- Added `verdi group delete --delete-nodes`, to also delete the nodes in a group during its removal ([#4578](https://github.com/aiidateam/aiida-core/pull/4578)). +- Improved `verdi group remove-nodes` command to warn when requested nodes are not in the specified group ([#4728](https://github.com/aiidateam/aiida-core/pull/4728)). +- Added `exception` to the projection mapping of `verdi process list`, for example to use in debugging as: `verdi process list -S excepted -P ctime pk exception` ([#4786](https://github.com/aiidateam/aiida-core/pull/4786)). +- Added `verdi database summary` ([#4737](https://github.com/aiidateam/aiida-core/pull/4737)): + This prints a summary of the count of each entity and (optionally) the list of unique identifiers for some entities. +- Improved `verdi process play` performance, by only querying for active processes with the `--all` flag ([#4671](https://github.com/aiidateam/aiida-core/pull/4671)) +- Added the `verdi database version` command ([#4613](https://github.com/aiidateam/aiida-core/pull/4613)): + This shows the schema generation and version of the database of the given profile, useful mostly for developers when debugging. +- Improved `verdi node delete` performance ([#4575](https://github.com/aiidateam/aiida-core/pull/4575)): + The logic has been re-written to greatly reduce the time to delete large amounts of nodes. +- Fixed `verdi quicksetup --non-interactive`, to ensure it does not include any user prompts ([#4573](https://github.com/aiidateam/aiida-core/pull/4573)) +- Fixed `verdi --version` when used in editable mode ([#4576](https://github.com/aiidateam/aiida-core/pull/4576)) + +### API additions and improvements 👌 + +The base `Node` class now evaluates equality based on the node's UUID ([#4753](https://github.com/aiidateam/aiida-core/pull/4753)). +For example, loading the same node twice will always resolve as equivalent: `load_node(1) == load_node(1)`. +Note that existing, class specific, equality relationships will still override the base class behaviour, for example: `Int(99) == Int(99)`, even if the nodes have different UUIDs. +This behaviour for subclasses is still under discussion at: + +When hashing nodes for use with the caching features, `-0.` is now converted to `0.`, to reduce issues with differing hashes before/after node storage ([#4648](https://github.com/aiidateam/aiida-core/pull/4648)). +Known failure modes for hashing are now also raised with the `HashingError` exception ([#4778](https://github.com/aiidateam/aiida-core/pull/4778)). 
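Concretely, the UUID-based equality described above behaves as follows (a sketch; the PK `1234` is a hypothetical existing node):

```python
from aiida.orm import Int, load_node

a = load_node(1234)  # hypothetical PK of a stored node
b = load_node(1234)
assert a == b        # same UUID, so the two instances compare equal

x, y = Int(99), Int(99)
assert x == y            # value-based equality of the subclass still takes precedence
assert x.uuid != y.uuid  # even though these are two distinct node instances
```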
+ +Both `aiida.tools.delete_nodes` ([#4578](https://github.com/aiidateam/aiida-core/pull/4578)) and `aiida.orm.to_aiida_type` ([#4672](https://github.com/aiidateam/aiida-core/pull/4672)) have been exposed for use in the public API. + +A `pathlib.Path` instance can now be used for the `file` argument of `SinglefileData` ([#3614](https://github.com/aiidateam/aiida-core/pull/3614)) + +Type annotations have been added to all inputs/outputs of functions and methods in `aiida.engine` ([#4669](https://github.com/aiidateam/aiida-core/pull/4669)) and `aiida/orm/nodes/processes` ([#4772](https://github.com/aiidateam/aiida-core/pull/4772)). +As outlined in [PEP 484](https://www.python.org/dev/peps/pep-0484/), this improves static code analysis and, for example, allows for better auto-completion and type checking in many code editors. + +### New REST API Query endpoint ✨ + +The `/querybuilder` endpoint is the first POST method available for AiiDA's RESTful API ([#4337](https://github.com/aiidateam/aiida-core/pull/4337)) + +The POST endpoint returns what the QueryBuilder would return, when providing it with a proper `queryhelp` dictionary ([see the documentation here](https://aiida.readthedocs.io/projects/aiida-core/en/latest/topics/database.html#the-queryhelp)). +Furthermore, it returns the entities/results in the "standard" REST API format - with the exception of `link_type` and `link_label` keys for links (these particular keys are still present as `type` and `label`, respectively). + +For security, POST methods can be toggled on/off with the `verdi restapi --posting/--no-posting` options (it is on by default). +Although note that this option is not yet strictly public, since its naming may be changed in the future! + +See [AiiDA REST API documentation](https://aiida.readthedocs.io/projects/aiida-core/en/latest/reference/rest_api.html) for more details. + +### Additional Changes + +- Fixed the direct scheduler which, in combination with `SshTransport`, was hanging on submit command ([#4735](https://github.com/aiidateam/aiida-core/pull/4735)). + In the ssh transport, to emulate 'chdir', the current directory is now kept in memory, and every command prepended with `cd FOLDER_NAME && ACTUALCOMMAND`. + +- In `aiida.tools.ipython.ipython_magics`, `load_ipython_extension` has been **deprecated** in favour of `register_ipython_extension` ([#4548](https://github.com/aiidateam/aiida-core/pull/4548)). + +- Refactored `.ci/` folder to make tests more portable and easier to understand ([#4565](https://github.com/aiidateam/aiida-core/pull/4565)) + The `ci/` folder had become cluttered, containing configuration and scripts for both the GitHub Actions and Jenkins CI. + This change moved the GH actions specific scripts to `.github/system_tests`, and refactored the Jenkins setup/tests to use [molecule](molecule.readthedocs.io) in the `.molecule/` folder. + +- For aiida-core development, the pytest `requires_rmq` marker and `config_with_profile` fixture have been added ([#4739](https://github.com/aiidateam/aiida-core/pull/4739) and [#4764](https://github.com/aiidateam/aiida-core/pull/4764)) + +## v1.5.2 - 2020-12-07 Note: release `v1.5.1` was skipped due to a problem with the uploaded files to PyPI. 
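Going back to the `/querybuilder` endpoint introduced in v1.6.0 above, here is a minimal sketch of posting a query from Python; the URL, port and the exact `queryhelp` payload shown are illustrative assumptions:

```python
import requests

# Illustrative queryhelp: project the id and ctime of all Dict nodes.
queryhelp = {
    'path': [{'entity_type': 'data.dict.Dict.', 'tag': 'dict'}],
    'project': {'dict': ['id', 'ctime']},
}

# Assumes `verdi restapi` is serving on localhost:5000 with POST methods enabled.
response = requests.post('http://localhost:5000/api/v4/querybuilder', json=queryhelp)
response.raise_for_status()
print(response.json())
```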
@@ -16,7 +122,7 @@ Note: release `v1.5.1` was skipped due to a problem with the uploaded files to P - CI: manually install `numpy` to prevent incompatible releases [[#4615]](https://github.com/aiidateam/aiida-core/pull/4615) -## v1.5.0 +## v1.5.0 - 2020-11-13 In this minor version release, support for Python 3.9 is added [[#4301]](https://github.com/aiidateam/aiida-core/pull/4301), while support for Python 3.5 is dropped [[#4386]](https://github.com/aiidateam/aiida-core/pull/4386). This version is compatible with all current Python versions that are not end-of-life: @@ -59,7 +165,7 @@ This version is compatible with all current Python versions that are not end-of- ### Dependencies - Update requirement `pytest~=6.0` and use `pyproject.toml` [[#4410]](https://github.com/aiidateam/aiida-core/pull/4410) -### Archive (import/export) refactor: +### Archive (import/export) refactor - The refactoring goal was to pave the way for the implementation of a new archive format in v2.0.0 ([ aiidateamAEP005](https://github.com/aiidateam/AEP/pull/21)) - Three abstract+concrete interface classes are defined; writer, reader, migrator, which are **independent of theinternal structure of the archive**. These classes are used within the export/import code. - The code in `aiida/tools/importexport` has been largely re-written, in particular adding `aiida/toolsimportexport/archive`, which contains this code for interfacing with an archive, and **does not require connectionto an AiiDA profile**. diff --git a/Dockerfile b/Dockerfile index 849a1bd5ff..3c076051b4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM aiidateam/aiida-prerequisites:0.2.1 +FROM aiidateam/aiida-prerequisites:0.3.0 USER root @@ -14,6 +14,7 @@ ENV AIIDADB_BACKEND django # Copy and install AiiDA COPY . aiida-core RUN pip install ./aiida-core[atomic_tools] +RUN pip install --upgrade git+https://github.com/unkcpz/circus.git@fix/quit-wait # Configure aiida for the user COPY .docker/opt/configure-aiida.sh /opt/configure-aiida.sh diff --git a/MANIFEST.in b/MANIFEST.in index 0c13bf8b4d..845905c022 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include aiida/cmdline/templates/*.tpl include aiida/manage/backup/backup_info.json.tmpl +include aiida/manage/configuration/schema/*.json include setup.json include AUTHORS.txt include CHANGELOG.md diff --git a/README.md b/README.md index f04634d30b..82def74c91 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,8 @@ If you are experiencing problems with your AiiDA installation, please refer to t If you use AiiDA in your research, please consider citing the following publications: * **AiiDA >= 1.0**: S. P. Huber *et al.*, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) - * **AiiDA < 1.0**: Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari,and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Comp. Mat. Sci **111**, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) + * **AiiDA >= 1.0**: M. 
Uhrin *et al.*, *Workflows in AiiDA: Engineering a high-throughput, event-based engine for robust and modular computational workflows*, Computational Materials Science **187**, 110086 (2021); DOI: [10.1016/j.commatsci.2020.110086](https://doi.org/10.1016/j.commatsci.2020.110086) + * **AiiDA < 1.0**: Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari,and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Computational Materials Science **111**, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) ## License diff --git a/aiida/__init__.py b/aiida/__init__.py index ac82ab948a..e357c2907b 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -31,15 +31,13 @@ 'For further information please visit http://www.aiida.net/. All rights reserved.' ) __license__ = 'MIT license, see LICENSE.txt file.' -__version__ = '1.5.2' +__version__ = '1.6.0' __authors__ = 'The AiiDA team.' __paper__ = ( - 'G. Pizzi, A. Cepellotti, R. Sabatini, N. Marzari, and B. Kozinsky,' - '"AiiDA: automated interactive infrastructure and database for computational science", ' - 'Comp. Mat. Sci 111, 218-230 (2016); https://doi.org/10.1016/j.commatsci.2015.09.013 ' - '- http://www.aiida.net.' + 'S. P. Huber et al., "AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and ' + 'data provenance", Scientific Data 7, 300 (2020); https://doi.org/10.1038/s41597-020-00638-4' ) -__paper_short__ = 'G. Pizzi et al., Comp. Mat. Sci 111, 218 (2016).' +__paper_short__ = 'S. P. Huber et al., Scientific Data 7, 300 (2020).' def load_dbenv(profile=None): diff --git a/aiida/backends/djsite/db/testbase.py b/aiida/backends/djsite/db/testbase.py index fe602e328d..a76aab5763 100644 --- a/aiida/backends/djsite/db/testbase.py +++ b/aiida/backends/djsite/db/testbase.py @@ -12,12 +12,6 @@ """ from aiida.backends.testimplbase import AiidaTestImplementation -from aiida.orm.implementation.django.backend import DjangoBackend - -# Add a new entry here if you add a file with tests under aiida.backends.djsite.db.subtests -# The key is the name to use in the 'verdi test' command (e.g., a key 'generic' -# can be run using 'verdi test db.generic') -# The value must be the module name containing the subclasses of unittest.TestCase # This contains the codebase for the setUpClass and tearDown methods used internally by the AiidaTestCase @@ -29,13 +23,6 @@ class DjangoTests(AiidaTestImplementation): Automatically takes care of the setUpClass and TearDownClass, when needed. """ - # pylint: disable=attribute-defined-outside-init - - # Note this is has to be a normal method, not a class method - def setUpClass_method(self): - self.clean_db() - self.backend = DjangoBackend() - def clean_db(self): from aiida.backends.djsite.db import models @@ -49,8 +36,3 @@ def clean_db(self): models.DbUser.objects.all().delete() # pylint: disable=no-member models.DbComputer.objects.all().delete() models.DbGroup.objects.all().delete() - - def tearDownClass_method(self): - """ - Backend-specific tasks for tearing down the test environment. 
- """ diff --git a/aiida/backends/sqlalchemy/testbase.py b/aiida/backends/sqlalchemy/testbase.py index 9b5bac8f01..3e1168740b 100644 --- a/aiida/backends/sqlalchemy/testbase.py +++ b/aiida/backends/sqlalchemy/testbase.py @@ -20,22 +20,6 @@ class SqlAlchemyTests(AiidaTestImplementation): """Base class to test SQLA-related functionalities.""" connection = None - _backend = None - - def setUpClass_method(self): - self.clean_db() - - def tearDownClass_method(self): - """Backend-specific tasks for tearing down the test environment.""" - - @property - def backend(self): - """Get the backend.""" - if self._backend is None: - from aiida.manage.manager import get_manager - self._backend = get_manager().get_backend() - - return self._backend def clean_db(self): from sqlalchemy.sql import table diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index 68d8aa107c..871ce36931 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -12,12 +12,11 @@ import unittest import traceback -from tornado import ioloop - from aiida.common.exceptions import ConfigurationError, TestsNotAllowedError, InternalError -from aiida.common.lang import classproperty from aiida.manage import configuration from aiida.manage.manager import get_manager, reset_manager +from aiida import orm +from aiida.common.lang import classproperty TEST_KEYWORD = 'test_' @@ -33,6 +32,8 @@ class AiidaTestCase(unittest.TestCase): """This is the base class for AiiDA tests, independent of the backend. Internally it loads the AiidaTestImplementation subclass according to the current backend.""" + _computer = None # type: aiida.orm.Computer + _user = None # type: aiida.orm.User _class_was_setup = False __backend_instance = None backend = None # type: aiida.orm.implementation.Backend @@ -65,46 +66,39 @@ def get_backend_class(cls): return cls.__impl_class @classmethod - def setUpClass(cls, *args, **kwargs): # pylint: disable=arguments-differ + def setUpClass(cls): + """Set up test class.""" # Note: this will raise an exception, that will be seen as a test # failure. To be safe, you should do the same check also in the tearDownClass # to avoid that it is run check_if_tests_can_run() # Force the loading of the backend which will load the required database environment - get_manager().get_backend() - + cls.backend = get_manager().get_backend() cls.__backend_instance = cls.get_backend_class()() - cls.__backend_instance.setUpClass_method(*args, **kwargs) - cls.backend = cls.__backend_instance.backend - cls._class_was_setup = True - cls.clean_db() - cls.insert_data() + cls.refurbish_db() + + @classmethod + def tearDownClass(cls): + """Tear down test class. + + Note: Also cleans file repository. + """ + # Double check for double security to avoid to run the tearDown + # if this is not a test profile - def setUp(self): - # Install a new IOLoop so that any messing up of the state of the loop is not propagated - # to subsequent tests. - # This call should come before the backend instance setup call just in case it uses the loop - ioloop.IOLoop().make_current() + check_if_tests_can_run() + if orm.autogroup.CURRENT_AUTOGROUP is not None: + orm.autogroup.CURRENT_AUTOGROUP.clear_group_cache() + cls.clean_db() + cls.clean_repository() def tearDown(self): - # Clean up the loop we created in set up. 
- # Call this after the instance tear down just in case it uses the loop reset_manager() - loop = ioloop.IOLoop.current() - if not loop._closing: # pylint: disable=protected-access,no-member - loop.close() - - def reset_database(self): - """Reset the database to the default state deleting any content currently stored""" - from aiida.orm import autogroup - self.clean_db() - if autogroup.CURRENT_AUTOGROUP is not None: - autogroup.CURRENT_AUTOGROUP.clear_group_cache() - self.insert_data() + ### Database/repository-related methods @classmethod def insert_data(cls): @@ -113,19 +107,9 @@ def insert_data(cls): inserts default data into the database (which is for the moment a default computer). """ - from aiida.orm import User - - cls.create_user() - User.objects.reset() - cls.create_computer() - - @classmethod - def create_user(cls): - cls.__backend_instance.create_user() - - @classmethod - def create_computer(cls): - cls.__backend_instance.create_computer() + orm.User.objects.reset() # clear Aiida's cache of the default user + # populate user cache of test clases + cls.user # pylint: disable=pointless-statement @classmethod def clean_db(cls): @@ -144,9 +128,23 @@ def clean_db(cls): raise InvalidOperation('You cannot call clean_db before running the setUpClass') cls.__backend_instance.clean_db() + cls._computer = None + cls._user = None + + if orm.autogroup.CURRENT_AUTOGROUP is not None: + orm.autogroup.CURRENT_AUTOGROUP.clear_group_cache() reset_manager() + @classmethod + def refurbish_db(cls): + """Clean up database and repopulate with initial data. + + Combines clean_db and insert_data. + """ + cls.clean_db() + cls.insert_data() + @classmethod def clean_repository(cls): """ @@ -177,24 +175,31 @@ def computer(cls): # pylint: disable=no-self-argument :return: the test computer :rtype: :class:`aiida.orm.Computer`""" - return cls.__backend_instance.get_computer() + if cls._computer is None: + created, computer = orm.Computer.objects.get_or_create( + label='localhost', + hostname='localhost', + transport_type='local', + scheduler_type='direct', + workdir='/tmp/aiida', + ) + if created: + computer.store() + cls._computer = computer + + return cls._computer @classproperty - def user_email(cls): # pylint: disable=no-self-argument - return cls.__backend_instance.get_user_email() + def user(cls): # pylint: disable=no-self-argument + if cls._user is None: + cls._user = get_default_user() + return cls._user - @classmethod - def tearDownClass(cls, *args, **kwargs): # pylint: disable=arguments-differ - # Double check for double security to avoid to run the tearDown - # if this is not a test profile - from aiida.orm import autogroup + @classproperty + def user_email(cls): # pylint: disable=no-self-argument + return cls.user.email # pylint: disable=no-member - check_if_tests_can_run() - if autogroup.CURRENT_AUTOGROUP is not None: - autogroup.CURRENT_AUTOGROUP.clear_group_cache() - cls.clean_db() - cls.clean_repository() - cls.__backend_instance.tearDownClass_method(*args, **kwargs) + ### Usability methods def assertClickSuccess(self, cli_result): # pylint: disable=invalid-name self.assertEqual(cli_result.exit_code, 0, cli_result.output) @@ -219,3 +224,27 @@ def tearDownClass(cls, *args, **kwargs): """Close the PGTest postgres test cluster.""" super().tearDownClass(*args, **kwargs) cls.pg_test.close() + + +def get_default_user(**kwargs): + """Creates and stores the default user in the database. + + Default user email is taken from current profile. + No-op if user already exists. 
+ The same is done in `verdi setup`. + + :param kwargs: Additional information to use for new user, i.e. 'first_name', 'last_name' or 'institution'. + :returns: the :py:class:`~aiida.orm.User` + """ + from aiida.manage.configuration import get_config + email = get_config().current_profile.default_user + + if kwargs.pop('email', None): + raise ValueError('Do not specify the user email (must coincide with default user email of profile).') + + # Create the AiiDA user if it does not yet exist + created, user = orm.User.objects.get_or_create(email=email, **kwargs) + if created: + user.store() + + return user diff --git a/aiida/backends/testimplbase.py b/aiida/backends/testimplbase.py index 83603e9c42..6390b74949 100644 --- a/aiida/backends/testimplbase.py +++ b/aiida/backends/testimplbase.py @@ -10,87 +10,20 @@ """Implementation-dependednt base tests""" from abc import ABC, abstractmethod -from aiida import orm -from aiida.common import exceptions - class AiidaTestImplementation(ABC): - """For each implementation, define what to do at setUp and tearDown. - - Each subclass must reimplement two *standard* methods (i.e., *not* classmethods), called - respectively ``setUpClass_method`` and ``tearDownClass_method``. - It is also required to implement setUp_method and tearDown_method to be run for each single test - They can set local properties (e.g. ``self.xxx = yyy``) but remember that ``xxx`` - is not visible to the upper (calling) Test class. - - Moreover, it is required that they define in the setUpClass_method the two properties: - - - ``self.computer`` that must be a Computer object - - ``self.user_email`` that must be a string + """Backend-specific test implementations.""" + _backend = None - These two are then exposed by the ``self.get_computer()`` and ``self.get_user_email()`` - methods.""" - # This should be set by the implementing class in setUpClass_method() - backend = None # type: aiida.orm.implementation.Backend - computer = None # type: aiida.orm.Computer - user = None # type: aiida.orm.User - user_email = None # type: str + @property + def backend(self): + """Get the backend.""" + if self._backend is None: + from aiida.manage.manager import get_manager + self._backend = get_manager().get_backend() - @abstractmethod - def setUpClass_method(self): # pylint: disable=invalid-name - """This class prepares the database (cleans it up and installs some basic entries). - You have also to set a self.computer and a self.user_email as explained in the docstring of the - AiidaTestImplemention docstring.""" - - @abstractmethod - def tearDownClass_method(self): # pylint: disable=invalid-name - """Backend-specific tasks for tearing down the test environment.""" + return self._backend @abstractmethod def clean_db(self): - """This method implements the logic to fully clean the DB.""" - - def insert_data(self): - pass - - def create_user(self): - """This method creates and stores the default user. It has the same effect - as the verdi setup.""" - from aiida.manage.configuration import get_config - self.user_email = get_config().current_profile.default_user - - # Since the default user is needed for many operations in AiiDA, it is not deleted by clean_db. - # In principle, it should therefore always exist - if not we create it anyhow. 
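# Editorial sketch, not part of the diff: the `try`/`except NotExistent` pattern removed just below
# is what the new module-level helper `get_default_user` in aiida/backends/testbase.py replaces
# with a single get-or-create call. Hypothetical usage, assuming a loaded test profile (the email
# is always taken from the profile):

from aiida.backends.testbase import get_default_user

user = get_default_user(first_name='Test', institution='Testing')  # extra kwargs are optional
print(user.email)  # equals the `default_user` of the current profile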
- try: - self.user = orm.User.objects.get(email=self.user_email) - except exceptions.NotExistent: - self.user = orm.User(email=self.user_email).store() - - def create_computer(self): - """This method creates and stores a computer.""" - self.computer = orm.Computer( - label='localhost', - hostname='localhost', - transport_type='local', - scheduler_type='direct', - workdir='/tmp/aiida', - backend=self.backend - ).store() - - def get_computer(self): - """An ORM Computer object present in the DB.""" - try: - return self.computer - except AttributeError: - raise exceptions.InternalError( - 'The AiiDA Test implementation should define a self.computer in the setUpClass_method' - ) - - def get_user_email(self): - """A string with the email of the User.""" - try: - return self.user_email - except AttributeError: - raise exceptions.InternalError( - 'The AiiDA Test implementation should define a self.user_email in the setUpClass_method' - ) + """This method fully cleans the DB.""" diff --git a/aiida/calculations/transfer.py b/aiida/calculations/transfer.py new file mode 100644 index 0000000000..def70db1fb --- /dev/null +++ b/aiida/calculations/transfer.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Implementation of Transfer CalcJob.""" + +import os +from aiida import orm +from aiida.engine import CalcJob +from aiida.common.datastructures import CalcInfo + + +def validate_instructions(instructions, _): + """Check that the instructions dict contains the necessary keywords""" + + instructions_dict = instructions.get_dict() + retrieve_files = instructions_dict.get('retrieve_files', None) + + if retrieve_files is None: + errmsg = ( + '\n\n' + 'no indication of what to do in the instruction node:\n' + f' > {instructions.uuid}\n' + '(to store the files in the repository set retrieve_files=True,\n' + 'to copy them to the specified folder on the remote computer,\n' + 'set it to False)\n' + ) + return errmsg + + if not isinstance(retrieve_files, bool): + errmsg = ( + 'entry for retrieve files inside of instruction node:\n' + f' > {instructions.uuid}\n' + 'must be either True or False; instead, it is:\n' + f' > {retrieve_files}\n' + ) + return errmsg + + local_files = instructions_dict.get('local_files', None) + remote_files = instructions_dict.get('remote_files', None) + symlink_files = instructions_dict.get('symlink_files', None) + + if not any([local_files, remote_files, symlink_files]): + errmsg = ( + 'no indication of which files to copy were found in the instruction node:\n' + f' > {instructions.uuid}\n' + 'Please include at least one of `local_files`, `remote_files`, or `symlink_files`.\n' + 'These should be lists containing 3-tuples with the following format:\n' + ' (source_node_key, source_relpath, target_relpath)\n' + ) + return errmsg + + +def validate_transfer_inputs(inputs, _): + """Check that the instructions dict and the source nodes are consistent""" + + source_nodes = inputs['source_nodes'] + instructions = inputs['instructions'] + computer = inputs['metadata']['computer'] + + instructions_dict = instructions.get_dict() + 
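# Editorial sketch, not part of the diff: the `instructions` Dict validated here is expected to
# carry the `retrieve_files` flag plus at least one of the file lists, each holding 3-tuples of
# (source_node_key, source_relpath, target_relpath). A hypothetical valid node would be:
#
#     orm.Dict(dict={
#         'retrieve_files': True,
#         'local_files': [('folder', 'sub/data.txt', 'data.txt')],
#         'remote_files': [('remote', 'out/aiida.out', 'aiida.out')],
#     })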
local_files = instructions_dict.get('local_files', []) + remote_files = instructions_dict.get('remote_files', []) + symlink_files = instructions_dict.get('symlink_files', []) + + source_nodes_provided = set(source_nodes.keys()) + source_nodes_required = set() + error_message_list = [] + + for node_label, node_object in source_nodes.items(): + if isinstance(node_object, orm.RemoteData): + if computer.label != node_object.computer.label: + error_message = ( + f' > remote node `{node_label}` points to computer `{node_object.computer}`, ' + f'not the one being used (`{computer}`)' + ) + error_message_list.append(error_message) + + for source_label, _, _ in local_files: + source_nodes_required.add(source_label) + source_node = source_nodes.get(source_label, None) + error_message = check_node_type('local_files', source_label, source_node, orm.FolderData) + if error_message: + error_message_list.append(error_message) + + for source_label, _, _ in remote_files: + source_nodes_required.add(source_label) + source_node = source_nodes.get(source_label, None) + error_message = check_node_type('remote_files', source_label, source_node, orm.RemoteData) + if error_message: + error_message_list.append(error_message) + + for source_label, _, _ in symlink_files: + source_nodes_required.add(source_label) + source_node = source_nodes.get(source_label, None) + error_message = check_node_type('symlink_files', source_label, source_node, orm.RemoteData) + if error_message: + error_message_list.append(error_message) + + unrequired_nodes = source_nodes_provided.difference(source_nodes_required) + for node_label in unrequired_nodes: + error_message = f' > node `{node_label}` provided as inputs is not being used' + error_message_list.append(error_message) + + if len(error_message_list) > 0: + error_message = '\n\n' + for error_add in error_message_list: + error_message = error_message + error_add + '\n' + return error_message + + +def check_node_type(list_name, node_label, node_object, node_type): + """Common utility function to check the type of a node""" + + if node_object is None: + return f' > node `{node_label}` requested on list `{list_name}` not found among inputs' + + if not isinstance(node_object, node_type): + target_class = node_type.class_node_type + return f' > node `{node_label}`, requested on list `{list_name}` should be of type `{target_class}`' + + return None + + +class TransferCalculation(CalcJob): + """Utility to copy files from different FolderData and RemoteData nodes into a single place. + + The final destination for these files can be either the local repository (by creating a + new FolderData node to store them) or in the remote computer (by leaving the files in a + new remote folder saved in a RemoteData node). + + Only files from the local computer and from remote folders in the same external computer + can be moved at the same time with a single instance of this CalcJob. + + The user needs to provide three inputs: + + * ``instructions``: a dict node specifying which files to copy from which nodes. + * ``source_nodes``: a dict of nodes, each with a unique identifier label as its key. + * ``metadata.computer``: the computer that contains the remote files and will contain + the final RemoteData node. + + The ``instructions`` dict must have the ``retrieve_files`` flag. The CalcJob will create a + new folder in the remote machine (``RemoteData``) and put all the files there and will either: + + (1) leave them there (``retrieve_files = False``) or ... 
+ (2) retrieve all the files and store them locally in a ``FolderData`` (``retrieve_files = True``) + + The `instructions` dict must also contain at least one list with specifications of which files + to copy and from where. All these lists take tuples of 3 that have the following format: + + .. code-block:: python + + ( source_node_key, path_to_file_in_source, path_to_file_in_target) + + where the ``source_node_key`` has to be the respective one used when providing the node in the + ``source_nodes`` input nodes dictionary. + + + The two main lists to include are ``local_files`` (for files to be taken from FolderData nodes) + and ``remote_files`` (for files to be taken from RemoteData nodes). Alternatively, files inside + of RemoteData nodes can instead be put in the ``symlink_files`` list: the only difference is that + files from the first list will be fully copied in the target RemoteData folder, whereas for the + files in second list only a symlink to the original file will be created there. This will only + affect the content of the final RemoteData target folder, but in both cases the full file will + be copied back in the local target FolderData (if ``retrieve_files = True``). + """ + + @classmethod + def define(cls, spec): + super().define(spec) + + spec.input( + 'instructions', + valid_type=orm.Dict, + help='A dictionary containing the `retrieve_files` flag and at least one of the file lists:' + '`local_files`, `remote_files` and/or `symlink_files`.', + validator=validate_instructions, + ) + spec.input_namespace( + 'source_nodes', + valid_type=(orm.FolderData, orm.RemoteData), + dynamic=True, + help='All the nodes that contain files referenced in the instructions.', + ) + + # The transfer just needs a computer, the code are resources are set here + spec.inputs.pop('code', None) + spec.inputs['metadata']['computer'].required = True + spec.inputs['metadata']['options']['resources'].default = { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1, + } + + spec.inputs.validator = validate_transfer_inputs + + def prepare_for_submission(self, folder): + source_nodes = self.inputs.source_nodes + instructions = self.inputs.instructions.get_dict() + + local_files = instructions.get('local_files', []) + remote_files = instructions.get('remote_files', []) + symlink_files = instructions.get('symlink_files', []) + retrieve_files = instructions.get('retrieve_files') + + calc_info = CalcInfo() + calc_info.skip_submit = True + calc_info.codes_info = [] + calc_info.local_copy_list = [] + calc_info.remote_copy_list = [] + calc_info.remote_symlink_list = [] + retrieve_paths = [] + + for source_label, source_relpath, target_relpath in local_files: + + source_node = source_nodes[source_label] + retrieve_paths.append(target_relpath) + calc_info.local_copy_list.append(( + source_node.uuid, + source_relpath, + target_relpath, + )) + + for source_label, source_relpath, target_relpath in remote_files: + + source_node = source_nodes[source_label] + retrieve_paths.append(target_relpath) + calc_info.remote_copy_list.append(( + source_node.computer.uuid, + os.path.join(source_node.get_remote_path(), source_relpath), + target_relpath, + )) + + for source_label, source_relpath, target_relpath in symlink_files: + + source_node = source_nodes[source_label] + retrieve_paths.append(target_relpath) + calc_info.remote_symlink_list.append(( + source_node.computer.uuid, + os.path.join(source_node.get_remote_path(), source_relpath), + target_relpath, + )) + + if retrieve_files: + calc_info.retrieve_list = retrieve_paths + 
else: + calc_info.retrieve_list = [] + + return calc_info diff --git a/aiida/cmdline/commands/__init__.py b/aiida/cmdline/commands/__init__.py index 93ebb48cc7..c80c47b6e8 100644 --- a/aiida/cmdline/commands/__init__.py +++ b/aiida/cmdline/commands/__init__.py @@ -16,7 +16,7 @@ # Import to populate the `verdi` sub commands from aiida.cmdline.commands import ( - cmd_calcjob, cmd_code, cmd_comment, cmd_completioncommand, cmd_computer, cmd_config, cmd_data, cmd_database, - cmd_daemon, cmd_devel, cmd_export, cmd_graph, cmd_group, cmd_help, cmd_import, cmd_node, cmd_plugin, cmd_process, - cmd_profile, cmd_rehash, cmd_restapi, cmd_run, cmd_setup, cmd_shell, cmd_status, cmd_user + cmd_archive, cmd_calcjob, cmd_code, cmd_comment, cmd_completioncommand, cmd_computer, cmd_config, cmd_data, + cmd_database, cmd_daemon, cmd_devel, cmd_export, cmd_graph, cmd_group, cmd_help, cmd_import, cmd_node, cmd_plugin, + cmd_process, cmd_profile, cmd_rehash, cmd_restapi, cmd_run, cmd_setup, cmd_shell, cmd_status, cmd_user ) diff --git a/aiida/cmdline/commands/cmd_archive.py b/aiida/cmdline/commands/cmd_archive.py new file mode 100644 index 0000000000..43878ca126 --- /dev/null +++ b/aiida/cmdline/commands/cmd_archive.py @@ -0,0 +1,490 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-arguments,import-error,too-many-locals,broad-except +"""`verdi archive` command.""" +from enum import Enum +from typing import List, Tuple +import traceback +import urllib.request + +import click +import tabulate + +from aiida.cmdline.commands.cmd_verdi import verdi +from aiida.cmdline.params import arguments, options +from aiida.cmdline.params.types import GroupParamType, PathOrUrl +from aiida.cmdline.utils import decorators, echo +from aiida.common.links import GraphTraversalRules + +EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none', 'ask'] +EXTRAS_MODE_NEW = ['import', 'none'] +COMMENT_MODE = ['newest', 'overwrite'] + + +@verdi.group('archive') +def verdi_archive(): + """Create, inspect and import AiiDA archives.""" + + +@verdi_archive.command('inspect') +@click.argument('archive', nargs=1, type=click.Path(exists=True, readable=True)) +@click.option('-v', '--version', is_flag=True, help='Print the archive format version and exit.') +@click.option('-d', '--data', hidden=True, is_flag=True, help='Print the data contents and exit.') +@click.option('-m', '--meta-data', is_flag=True, help='Print the meta data contents and exit.') +def inspect(archive, version, data, meta_data): + """Inspect contents of an archive without importing it. + + By default a summary of the archive contents will be printed. The various options can be used to change exactly what + information is displayed. + + .. 
deprecated:: 1.5.0 + Support for the --data flag + + """ + import dataclasses + from aiida.tools.importexport import CorruptArchive, detect_archive_type, get_reader + + reader_cls = get_reader(detect_archive_type(archive)) + + with reader_cls(archive) as reader: + try: + if version: + echo.echo(reader.export_version) + elif data: + # data is an internal implementation detail + echo.echo_deprecated('--data is deprecated and will be removed in v2.0.0') + echo.echo_dictionary(reader._get_data()) # pylint: disable=protected-access + elif meta_data: + echo.echo_dictionary(dataclasses.asdict(reader.metadata)) + else: + statistics = { + 'Version aiida': reader.metadata.aiida_version, + 'Version format': reader.metadata.export_version, + 'Computers': reader.entity_count('Computer'), + 'Groups': reader.entity_count('Group'), + 'Links': reader.link_count, + 'Nodes': reader.entity_count('Node'), + 'Users': reader.entity_count('User'), + } + if reader.metadata.conversion_info: + statistics['Conversion info'] = '\n'.join(reader.metadata.conversion_info) + + echo.echo(tabulate.tabulate(statistics.items())) + except CorruptArchive as exception: + echo.echo_critical(f'corrupt archive: {exception}') + + +@verdi_archive.command('create') +@arguments.OUTPUT_FILE(type=click.Path(exists=False)) +@options.CODES() +@options.COMPUTERS() +@options.GROUPS() +@options.NODES() +@options.ARCHIVE_FORMAT( + type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']), +) +@options.FORCE(help='Overwrite output file if it already exists.') +@click.option( + '-v', + '--verbosity', + default='INFO', + type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), + help='Control the verbosity of console logging' +) +@options.graph_traversal_rules(GraphTraversalRules.EXPORT.value) +@click.option( + '--include-logs/--exclude-logs', + default=True, + show_default=True, + help='Include or exclude logs for node(s) in export.' +) +@click.option( + '--include-comments/--exclude-comments', + default=True, + show_default=True, + help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' +) +# will only be useful when moving to a new archive format, that does not store all data in memory +# @click.option( +# '-b', +# '--batch-size', +# default=1000, +# type=int, +# help='Batch database query results in sub-collections to reduce memory usage.' +# ) +@decorators.with_dbenv() +def create( + output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, + create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity +): + """ + Export subsets of the provenance graph to file for sharing. + + Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs. + + By default, the archive file will include not only the entities explicitly provided via the command line but also + their provenance, according to the rules outlined in the documentation. + You can modify some of those rules using options of this command. 
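# Editorial sketch, not part of the diff: the command is a thin wrapper around the Python export
# API used in the body below; a rough programmatic equivalent, assuming a loaded profile and a
# hypothetical node PK, would be:

from aiida import orm
from aiida.tools.importexport import export, ExportFileFormat

node = orm.load_node(1234)  # hypothetical PK
export([node], filename='archive.aiida', file_format=ExportFileFormat.ZIP, include_logs=True)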
+ """ + # pylint: disable=too-many-branches + from aiida.common.log import override_log_formatter_context + from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter + from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER + from aiida.tools.importexport.common.exceptions import ArchiveExportError + + entities = [] + + if codes: + entities.extend(codes) + + if computers: + entities.extend(computers) + + if groups: + entities.extend(groups) + + if nodes: + entities.extend(nodes) + + kwargs = { + 'input_calc_forward': input_calc_forward, + 'input_work_forward': input_work_forward, + 'create_backward': create_backward, + 'return_backward': return_backward, + 'call_calc_backward': call_calc_backward, + 'call_work_backward': call_work_backward, + 'include_comments': include_comments, + 'include_logs': include_logs, + 'overwrite': force, + } + + if archive_format == 'zip': + export_format = ExportFileFormat.ZIP + kwargs.update({'writer_init': {'use_compression': True}}) + elif archive_format == 'zip-uncompressed': + export_format = ExportFileFormat.ZIP + kwargs.update({'writer_init': {'use_compression': False}}) + elif archive_format == 'zip-lowmemory': + export_format = ExportFileFormat.ZIP + kwargs.update({'writer_init': {'cache_zipinfo': True}}) + elif archive_format == 'tar.gz': + export_format = ExportFileFormat.TAR_GZIPPED + elif archive_format == 'null': + export_format = 'null' + + if verbosity in ['DEBUG', 'INFO']: + set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + else: + set_progress_reporter(None) + EXPORT_LOGGER.setLevel(verbosity) + + try: + with override_log_formatter_context('%(message)s'): + export(entities, filename=output_file, file_format=export_format, **kwargs) + except ArchiveExportError as exception: + echo.echo_critical(f'failed to write the archive file. Exception: {exception}') + else: + echo.echo_success(f'wrote the export archive file to {output_file}') + + +@verdi_archive.command('migrate') +@arguments.INPUT_FILE() +@arguments.OUTPUT_FILE(required=False) +@options.ARCHIVE_FORMAT() +@options.FORCE(help='overwrite output file if it already exists') +@click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') +@options.SILENT(hidden=True) +@click.option( + '-v', + '--version', + type=click.STRING, + required=False, + metavar='VERSION', + # Note: Adding aiida.tools.EXPORT_VERSION as a default value explicitly would result in a slow import of + # aiida.tools and, as a consequence, aiida.orm. As long as this is the case, better determine the latest export + # version inside the function when needed. + help='Archive format version to migrate to (defaults to latest version).', +) +@click.option( + '--verbosity', + default='INFO', + type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), + help='Control the verbosity of console logging' +) +def migrate(input_file, output_file, force, silent, in_place, archive_format, version, verbosity): + """Migrate an export archive to a more recent format version. + + .. 
deprecated:: 1.5.0 + Support for the --silent flag, replaced by --verbosity + + """ + from aiida.common.log import override_log_formatter_context + from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter + from aiida.tools.importexport import detect_archive_type, EXPORT_VERSION + from aiida.tools.importexport.archive.migrators import get_migrator, MIGRATE_LOGGER + + if silent is True: + echo.echo_deprecated('the --silent option is deprecated, use --verbosity') + + if in_place: + if output_file: + echo.echo_critical('output file specified together with --in-place flag') + output_file = input_file + force = True + elif not output_file: + echo.echo_critical( + 'no output file specified. Please add --in-place flag if you would like to migrate in place.' + ) + + if verbosity in ['DEBUG', 'INFO']: + set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + else: + set_progress_reporter(None) + MIGRATE_LOGGER.setLevel(verbosity) + + if version is None: + version = EXPORT_VERSION + + migrator_cls = get_migrator(detect_archive_type(input_file)) + migrator = migrator_cls(input_file) + + try: + with override_log_formatter_context('%(message)s'): + migrator.migrate(version, output_file, force=force, out_compression=archive_format) + except Exception as error: # pylint: disable=broad-except + if verbosity == 'DEBUG': + raise + echo.echo_critical( + 'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): ' + f'{error.__class__.__name__}:{error}' + ) + + if verbosity in ['DEBUG', 'INFO']: + echo.echo_success(f'migrated the archive to version {version}') + + +class ExtrasImportCode(Enum): + """Exit codes for the verdi command line.""" + keep_existing = 'kcl' + update_existing = 'kcu' + mirror = 'ncu' + none = 'knl' + ask = 'kca' + + +@verdi_archive.command('import') +@click.argument('archives', nargs=-1, type=PathOrUrl(exists=True, readable=True)) +@click.option( + '-w', + '--webpages', + type=click.STRING, + cls=options.MultipleValueOption, + help='Discover all URL targets pointing to files with the .aiida extension for these HTTP addresses. ' + 'Automatically discovered archive URLs will be downloaded and added to ARCHIVES for importing.' +) +@options.GROUP( + type=GroupParamType(create_if_not_exist=True), + help='Specify group to which all the import nodes will be added. If such a group does not exist, it will be' + ' created automatically.' +) +@click.option( + '-e', + '--extras-mode-existing', + type=click.Choice(EXTRAS_MODE_EXISTING), + default='keep_existing', + help='Specify which extras from the export archive should be imported for nodes that are already contained in the ' + 'database: ' + 'ask: import all extras and prompt what to do for existing extras. ' + 'keep_existing: import all extras and keep original value of existing extras. ' + 'update_existing: import all extras and overwrite value of existing extras. ' + 'mirror: import all extras and remove any existing extras that are not present in the archive. ' + 'none: do not import any extras.' +) +@click.option( + '-n', + '--extras-mode-new', + type=click.Choice(EXTRAS_MODE_NEW), + default='import', + help='Specify whether to import extras of new nodes: ' + 'import: import extras. ' + 'none: do not import extras.' +) +@click.option( + '--comment-mode', + type=click.Choice(COMMENT_MODE), + default='newest', + help='Specify the way to import Comments with identical UUIDs: ' + 'newest: Only the newest Comments (based on mtime) (default).' 
+ 'overwrite: Replace existing Comments with those from the import file.' +) +@click.option( + '--migration/--no-migration', + default=True, + show_default=True, + help='Force migration of archive file archives, if needed.' +) +@click.option( + '-v', + '--verbosity', + default='INFO', + type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), + help='Control the verbosity of console logging' +) +@options.NON_INTERACTIVE() +@decorators.with_dbenv() +@click.pass_context +def import_archive( + ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive, + verbosity +): + """Import data from an AiiDA archive file. + + The archive can be specified by its relative or absolute file path, or its HTTP URL. + """ + # pylint: disable=unused-argument + from aiida.common.log import override_log_formatter_context + from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter + from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER + from aiida.tools.importexport.archive.migrators import MIGRATE_LOGGER + + if verbosity in ['DEBUG', 'INFO']: + set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + else: + set_progress_reporter(None) + IMPORT_LOGGER.setLevel(verbosity) + MIGRATE_LOGGER.setLevel(verbosity) + + all_archives = _gather_imports(archives, webpages) + + # Preliminary sanity check + if not all_archives: + echo.echo_critical('no valid exported archives were found') + + # Shared import key-word arguments + import_kwargs = { + 'group': group, + 'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value, + 'extras_mode_new': extras_mode_new, + 'comment_mode': comment_mode, + } + + with override_log_formatter_context('%(message)s'): + for archive, web_based in all_archives: + _import_archive(archive, web_based, import_kwargs, migration) + + +def _echo_exception(msg: str, exception, warn_only: bool = False): + """Correctly report and exception. + + :param msg: The message prefix + :param exception: the exception raised + :param warn_only: If True only print a warning, otherwise calls sys.exit with a non-zero exit status + + """ + from aiida.tools.importexport import IMPORT_LOGGER + message = f'{msg}: {exception.__class__.__name__}: {str(exception)}' + if warn_only: + echo.echo_warning(message) + else: + IMPORT_LOGGER.debug('%s', traceback.format_exc()) + echo.echo_critical(message) + + +def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: + """Gather archives to import and sort into local files and URLs. 
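# Editorial sketch, not part of the diff: local paths are passed through unchanged and anything
# starting with http(s):// is flagged as web based (filenames below are hypothetical):
#
#     _gather_imports(['local.aiida', 'https://example.org/remote.aiida'], webpages=None)
#     # -> [('local.aiida', False), ('https://example.org/remote.aiida', True)]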
+ + :returns: list of (archive path, whether it is web based) + + """ + from aiida.tools.importexport.common.utils import get_valid_import_links + + final_archives = [] + + # Build list of archives to be imported + for archive in archives: + if archive.startswith('http://') or archive.startswith('https://'): + final_archives.append((archive, True)) + else: + final_archives.append((archive, False)) + + # Discover and retrieve *.aiida files at URL(s) + if webpages is not None: + for webpage in webpages: + try: + echo.echo_info(f'retrieving archive URLS from {webpage}') + urls = get_valid_import_links(webpage) + except Exception as error: + echo.echo_critical( + f'an exception occurred while trying to discover archives at URL {webpage}:\n{error}' + ) + else: + echo.echo_success(f'{len(urls)} archive URLs discovered and added') + final_archives.extend([(u, True) for u in urls]) + + return final_archives + + +def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): + """Perform the archive import. + + :param archive: the path or URL to the archive + :param web_based: If the archive needs to be downloaded first + :param import_kwargs: keyword arguments to pass to the import function + :param try_migration: whether to try a migration if the import raises IncompatibleArchiveVersionError + + """ + from aiida.common.folders import SandboxFolder + from aiida.tools.importexport import ( + detect_archive_type, EXPORT_VERSION, import_data, IncompatibleArchiveVersionError + ) + from aiida.tools.importexport.archive.migrators import get_migrator + + with SandboxFolder() as temp_folder: + + archive_path = archive + + if web_based: + echo.echo_info(f'downloading archive: {archive}') + try: + response = urllib.request.urlopen(archive) + except Exception as exception: + _echo_exception(f'downloading archive {archive} failed', exception) + temp_folder.create_file_from_filelike(response, 'downloaded_archive.zip') + archive_path = temp_folder.get_abs_path('downloaded_archive.zip') + echo.echo_success('archive downloaded, proceeding with import') + + echo.echo_info(f'starting import: {archive}') + try: + import_data(archive_path, **import_kwargs) + except IncompatibleArchiveVersionError as exception: + if try_migration: + + echo.echo_info(f'incompatible version detected for {archive}, trying migration') + try: + migrator = get_migrator(detect_archive_type(archive_path))(archive_path) + archive_path = migrator.migrate( + EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath + ) + except Exception as exception: + _echo_exception(f'an exception occurred while migrating the archive {archive}', exception) + + echo.echo_info('proceeding with import of migrated archive') + try: + import_data(archive_path, **import_kwargs) + except Exception as exception: + _echo_exception( + f'an exception occurred while trying to import the migrated archive {archive}', exception + ) + else: + _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception) + except Exception as exception: + _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception) + + echo.echo_success(f'imported archive {archive}') diff --git a/aiida/cmdline/commands/cmd_code.py b/aiida/cmdline/commands/cmd_code.py index caf2f7e46c..b431271c70 100644 --- a/aiida/cmdline/commands/cmd_code.py +++ b/aiida/cmdline/commands/cmd_code.py @@ -9,6 +9,7 @@ ########################################################################### """`verdi code` 
command.""" from functools import partial +import logging import click import tabulate @@ -192,16 +193,25 @@ def delete(codes, verbose, dry_run, force): Note that codes are part of the data provenance, and deleting a code will delete all calculations using it. """ - from aiida.manage.database.delete.nodes import delete_nodes + from aiida.common.log import override_log_formatter_context + from aiida.tools import delete_nodes, DELETE_LOGGER - verbosity = 1 - if force: - verbosity = 0 - elif verbose: - verbosity = 2 + verbosity = logging.DEBUG if verbose else logging.INFO + DELETE_LOGGER.setLevel(verbosity) node_pks_to_delete = [code.pk for code in codes] - delete_nodes(node_pks_to_delete, dry_run=dry_run, verbosity=verbosity, force=force) + + def _dry_run_callback(pks): + if not pks or force: + return False + echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! THIS CANNOT BE UNDONE!') + return not click.confirm('Shall I continue?', abort=True) + + with override_log_formatter_context('%(message)s'): + _, was_deleted = delete_nodes(node_pks_to_delete, dry_run=dry_run or _dry_run_callback) + + if was_deleted: + echo.echo_success('Finished deletion.') @verdi_code.command() diff --git a/aiida/cmdline/commands/cmd_computer.py b/aiida/cmdline/commands/cmd_computer.py index 3d83d66bd6..0ca30081c5 100644 --- a/aiida/cmdline/commands/cmd_computer.py +++ b/aiida/cmdline/commands/cmd_computer.py @@ -338,7 +338,7 @@ def computer_disable(computer, user): @verdi_computer.command('list') @options.ALL(help='Show also disabled or unconfigured computers.') -@options.RAW(help='Show only the computer names, one per line.') +@options.RAW(help='Show only the computer labels, one per line.') @with_dbenv() def computer_list(all_entries, raw): """List all available computers.""" @@ -346,7 +346,7 @@ def computer_list(all_entries, raw): if not raw: echo.echo_info('List of configured computers') - echo.echo_info("Use 'verdi computer show COMPUTERNAME' to display more detailed information") + echo.echo_info("Use 'verdi computer show COMPUTERLABEL' to display more detailed information") computers = Computer.objects.all() user = User.objects.get_default() @@ -357,7 +357,7 @@ def computer_list(all_entries, raw): sort = lambda computer: computer.label highlight = lambda comp: comp.is_user_configured(user) and comp.is_user_enabled(user) hide = lambda comp: not (comp.is_user_configured(user) and comp.is_user_enabled(user)) and not all_entries - echo.echo_formatted_list(computers, ['name'], sort=sort, highlight=highlight, hide=hide) + echo.echo_formatted_list(computers, ['label'], sort=sort, highlight=highlight, hide=hide) @verdi_computer.command('show') diff --git a/aiida/cmdline/commands/cmd_config.py b/aiida/cmdline/commands/cmd_config.py index 5eb87cf376..94415a1c22 100644 --- a/aiida/cmdline/commands/cmd_config.py +++ b/aiida/cmdline/commands/cmd_config.py @@ -8,46 +8,231 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi config` command.""" +import textwrap import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import arguments -from aiida.cmdline.utils import echo +from aiida.cmdline.utils import decorators, echo -@verdi.command('config') +class _DeprecateConfigCommandsGroup(click.Group): + """Overloads the get_command with one that identifies deprecated commands.""" + + def get_command(self, ctx, cmd_name): + """Override the default click.Group get_command with one that 
identifies deprecated commands.""" + cmd = click.Group.get_command(self, ctx, cmd_name) + + if cmd is not None: + return cmd + + if cmd_name in [ + 'daemon.default_workers', 'logging.plumpy_loglevel', 'daemon.timeout', 'logging.sqlalchemy_loglevel', + 'daemon.worker_process_slots', 'logging.tornado_loglevel', 'db.batch_size', 'runner.poll.interval', + 'logging.aiida_loglevel', 'user.email', 'logging.alembic_loglevel', 'user.first_name', + 'logging.circus_loglevel', 'user.institution', 'logging.db_loglevel', 'user.last_name', + 'logging.kiwipy_loglevel', 'verdi.shell.auto_import', 'logging.paramiko_loglevel', + 'warnings.showdeprecations', 'autofill.user.email', 'autofill.user.first_name', 'autofill.user.last_name', + 'autofill.user.institution' + ]: + ctx.obj.deprecated_name = cmd_name + cmd = click.Group.get_command(self, ctx, '_deprecated') + return cmd + + ctx.fail(f"'{cmd_name}' is not a verdi config command.") + + return None + + +@verdi.group('config', cls=_DeprecateConfigCommandsGroup) +def verdi_config(): + """Manage the AiiDA configuration.""" + + +@verdi_config.command('list') +@click.argument('prefix', metavar='PREFIX', required=False, default='') +@click.option('-d', '--description', is_flag=True, help='Include description of options') +@click.pass_context +def verdi_config_list(ctx, prefix, description: bool): + """List AiiDA options for the current profile. + + Optionally filtered by a prefix. + """ + from tabulate import tabulate + + from aiida.manage.configuration import Config, Profile + + config: Config = ctx.obj.config + profile: Profile = ctx.obj.profile + + if not profile: + echo.echo_warning('no profiles configured: run `verdi setup` to create one') + + option_values = config.get_options(profile.name if profile else None) + + def _join(val): + """split arrays into multiple lines.""" + if isinstance(val, list): + return '\n'.join(str(v) for v in val) + return val + + if description: + table = [[name, source, _join(value), '\n'.join(textwrap.wrap(c.description))] + for name, (c, source, value) in option_values.items() + if name.startswith(prefix)] + headers = ['name', 'source', 'value', 'description'] + else: + table = [[name, source, _join(value)] + for name, (c, source, value) in option_values.items() + if name.startswith(prefix)] + headers = ['name', 'source', 'value'] + + # sort by name + table = sorted(table, key=lambda x: x[0]) + echo.echo(tabulate(table, headers=headers)) + + +@verdi_config.command('show') @arguments.CONFIG_OPTION(metavar='OPTION_NAME') -@click.argument('value', metavar='OPTION_VALUE', required=False) -@click.option('--global', 'globally', is_flag=True, help='Apply the option configuration wide.') -@click.option('--unset', is_flag=True, help='Remove the line matching the option name from the config file.') @click.pass_context -def verdi_config(ctx, option, value, globally, unset): - """Configure profile-specific or global AiiDA options.""" - config = ctx.obj.config - profile = ctx.obj.profile +def verdi_config_show(ctx, option): + """Show details of an AiiDA option for the current profile.""" + from aiida.manage.configuration import Config, Profile + from aiida.manage.configuration.options import NO_DEFAULT + + config: Config = ctx.obj.config + profile: Profile = ctx.obj.profile + + dct = { + 'schema': option.schema, + 'values': { + 'default': '' if option.default is NO_DEFAULT else option.default, + 'global': config.options.get(option.name, ''), + } + } + + if not profile: + echo.echo_warning('no profiles configured: run `verdi setup` to create 
one') + else: + dct['values']['profile'] = profile.options.get(option.name, '') + + echo.echo_dictionary(dct, fmt='yaml', sort_keys=False) + + +@verdi_config.command('get') +@arguments.CONFIG_OPTION(metavar='OPTION_NAME') +def verdi_config_get(option): + """Get the value of an AiiDA option for the current profile.""" + from aiida import get_config_option + + value = get_config_option(option.name) + echo.echo(str(value)) + + +@verdi_config.command('set') +@arguments.CONFIG_OPTION(metavar='OPTION_NAME') +@click.argument('value', metavar='OPTION_VALUE') +@click.option('-g', '--global', 'globally', is_flag=True, help='Apply the option configuration wide.') +@click.option('-a', '--append', is_flag=True, help='Append the value to an existing array.') +@click.option('-r', '--remove', is_flag=True, help='Remove the value from an existing array.') +@click.pass_context +def verdi_config_set(ctx, option, value, globally, append, remove): + """Set an AiiDA option. + + List values are split by whitespace, e.g. "a b" becomes ["a", "b"]. + """ + from aiida.manage.configuration import Config, Profile, ConfigValidationError + + if append and remove: + echo.echo_critical('Cannot flag both append and remove') + + config: Config = ctx.obj.config + profile: Profile = ctx.obj.profile + + if option.global_only: + globally = True + + # Define the string that determines the scope: for specific profile or globally + scope = profile.name if (not globally and profile) else None + scope_text = f"for '{profile.name}' profile" if (not globally and profile) else 'globally' + + if append or remove: + try: + current = config.get_option(option.name, scope=scope) + except ConfigValidationError as error: + echo.echo_critical(str(error)) + if not isinstance(current, list): + echo.echo_critical(f'cannot append/remove to value: {current}') + if append: + value = current + [value] + else: + value = [item for item in current if item != value] + + # Set the specified option + try: + value = config.set_option(option.name, value, scope=scope) + except ConfigValidationError as error: + echo.echo_critical(str(error)) + + config.store() + echo.echo_success(f"'{option.name}' set to {value} {scope_text}") + + +@verdi_config.command('unset') +@arguments.CONFIG_OPTION(metavar='OPTION_NAME') +@click.option('-g', '--global', 'globally', is_flag=True, help='Unset the option configuration wide.') +@click.pass_context +def verdi_config_unset(ctx, option, globally): + """Unset an AiiDA option.""" + from aiida.manage.configuration import Config, Profile + + config: Config = ctx.obj.config + profile: Profile = ctx.obj.profile if option.global_only: globally = True # Define the string that determines the scope: for specific profile or globally scope = profile.name if (not globally and profile) else None - scope_text = f'for {profile.name}' if (not globally and profile) else 'globally' + scope_text = f"for '{profile.name}' profile" if (not globally and profile) else 'globally' # Unset the specified option - if unset: - config.unset_option(option.name, scope=scope) - config.store() - echo.echo_success(f'{option.name} unset {scope_text}') + config.unset_option(option.name, scope=scope) + config.store() + echo.echo_success(f"'{option.name}' unset {scope_text}") - # Get the specified option - elif value is None: - option_value = config.get_option(option.name, scope=scope, default=False) - if option_value: - echo.echo(f'{option_value}') - # Set the specified option +@verdi_config.command('caching') +@click.option('-d', '--disabled', is_flag=True, 
help='List disabled types instead.') +def verdi_config_caching(disabled): + """List caching-enabled process types for the current profile.""" + from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, get_entry_point_names + from aiida.manage.caching import get_use_cache + + for group in ['aiida.calculations', 'aiida.workflows']: + for entry_point in get_entry_point_names(group): + identifier = ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point]) + if get_use_cache(identifier=identifier): + if not disabled: + echo.echo(identifier) + elif disabled: + echo.echo(identifier) + + +@verdi_config.command('_deprecated', hidden=True) +@decorators.deprecated_command("This command has been deprecated. Please use 'verdi config show/set/unset' instead.") +@click.argument('value', metavar='OPTION_VALUE', required=False) +@click.option('--global', 'globally', is_flag=True, help='Apply the option configuration wide.') +@click.option('--unset', is_flag=True, help='Remove the line matching the option name from the config file.') +@click.pass_context +def verdi_config_deprecated(ctx, value, globally, unset): + """"This command has been deprecated. Please use 'verdi config show/set/unset' instead.""" + from aiida.manage.configuration import get_option + option = get_option(ctx.obj.deprecated_name) + if unset: + ctx.invoke(verdi_config_unset, option=option, globally=globally) + elif value is not None: + ctx.invoke(verdi_config_set, option=option, value=value, globally=globally) else: - config.set_option(option.name, value, scope=scope) - config.store() - echo.echo_success(f'{option.name} set to {value} {scope_text}') + ctx.invoke(verdi_config_get, option=option) diff --git a/aiida/cmdline/commands/cmd_daemon.py b/aiida/cmdline/commands/cmd_daemon.py index faf720e436..5fbac9e013 100644 --- a/aiida/cmdline/commands/cmd_daemon.py +++ b/aiida/cmdline/commands/cmd_daemon.py @@ -54,6 +54,8 @@ def start(foreground, number): If the NUMBER of desired workers is not specified, the default is used, which is determined by the configuration option `daemon.default_workers`, which if not explicitly changed defaults to 1. + + Returns exit code 0 if the daemon is OK, non-zero if there was an error. """ from aiida.engine.daemon.client import get_daemon_client @@ -78,7 +80,9 @@ def start(foreground, number): time.sleep(1) response = client.get_status() - print_client_response_status(response) + retcode = print_client_response_status(response) + if retcode: + sys.exit(retcode) @verdi_daemon.command() @@ -115,24 +119,34 @@ def status(all_profiles): @click.argument('number', default=1, type=int) @decorators.only_if_daemon_running() def incr(number): - """Add NUMBER [default=1] workers to the running daemon.""" + """Add NUMBER [default=1] workers to the running daemon. + + Returns exit code 0 if the daemon is OK, non-zero if there was an error. + """ from aiida.engine.daemon.client import get_daemon_client client = get_daemon_client() response = client.increase_workers(number) - print_client_response_status(response) + retcode = print_client_response_status(response) + if retcode: + sys.exit(retcode) @verdi_daemon.command() @click.argument('number', default=1, type=int) @decorators.only_if_daemon_running() def decr(number): - """Remove NUMBER [default=1] workers from the running daemon.""" + """Remove NUMBER [default=1] workers from the running daemon. + + Returns exit code 0 if the daemon is OK, non-zero if there was an error. 
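# Editorial sketch, not part of the diff: since these daemon commands now exit non-zero on failure,
# the behaviour can be asserted with click's test runner. Hypothetical check, assuming a loaded
# profile whose daemon is running:

from click.testing import CliRunner
from aiida.cmdline.commands.cmd_daemon import decr

result = CliRunner().invoke(decr, ['1'])
assert result.exit_code == 0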
+ """ from aiida.engine.daemon.client import get_daemon_client client = get_daemon_client() response = client.decrease_workers(number) - print_client_response_status(response) + retcode = print_client_response_status(response) + if retcode: + sys.exit(retcode) @verdi_daemon.command() @@ -154,7 +168,10 @@ def logshow(): @click.option('--no-wait', is_flag=True, help='Do not wait for confirmation.') @click.option('--all', 'all_profiles', is_flag=True, help='Stop all daemons.') def stop(no_wait, all_profiles): - """Stop the daemon.""" + """Stop the daemon. + + Returns exit code 0 if the daemon was shut down successfully (or was not running), non-zero if there was an error. + """ from aiida.engine.daemon.client import get_daemon_client config = get_config() @@ -190,7 +207,9 @@ def stop(no_wait, all_profiles): if response['status'] == client.DAEMON_ERROR_NOT_RUNNING: click.echo('The daemon was not running.') else: - print_client_response_status(response) + retcode = print_client_response_status(response) + if retcode: + sys.exit(retcode) @verdi_daemon.command() @@ -205,6 +224,8 @@ def restart(ctx, reset, no_wait): By default will only reset the workers of the running daemon. After the restart the same amount of workers will be running. If the `--reset` flag is passed, however, the full daemon will be stopped and restarted with the default number of workers that is started when calling `verdi daemon start` manually. + + Returns exit code 0 if the result is OK, non-zero if there was an error. """ from aiida.engine.daemon.client import get_daemon_client @@ -230,7 +251,9 @@ def restart(ctx, reset, no_wait): response = client.restart_daemon(wait) if wait: - print_client_response_status(response) + retcode = print_client_response_status(response) + if retcode: + sys.exit(retcode) @verdi_daemon.command(hidden=True) diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py index c486ea038f..6da318ea34 100644 --- a/aiida/cmdline/commands/cmd_database.py +++ b/aiida/cmdline/commands/cmd_database.py @@ -23,6 +23,24 @@ def verdi_database(): """Inspect and manage the database.""" +@verdi_database.command('version') +def database_version(): + """Show the version of the database. + + The database version is defined by the tuple of the schema generation and schema revision. 
+ """ + from aiida.manage.manager import get_manager + + manager = get_manager() + manager._load_backend(schema_check=False) # pylint: disable=protected-access + backend_manager = manager.get_backend_manager() + + echo.echo('Generation: ', bold=True, nl=False) + echo.echo(backend_manager.get_schema_generation_database()) + echo.echo('Revision: ', bold=True, nl=False) + echo.echo(backend_manager.get_schema_version_database()) + + @verdi_database.command('migrate') @options.FORCE() def database_migrate(force): @@ -176,3 +194,48 @@ def detect_invalid_nodes(): echo.echo_success('no integrity violations detected') else: echo.echo_critical('one or more integrity violations detected') + + +@verdi_database.command('summary') +@options.VERBOSE() +def database_summary(verbose): + """Summarise the entities in the database.""" + from aiida.orm import QueryBuilder, Node, Group, Computer, Comment, Log, User + data = {} + + # User + query_user = QueryBuilder().append(User, project=['email']) + data['Users'] = {'count': query_user.count()} + if verbose: + data['Users']['emails'] = query_user.distinct().all(flat=True) + + # Computer + query_comp = QueryBuilder().append(Computer, project=['name']) + data['Computers'] = {'count': query_comp.count()} + if verbose: + data['Computers']['names'] = query_comp.distinct().all(flat=True) + + # Node + count = QueryBuilder().append(Node).count() + data['Nodes'] = {'count': count} + if verbose: + node_types = QueryBuilder().append(Node, project=['node_type']).distinct().all(flat=True) + data['Nodes']['node_types'] = node_types + process_types = QueryBuilder().append(Node, project=['process_type']).distinct().all(flat=True) + data['Nodes']['process_types'] = [p for p in process_types if p] + + # Group + query_group = QueryBuilder().append(Group, project=['type_string']) + data['Groups'] = {'count': query_group.count()} + if verbose: + data['Groups']['type_strings'] = query_group.distinct().all(flat=True) + + # Comment + count = QueryBuilder().append(Comment).count() + data['Comments'] = {'count': count} + + # Log + count = QueryBuilder().append(Log).count() + data['Logs'] = {'count': count} + + echo.echo_dictionary(data, sort_keys=False, fmt='yaml') diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py index add1f9641e..0e959de06c 100644 --- a/aiida/cmdline/commands/cmd_export.py +++ b/aiida/cmdline/commands/cmd_export.py @@ -7,30 +7,32 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-arguments,import-error,too-many-locals +# pylint: disable=too-many-arguments,import-error,too-many-locals,unused-argument """`verdi export` command.""" import click -import tabulate from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments -from aiida.cmdline.params import options +from aiida.cmdline.params import arguments, options from aiida.cmdline.utils import decorators -from aiida.cmdline.utils import echo from aiida.common.links import GraphTraversalRules +from aiida.cmdline.commands import cmd_archive -@verdi.group('export') + +@verdi.group('export', hidden=True) +@decorators.deprecated_command("This command has been deprecated. 
Please use 'verdi archive' instead.") def verdi_export(): - """Create and manage export archives.""" + """Deprecated, use `verdi archive`.""" @verdi_export.command('inspect') +@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive inspect' instead.") @click.argument('archive', nargs=1, type=click.Path(exists=True, readable=True)) @click.option('-v', '--version', is_flag=True, help='Print the archive format version and exit.') -@click.option('-d', '--data', is_flag=True, help='Print the data contents and exit.') +@click.option('-d', '--data', hidden=True, is_flag=True, help='Print the data contents and exit.') @click.option('-m', '--meta-data', is_flag=True, help='Print the meta data contents and exit.') -def inspect(archive, version, data, meta_data): +@click.pass_context +def inspect(ctx, archive, version, data, meta_data): """Inspect contents of an exported archive without importing it. By default a summary of the archive contents will be printed. The various options can be used to change exactly what @@ -40,40 +42,11 @@ def inspect(archive, version, data, meta_data): Support for the --data flag """ - import dataclasses - from aiida.tools.importexport import CorruptArchive, detect_archive_type, get_reader - - reader_cls = get_reader(detect_archive_type(archive)) - - with reader_cls(archive) as reader: - try: - if version: - echo.echo(reader.export_version) - elif data: - # data is an internal implementation detail - echo.echo_deprecated('--data is deprecated and will be removed in v2.0.0') - echo.echo_dictionary(reader._get_data()) # pylint: disable=protected-access - elif meta_data: - echo.echo_dictionary(dataclasses.asdict(reader.metadata)) - else: - statistics = { - 'Version aiida': reader.metadata.aiida_version, - 'Version format': reader.metadata.export_version, - 'Computers': reader.entity_count('Computer'), - 'Groups': reader.entity_count('Group'), - 'Links': reader.link_count, - 'Nodes': reader.entity_count('Node'), - 'Users': reader.entity_count('User'), - } - if reader.metadata.conversion_info: - statistics['Conversion info'] = '\n'.join(reader.metadata.conversion_info) - - echo.echo(tabulate.tabulate(statistics.items())) - except CorruptArchive as exception: - echo.echo_critical(f'corrupt archive: {exception}') + ctx.forward(cmd_archive.inspect) @verdi_export.command('create') +@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive create' instead.") @arguments.OUTPUT_FILE(type=click.Path(exists=False)) @options.CODES() @options.COMPUTERS() @@ -103,17 +76,10 @@ def inspect(archive, version, data, meta_data): show_default=True, help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' ) -# will only be useful when moving to a new archive format, that does not store all data in memory -# @click.option( -# '-b', -# '--batch-size', -# default=1000, -# type=int, -# help='Batch database query results in sub-collections to reduce memory usage.' -# ) +@click.pass_context @decorators.with_dbenv() def create( - output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, + ctx, output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity ): """ @@ -125,74 +91,17 @@ def create( their provenance, according to the rules outlined in the documentation. 
You can modify some of those rules using options of this command. """ - # pylint: disable=too-many-branches - from aiida.common.log import override_log_formatter_context - from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER - from aiida.tools.importexport.common.exceptions import ArchiveExportError - - entities = [] - - if codes: - entities.extend(codes) - - if computers: - entities.extend(computers) - - if groups: - entities.extend(groups) - - if nodes: - entities.extend(nodes) - - kwargs = { - 'input_calc_forward': input_calc_forward, - 'input_work_forward': input_work_forward, - 'create_backward': create_backward, - 'return_backward': return_backward, - 'call_calc_backward': call_calc_backward, - 'call_work_backward': call_work_backward, - 'include_comments': include_comments, - 'include_logs': include_logs, - 'overwrite': force, - } - - if archive_format == 'zip': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': True}}) - elif archive_format == 'zip-uncompressed': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': False}}) - elif archive_format == 'zip-lowmemory': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'cache_zipinfo': True}}) - elif archive_format == 'tar.gz': - export_format = ExportFileFormat.TAR_GZIPPED - elif archive_format == 'null': - export_format = 'null' - - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) - else: - set_progress_reporter(None) - EXPORT_LOGGER.setLevel(verbosity) - - try: - with override_log_formatter_context('%(message)s'): - export(entities, filename=output_file, file_format=export_format, **kwargs) - except ArchiveExportError as exception: - echo.echo_critical(f'failed to write the archive file. Exception: {exception}') - else: - echo.echo_success(f'wrote the export archive file to {output_file}') + ctx.forward(cmd_archive.create) @verdi_export.command('migrate') +@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive migrate' instead.") @arguments.INPUT_FILE() @arguments.OUTPUT_FILE(required=False) @options.ARCHIVE_FORMAT() @options.FORCE(help='overwrite output file if it already exists') @click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') -@options.SILENT() +@options.SILENT(hidden=True) @click.option( '-v', '--version', @@ -210,53 +119,12 @@ def create( type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), help='Control the verbosity of console logging' ) -def migrate(input_file, output_file, force, silent, in_place, archive_format, version, verbosity): +@click.pass_context +def migrate(ctx, input_file, output_file, force, silent, in_place, archive_format, version, verbosity): """Migrate an export archive to a more recent format version. .. 
deprecated:: 1.5.0 Support for the --silent flag, replaced by --verbosity """ - from aiida.common.log import override_log_formatter_context - from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import detect_archive_type, EXPORT_VERSION - from aiida.tools.importexport.archive.migrators import get_migrator, MIGRATE_LOGGER - - if silent is True: - echo.echo_deprecated('the --silent option is deprecated, use --verbosity') - - if in_place: - if output_file: - echo.echo_critical('output file specified together with --in-place flag') - output_file = input_file - force = True - elif not output_file: - echo.echo_critical( - 'no output file specified. Please add --in-place flag if you would like to migrate in place.' - ) - - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) - else: - set_progress_reporter(None) - MIGRATE_LOGGER.setLevel(verbosity) - - if version is None: - version = EXPORT_VERSION - - migrator_cls = get_migrator(detect_archive_type(input_file)) - migrator = migrator_cls(input_file) - - try: - with override_log_formatter_context('%(message)s'): - migrator.migrate(version, output_file, force=force, out_compression=archive_format) - except Exception as error: # pylint: disable=broad-except - if verbosity == 'DEBUG': - raise - echo.echo_critical( - 'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): ' - f'{error.__class__.__name__}:{error}' - ) - - if verbosity in ['DEBUG', 'INFO']: - echo.echo_success(f'migrated the archive to version {version}') + ctx.forward(cmd_archive.migrate) diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index ee773b7daf..20b8303f0e 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -9,6 +9,7 @@ ########################################################################### """`verdi group` commands""" import warnings +import logging import click from aiida.common.exceptions import UniquenessError @@ -17,6 +18,7 @@ from aiida.cmdline.params import options, arguments from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv +from aiida.common.links import GraphTraversalRules @verdi.group('group') @@ -30,7 +32,7 @@ def verdi_group(): @arguments.NODES() @with_dbenv() def group_add_nodes(group, force, nodes): - """Add nodes to the a group.""" + """Add nodes to a group.""" if not force: click.confirm(f'Do you really want to add {len(nodes)} nodes to Group<{group.label}>?', abort=True) @@ -45,12 +47,41 @@ def group_add_nodes(group, force, nodes): @with_dbenv() def group_remove_nodes(group, nodes, clear, force): """Remove nodes from a group.""" - if clear: - message = f'Do you really want to remove ALL the nodes from Group<{group.label}>?' - else: - message = f'Do you really want to remove {len(nodes)} nodes from Group<{group.label}>?' + from aiida.orm import QueryBuilder, Group, Node + + label = group.label + klass = group.__class__.__name__ + + if nodes and clear: + echo.echo_critical( + 'Specify either the `--clear` flag to remove all nodes or the identifiers of the nodes you want to remove.' 
+ ) if not force: + + if nodes: + node_pks = [node.pk for node in nodes] + + query = QueryBuilder() + query.append(Group, filters={'id': group.pk}, tag='group') + query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') + + group_node_pks = query.all(flat=True) + + if not group_node_pks: + echo.echo_critical(f'None of the specified nodes are in {klass}<{label}>.') + + if len(node_pks) > len(group_node_pks): + node_pks = set(node_pks).difference(set(group_node_pks)) + echo.echo_warning(f'{len(node_pks)} nodes with PK {node_pks} are not in {klass}<{label}>.') + + message = f'Are you sure you want to remove {len(group_node_pks)} nodes from {klass}<{label}>?' + + elif clear: + message = f'Are you sure you want to remove ALL the nodes from {klass}<{label}>?' + else: + echo.echo_critical(f'No nodes were provided for removal from {klass}<{label}>.') + click.confirm(message, abort=True) if clear: @@ -61,29 +92,55 @@ def group_remove_nodes(group, nodes, clear, force): @verdi_group.command('delete') @arguments.GROUP() +@options.FORCE() +@click.option( + '--delete-nodes', is_flag=True, default=False, help='Delete all nodes in the group along with the group itself.' +) +@options.graph_traversal_rules(GraphTraversalRules.DELETE.value) +@options.DRY_RUN() +@options.VERBOSE() @options.GROUP_CLEAR( help='Remove all nodes before deleting the group itself.' + ' [deprecated: No longer has any effect. Will be removed in 2.0.0]' ) -@options.FORCE() @with_dbenv() -def group_delete(group, clear, force): - """Delete a group. - - Note that this command only deletes groups - nodes contained in the group will remain untouched. - """ +def group_delete(group, clear, delete_nodes, dry_run, force, verbose, **traversal_rules): + """Delete a group and (optionally) the nodes it contains.""" + from aiida.common.log import override_log_formatter_context + from aiida.tools import delete_group_nodes, DELETE_LOGGER from aiida import orm - label = group.label - if clear: warnings.warn('`--clear` is deprecated and no longer has any effect.', AiidaDeprecationWarning) # pylint: disable=no-member - if not force: - click.confirm(f'Are you sure to delete Group<{label}>?', abort=True) + label = group.label + klass = group.__class__.__name__ + + verbosity = logging.DEBUG if verbose else logging.INFO + DELETE_LOGGER.setLevel(verbosity) + + if not (force or dry_run): + click.confirm(f'Are you sure to delete {klass}<{label}>?', abort=True) + elif dry_run: + echo.echo_info(f'Would have deleted {klass}<{label}>.') + + if delete_nodes: + + def _dry_run_callback(pks): + if not pks or force: + return False + echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! 
THIS CANNOT BE UNDONE!') + return not click.confirm('Shall I continue?', abort=True) + + with override_log_formatter_context('%(message)s'): + _, nodes_deleted = delete_group_nodes([group.pk], dry_run=dry_run or _dry_run_callback, **traversal_rules) + if not nodes_deleted: + # don't delete the group if the nodes were not deleted + return - orm.Group.objects.delete(group.pk) - echo.echo_success(f'Group<{label}> deleted.') + if not dry_run: + orm.Group.objects.delete(group.pk) + echo.echo_success(f'{klass}<{label}> deleted.') @verdi_group.command('relabel') diff --git a/aiida/cmdline/commands/cmd_import.py b/aiida/cmdline/commands/cmd_import.py index 251006375e..1dad604063 100644 --- a/aiida/cmdline/commands/cmd_import.py +++ b/aiida/cmdline/commands/cmd_import.py @@ -8,34 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi import` command.""" -# pylint: disable=broad-except -from enum import Enum -from typing import List, Tuple -import traceback -import urllib.request - +# pylint: disable=broad-except,unused-argument import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import options from aiida.cmdline.params.types import GroupParamType, PathOrUrl -from aiida.cmdline.utils import decorators, echo - -EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none', 'ask'] -EXTRAS_MODE_NEW = ['import', 'none'] -COMMENT_MODE = ['newest', 'overwrite'] +from aiida.cmdline.utils import decorators +from aiida.cmdline.commands.cmd_archive import import_archive, EXTRAS_MODE_EXISTING, EXTRAS_MODE_NEW, COMMENT_MODE -class ExtrasImportCode(Enum): - """Exit codes for the verdi command line.""" - keep_existing = 'kcl' - update_existing = 'kcu' - mirror = 'ncu' - none = 'knl' - ask = 'kca' - -@verdi.command('import') +@verdi.command('import', hidden=True) +@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive import' instead.") @click.argument('archives', nargs=-1, type=PathOrUrl(exists=True, readable=True)) @click.option( '-w', @@ -100,147 +85,5 @@ def cmd_import( ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive, verbosity ): - """Import data from an AiiDA archive file. - - The archive can be specified by its relative or absolute file path, or its HTTP URL. 
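The `verdi import` logic removed below (it now lives behind `verdi archive import`) first attempts a plain import and falls back to migrating the archive when its version is incompatible. A rough, self-contained sketch of that control flow, using placeholder functions rather than the real `aiida.tools.importexport` API:

# Sketch of the import-with-migration-fallback flow implemented by the helper
# removed below. The exception and functions here are placeholders only.
class IncompatibleVersionError(Exception):
    """Placeholder for an 'archive version too old' error."""


def import_archive(path: str) -> None:
    # Placeholder: pretend only migrated archives are compatible.
    if not path.endswith('.migrated'):
        raise IncompatibleVersionError(path)
    print(f'imported {path}')


def migrate_archive(path: str) -> str:
    # Placeholder: return the path of the migrated copy.
    return path + '.migrated'


def import_with_fallback(path: str, try_migration: bool = True) -> None:
    """Try a direct import; on a version mismatch, migrate once and retry."""
    try:
        import_archive(path)
    except IncompatibleVersionError:
        if not try_migration:
            raise
        import_archive(migrate_archive(path))


if __name__ == '__main__':
    import_with_fallback('export.aiida')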
- """ - # pylint: disable=unused-argument - from aiida.common.log import override_log_formatter_context - from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER - from aiida.tools.importexport.archive.migrators import MIGRATE_LOGGER - - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) - else: - set_progress_reporter(None) - IMPORT_LOGGER.setLevel(verbosity) - MIGRATE_LOGGER.setLevel(verbosity) - - all_archives = _gather_imports(archives, webpages) - - # Preliminary sanity check - if not all_archives: - echo.echo_critical('no valid exported archives were found') - - # Shared import key-word arguments - import_kwargs = { - 'group': group, - 'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value, - 'extras_mode_new': extras_mode_new, - 'comment_mode': comment_mode, - } - - with override_log_formatter_context('%(message)s'): - for archive, web_based in all_archives: - _import_archive(archive, web_based, import_kwargs, migration) - - -def _echo_exception(msg: str, exception, warn_only: bool = False): - """Correctly report and exception. - - :param msg: The message prefix - :param exception: the exception raised - :param warn_only: If True only print a warning, otherwise calls sys.exit with a non-zero exit status - - """ - from aiida.tools.importexport import IMPORT_LOGGER - message = f'{msg}: {exception.__class__.__name__}: {str(exception)}' - if warn_only: - echo.echo_warning(message) - else: - IMPORT_LOGGER.debug('%s', traceback.format_exc()) - echo.echo_critical(message) - - -def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: - """Gather archives to import and sort into local files and URLs. - - :returns: list of (archive path, whether it is web based) - - """ - from aiida.tools.importexport.common.utils import get_valid_import_links - - final_archives = [] - - # Build list of archives to be imported - for archive in archives: - if archive.startswith('http://') or archive.startswith('https://'): - final_archives.append((archive, True)) - else: - final_archives.append((archive, False)) - - # Discover and retrieve *.aiida files at URL(s) - if webpages is not None: - for webpage in webpages: - try: - echo.echo_info(f'retrieving archive URLS from {webpage}') - urls = get_valid_import_links(webpage) - except Exception as error: - echo.echo_critical( - f'an exception occurred while trying to discover archives at URL {webpage}:\n{error}' - ) - else: - echo.echo_success(f'{len(urls)} archive URLs discovered and added') - final_archives.extend([(u, True) for u in urls]) - - return final_archives - - -def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): - """Perform the archive import. 
- - :param archive: the path or URL to the archive - :param web_based: If the archive needs to be downloaded first - :param import_kwargs: keyword arguments to pass to the import function - :param try_migration: whether to try a migration if the import raises IncompatibleArchiveVersionError - - """ - from aiida.common.folders import SandboxFolder - from aiida.tools.importexport import ( - detect_archive_type, EXPORT_VERSION, import_data, IncompatibleArchiveVersionError - ) - from aiida.tools.importexport.archive.migrators import get_migrator - - with SandboxFolder() as temp_folder: - - archive_path = archive - - if web_based: - echo.echo_info(f'downloading archive: {archive}') - try: - response = urllib.request.urlopen(archive) - except Exception as exception: - _echo_exception(f'downloading archive {archive} failed', exception) - temp_folder.create_file_from_filelike(response, 'downloaded_archive.zip') - archive_path = temp_folder.get_abs_path('downloaded_archive.zip') - echo.echo_success('archive downloaded, proceeding with import') - - echo.echo_info(f'starting import: {archive}') - try: - import_data(archive_path, **import_kwargs) - except IncompatibleArchiveVersionError as exception: - if try_migration: - - echo.echo_info(f'incompatible version detected for {archive}, trying migration') - try: - migrator = get_migrator(detect_archive_type(archive_path))(archive_path) - archive_path = migrator.migrate( - EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath - ) - except Exception as exception: - _echo_exception(f'an exception occurred while migrating the archive {archive}', exception) - - echo.echo_info('proceeding with import of migrated archive') - try: - import_data(archive_path, **import_kwargs) - except Exception as exception: - _echo_exception( - f'an exception occurred while trying to import the migrated archive {archive}', exception - ) - else: - _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception) - except Exception as exception: - _echo_exception(f'an exception occurred while trying to import the archive {archive}', exception) - - echo.echo_success(f'imported archive {archive}') + """Deprecated, use `verdi archive import`.""" + ctx.forward(import_archive) diff --git a/aiida/cmdline/commands/cmd_node.py b/aiida/cmdline/commands/cmd_node.py index 3f73c106ff..1be70e7ee4 100644 --- a/aiida/cmdline/commands/cmd_node.py +++ b/aiida/cmdline/commands/cmd_node.py @@ -9,6 +9,7 @@ ########################################################################### """`verdi node` command.""" +import logging import shutil import pathlib @@ -296,30 +297,46 @@ def tree(nodes, depth): @verdi_node.command('delete') -@arguments.NODES('nodes', required=True) +@click.argument('identifier', nargs=-1, metavar='NODES') @options.VERBOSE() @options.DRY_RUN() @options.FORCE() @options.graph_traversal_rules(GraphTraversalRules.DELETE.value) @with_dbenv() -def node_delete(nodes, dry_run, verbose, force, **kwargs): +def node_delete(identifier, dry_run, verbose, force, **traversal_rules): """Delete nodes from the provenance graph. This will not only delete the nodes explicitly provided via the command line, but will also include the nodes necessary to keep a consistent graph, according to the rules outlined in the documentation. You can modify some of those rules using options of this command. 
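Both the rewritten `verdi group delete` and `verdi node delete` pass a confirmation callback as `dry_run`, so the user is shown the complete set of PKs selected by the graph traversal before anything is removed. A simplified, self-contained sketch of that callback protocol (`delete_nodes` and `collect_pks_to_delete` here are stand-ins, not the real `aiida.tools` functions):

# Sketch of the confirmation-callback pattern used by the rewritten delete
# commands: the deletion routine computes the full set of PKs and asks the
# callback whether to abort (callback returns True to abort).
from typing import Callable, Iterable, Set, Union


def collect_pks_to_delete(pks: Iterable[int]) -> Set[int]:
    """Stand-in: the real traversal also adds nodes required for consistency."""
    return set(pks)


def delete_nodes(pks: Iterable[int], dry_run: Union[bool, Callable[[Set[int]], bool]]) -> bool:
    """Return True if nodes were deleted, False if the run was aborted or dry."""
    to_delete = collect_pks_to_delete(pks)
    if callable(dry_run):
        if dry_run(to_delete):   # callback returned True: abort
            return False
    elif dry_run:
        return False             # plain dry run: report only
    print(f'deleting {sorted(to_delete)}')
    return True


def _dry_run_callback(pks: Set[int], force: bool = False) -> bool:
    if not pks or force:
        return False
    answer = input(f'About to delete {len(pks)} nodes, continue? [y/N] ')
    return answer.lower() != 'y'


if __name__ == '__main__':
    # With force=True the callback never prompts, so this runs non-interactively.
    delete_nodes([1, 2, 3], dry_run=lambda pks: _dry_run_callback(pks, force=True))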
""" - from aiida.manage.database.delete.nodes import delete_nodes + from aiida.common.log import override_log_formatter_context + from aiida.orm.utils.loaders import NodeEntityLoader + from aiida.tools import delete_nodes, DELETE_LOGGER - verbosity = 1 - if force: - verbosity = 0 - elif verbose: - verbosity = 2 + verbosity = logging.DEBUG if verbose else logging.INFO + DELETE_LOGGER.setLevel(verbosity) - node_pks_to_delete = [node.pk for node in nodes] + pks = [] - delete_nodes(node_pks_to_delete, dry_run=dry_run, verbosity=verbosity, force=force, **kwargs) + for obj in identifier: + # we only load the node if we need to convert from a uuid/label + try: + pks.append(int(obj)) + except ValueError: + pks.append(NodeEntityLoader.load_entity(obj).pk) + + def _dry_run_callback(pks): + if not pks or force: + return False + echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! THIS CANNOT BE UNDONE!') + return not click.confirm('Shall I continue?', abort=True) + + with override_log_formatter_context('%(message)s'): + _, was_deleted = delete_nodes(pks, dry_run=dry_run or _dry_run_callback, **traversal_rules) + + if was_deleted: + echo.echo_success('Finished deletion.') @verdi_node.command('rehash') diff --git a/aiida/cmdline/commands/cmd_process.py b/aiida/cmdline/commands/cmd_process.py index 04513ad94a..ba70c81b37 100644 --- a/aiida/cmdline/commands/cmd_process.py +++ b/aiida/cmdline/commands/cmd_process.py @@ -247,7 +247,8 @@ def process_play(processes, all_entries, timeout, wait): raise click.BadOptionUsage('all', 'cannot specify individual processes and the `--all` flag at the same time.') if not processes and all_entries: - builder = QueryBuilder().append(ProcessNode, filters={'attributes.paused': True}) + filters = CalculationQueryBuilder().get_filters(process_state=('created', 'waiting', 'running'), paused=True) + builder = QueryBuilder().append(ProcessNode, filters=filters) processes = builder.all(flat=True) futures = {} @@ -305,7 +306,7 @@ def _print(communicator, body, sender, subject, correlation_id): # pylint: disa echo.echo('') # add a new line after the interrupt character echo.echo_info('received interrupt, exiting...') try: - communicator.stop() + communicator.close() except RuntimeError: pass @@ -336,6 +337,7 @@ def process_actions(futures_map, infinitive, present, past, wait=False, timeout= """ # pylint: disable=too-many-branches import kiwipy + from plumpy.futures import unwrap_kiwi_future from concurrent import futures from aiida.manage.external.rmq import CommunicationTimeout @@ -347,6 +349,8 @@ def process_actions(futures_map, infinitive, present, past, wait=False, timeout= process = futures_map[future] try: + # unwrap is need here since LoopCommunicator will also wrap a future + future = unwrap_kiwi_future(future) result = future.result() except CommunicationTimeout: echo.echo_error(f'call to {infinitive} Process<{process.pk}> timed out') diff --git a/aiida/cmdline/commands/cmd_restapi.py b/aiida/cmdline/commands/cmd_restapi.py index 0ff7546b30..1f7ad1413e 100644 --- a/aiida/cmdline/commands/cmd_restapi.py +++ b/aiida/cmdline/commands/cmd_restapi.py @@ -38,7 +38,15 @@ help='Whether to enable WSGI profiler middleware for finding bottlenecks' ) @click.option('--hookup/--no-hookup', 'hookup', is_flag=True, default=None, help='Hookup app to flask server') -def restapi(hostname, port, config_dir, debug, wsgi_profile, hookup): +@click.option( + '--posting/--no-posting', + 'posting', + is_flag=True, + default=config.CLI_DEFAULTS['POSTING'], + help='Enable POST endpoints 
(currently only /querybuilder).', + hidden=True, +) +def restapi(hostname, port, config_dir, debug, wsgi_profile, hookup, posting): """ Run the AiiDA REST API server. @@ -55,4 +63,5 @@ def restapi(hostname, port, config_dir, debug, wsgi_profile, hookup): debug=debug, wsgi_profile=wsgi_profile, hookup=hookup, + posting=posting, ) diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index 28b34effdd..241e048bb8 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -90,10 +90,10 @@ def setup( echo.echo_success('database migration completed.') # Optionally setting configuration default user settings - config.set_option('user.email', email, override=False) - config.set_option('user.first_name', first_name, override=False) - config.set_option('user.last_name', last_name, override=False) - config.set_option('user.institution', institution, override=False) + config.set_option('autofill.user.email', email, override=False) + config.set_option('autofill.user.first_name', first_name, override=False) + config.set_option('autofill.user.last_name', last_name, override=False) + config.set_option('autofill.user.institution', institution, override=False) # Create the user if it does not yet exist created, user = orm.User.objects.get_or_create( diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py index a5599ee630..c12344de35 100644 --- a/aiida/cmdline/commands/cmd_status.py +++ b/aiida/cmdline/commands/cmd_status.py @@ -113,7 +113,7 @@ def verdi_status(print_traceback, no_rmq): with Capturing(capture_stderr=True): with override_log_level(): # temporarily suppress noisy logging comm = manager.create_communicator(with_orm=False) - comm.stop() + comm.close() except Exception as exc: message = f'Unable to connect to rabbitmq with URL: {profile.get_rmq_url()}' print_status(ServiceStatus.ERROR, 'rabbitmq', message, exception=exc, print_traceback=print_traceback) diff --git a/aiida/cmdline/commands/cmd_verdi.py b/aiida/cmdline/commands/cmd_verdi.py index e9cebf5229..6a395b1185 100644 --- a/aiida/cmdline/commands/cmd_verdi.py +++ b/aiida/cmdline/commands/cmd_verdi.py @@ -12,6 +12,7 @@ import difflib import click +from aiida import __version__ from aiida.cmdline.params import options, types GIU = ( @@ -84,7 +85,9 @@ def get_command(self, ctx, cmd_name): @click.command(cls=MostSimilarCommandGroup, context_settings={'help_option_names': ['-h', '--help']}) @options.PROFILE(type=types.ProfileParamType(load_profile=True)) -@click.version_option(None, '-v', '--version', message='AiiDA version %(version)s') +# Note, __version__ should always be passed explicitly here, +# because click does not retrieve a dynamic version when installed in editable mode +@click.version_option(__version__, '-v', '--version', message='AiiDA version %(version)s') @click.pass_context def verdi(ctx, profile): """The command line interface of AiiDA.""" diff --git a/aiida/cmdline/params/options/commands/setup.py b/aiida/cmdline/params/options/commands/setup.py index b600c73e83..b5cb9f974d 100644 --- a/aiida/cmdline/params/options/commands/setup.py +++ b/aiida/cmdline/params/options/commands/setup.py @@ -159,32 +159,32 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume SETUP_USER_EMAIL = options.USER_EMAIL.clone( prompt='Email Address (for sharing data)', - default=get_config_option('user.email'), - required_fn=lambda x: get_config_option('user.email') is None, + 
default=get_config_option('autofill.user.email'), + required_fn=lambda x: get_config_option('autofill.user.email') is None, required=True, cls=options.interactive.InteractiveOption ) SETUP_USER_FIRST_NAME = options.USER_FIRST_NAME.clone( prompt='First name', - default=get_config_option('user.first_name'), - required_fn=lambda x: get_config_option('user.first_name') is None, + default=get_config_option('autofill.user.first_name'), + required_fn=lambda x: get_config_option('autofill.user.first_name') is None, required=True, cls=options.interactive.InteractiveOption ) SETUP_USER_LAST_NAME = options.USER_LAST_NAME.clone( prompt='Last name', - default=get_config_option('user.last_name'), - required_fn=lambda x: get_config_option('user.last_name') is None, + default=get_config_option('autofill.user.last_name'), + required_fn=lambda x: get_config_option('autofill.user.last_name') is None, required=True, cls=options.interactive.InteractiveOption ) SETUP_USER_INSTITUTION = options.USER_INSTITUTION.clone( prompt='Institution', - default=get_config_option('user.institution'), - required_fn=lambda x: get_config_option('user.institution') is None, + default=get_config_option('autofill.user.institution'), + required_fn=lambda x: get_config_option('autofill.user.institution') is None, required=True, cls=options.interactive.InteractiveOption ) diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 89287e047b..df37ffd49d 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -197,7 +197,7 @@ def format_nested_links(links, headers): :param headers: headers to use :return: nested formatted string """ - import collections + from collections.abc import Mapping import tabulate as tb tb.PRESERVE_WHITESPACE = True @@ -208,7 +208,7 @@ def format_recursive(links, depth=0): """Recursively format a dictionary of nodes into indented strings.""" rows = [] for label, value in links.items(): - if isinstance(value, collections.Mapping): + if isinstance(value, Mapping): rows.append([depth, label, '', '']) rows.extend(format_recursive(value, depth=depth + 1)) else: diff --git a/aiida/cmdline/utils/daemon.py b/aiida/cmdline/utils/daemon.py index 552501ee39..afd7bed95a 100644 --- a/aiida/cmdline/utils/daemon.py +++ b/aiida/cmdline/utils/daemon.py @@ -21,23 +21,29 @@ def print_client_response_status(response): Print the response status of a call to the CircusClient through the DaemonClient :param response: the response object + :return: an integer error code; non-zero means there was an error (FAILED, TIMEOUT), zero means OK (OK, RUNNING) """ from aiida.engine.daemon.client import DaemonClient if 'status' not in response: - return + return 1 if response['status'] == 'active': click.secho('RUNNING', fg='green', bold=True) - elif response['status'] == 'ok': + return 0 + if response['status'] == 'ok': click.secho('OK', fg='green', bold=True) - elif response['status'] == DaemonClient.DAEMON_ERROR_NOT_RUNNING: + return 0 + if response['status'] == DaemonClient.DAEMON_ERROR_NOT_RUNNING: click.secho('FAILED', fg='red', bold=True) click.echo('Try to run \'verdi daemon start --foreground\' to potentially see the exception') - elif response['status'] == DaemonClient.DAEMON_ERROR_TIMEOUT: + return 2 + if response['status'] == DaemonClient.DAEMON_ERROR_TIMEOUT: click.secho('TIMEOUT', fg='red', bold=True) - else: - click.echo(response['status']) + return 3 + # Unknown status, I will consider it as failed + click.echo(response['status']) + return -1 def get_daemon_status(client): diff 
--git a/aiida/cmdline/utils/echo.py b/aiida/cmdline/utils/echo.py index 7a7c3210cf..248a01a2db 100644 --- a/aiida/cmdline/utils/echo.py +++ b/aiida/cmdline/utils/echo.py @@ -187,7 +187,7 @@ def echo_formatted_list(collection, attributes, sort=None, highlight=None, hide= click.secho(template.format(symbol=' ', *values)) -def _format_dictionary_json_date(dictionary): +def _format_dictionary_json_date(dictionary, sort_keys=True): """Return a dictionary formatted as a string using the json format and converting dates to strings.""" from aiida.common import json @@ -201,19 +201,31 @@ def default_jsondump(data): raise TypeError(f'{repr(data)} is not JSON serializable') - return json.dumps(dictionary, indent=4, sort_keys=True, default=default_jsondump) + return json.dumps(dictionary, indent=4, sort_keys=sort_keys, default=default_jsondump) -VALID_DICT_FORMATS_MAPPING = OrderedDict((('json+date', _format_dictionary_json_date), ('yaml', yaml.dump), - ('yaml_expanded', lambda d: yaml.dump(d, default_flow_style=False)))) +def _format_yaml(dictionary, sort_keys=True): + """Return a dictionary formatted as a string using the YAML format.""" + return yaml.dump(dictionary, sort_keys=sort_keys) -def echo_dictionary(dictionary, fmt='json+date'): +def _format_yaml_expanded(dictionary, sort_keys=True): + """Return a dictionary formatted as a string using the expanded YAML format.""" + return yaml.dump(dictionary, sort_keys=sort_keys, default_flow_style=False) + + +VALID_DICT_FORMATS_MAPPING = OrderedDict( + (('json+date', _format_dictionary_json_date), ('yaml', _format_yaml), ('yaml_expanded', _format_yaml_expanded)) +) + + +def echo_dictionary(dictionary, fmt='json+date', sort_keys=True): """ Print the given dictionary to stdout in the given format :param dictionary: the dictionary :param fmt: the format to use for printing + :param sort_keys: Whether to automatically sort keys """ try: format_function = VALID_DICT_FORMATS_MAPPING[fmt] @@ -221,7 +233,7 @@ def echo_dictionary(dictionary, fmt='json+date'): formats = ', '.join(VALID_DICT_FORMATS_MAPPING.keys()) raise ValueError(f'Unrecognised printing format. 
Valid formats are: {formats}') - echo(format_function(dictionary)) + echo(format_function(dictionary, sort_keys=sort_keys)) def is_stdout_redirected(): diff --git a/aiida/cmdline/utils/query/calculation.py b/aiida/cmdline/utils/query/calculation.py index b2026baebf..d52ace1a34 100644 --- a/aiida/cmdline/utils/query/calculation.py +++ b/aiida/cmdline/utils/query/calculation.py @@ -21,7 +21,8 @@ class CalculationQueryBuilder: _default_projections = ('pk', 'ctime', 'process_label', 'state', 'process_status') _valid_projections = ( 'pk', 'uuid', 'ctime', 'mtime', 'state', 'process_state', 'process_status', 'exit_status', 'sealed', - 'process_label', 'label', 'description', 'node_type', 'paused', 'process_type', 'job_state', 'scheduler_state' + 'process_label', 'label', 'description', 'node_type', 'paused', 'process_type', 'job_state', 'scheduler_state', + 'exception' ) def __init__(self, mapper=None): diff --git a/aiida/cmdline/utils/query/mapping.py b/aiida/cmdline/utils/query/mapping.py index 17d46a8b43..56a62f1e20 100644 --- a/aiida/cmdline/utils/query/mapping.py +++ b/aiida/cmdline/utils/query/mapping.py @@ -91,6 +91,7 @@ def __init__(self, projections, projection_labels=None, projection_attributes=No process_state_key = f'attributes.{ProcessNode.PROCESS_STATE_KEY}' process_status_key = f'attributes.{ProcessNode.PROCESS_STATUS_KEY}' exit_status_key = f'attributes.{ProcessNode.EXIT_STATUS_KEY}' + exception_key = f'attributes.{ProcessNode.EXCEPTION_KEY}' default_labels = {'pk': 'PK', 'uuid': 'UUID', 'ctime': 'Created', 'mtime': 'Modified', 'state': 'Process State'} @@ -104,6 +105,7 @@ def __init__(self, projections, projection_labels=None, projection_attributes=No 'process_state': process_state_key, 'process_status': process_status_key, 'exit_status': exit_status_key, + 'exception': exception_key, } default_formatters = { diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index e10a7cca22..271cdaec48 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -12,7 +12,13 @@ from .extendeddicts import DefaultFieldsAttributeDict -__all__ = ('CalcJobState', 'CalcInfo', 'CodeInfo', 'CodeRunMode') +__all__ = ('StashMode', 'CalcJobState', 'CalcInfo', 'CodeInfo', 'CodeRunMode') + + +class StashMode(Enum): + """Mode to use when stashing files from the working directory of a completed calculation job for safekeeping.""" + + COPY = 'copy' class CalcJobState(Enum): @@ -21,6 +27,7 @@ class CalcJobState(Enum): UPLOADING = 'uploading' SUBMITTING = 'submitting' WITHSCHEDULER = 'withscheduler' + STASHING = 'stashing' RETRIEVING = 'retrieving' PARSING = 'parsing' @@ -32,25 +39,32 @@ class CalcInfo(DefaultFieldsAttributeDict): In the following descriptions all paths have to be considered relative - * retrieve_list: a list of strings or tuples that indicate files that are to be retrieved from the remote - after the calculation has finished and stored in the repository in a FolderData. - If the entry in the list is just a string, it is assumed to be the filepath on the remote and it will - be copied to '.' of the repository with name os.path.split(item)[1] - If the entry is a tuple it is expected to have the following format + * retrieve_list: a list of strings or tuples that indicate files that are to be retrieved from the remote after the + calculation has finished and stored in the ``retrieved_folder`` output node of type ``FolderData``. 
If the entry + in the list is just a string, it is assumed to be the filepath on the remote and it will be copied to the base + directory of the retrieved folder, where the name corresponds to the basename of the remote relative path. This + means that any remote folder hierarchy is ignored entirely. + + Remote folder hierarchy can be (partially) maintained by using a tuple instead, with the following format + + (source, target, depth) + + The ``source`` and ``target`` elements are relative filepaths in the remote and retrieved folder. The contents + of ``source`` (whether it is a file or folder) are copied in its entirety to the ``target`` subdirectory in the + retrieved folder. If no subdirectory should be created, ``'.'`` should be specified for ``target``. - ('remotepath', 'localpath', depth) + The ``source`` filepaths support glob patterns ``*`` in case the exact name of the files that are to be + retrieved are not know a priori. - If the 'remotepath' is a file or folder, it will be copied in the repository to 'localpath'. - However, if the 'remotepath' contains file patterns with wildcards, the 'localpath' should be set to '.' - and the depth parameter should be an integer that decides the localname. The 'remotepath' will be split on - file separators and the local filename will be determined by joining the N last elements, where N is - given by the depth variable. + The ``depth`` element can be used to control what level of nesting of the source folder hierarchy should be + maintained. If ``depth`` equals ``0`` or ``1`` (they are equivalent), only the basename of the ``source`` + filepath is kept. For each additional level, another subdirectory of the remote hierarchy is kept. For example: - Example: ('some/remote/path/files/pattern*[0-9].xml', '.', 2) + ('path/sub/file.txt', '.', 2) - Will result in all files that match the pattern to be copied to the local repository with path + will retrieve the ``file.txt`` and store it under the path: - 'files/pattern*[0-9].xml' + sub/file.txt * retrieve_temporary_list: a list of strings or tuples that indicate files that will be retrieved and stored temporarily in a FolderData, that will be available only during the parsing call. @@ -74,6 +88,8 @@ class CalcInfo(DefaultFieldsAttributeDict): already indirectly present in the repository through one of the data nodes passed as input to the calculation. * codes_info: a list of dictionaries used to pass the info of the execution of a code * codes_run_mode: a string used to specify the order in which multi codes can be executed + * skip_submit: a flag that, when set to True, orders the engine to skip the submit/update steps (so no code will + run, it will only upload the files and then retrieve/parse). 
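A small illustration of the `(source, target, depth)` semantics documented above. It mirrors the documented examples only; it is not the engine's actual implementation and ignores glob patterns:

# Keep the last `depth` path components of `source` (at least the basename)
# and place them under `target`, as described in the docstring above.
import os


def resolved_local_path(source: str, target: str, depth: int) -> str:
    parts = source.split('/')
    keep = max(depth, 1)   # depth 0 and 1 both keep only the basename
    relative = os.path.join(*parts[-keep:])
    return os.path.normpath(os.path.join(target, relative))


assert resolved_local_path('path/sub/file.txt', '.', 2) == 'sub/file.txt'
assert resolved_local_path('path/sub/file.txt', '.', 0) == 'file.txt'
assert resolved_local_path('output', 'extra', 1) == 'extra/output'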
""" _default_fields = ( @@ -98,7 +114,8 @@ class CalcInfo(DefaultFieldsAttributeDict): 'remote_symlink_list', 'provenance_exclude_list', 'codes_info', - 'codes_run_mode' + 'codes_run_mode', + 'skip_submit' ) diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py index 72ed077628..72909d73e8 100644 --- a/aiida/common/exceptions.py +++ b/aiida/common/exceptions.py @@ -17,7 +17,7 @@ 'PluginInternalError', 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError', 'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleDatabaseSchema', 'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', 'TestsNotAllowedError', - 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError' + 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError' ) @@ -250,3 +250,9 @@ class CircusCallError(AiidaException): """ Raised when an attempt to contact Circus returns an error in the response """ + + +class HashingError(AiidaException): + """ + Raised when an attempt to hash an object fails via a known failure mode + """ diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py index 02d7c6e95d..1c688ef740 100644 --- a/aiida/common/hashing.py +++ b/aiida/common/hashing.py @@ -22,6 +22,8 @@ import pytz from aiida.common.constants import AIIDA_FLOAT_PRECISION +from aiida.common.exceptions import HashingError + from .folders import Folder # The prefix of the hashed using pbkdf2_sha256 algorithm in Django @@ -101,7 +103,6 @@ def make_hash(object_to_hash, **kwargs): hashing iteratively. Uses python's sorted function to sort unsorted sets and dictionaries by sorting the hashed keys. """ - hashes = _make_hash(object_to_hash, **kwargs) # pylint: disable=assignment-from-no-return # use the Unlimited fanout hashing protocol outlined in @@ -123,7 +124,7 @@ def _make_hash(object_to_hash, **_): Implementation of the ``make_hash`` function. The hash is created as a 28 byte integer, and only later converted to a string. """ - raise ValueError(f'Value of type {type(object_to_hash)} cannot be hashed') + raise HashingError(f'Value of type {type(object_to_hash)} cannot be hashed') def _single_digest(obj_type, obj_bytes=b''): @@ -288,5 +289,7 @@ def float_to_text(value, sig): :param value: the float value to convert :param sig: choose how many digits after the comma should be output """ + if value == 0: + value = 0. # Identify value of -0. and overwrite with 0. 
fmt = f'{{:.{sig}g}}' return fmt.format(value) diff --git a/aiida/common/log.py b/aiida/common/log.py index 36310b36df..d916071511 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -83,11 +83,6 @@ def filter(self, record): 'level': lambda: get_config_option('logging.aiida_loglevel'), 'propagate': False, }, - 'tornado': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.tornado_loglevel'), - 'propagate': False, - }, 'plumpy': { 'handlers': ['console'], 'level': lambda: get_config_option('logging.plumpy_loglevel'), @@ -108,6 +103,11 @@ def filter(self, record): 'level': lambda: get_config_option('logging.alembic_loglevel'), 'propagate': False, }, + 'aio_pika': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiopika_loglevel'), + 'propagate': False, + }, 'sqlalchemy': { 'handlers': ['console'], 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 41e147e19e..984ff61866 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -14,4 +14,4 @@ from .processes import * from .utils import * -__all__ = (launch.__all__ + processes.__all__ + utils.__all__) +__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index 32d96466bf..428e702d19 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -16,13 +16,21 @@ import shutil import socket import tempfile +from typing import Any, Dict, Optional, TYPE_CHECKING from aiida.manage.configuration import get_config, get_config_option +from aiida.manage.configuration.profile import Profile + +if TYPE_CHECKING: + from circus.client import CircusClient VERDI_BIN = shutil.which('verdi') # Recent versions of virtualenv create the environment variable VIRTUAL_ENV VIRTUALENV = os.environ.get('VIRTUAL_ENV', None) +# see https://github.com/python/typing/issues/182 +JsonDictType = Dict[str, Any] + class ControllerProtocol(enum.Enum): """ @@ -33,13 +41,13 @@ class ControllerProtocol(enum.Enum): TCP = 1 -def get_daemon_client(profile_name=None): +def get_daemon_client(profile_name: Optional[str] = None) -> 'DaemonClient': """ Return the daemon client for the given profile or the current profile if not specified. 
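The daemon client module now imports `CircusClient` only under `TYPE_CHECKING`. A tiny, generic sketch of that pattern, using a standard-library class as a stand-in for the optional dependency:

# Heavyweight or optional dependencies are imported only for static type
# checkers and referenced as string annotations; the runtime import is
# deferred to the point of use.
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from http.client import HTTPConnection   # stand-in for an optional dependency


def make_client(host: Optional[str] = None) -> 'HTTPConnection':
    from http.client import HTTPConnection   # real import deferred to call time
    return HTTPConnection(host or 'localhost')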
:param profile_name: the profile name, will use the current profile if None :return: the daemon client - :rtype: :class:`aiida.engine.daemon.client.DaemonClient` + :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found :raises aiida.common.ProfileConfigurationError: if the given profile does not exist """ @@ -65,7 +73,7 @@ class DaemonClient: # pylint: disable=too-many-public-methods _DAEMON_NAME = 'aiida-{name}' _ENDPOINT_PROTOCOL = ControllerProtocol.IPC - def __init__(self, profile): + def __init__(self, profile: Profile): """ Construct a DaemonClient instance for a given profile @@ -73,22 +81,22 @@ def __init__(self, profile): """ config = get_config() self._profile = profile - self._SOCKET_DIRECTORY = None # pylint: disable=invalid-name - self._DAEMON_TIMEOUT = config.get_option('daemon.timeout') # pylint: disable=invalid-name + self._SOCKET_DIRECTORY: Optional[str] = None # pylint: disable=invalid-name + self._DAEMON_TIMEOUT: int = config.get_option('daemon.timeout') # pylint: disable=invalid-name @property - def profile(self): + def profile(self) -> Profile: return self._profile @property - def daemon_name(self): + def daemon_name(self) -> str: """ Get the daemon name which is tied to the profile name """ return self._DAEMON_NAME.format(name=self.profile.name) @property - def cmd_string(self): + def cmd_string(self) -> str: """ Return the command string to start the AiiDA daemon """ @@ -101,42 +109,42 @@ def cmd_string(self): return f'{VERDI_BIN} -p {self.profile.name} devel run_daemon' @property - def loglevel(self): + def loglevel(self) -> str: return get_config_option('logging.circus_loglevel') @property - def virtualenv(self): + def virtualenv(self) -> Optional[str]: return VIRTUALENV @property - def circus_log_file(self): + def circus_log_file(self) -> str: return self.profile.filepaths['circus']['log'] @property - def circus_pid_file(self): + def circus_pid_file(self) -> str: return self.profile.filepaths['circus']['pid'] @property - def circus_port_file(self): + def circus_port_file(self) -> str: return self.profile.filepaths['circus']['port'] @property - def circus_socket_file(self): + def circus_socket_file(self) -> str: return self.profile.filepaths['circus']['socket']['file'] @property - def circus_socket_endpoints(self): + def circus_socket_endpoints(self) -> Dict[str, str]: return self.profile.filepaths['circus']['socket'] @property - def daemon_log_file(self): + def daemon_log_file(self) -> str: return self.profile.filepaths['daemon']['log'] @property - def daemon_pid_file(self): + def daemon_pid_file(self) -> str: return self.profile.filepaths['daemon']['pid'] - def get_circus_port(self): + def get_circus_port(self) -> int: """ Retrieve the port for the circus controller, which should be written to the circus port file. If the daemon is running, the port file should exist and contain the port to which the controller is connected. @@ -158,7 +166,7 @@ def get_circus_port(self): return port - def get_circus_socket_directory(self): + def get_circus_socket_directory(self) -> str: """ Retrieve the absolute path of the directory where the circus sockets are stored if the IPC protocol is used and the daemon is running. 
If the daemon is running, the sockets file should exist and contain the @@ -176,7 +184,9 @@ def get_circus_socket_directory(self): """ if self.is_daemon_running: try: - return open(self.circus_socket_file, 'r', encoding='utf8').read().strip() + with open(self.circus_socket_file, 'r', encoding='utf8') as fhandle: + content = fhandle.read().strip() + return content except (ValueError, IOError): raise RuntimeError('daemon is running so sockets file should have been there but could not read it') else: @@ -192,7 +202,7 @@ def get_circus_socket_directory(self): self._SOCKET_DIRECTORY = socket_dir_path return socket_dir_path - def get_daemon_pid(self): + def get_daemon_pid(self) -> Optional[int]: """ Get the daemon pid which should be written in the daemon pid file specific to the profile @@ -200,14 +210,16 @@ def get_daemon_pid(self): """ if os.path.isfile(self.circus_pid_file): try: - return int(open(self.circus_pid_file, 'r', encoding='utf8').read().strip()) + with open(self.circus_pid_file, 'r', encoding='utf8') as fhandle: + content = fhandle.read().strip() + return int(content) except (ValueError, IOError): return None else: return None @property - def is_daemon_running(self): + def is_daemon_running(self) -> bool: """ Return whether the daemon is running, which is determined by seeing if the daemon pid file is present @@ -215,7 +227,7 @@ def is_daemon_running(self): """ return self.get_daemon_pid() is not None - def delete_circus_socket_directory(self): + def delete_circus_socket_directory(self) -> None: """ Attempt to delete the directory used to store the circus endpoint sockets. Will not raise if the directory does not exist @@ -321,7 +333,7 @@ def get_tcp_endpoint(self, port=None): return endpoint @property - def client(self): + def client(self) -> 'CircusClient': """ Return an instance of the CircusClient with the endpoint defined by the controller endpoint, which used the port that was written to the port file upon starting of the daemon @@ -334,7 +346,7 @@ def client(self): from circus.client import CircusClient return CircusClient(endpoint=self.get_controller_endpoint(), timeout=self._DAEMON_TIMEOUT) - def call_client(self, command): + def call_client(self, command: JsonDictType) -> JsonDictType: """ Call the client with a specific command. Will check whether the daemon is running first by checking for the pid file. 
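The response dictionaries returned by these `call_client` wrappers carry a `status` key, which the reworked `print_client_response_status` earlier in this patch maps to an exit code. A simplified sketch of that mapping; the error status strings below are placeholders for the `DaemonClient` constants:

# 0 = OK/RUNNING, non-zero = some failure, mirroring the rework above.
def response_to_exit_code(response: dict) -> int:
    status = response.get('status')
    if status is None:
        return 1
    if status in ('active', 'ok'):
        return 0
    if status == 'not-running':   # placeholder for DaemonClient.DAEMON_ERROR_NOT_RUNNING
        return 2
    if status == 'timeout':       # placeholder for DaemonClient.DAEMON_ERROR_TIMEOUT
        return 3
    return -1                     # any unknown status is treated as a failure


assert response_to_exit_code({'status': 'active'}) == 0
assert response_to_exit_code({}) == 1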
When the pid is found yet the call still fails with a @@ -358,47 +370,51 @@ def call_client(self, command): return result - def get_status(self): + def get_status(self) -> JsonDictType: """ Get the daemon running status :return: the client call response + If successful, will will contain 'status' key """ command = {'command': 'status', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_numprocesses(self): + def get_numprocesses(self) -> JsonDictType: """ Get the number of running daemon processes :return: the client call response + If successful, will contain 'numprocesses' key """ command = {'command': 'numprocesses', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_worker_info(self): + def get_worker_info(self) -> JsonDictType: """ Get workers statistics for this daemon :return: the client call response + If successful, will contain 'info' key """ command = {'command': 'stats', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_daemon_info(self): + def get_daemon_info(self) -> JsonDictType: """ Get statistics about this daemon itself :return: the client call response + If successful, will contain 'info' key """ command = {'command': 'dstats', 'properties': {}} return self.call_client(command) - def increase_workers(self, number): + def increase_workers(self, number: int) -> JsonDictType: """ Increase the number of workers @@ -409,7 +425,7 @@ def increase_workers(self, number): return self.call_client(command) - def decrease_workers(self, number): + def decrease_workers(self, number: int) -> JsonDictType: """ Decrease the number of workers @@ -420,7 +436,7 @@ def decrease_workers(self, number): return self.call_client(command) - def stop_daemon(self, wait): + def stop_daemon(self, wait: bool) -> JsonDictType: """ Stop the daemon @@ -436,7 +452,7 @@ def stop_daemon(self, wait): return result - def restart_daemon(self, wait): + def restart_daemon(self, wait: bool) -> JsonDictType: """ Restart the daemon diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 5f8a136589..02dc638a99 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -13,23 +13,56 @@ the routines make reference to the suitable plugins for all plugin-specific operations. """ +from collections.abc import Mapping +from logging import LoggerAdapter import os import shutil +from tempfile import NamedTemporaryFile +from typing import Any, List, Optional, Mapping as MappingType, Tuple, Union from aiida.common import AIIDA_LOGGER, exceptions +from aiida.common.datastructures import CalcInfo from aiida.common.folders import SandboxFolder from aiida.common.links import LinkType -from aiida.orm import FolderData, Node +from aiida.orm import load_node, CalcJobNode, Code, FolderData, Node, RemoteData from aiida.orm.utils.log import get_dblogger_extra from aiida.plugins import DataFactory from aiida.schedulers.datastructures import JobState +from aiida.transports import Transport REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' -execlogger = AIIDA_LOGGER.getChild('execmanager') +EXEC_LOGGER = AIIDA_LOGGER.getChild('execmanager') -def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run=False): +def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: + """Find and return the node with the given UUID from a nested mapping of input nodes. 
+ + :param inputs: (nested) mapping of nodes + :param uuid: UUID of the node to find + :return: instance of `Node` or `None` if not found + """ + data_node = None + + for input_node in inputs.values(): + if isinstance(input_node, Mapping): + data_node = _find_data_node(input_node, uuid) + elif isinstance(input_node, Node) and input_node.uuid == uuid: + data_node = input_node + if data_node is not None: + break + + return data_node + + +def upload_calculation( + node: CalcJobNode, + transport: Transport, + calc_info: CalcInfo, + folder: SandboxFolder, + inputs: Optional[MappingType[str, Any]] = None, + dry_run: bool = False +) -> None: """Upload a `CalcJob` instance :param node: the `CalcJobNode`. @@ -38,16 +71,13 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` """ # pylint: disable=too-many-locals,too-many-branches,too-many-statements - from logging import LoggerAdapter - from tempfile import NamedTemporaryFile - from aiida.orm import load_node, Code, RemoteData # If the calculation already has a `remote_folder`, simply return. The upload was apparently already completed # before, which can happen if the daemon is restarted and it shuts down after uploading but before getting the # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the upload. link_label = 'remote_folder' if node.get_outgoing(RemoteData, link_label_filter=link_label).first(): - execlogger.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload') + EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload') return calc_info computer = node.computer @@ -57,7 +87,7 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= logger_extra = get_dblogger_extra(node) transport.set_logger_extra(logger_extra) - logger = LoggerAdapter(logger=execlogger, extra=logger_extra) + logger = LoggerAdapter(logger=EXEC_LOGGER, extra=logger_extra) if not dry_run and node.has_cached_links(): raise ValueError( @@ -162,30 +192,10 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= for uuid, filename, target in local_copy_list: logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}') - def find_data_node(inputs, uuid): - """Find and return the node with the given UUID from a nested mapping of input nodes. 
- - :param inputs: (nested) mapping of nodes - :param uuid: UUID of the node to find - :return: instance of `Node` or `None` if not found - """ - from collections.abc import Mapping - data_node = None - - for input_node in inputs.values(): - if isinstance(input_node, Mapping): - data_node = find_data_node(input_node, uuid) - elif isinstance(input_node, Node) and input_node.uuid == uuid: - data_node = input_node - if data_node is not None: - break - - return data_node - try: data_node = load_node(uuid=uuid) except exceptions.NotExistent: - data_node = find_data_node(inputs, uuid) + data_node = _find_data_node(inputs, uuid) if inputs else None if data_node is None: logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') @@ -294,7 +304,7 @@ def find_data_node(inputs, uuid): remotedata.store() -def submit_calculation(calculation, transport): +def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str: """Submit a previously uploaded `CalcJob` to the scheduler. :param calculation: the instance of CalcJobNode to submit. @@ -322,7 +332,66 @@ def submit_calculation(calculation, transport): return job_id -def retrieve_calculation(calculation, transport, retrieved_temporary_folder): +def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: + """Stash files from the working directory of a completed calculation to a permanent remote folder. + + After a calculation has been completed, optionally stash files from the work directory to a storage location on the + same remote machine. This is useful if one wants to keep certain files from a completed calculation to be removed + from the scratch directory, because they are necessary for restarts, but that are too heavy to retrieve. + Instructions of which files to copy where are retrieved from the `stash.source_list` option. + + :param calculation: the calculation job node. + :param transport: an already opened transport. 
+ """ + from aiida.common.datastructures import StashMode + from aiida.orm import RemoteStashFolderData + + logger_extra = get_dblogger_extra(calculation) + + stash_options = calculation.get_option('stash') + stash_mode = stash_options.get('mode', StashMode.COPY.value) + source_list = stash_options.get('source_list', []) + + if not source_list: + return + + if stash_mode != StashMode.COPY.value: + EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.') + return + + cls = RemoteStashFolderData + + EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra) + + uuid = calculation.uuid + target_basepath = os.path.join(stash_options['target_base'], uuid[:2], uuid[2:4], uuid[4:]) + + for source_filename in source_list: + + source_filepath = os.path.join(calculation.get_remote_workdir(), source_filename) + target_filepath = os.path.join(target_basepath, source_filename) + + # If the source file is in a (nested) directory, create those directories first in the target directory + target_dirname = os.path.dirname(target_filepath) + transport.makedirs(target_dirname, ignore_existing=True) + + try: + transport.copy(source_filepath, target_filepath) + except (IOError, ValueError) as exception: + EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') + else: + EXEC_LOGGER.debug(f'stashed {source_filepath} to {target_filepath}') + + remote_stash = cls( + computer=calculation.computer, + target_basepath=target_basepath, + stash_mode=StashMode(stash_mode), + source_list=source_list, + ).store() + remote_stash.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') + + +def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: """Retrieve all the files of a completed job calculation using the given transport. If the job defined anything in the `retrieve_temporary_list`, those entries will be stored in the @@ -336,15 +405,15 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): logger_extra = get_dblogger_extra(calculation) workdir = calculation.get_remote_workdir() - execlogger.debug(f'Retrieving calc {calculation.pk}', extra=logger_extra) - execlogger.debug(f'[retrieval of calc {calculation.pk}] chdir {workdir}', extra=logger_extra) + EXEC_LOGGER.debug(f'Retrieving calc {calculation.pk}', extra=logger_extra) + EXEC_LOGGER.debug(f'[retrieval of calc {calculation.pk}] chdir {workdir}', extra=logger_extra) # If the calculation already has a `retrieved` folder, simply return. The retrieval was apparently already completed # before, which can happen if the daemon is restarted and it shuts down after retrieving but before getting the # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the retrieval. 
link_label = calculation.link_label_retrieved if calculation.get_outgoing(FolderData, link_label_filter=link_label).first(): - execlogger.warning( + EXEC_LOGGER.warning( f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval' ) return @@ -377,13 +446,13 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): # Log the files that were retrieved in the temporary folder for filename in os.listdir(retrieved_temporary_folder): - execlogger.debug( + EXEC_LOGGER.debug( f"[retrieval of calc {calculation.pk}] Retrieved temporary file or folder '{filename}'", extra=logger_extra ) # Store everything - execlogger.debug( + EXEC_LOGGER.debug( f'[retrieval of calc {calculation.pk}] Storing retrieved_files={retrieved_files.pk}', extra=logger_extra ) retrieved_files.store() @@ -394,7 +463,7 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): retrieved_files.add_incoming(calculation, link_type=LinkType.CREATE, link_label=calculation.link_label_retrieved) -def kill_calculation(calculation, transport): +def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: """ Kill the calculation through the scheduler @@ -403,6 +472,10 @@ def kill_calculation(calculation, transport): """ job_id = calculation.get_job_id() + if job_id is None: + # the calculation has not yet been submitted to the scheduler + return + # Get the scheduler plugin class and initialize it with the correct transport scheduler = calculation.computer.get_scheduler() scheduler.set_transport(transport) @@ -420,16 +493,22 @@ def kill_calculation(calculation, transport): if job is not None and job.job_state != JobState.DONE: raise exceptions.RemoteOperationError(f'scheduler.kill({job_id}) was unsuccessful') else: - execlogger.warning('scheduler.kill() failed but job<{%s}> no longer seems to be running regardless', job_id) - - return True + EXEC_LOGGER.warning( + 'scheduler.kill() failed but job<{%s}> no longer seems to be running regardless', job_id + ) -def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_extra=None): +def _retrieve_singlefiles( + job: CalcJobNode, + transport: Transport, + folder: SandboxFolder, + retrieve_file_list: List[Tuple[str, str, str]], + logger_extra: Optional[dict] = None +): """Retrieve files specified through the singlefile list mechanism.""" singlefile_list = [] for (linkname, subclassname, filename) in retrieve_file_list: - execlogger.debug( + EXEC_LOGGER.debug( '[retrieval of calc {}] Trying ' "to retrieve remote singlefile '{}'".format(job.pk, filename), extra=logger_extra @@ -450,11 +529,14 @@ def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_ext singlefiles.append(singlefile) for fil in singlefiles: - execlogger.debug(f'[retrieval of calc {job.pk}] Storing retrieved_singlefile={fil.pk}', extra=logger_extra) + EXEC_LOGGER.debug(f'[retrieval of calc {job.pk}] Storing retrieved_singlefile={fil.pk}', extra=logger_extra) fil.store() -def retrieve_files_from_list(calculation, transport, folder, retrieve_list): +def retrieve_files_from_list( + calculation: CalcJobNode, transport: Transport, folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], + list]] +) -> None: """ Retrieve all the files in the retrieve_list from the remote into the local folder instance through the transport. 
The entries in the retrieve_list diff --git a/aiida/engine/daemon/runner.py b/aiida/engine/daemon/runner.py index c428dc9da0..c807c54953 100644 --- a/aiida/engine/daemon/runner.py +++ b/aiida/engine/daemon/runner.py @@ -10,15 +10,34 @@ """Function that starts a daemon runner.""" import logging import signal +import asyncio from aiida.common.log import configure_logging from aiida.engine.daemon.client import get_daemon_client +from aiida.engine.runners import Runner from aiida.manage.manager import get_manager LOGGER = logging.getLogger(__name__) -def start_daemon(): +async def shutdown_runner(runner: Runner) -> None: + """Cleanup tasks tied to the service's shutdown.""" + from asyncio import all_tasks + from asyncio import current_task + + LOGGER.info('Received signal to shut down the daemon runner') + tasks = [task for task in all_tasks() if task is not current_task()] + + for task in tasks: + task.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) + runner.close() + + LOGGER.info('Daemon runner stopped') + + +def start_daemon() -> None: """Start a daemon runner for the currently configured profile.""" daemon_client = get_daemon_client() configure_logging(daemon=True, daemon_log_file=daemon_client.daemon_log_file) @@ -31,19 +50,15 @@ def start_daemon(): LOGGER.exception('daemon runner failed to start') raise - def shutdown_daemon(_num, _frame): - LOGGER.info('Received signal to shut down the daemon runner') - runner.close() - - signal.signal(signal.SIGINT, shutdown_daemon) - signal.signal(signal.SIGTERM, shutdown_daemon) - - LOGGER.info('Starting a daemon runner') + signals = (signal.SIGTERM, signal.SIGINT) + for s in signals: # pylint: disable=invalid-name + runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_runner(runner))) try: + LOGGER.info('Starting a daemon runner') runner.start() except SystemError as exception: LOGGER.info('Received a SystemError: %s', exception) runner.close() - LOGGER.info('Daemon runner stopped') + LOGGER.info('Daemon runner started') diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index cea0378434..6026ac4731 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -8,27 +8,30 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Top level functions that can be used to launch a Process.""" +from typing import Any, Dict, Tuple, Type, Union from aiida.common import InvalidOperation from aiida.manage import manager +from aiida.orm import ProcessNode from .processes.functions import FunctionProcess -from .processes.process import Process +from .processes.process import Process, ProcessBuilder from .utils import is_process_scoped, instantiate_process __all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') +TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +# run can also be process function, but it is not clear what type this should be +TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name -def run(process, *args, **inputs): + +def run(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. 
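# Illustrative sketch (not part of the patch) of the shutdown pattern adopted in
# `runner.py` above: signal handlers are registered on the asyncio loop (Unix only)
# and schedule a coroutine that cancels outstanding tasks before the runner closes.
# This is a simplified, self-contained stand-in, not the daemon runner itself.
import asyncio
import signal

async def shutdown(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel every task except this one and wait for them to finish."""
    tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)

async def main() -> None:
    loop = asyncio.get_running_loop()
    for signum in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(signum, lambda: asyncio.create_task(shutdown(loop)))
    await asyncio.sleep(3600)  # stand-in for the runner's long-running work

if __name__ == '__main__':
    try:
        asyncio.run(main())
    except asyncio.CancelledError:
        pass  # raised when the sleeping main task is cancelled by the shutdown handler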
:param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - :param inputs: the inputs to be passed to the process - :type inputs: dict :return: the outputs of the process - :rtype: dict + """ if isinstance(process, Process): runner = process.runner @@ -38,17 +41,13 @@ def run(process, *args, **inputs): return runner.run(process, *args, **inputs) -def run_get_node(process, *args, **inputs): +def run_get_node(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. - :param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance, builder or function to run :param inputs: the inputs to be passed to the process - :type inputs: dict :return: tuple of the outputs of the process and the process node - :rtype: (dict, :class:`aiida.orm.ProcessNode`) """ if isinstance(process, Process): @@ -59,17 +58,14 @@ def run_get_node(process, *args, **inputs): return runner.run_get_node(process, *args, **inputs) -def run_get_pk(process, *args, **inputs): +def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], int]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. - :param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance, builder or function to run :param inputs: the inputs to be passed to the process - :type inputs: dict :return: tuple of the outputs of the process and process node pk - :rtype: (dict, int) + """ if isinstance(process, Process): runner = process.runner @@ -79,7 +75,7 @@ def run_get_pk(process, *args, **inputs): return runner.run_get_pk(process, *args, **inputs) -def submit(process, **inputs): +def submit(process: TYPE_SUBMIT_PROCESS, **inputs: Any) -> ProcessNode: """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. .. warning: this should not be used within another process. Instead, there one should use the `submit` method of @@ -87,14 +83,11 @@ def submit(process, **inputs): .. 
warning: submission of processes requires `store_provenance=True` - :param process: the process class to submit - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance or builder to submit :param inputs: the inputs to be passed to the process - :type inputs: dict :return: the calculation node of the process - :rtype: :class:`aiida.orm.ProcessNode` + """ # Submitting from within another process requires `self.submit` unless it is a work function, in which case the # current process in the scope should be an instance of `FunctionProcess` @@ -102,28 +95,29 @@ def submit(process, **inputs): raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') runner = manager.get_manager().get_runner() - controller = manager.get_manager().get_process_controller() + assert runner.persister is not None, 'runner does not have a persister' + assert runner.controller is not None, 'runner does not have a persister' - process = instantiate_process(runner, process, **inputs) + process_inited = instantiate_process(runner, process, **inputs) # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this # instead of raising, because in this way the user does not have to change the launcher when testing. - if process.metadata.get('dry_run', False): - _, node = run_get_node(process) + if process_inited.metadata.get('dry_run', False): + _, node = run_get_node(process_inited) return node - if not process.metadata.store_provenance: + if not process_inited.metadata.store_provenance: raise InvalidOperation('cannot submit a process with `store_provenance=False`') - runner.persister.save_checkpoint(process) - process.close() + runner.persister.save_checkpoint(process_inited) + process_inited.close() # Do not wait for the future's result, because in the case of a single worker this would cock-block itself - controller.continue_process(process.pid, nowait=False, no_reply=True) + runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) - return process.node + return process_inited.node # Allow one to also use run.get_node and run.get_pk as a shortcut, without having to import the functions themselves -run.get_node = run_get_node -run.get_pk = run_get_pk +run.get_node = run_get_node # type: ignore[attr-defined] +run.get_pk = run_get_pk # type: ignore[attr-defined] diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 5aedd9d386..2ccdac03c1 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -13,21 +13,27 @@ import importlib import logging import traceback +from typing import Any, Hashable, Optional, TYPE_CHECKING -import plumpy +import plumpy.persistence +import plumpy.loaders +from plumpy.exceptions import PersistenceError from aiida.orm.utils import serialize +if TYPE_CHECKING: + from aiida.engine.processes.process import Process + __all__ = ('AiiDAPersister', 'ObjectLoader', 'get_object_loader') LOGGER = logging.getLogger(__name__) OBJECT_LOADER = None -class ObjectLoader(plumpy.DefaultObjectLoader): +class ObjectLoader(plumpy.loaders.DefaultObjectLoader): """Custom object loader for `aiida-core`.""" - def load_object(self, identifier): + def load_object(self, identifier: str) -> Any: # pylint: disable=no-self-use """Attempt to load the object identified by the given `identifier`. .. 
note:: We override the `plumpy.DefaultObjectLoader` to be able to throw an `ImportError` instead of a @@ -37,11 +43,11 @@ def load_object(self, identifier): :return: loaded object :raises ImportError: if the object cannot be loaded """ - module, name = identifier.split(':') + module_name, name = identifier.split(':') try: - module = importlib.import_module(module) + module = importlib.import_module(module_name) except ImportError: - raise ImportError(f"module '{module}' from identifier '{identifier}' could not be loaded") + raise ImportError(f"module '{module_name}' from identifier '{identifier}' could not be loaded") try: return getattr(module, name) @@ -49,11 +55,11 @@ def load_object(self, identifier): raise ImportError(f"object '{name}' from identifier '{identifier}' could not be loaded") -def get_object_loader(): +def get_object_loader() -> ObjectLoader: """Return the global AiiDA object loader. :return: The global object loader - :rtype: :class:`plumpy.ObjectLoader` + """ global OBJECT_LOADER if OBJECT_LOADER is None: @@ -61,15 +67,15 @@ def get_object_loader(): return OBJECT_LOADER -class AiiDAPersister(plumpy.Persister): +class AiiDAPersister(plumpy.persistence.Persister): """Persister to take saved process instance states and persisting them to the database.""" - def save_checkpoint(self, process, tag=None): + def save_checkpoint(self, process: 'Process', tag: Optional[str] = None): # type: ignore[override] # pylint: disable=no-self-use """Persist a Process instance. :param process: :class:`aiida.engine.Process` :param tag: optional checkpoint identifier to allow distinguishing multiple checkpoints for the same process - :raises: :class:`plumpy.PersistenceError` Raised if there was a problem saving the checkpoint + :raises: :class:`PersistenceError` Raised if there was a problem saving the checkpoint """ LOGGER.debug('Persisting process<%d>', process.pid) @@ -77,26 +83,26 @@ def save_checkpoint(self, process, tag=None): raise NotImplementedError('Checkpoint tags not supported yet') try: - bundle = plumpy.Bundle(process, plumpy.LoadSaveContext(loader=get_object_loader())) + bundle = plumpy.persistence.Bundle(process, plumpy.persistence.LoadSaveContext(loader=get_object_loader())) except ImportError: # Couldn't create the bundle - raise plumpy.PersistenceError(f"Failed to create a bundle for '{process}': {traceback.format_exc()}") + raise PersistenceError(f"Failed to create a bundle for '{process}': {traceback.format_exc()}") try: process.node.set_checkpoint(serialize.serialize(bundle)) except Exception: - raise plumpy.PersistenceError(f"Failed to store a checkpoint for '{process}': {traceback.format_exc()}") + raise PersistenceError(f"Failed to store a checkpoint for '{process}': {traceback.format_exc()}") return bundle - def load_checkpoint(self, pid, tag=None): + def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.persistence.Bundle: # pylint: disable=no-self-use """Load a process from a persisted checkpoint by its process id. 
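# Illustrative sketch (not part of the patch) of the 'module:name' identifier
# convention handled by `ObjectLoader.load_object` above. The error handling from the
# patch is omitted and the identifier shown is only an example.
import importlib

def load_object(identifier: str):
    """Load an object from an identifier of the form '<module path>:<attribute name>'."""
    module_name, name = identifier.split(':')
    module = importlib.import_module(module_name)
    return getattr(module, name)

json_dumps = load_object('json:dumps')
print(json_dumps({'answer': 42}))  # prints {"answer": 42}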
:param pid: the process id of the :class:`plumpy.Process` :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint :return: a bundle with the process state :rtype: :class:`plumpy.Bundle` - :raises: :class:`plumpy.PersistenceError` Raised if there was a problem loading the checkpoint + :raises: :class:`PersistenceError` Raised if there was a problem loading the checkpoint """ from aiida.common.exceptions import MultipleObjectsError, NotExistent from aiida.orm import load_node @@ -107,17 +113,17 @@ def load_checkpoint(self, pid, tag=None): try: calculation = load_node(pid) except (MultipleObjectsError, NotExistent): - raise plumpy.PersistenceError(f'Failed to load the node for process<{pid}>: {traceback.format_exc()}') + raise PersistenceError(f'Failed to load the node for process<{pid}>: {traceback.format_exc()}') checkpoint = calculation.checkpoint if checkpoint is None: - raise plumpy.PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') + raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') try: bundle = serialize.deserialize(checkpoint) except Exception: - raise plumpy.PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') + raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') return bundle @@ -127,14 +133,14 @@ def get_checkpoints(self): :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag. """ - def get_process_checkpoints(self, pid): + def get_process_checkpoints(self, pid: Hashable): """Return a list of all the current persisted process checkpoints for the specified process. :param pid: the process pid :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag. """ - def delete_checkpoint(self, pid, tag=None): + def delete_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> None: # pylint: disable=no-self-use,unused-argument """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist. :param pid: the process id of the :class:`plumpy.Process` @@ -145,7 +151,7 @@ def delete_checkpoint(self, pid, tag=None): calc = load_node(pid) calc.delete_checkpoint() - def delete_process_checkpoints(self, pid): + def delete_process_checkpoints(self, pid: Hashable): """Delete all persisted checkpoints related to the given process id. 
:param pid: the process id of the :class:`aiida.engine.processes.process.Process` diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index de5a86cf18..b3045dcfd4 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -19,6 +19,6 @@ from .workchains import * __all__ = ( - builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + ports.__all__ + process.__all__ + - process_spec.__all__ + workchains.__all__ + builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + # type: ignore[name-defined] + ports.__all__ + process.__all__ + process_spec.__all__ + workchains.__all__ # type: ignore[name-defined] ) diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index 9a620244b4..c7f6939918 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -9,10 +9,14 @@ ########################################################################### """Convenience classes to help building the input dictionaries for Processes.""" import collections +from typing import Any, Type, TYPE_CHECKING from aiida.orm import Node from aiida.engine.processes.ports import PortNamespace +if TYPE_CHECKING: + from aiida.engine.processes.process import Process + __all__ = ('ProcessBuilder', 'ProcessBuilderNamespace') @@ -22,7 +26,7 @@ class ProcessBuilderNamespace(collections.abc.MutableMapping): Dynamically generates the getters and setters for the input ports of a given PortNamespace """ - def __init__(self, port_namespace): + def __init__(self, port_namespace: PortNamespace) -> None: """Dynamically construct the get and set properties for the ports of the given port namespace. For each port in the given port namespace a get and set property will be constructed dynamically @@ -30,7 +34,7 @@ def __init__(self, port_namespace): by calling str() on the Port, which should return the description of the Port. :param port_namespace: the inputs PortNamespace for which to construct the builder - :type port_namespace: str + """ # pylint: disable=super-init-not-called self._port_namespace = port_namespace @@ -52,7 +56,7 @@ def fgetter(self, name=name): return self._data.get(name) elif port.has_default(): - def fgetter(self, name=name, default=port.default): # pylint: disable=cell-var-from-loop + def fgetter(self, name=name, default=port.default): # type: ignore # pylint: disable=cell-var-from-loop return self._data.get(name, default) else: @@ -67,16 +71,12 @@ def fsetter(self, value, name=name): getter.setter(fsetter) # pylint: disable=too-many-function-args setattr(self.__class__, name, getter) - def __setattr__(self, attr, value): + def __setattr__(self, attr: str, value: Any) -> None: """Assign the given value to the port with key `attr`. .. 
note:: Any attributes without a leading underscore being set correspond to inputs and should hence be validated with respect to the corresponding input port from the process spec - :param attr: attribute - :type attr: str - - :param value: value """ if attr.startswith('_'): object.__setattr__(self, attr, value) @@ -87,7 +87,7 @@ def __setattr__(self, attr, value): if not self._port_namespace.dynamic: raise AttributeError(f'Unknown builder parameter: {attr}') else: - value = port.serialize(value) + value = port.serialize(value) # type: ignore[union-attr] validation_error = port.validate(value) if validation_error: raise ValueError(f'invalid attribute value {validation_error.message}') @@ -126,10 +126,8 @@ def _update(self, *args, **kwds): principle the method functions just as `collections.abc.MutableMapping.update`. :param args: a single mapping that should be mapped on the namespace - :type args: list :param kwds: keyword value pairs that should be mapped onto the ports - :type kwds: dict """ if len(args) > 1: raise TypeError(f'update expected at most 1 arguments, got {int(len(args))}') @@ -147,7 +145,7 @@ def _update(self, *args, **kwds): else: self.__setattr__(key, value) - def _inputs(self, prune=False): + def _inputs(self, prune: bool = False) -> dict: """Return the entire mapping of inputs specified for this builder. :param prune: boolean, when True, will prune nested namespaces that contain no actual values whatsoever @@ -182,7 +180,7 @@ def _prune(self, value): class ProcessBuilder(ProcessBuilderNamespace): # pylint: disable=too-many-ancestors """A process builder that helps setting up the inputs for creating a new process.""" - def __init__(self, process_class): + def __init__(self, process_class: Type['Process']): """Construct a `ProcessBuilder` instance for the given `Process` class. 
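# Illustrative sketch (not part of the patch) of how the dynamically generated builder
# attributes described above are typically used. The entry point and code label are
# hypothetical examples; assigning to a port serializes and validates the value, as
# implemented by `__setattr__` above.
from aiida import orm
from aiida.plugins import CalculationFactory

ArithmeticAdd = CalculationFactory('arithmetic.add')  # example calculation plugin
builder = ArithmeticAdd.get_builder()

builder.code = orm.load_code('add@localhost')  # hypothetical code label
builder.x = orm.Int(1)
builder.y = orm.Int(2)
builder.metadata.options.resources = {'num_machines': 1}

print(builder._inputs(prune=True))  # nested dictionary of the inputs set so far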
:param process_class: the `Process` subclass @@ -192,6 +190,6 @@ def __init__(self, process_class): super().__init__(self._process_spec.inputs) @property - def process_class(self): + def process_class(self) -> Type['Process']: """Return the process class for which this builder is constructed.""" return self._process_class diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index dc7c275880..57d4777ae7 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -12,4 +12,4 @@ from .calcjob import * -__all__ = (calcjob.__all__) +__all__ = (calcjob.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index db4f8a60c3..f13a65a965 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -9,8 +9,12 @@ ########################################################################### """Implementation of the CalcJob process.""" import io +import os +import shutil +from typing import Any, Dict, Hashable, Optional, Type, Union -import plumpy +import plumpy.ports +import plumpy.process_states from aiida import orm from aiida.common import exceptions, AttributeDict @@ -20,6 +24,7 @@ from aiida.common.links import LinkType from ..exit_code import ExitCode +from ..ports import PortNamespace from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec from .tasks import Waiting, UPLOAD_COMMAND @@ -27,7 +32,7 @@ __all__ = ('CalcJob',) -def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statements +def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pylint: disable=too-many-return-statements """Validate the entire set of inputs passed to the `CalcJob` constructor. 
Reasons that will cause this validation to raise an `InputValidationError`: @@ -43,7 +48,7 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement ctx.get_port('metadata.computer') except ValueError: # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation - return + return None code = inputs.get('code', None) computer_from_code = code.computer @@ -69,11 +74,11 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement try: resources_port = ctx.get_port('metadata.options.resources') except ValueError: - return + return None # If the resources port exists but is not required, we don't need to validate it against the computer's scheduler if not resources_port.required: - return + return None computer = computer_from_code or computer_from_metadata scheduler = computer.get_scheduler() @@ -89,19 +94,60 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement except ValueError as exception: return f'input `metadata.options.resources` is not valid for the `{scheduler}` scheduler: {exception}' + return None -def validate_parser(parser_name, _): + +def validate_stash_options(stash_options: Any, _: Any) -> Optional[str]: + """Validate the ``stash`` options.""" + from aiida.common.datastructures import StashMode + + target_base = stash_options.get('target_base', None) + source_list = stash_options.get('source_list', None) + stash_mode = stash_options.get('mode', StashMode.COPY.value) + + if not isinstance(target_base, str) or not os.path.isabs(target_base): + return f'`metadata.options.stash.target_base` should be an absolute filepath, got: {target_base}' + + if ( + not isinstance(source_list, (list, tuple)) or + any(not isinstance(src, str) or os.path.isabs(src) for src in source_list) + ): + port = 'metadata.options.stash.source_list' + return f'`{port}` should be a list or tuple of relative filepaths, got: {source_list}' + + try: + StashMode(stash_mode) + except ValueError: + port = 'metadata.options.stash.mode' + return f'`{port}` should be a member of aiida.common.datastructures.StashMode, got: {stash_mode}' + + return None + + +def validate_parser(parser_name: Any, _: Any) -> Optional[str]: """Validate the parser. :return: string with error message in case the inputs are invalid """ from aiida.plugins import ParserFactory - if parser_name is not plumpy.UNSPECIFIED: - try: - ParserFactory(parser_name) - except exceptions.EntryPointError as exception: - return f'invalid parser specified: {exception}' + try: + ParserFactory(parser_name) + except exceptions.EntryPointError as exception: + return f'invalid parser specified: {exception}' + + return None + + +def validate_additional_retrieve_list(additional_retrieve_list: Any, _: Any) -> Optional[str]: + """Validate the additional retrieve list. + + :return: string with error message in case the input is invalid. + """ + if any(not isinstance(value, str) or os.path.isabs(value) for value in additional_retrieve_list): + return f'`additional_retrieve_list` should only contain relative filepaths but got: {additional_retrieve_list}' + + return None class CalcJob(Process): @@ -109,9 +155,9 @@ class CalcJob(Process): _node_class = orm.CalcJobNode _spec_class = CalcJobProcessSpec - link_label_retrieved = 'retrieved' + link_label_retrieved: str = 'retrieved' - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Construct a CalcJob instance. 
Construct the instance only if it is a sub class of `CalcJob`, otherwise raise `InvalidOperation`. @@ -124,14 +170,18 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @classmethod - def define(cls, spec: CalcJobProcessSpec): - # yapf: disable + def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] """Define the process specification, including its inputs, outputs and known exit codes. + Ports are added to the `metadata` input namespace (inherited from the base Process), + and a `code` input Port, a `remote_folder` output Port and retrieved folder output Port + are added. + :param spec: the calculation job process spec to define. """ + # yapf: disable super().define(spec) - spec.inputs.validator = validate_calc_job + spec.inputs.validator = validate_calc_job # type: ignore[assignment] # takes only PortNamespace not Port spec.input('code', valid_type=orm.Code, help='The `Code` to use for this job.') spec.input('metadata.dry_run', valid_type=bool, default=False, help='When set to `True` will prepare the calculation job for submission but not actually launch it.') @@ -186,9 +236,24 @@ def define(cls, spec: CalcJobProcessSpec): 'script, just after the code execution',) spec.input('metadata.options.parser_name', valid_type=str, required=False, validator=validate_parser, help='Set a string for the output parser. Can be None if no output plugin is available or needed') + spec.input('metadata.options.additional_retrieve_list', required=False, + valid_type=(list, tuple), validator=validate_additional_retrieve_list, + help='List of relative file paths that should be retrieved in addition to what the plugin specifies.') + spec.input_namespace('metadata.options.stash', required=False, populate_defaults=False, + validator=validate_stash_options, + help='Optional directives to stash files after the calculation job has completed.') + spec.input('metadata.options.stash.target_base', valid_type=str, required=False, + help='The base location to where the files should be stashd. For example, for the `copy` stash mode, this ' + 'should be an absolute filepath on the remote computer.') + spec.input('metadata.options.stash.source_list', valid_type=(tuple, list), required=False, + help='Sequence of relative filepaths representing files in the remote directory that should be stashed.') + spec.input('metadata.options.stash.stash_mode', valid_type=str, required=False, + help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode.') spec.output('remote_folder', valid_type=orm.RemoteData, help='Input files necessary to run the process will be stored in this folder node.') + spec.output('remote_stash', valid_type=orm.RemoteStashData, required=False, + help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.') spec.output(cls.link_label_retrieved, valid_type=orm.FolderData, pass_to_parser=True, help='Files that are retrieved by the daemon will be stored in this node. 
By default the stdout and stderr ' 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.') @@ -200,6 +265,7 @@ def define(cls, spec: CalcJobProcessSpec): message='The job ran out of memory.') spec.exit_code(120, 'ERROR_SCHEDULER_OUT_OF_WALLTIME', message='The job ran out of walltime.') + # yapf: enable @classproperty def spec_options(cls): # pylint: disable=no-self-argument @@ -211,11 +277,11 @@ def spec_options(cls): # pylint: disable=no-self-argument return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object @property - def options(self): + def options(self) -> AttributeDict: """Return the options of the metadata that were specified when this process instance was launched. :return: options dictionary - :rtype: dict + """ try: return self.metadata.options @@ -223,14 +289,18 @@ def options(self): return AttributeDict() @classmethod - def get_state_classes(cls): + def get_state_classes(cls) -> Dict[Hashable, Type[plumpy.process_states.State]]: + """A mapping of the State constants to the corresponding state class. + + Overrides the waiting state with the Calcjob specific version. + """ # Overwrite the waiting state states_map = super().get_state_classes() states_map[ProcessState.WAITING] = Waiting return states_map @override - def on_terminated(self): + def on_terminated(self) -> None: """Cleanup the node by deleting the calulation job state. .. note:: This has to be done before calling the super because that will seal the node after we cannot change it @@ -239,13 +309,17 @@ def on_terminated(self): super().on_terminated() @override - def run(self): + def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: """Run the calculation job. This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the process in the `Wait` state, waiting for the `UPLOAD` transport task to be started. + + :returns: the `Stop` command if a dry run, int if the process has an exit status, + `Wait` command if the calcjob is to be uploaded + """ - if self.inputs.metadata.dry_run: + if self.inputs.metadata.dry_run: # type: ignore[union-attr] from aiida.common.folders import SubmitTestFolder from aiida.engine.daemon.execmanager import upload_calculation from aiida.transports.plugins.local import LocalTransport @@ -259,7 +333,7 @@ def run(self): 'folder': folder.abspath, 'script_filename': self.node.get_option('submit_script_filename') } - return plumpy.Stop(None, True) + return plumpy.process_states.Stop(None, True) # The following conditional is required for the caching to properly work. Even if the source node has a process # state of `Finished` the cached process will still enter the running state. The process state will have then @@ -269,7 +343,7 @@ def run(self): return self.node.exit_status # Launch the upload operation - return plumpy.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) + return plumpy.process_states.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. @@ -284,13 +358,14 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: """ raise NotImplementedError - def parse(self, retrieved_temporary_folder=None): + def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: """Parse a retrieved job calculation. This is called once it's finished waiting for the calculation to be finished and the data has been retrieved. 
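# Illustrative sketch (not part of the patch), continuing the hypothetical `builder`
# from the earlier sketch: the new `additional_retrieve_list` and `stash` option
# namespaces added in this hunk let a user retrieve extra files and stash results
# after job completion. Paths are placeholders; key names follow the port definitions above.
from aiida.common.datastructures import StashMode

builder.metadata.options.additional_retrieve_list = ['logs/extra.log', 'data.json']
builder.metadata.options.stash = {
    'target_base': '/storage/project/stash',  # absolute path on the remote computer
    'source_list': ['aiida.out', 'outputs'],  # relative paths in the remote work dir
    'stash_mode': StashMode.COPY.value,
}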
- """ - import shutil + :param retrieved_temporary_folder: The path to the temporary folder + + """ try: retrieved = self.node.outputs.retrieved except exceptions.NotExistent: @@ -320,6 +395,7 @@ def parse(self, retrieved_temporary_folder=None): self.logger.warning(msg) # The final exit code is that of the scheduler, unless the output parser returned one + exit_code: Optional[ExitCode] if exit_code_retrieved is not None: exit_code = exit_code_retrieved else: @@ -331,7 +407,7 @@ def parse(self, retrieved_temporary_folder=None): return exit_code or ExitCode(0) - def parse_scheduler_output(self, retrieved): + def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" scheduler = self.node.computer.get_scheduler() filename_stderr = self.node.get_option('scheduler_stderr') @@ -359,16 +435,16 @@ def parse_scheduler_output(self, retrieved): # Only attempt to call the scheduler parser if all three resources of information are available if any(entry is None for entry in [detailed_job_info, scheduler_stderr, scheduler_stdout]): - return + return None try: exit_code = scheduler.parse_output(detailed_job_info, scheduler_stdout, scheduler_stderr) except exceptions.FeatureNotAvailable: self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing') - return + return None except Exception as exception: # pylint: disable=broad-except self.logger.error(f'the `parse_output` method of the scheduler excepted: {exception}') - return + return None if exit_code is not None and not isinstance(exit_code, ExitCode): args = (scheduler.__class__.__name__, type(exit_code)) @@ -376,12 +452,12 @@ def parse_scheduler_output(self, retrieved): return exit_code - def parse_retrieved_output(self, retrieved_temporary_folder=None): + def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = None) -> Optional[ExitCode]: """Parse the retrieved data by calling the parser plugin if it was defined in the inputs.""" parser_class = self.node.get_parser_class() if parser_class is None: - return + return None parser = parser_class(self.node) parse_kwargs = parser.get_outputs_for_parsing() @@ -405,18 +481,15 @@ def parse_retrieved_output(self, retrieved_temporary_folder=None): return exit_code - def presubmit(self, folder): + def presubmit(self, folder: Folder) -> CalcInfo: """Prepares the calculation folder with all inputs, ready to be copied to the cluster. :param folder: a SandboxFolder that can be used to write calculation input files and the scheduling script. - :type folder: :class:`aiida.common.folders.Folder` :return calcinfo: the CalcInfo object containing the information needed by the daemon to handle operations. 
- :rtype calcinfo: :class:`aiida.common.CalcInfo` + """ # pylint: disable=too-many-locals,too-many-statements,too-many-branches - import os - from aiida.common.exceptions import PluginInternalError, ValidationError, InvalidOperation, InputValidationError from aiida.common import json from aiida.common.utils import validate_list_of_string_tuples @@ -428,19 +501,23 @@ def presubmit(self, folder): computer = self.node.computer inputs = self.node.get_incoming(link_type=LinkType.INPUT_CALC) - if not self.inputs.metadata.dry_run and self.node.has_cached_links(): + if not self.inputs.metadata.dry_run and self.node.has_cached_links(): # type: ignore[union-attr] raise InvalidOperation('calculation node has unstored links in cache') codes = [_ for _ in inputs.all_nodes() if isinstance(_, Code)] for code in codes: if not code.can_run_on(computer): - raise InputValidationError('The selected code {} for calculation {} cannot run on computer {}'.format( - code.pk, self.node.pk, computer.label)) + raise InputValidationError( + 'The selected code {} for calculation {} cannot run on computer {}'.format( + code.pk, self.node.pk, computer.label + ) + ) if code.is_local() and code.get_local_executable() in folder.get_content_list(): - raise PluginInternalError('The plugin created a file {} that is also the executable name!'.format( - code.get_local_executable())) + raise PluginInternalError( + f'The plugin created a file {code.get_local_executable()} that is also the executable name!' + ) calc_info = self.prepare_for_submission(folder) calc_info.uuid = str(self.node.uuid) @@ -462,29 +539,29 @@ def presubmit(self, folder): job_tmpl.sched_join_files = False # Set retrieve path, add also scheduler STDOUT and STDERR - retrieve_list = (calc_info.retrieve_list if calc_info.retrieve_list is not None else []) + retrieve_list = calc_info.retrieve_list or [] if (job_tmpl.sched_output_path is not None and job_tmpl.sched_output_path not in retrieve_list): retrieve_list.append(job_tmpl.sched_output_path) if not job_tmpl.sched_join_files: if (job_tmpl.sched_error_path is not None and job_tmpl.sched_error_path not in retrieve_list): retrieve_list.append(job_tmpl.sched_error_path) + retrieve_list.extend(self.node.get_option('additional_retrieve_list') or []) self.node.set_retrieve_list(retrieve_list) - retrieve_singlefile_list = (calc_info.retrieve_singlefile_list - if calc_info.retrieve_singlefile_list is not None else []) + retrieve_singlefile_list = calc_info.retrieve_singlefile_list or [] # a validation on the subclasses of retrieve_singlefile_list for _, subclassname, _ in retrieve_singlefile_list: file_sub_class = DataFactory(subclassname) if not issubclass(file_sub_class, orm.SinglefileData): raise PluginInternalError( '[presubmission of calc {}] retrieve_singlefile_list subclass problem: {} is ' - 'not subclass of SinglefileData'.format(self.node.pk, file_sub_class.__name__)) + 'not subclass of SinglefileData'.format(self.node.pk, file_sub_class.__name__) + ) if retrieve_singlefile_list: self.node.set_retrieve_singlefile_list(retrieve_singlefile_list) # Handle the retrieve_temporary_list - retrieve_temporary_list = (calc_info.retrieve_temporary_list - if calc_info.retrieve_temporary_list is not None else []) + retrieve_temporary_list = calc_info.retrieve_temporary_list or [] self.node.set_retrieve_temporary_list(retrieve_temporary_list) # the if is done so that if the method returns None, this is @@ -526,26 +603,24 @@ def presubmit(self, folder): raise PluginInternalError('Invalid codes_info, must be a list of 
CodeInfo objects') if code_info.code_uuid is None: - raise PluginInternalError('CalcInfo should have ' - 'the information of the code ' - 'to be launched') + raise PluginInternalError('CalcInfo should have the information of the code to be launched') this_code = load_node(code_info.code_uuid, sub_classes=(Code,)) this_withmpi = code_info.withmpi # to decide better how to set the default if this_withmpi is None: if len(calc_info.codes_info) > 1: - raise PluginInternalError('For more than one code, it is ' - 'necessary to set withmpi in ' - 'codes_info') + raise PluginInternalError('For more than one code, it is necessary to set withmpi in codes_info') else: this_withmpi = self.node.get_option('withmpi') if this_withmpi: - this_argv = (mpi_args + extra_mpirun_params + [this_code.get_execname()] + - (code_info.cmdline_params if code_info.cmdline_params is not None else [])) + this_argv = ( + mpi_args + extra_mpirun_params + [this_code.get_execname()] + + (code_info.cmdline_params if code_info.cmdline_params is not None else []) + ) else: - this_argv = [this_code.get_execname()] + (code_info.cmdline_params - if code_info.cmdline_params is not None else []) + this_argv = [this_code.get_execname() + ] + (code_info.cmdline_params if code_info.cmdline_params is not None else []) # overwrite the old cmdline_params and add codename and mpirun stuff code_info.cmdline_params = this_argv @@ -558,8 +633,8 @@ def presubmit(self, folder): if len(codes) > 1: try: job_tmpl.codes_run_mode = calc_info.codes_run_mode - except KeyError: - raise PluginInternalError('Need to set the order of the code execution (parallel or serial?)') + except KeyError as exc: + raise PluginInternalError('Need to set the order of the code execution (parallel or serial?)') from exc else: job_tmpl.codes_run_mode = CodeRunMode.SERIAL ######################################################################## @@ -613,26 +688,34 @@ def presubmit(self, folder): local_copy_list = calc_info.local_copy_list try: validate_list_of_string_tuples(local_copy_list, tuple_length=3) - except ValidationError as exc: - raise PluginInternalError(f'[presubmission of calc {this_pk}] local_copy_list format problem: {exc}') + except ValidationError as exception: + raise PluginInternalError( + f'[presubmission of calc {this_pk}] local_copy_list format problem: {exception}' + ) from exception remote_copy_list = calc_info.remote_copy_list try: validate_list_of_string_tuples(remote_copy_list, tuple_length=3) - except ValidationError as exc: - raise PluginInternalError(f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exc}') + except ValidationError as exception: + raise PluginInternalError( + f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exception}' + ) from exception for (remote_computer_uuid, _, dest_rel_path) in remote_copy_list: try: Computer.objects.get(uuid=remote_computer_uuid) # pylint: disable=unused-variable - except exceptions.NotExistent: - raise PluginInternalError('[presubmission of calc {}] ' - 'The remote copy requires a computer with UUID={}' - 'but no such computer was found in the ' - 'database'.format(this_pk, remote_computer_uuid)) + except exceptions.NotExistent as exception: + raise PluginInternalError( + '[presubmission of calc {}] ' + 'The remote copy requires a computer with UUID={}' + 'but no such computer was found in the ' + 'database'.format(this_pk, remote_computer_uuid) + ) from exception if os.path.isabs(dest_rel_path): - raise PluginInternalError('[presubmission of calc {}] ' - 'The 
destination path of the remote copy ' - 'is absolute! ({})'.format(this_pk, dest_rel_path)) + raise PluginInternalError( + '[presubmission of calc {}] ' + 'The destination path of the remote copy ' + 'is absolute! ({})'.format(this_pk, dest_rel_path) + ) return calc_info diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py index c6a2adfc96..3c3cb6229c 100644 --- a/aiida/engine/processes/calcjobs/manager.py +++ b/aiida/engine/processes/calcjobs/manager.py @@ -8,13 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module containing utilities and classes relating to job calculations running on systems that require transport.""" +import asyncio import contextlib +import contextvars import logging import time - -from tornado import concurrent, gen +from typing import Any, Dict, Hashable, Iterator, List, Optional, TYPE_CHECKING from aiida.common import lang +from aiida.orm import AuthInfo + +if TYPE_CHECKING: + from aiida.engine.transports import TransportQueue + from aiida.schedulers.datastructures import JobInfo __all__ = ('JobsList', 'JobManager') @@ -36,68 +42,65 @@ class JobsList: See the :py:class:`~aiida.engine.processes.calcjobs.manager.JobManager` for example usage. """ - def __init__(self, authinfo, transport_queue, last_updated=None): + def __init__(self, authinfo: AuthInfo, transport_queue: 'TransportQueue', last_updated: Optional[float] = None): """Construct an instance for the given authinfo and transport queue. :param authinfo: The authinfo used to check the jobs list - :type authinfo: :class:`aiida.orm.AuthInfo` :param transport_queue: A transport queue - :type: :class:`aiida.engine.transports.TransportQueue` :param last_updated: initialize the last updated timestamp - :type: float + """ lang.type_check(last_updated, float, allow_none=True) self._authinfo = authinfo self._transport_queue = transport_queue - self._loop = transport_queue.loop() + self._loop = transport_queue.loop self._logger = logging.getLogger(__name__) - self._jobs_cache = {} - self._job_update_requests = {} # Mapping: {job_id: Future} + self._jobs_cache: Dict[Hashable, 'JobInfo'] = {} + self._job_update_requests: Dict[Hashable, asyncio.Future] = {} # Mapping: {job_id: Future} self._last_updated = last_updated - self._update_handle = None + self._update_handle: Optional[asyncio.TimerHandle] = None @property - def logger(self): + def logger(self) -> logging.Logger: """Return the logger configured for this instance. :return: the logger """ return self._logger - def get_minimum_update_interval(self): + def get_minimum_update_interval(self) -> float: """Get the minimum interval that should be respected between updates of the list. :return: the minimum interval - :rtype: float + """ return self._authinfo.computer.get_minimum_job_poll_interval() @property - def last_updated(self): + def last_updated(self) -> Optional[float]: """Get the timestamp of when the list was last updated as produced by `time.time()` :return: The last update point - :rtype: float + """ return self._last_updated - @gen.coroutine - def _get_jobs_from_scheduler(self): + async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']: """Get the current jobs list from the scheduler. 
:return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances - :rtype: dict + """ with self._transport_queue.request_transport(self._authinfo) as request: self.logger.info('waiting for transport') - transport = yield request + transport = await request scheduler = self._authinfo.computer.get_scheduler() scheduler.set_transport(transport) - kwargs = {'as_dict': True} + kwargs: Dict[str, Any] = {'as_dict': True} if scheduler.get_feature('can_query_by_user'): kwargs['user'] = '$USER' else: @@ -113,10 +116,9 @@ def _get_jobs_from_scheduler(self): for job_id, job_info in scheduler_response.items(): jobs_cache[job_id] = job_info - raise gen.Return(jobs_cache) + return jobs_cache - @gen.coroutine - def _update_job_info(self): + async def _update_job_info(self) -> None: """Update all of the job information objects. This will set the futures for all pending update requests where the corresponding job has a new status compared @@ -127,7 +129,7 @@ def _update_job_info(self): return # Update our cache of the job states - self._jobs_cache = yield self._get_jobs_from_scheduler() + self._jobs_cache = await self._get_jobs_from_scheduler() except Exception as exception: # Set the exception on all the update futures for future in self._job_update_requests.values(): @@ -149,7 +151,7 @@ def _update_job_info(self): self._job_update_requests = {} @contextlib.contextmanager - def request_job_info_update(self, job_id): + def request_job_info_update(self, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: """Request job info about a job when the job next changes state. If the job is not found in the jobs list at the update, the future will resolve to `None`. @@ -158,7 +160,7 @@ def request_job_info_update(self, job_id): :return: future that will resolve to a `JobInfo` object when the job changes state """ # Get or create the future - request = self._job_update_requests.setdefault(job_id, concurrent.Future()) + request = self._job_update_requests.setdefault(job_id, asyncio.Future()) assert not request.done(), 'Expected pending job info future, found in done state.' try: @@ -167,33 +169,40 @@ def request_job_info_update(self, job_id): finally: pass - def _ensure_updating(self): + def _ensure_updating(self) -> None: """Ensure that we are updating the job list from the remote resource. This will automatically stop if there are no outstanding requests. """ - @gen.coroutine - def updating(): + async def updating(): """Do the actual update, stop if not requests left.""" - yield self._update_job_info() + await self._update_job_info() # Any outstanding requests? if self._update_requests_outstanding(): - self._update_handle = self._loop.call_later(self._get_next_update_delay(), updating) + self._update_handle = self._loop.call_later( + self._get_next_update_delay(), + asyncio.ensure_future, + updating(), + context=contextvars.Context(), # type: ignore[call-arg] + ) else: self._update_handle = None # Check if we're already updating if self._update_handle is None: - self._update_handle = self._loop.call_later(self._get_next_update_delay(), updating) + self._update_handle = self._loop.call_later( + self._get_next_update_delay(), + asyncio.ensure_future, + updating(), + context=contextvars.Context(), # type: ignore[call-arg] + ) @staticmethod - def _has_job_state_changed(old, new): + def _has_job_state_changed(old: Optional['JobInfo'], new: Optional['JobInfo']) -> bool: """Return whether the states `old` and `new` are different. 
- :type old: :class:`aiida.schedulers.JobInfo` or `None` - :type new: :class:`aiida.schedulers.JobInfo` or `None` - :rtype: bool + """ if old is None and new is None: return False @@ -204,14 +213,14 @@ def _has_job_state_changed(old, new): return old.job_state != new.job_state or old.job_substate != new.job_substate - def _get_next_update_delay(self): + def _get_next_update_delay(self) -> float: """Calculate when we are next allowed to poll the scheduler. This delay is calculated as the minimum polling interval defined by the authentication info for this instance, minus time elapsed since the last update. :return: delay (in seconds) after which the scheduler may be polled again - :rtype: float + """ if self.last_updated is None: # Never updated, so do it straight away @@ -225,10 +234,10 @@ def _get_next_update_delay(self): return delay - def _update_requests_outstanding(self): + def _update_requests_outstanding(self) -> bool: return any(not request.done() for request in self._job_update_requests.values()) - def _get_jobs_with_scheduler(self): + def _get_jobs_with_scheduler(self) -> List[str]: """Get all the jobs that are currently with scheduler. :return: the list of jobs with the scheduler @@ -252,11 +261,11 @@ class JobManager: only hold per runner. """ - def __init__(self, transport_queue): + def __init__(self, transport_queue: 'TransportQueue') -> None: self._transport_queue = transport_queue - self._job_lists = {} + self._job_lists: Dict[Hashable, 'JobInfo'] = {} - def get_jobs_list(self, authinfo): + def get_jobs_list(self, authinfo: AuthInfo) -> JobsList: """Get or create a new `JobLists` instance for the given authinfo. :param authinfo: the `AuthInfo` @@ -268,13 +277,11 @@ def get_jobs_list(self, authinfo): return self._job_lists[authinfo.id] @contextlib.contextmanager - def request_job_info_update(self, authinfo, job_id): + def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: """Get a future that will resolve to information about a given job. This is a context manager so that if the user leaves the context the request is automatically cancelled. - :return: A tuple containing the `JobInfo` object and detailed job info. Both can be None. 
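# Illustrative sketch (not part of the patch) of the polling-delay rule implemented by
# `_get_next_update_delay` above: never poll the scheduler more often than the minimum
# interval configured for the computer.
import time
from typing import Optional

def next_update_delay(last_updated: Optional[float], minimum_interval: float) -> float:
    if last_updated is None:
        return 0.0  # never polled before, may poll straight away
    elapsed = time.time() - last_updated
    return max(0.0, minimum_interval - elapsed)

print(next_update_delay(None, 10.0))               # 0.0
print(next_update_delay(time.time() - 3.0, 10.0))  # ~7.0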
- :rtype: :class:`tornado.concurrent.Future` """ with self.get_jobs_list(authinfo).request_job_info_update(job_id) as request: try: diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 0de3d8a8b7..95fb4b0f8e 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -8,41 +8,49 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Transport tasks for calculation jobs.""" +import asyncio import functools import logging import tempfile - -from tornado.gen import coroutine, Return +from typing import Any, Callable, Optional, TYPE_CHECKING import plumpy +import plumpy.process_states +import plumpy.futures from aiida.common.datastructures import CalcJobState from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException from aiida.common.folders import SandboxFolder from aiida.engine.daemon import execmanager -from aiida.engine.utils import exponential_backoff_retry, interruptable_task +from aiida.engine.transports import TransportQueue +from aiida.engine.utils import exponential_backoff_retry, interruptable_task, InterruptableFuture +from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode from aiida.schedulers.datastructures import JobState +from aiida.manage.configuration import get_config_option from ..process import ProcessState +if TYPE_CHECKING: + from .calcjob import CalcJob + UPLOAD_COMMAND = 'upload' SUBMIT_COMMAND = 'submit' UPDATE_COMMAND = 'update' RETRIEVE_COMMAND = 'retrieve' +STASH_COMMAND = 'stash' KILL_COMMAND = 'kill' -TRANSPORT_TASK_RETRY_INITIAL_INTERVAL = 20 -TRANSPORT_TASK_MAXIMUM_ATTEMTPS = 5 +RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval' +MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts' -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class PreSubmitException(Exception): """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`.""" -@coroutine -def task_upload_job(process, transport_queue, cancellable): +async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to upload the files of a job calculation to the remote. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager @@ -50,28 +58,26 @@ def task_upload_job(process, transport_queue, cancellable): retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
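# Illustrative sketch (not part of the patch): the retry parameters that used to be
# module-level constants are read from the AiiDA configuration in this patch (option
# names taken directly from the hunk above). Querying them at runtime might look as
# follows; this assumes a loaded AiiDA profile.
from aiida.manage.configuration import get_config_option

initial_interval = get_config_option('transport.task_retry_initial_interval')
max_attempts = get_config_option('transport.task_maximum_attempts')
print(initial_interval, max_attempts)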
If all retries fail, the task will raise a TransportTaskException - :param node: the node that represents the job calculation + :param process: the job calculation :param transport_queue: the TransportQueue from which to request a Transport :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` - :raises: Return if the tasks was successfully completed + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ node = process.node if node.get_state() == CalcJobState.SUBMITTING: logger.warning(f'CalcJob<{node.pk}> already marked as SUBMITTING, skipping task_update_job') - raise Return + return - initial_interval = TRANSPORT_TASK_RETRY_INITIAL_INTERVAL - max_attempts = TRANSPORT_TASK_MAXIMUM_ATTEMTPS + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) - authinfo = node.computer.get_authinfo(node.user) + authinfo = node.get_authinfo() - @coroutine - def do_upload(): + async def do_upload(): with transport_queue.request_transport(authinfo) as request: - transport = yield cancellable.with_interrupt(request) + transport = await cancellable.with_interrupt(request) with SandboxFolder() as folder: # Any exception thrown in `presubmit` call is not transient so we circumvent the exponential backoff @@ -81,30 +87,30 @@ def do_upload(): raise PreSubmitException('exception occurred in presubmit call') from exception else: execmanager.upload_calculation(node, transport, calc_info, folder) + skip_submit = calc_info.skip_submit or False - raise Return + return skip_submit try: logger.info(f'scheduled request to upload CalcJob<{node.pk}>') - ignore_exceptions = (plumpy.CancelledError, PreSubmitException) - result = yield exponential_backoff_retry( + ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption) + skip_submit = await exponential_backoff_retry( do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except PreSubmitException: raise - except plumpy.CancelledError: - pass - except Exception: + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: logger.warning(f'uploading CalcJob<{node.pk}> failed') - raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'uploading CalcJob<{node.pk}> successful') node.set_state(CalcJobState.SUBMITTING) - raise Return(result) + return skip_submit -@coroutine -def task_submit_job(node, transport_queue, cancellable): +async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to submit a job calculation. The task will first request a transport from the queue. 
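# Illustrative, simplified stand-in for the exponential-backoff behaviour described in
# the docstring above; it is not the actual `aiida.engine.utils.exponential_backoff_retry`.
# It retries an async callable up to `max_attempts` times, doubling the wait after each failure.
import asyncio
import logging

async def exponential_backoff_retry(coro_factory, initial_interval=20.0, max_attempts=5):
    interval = initial_interval
    for attempt in range(max_attempts):
        try:
            return await coro_factory()
        except Exception:  # the real helper lets configured exception types propagate
            if attempt == max_attempts - 1:
                raise
            logging.warning('attempt %d failed, retrying in %.0f s', attempt + 1, interval)
            await asyncio.sleep(interval)
            interval *= 2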
Once the transport is yielded, the relevant execmanager @@ -115,44 +121,42 @@ def task_submit_job(node, transport_queue, cancellable): :param node: the node that represents the job calculation :param transport_queue: the TransportQueue from which to request a Transport :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` - :raises: Return if the tasks was successfully completed + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ if node.get_state() == CalcJobState.WITHSCHEDULER: assert node.get_job_id() is not None, 'job is WITHSCHEDULER, however, it does not have a job id' logger.warning(f'CalcJob<{node.pk}> already marked as WITHSCHEDULER, skipping task_submit_job') - raise Return(node.get_job_id()) + return node.get_job_id() - initial_interval = TRANSPORT_TASK_RETRY_INITIAL_INTERVAL - max_attempts = TRANSPORT_TASK_MAXIMUM_ATTEMTPS + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) - authinfo = node.computer.get_authinfo(node.user) + authinfo = node.get_authinfo() - @coroutine - def do_submit(): + async def do_submit(): with transport_queue.request_transport(authinfo) as request: - transport = yield cancellable.with_interrupt(request) - raise Return(execmanager.submit_calculation(node, transport)) + transport = await cancellable.with_interrupt(request) + return execmanager.submit_calculation(node, transport) try: logger.info(f'scheduled request to submit CalcJob<{node.pk}>') - result = yield exponential_backoff_retry( - do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + result = await exponential_backoff_retry( + do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except plumpy.Interruption: - pass - except Exception: + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise + raise + except Exception as exception: logger.warning(f'submitting CalcJob<{node.pk}> failed') - raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'submitting CalcJob<{node.pk}> successful') node.set_state(CalcJobState.WITHSCHEDULER) - raise Return(result) + return result -@coroutine -def task_update_job(node, job_manager, cancellable): +async def task_update_job(node: CalcJobNode, job_manager, cancellable: InterruptableFuture): """Transport task that will attempt to update the scheduler status of the job calculation. The task will first request a transport from the queue. 
Once the transport is yielded, the relevant execmanager @@ -166,23 +170,24 @@ def task_update_job(node, job_manager, cancellable): :type job_manager: :class:`aiida.engine.processes.calcjobs.manager.JobManager` :param cancellable: A cancel flag :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` - :raises: Return containing True if the tasks was successfully completed, False otherwise + :return: True if the tasks was successfully completed, False otherwise """ - if node.get_state() == CalcJobState.RETRIEVING: - logger.warning(f'CalcJob<{node.pk}> already marked as RETRIEVING, skipping task_update_job') - raise Return(True) + state = node.get_state() - initial_interval = TRANSPORT_TASK_RETRY_INITIAL_INTERVAL - max_attempts = TRANSPORT_TASK_MAXIMUM_ATTEMTPS + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_update_job') + return True - authinfo = node.computer.get_authinfo(node.user) + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() job_id = node.get_job_id() - @coroutine - def do_update(): + async def do_update(): # Get the update request with job_manager.request_job_info_update(authinfo, job_id) as update_request: - job_info = yield cancellable.with_interrupt(update_request) + job_info = await cancellable.with_interrupt(update_request) if job_info is None: # If the job is computed or not found assume it's done @@ -193,28 +198,31 @@ def do_update(): node.set_scheduler_state(job_info.job_state) job_done = job_info.job_state == JobState.DONE - raise Return(job_done) + return job_done try: logger.info(f'scheduled request to update CalcJob<{node.pk}>') - job_done = yield exponential_backoff_retry( - do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + job_done = await exponential_backoff_retry( + do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except plumpy.Interruption: + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise raise - except Exception: + except Exception as exception: logger.warning(f'updating CalcJob<{node.pk}> failed') - raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'updating CalcJob<{node.pk}> successful') if job_done: - node.set_state(CalcJobState.RETRIEVING) + node.set_state(CalcJobState.STASHING) - raise Return(job_done) + return job_done -@coroutine -def task_retrieve_job(node, transport_queue, retrieved_temporary_folder, cancellable): +async def task_retrieve_job( + node: CalcJobNode, transport_queue: TransportQueue, retrieved_temporary_folder: str, + cancellable: InterruptableFuture +): """Transport task that will attempt to retrieve all files of a completed job calculation. The task will first request a transport from the queue. 
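The retry interval and maximum number of attempts are no longer module constants but are read from the configuration through `get_config_option`. Assuming the keys behind `RETRY_INTERVAL_OPTION` and `MAX_ATTEMPTS_OPTION` (not shown in this hunk) are `transport.task_retry_initial_interval` and `transport.task_maximum_attempts`, they could be inspected per profile roughly like this sketch:

from aiida import load_profile
from aiida.manage.configuration import get_config_option

load_profile()

# Option names and values are assumptions for illustration; defaults may differ per AiiDA version.
print(get_config_option('transport.task_retry_initial_interval'))  # e.g. 20 (seconds)
print(get_config_option('transport.task_maximum_attempts'))        # e.g. 5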
Once the transport is yielded, the relevant execmanager @@ -224,29 +232,28 @@ def task_retrieve_job(node, transport_queue, retrieved_temporary_folder, cancell :param node: the node that represents the job calculation :param transport_queue: the TransportQueue from which to request a Transport + :param retrieved_temporary_folder: the absolute path to a directory to store files :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` - :raises: Return if the tasks was successfully completed + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ if node.get_state() == CalcJobState.PARSING: logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job') - raise Return + return - initial_interval = TRANSPORT_TASK_RETRY_INITIAL_INTERVAL - max_attempts = TRANSPORT_TASK_MAXIMUM_ATTEMTPS + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) - authinfo = node.computer.get_authinfo(node.user) + authinfo = node.get_authinfo() - @coroutine - def do_retrieve(): + async def do_retrieve(): with transport_queue.request_transport(authinfo) as request: - transport = yield cancellable.with_interrupt(request) + transport = await cancellable.with_interrupt(request) # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the # accounting was called but could not be set. - scheduler = node.computer.get_scheduler() + scheduler = node.computer.get_scheduler() # type: ignore[union-attr] scheduler.set_transport(transport) try: @@ -257,27 +264,27 @@ def do_retrieve(): else: node.set_detailed_job_info(detailed_job_info) - raise Return(execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)) + return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) try: logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>') - yield exponential_backoff_retry( - do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + result = await exponential_backoff_retry( + do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) - except plumpy.Interruption: + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise raise - except Exception: + except Exception as exception: logger.warning(f'retrieving CalcJob<{node.pk}> failed') - raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') from exception else: node.set_state(CalcJobState.PARSING) logger.info(f'retrieving CalcJob<{node.pk}> successful') - raise Return + return result -@coroutine -def task_kill_job(node, transport_queue, cancellable): - """Transport task that will attempt to kill a job calculation. +async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will optionally stash files of a completed job calculation on the remote. 
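Each `do_*` coroutine awaits the transport request through `cancellable.with_interrupt(request)`, so a kill or pause can break in while the task is still waiting for a transport. A reduced sketch of that racing pattern in plain asyncio (not the actual `InterruptableFuture` implementation):

import asyncio


async def with_interrupt(awaitable, interrupt: asyncio.Future):
    """Wait for ``awaitable``, but abort as soon as ``interrupt`` is resolved."""
    task = asyncio.ensure_future(awaitable)
    done, _ = await asyncio.wait({task, interrupt}, return_when=asyncio.FIRST_COMPLETED)
    if interrupt in done:
        task.cancel()
        raise asyncio.CancelledError('waiting was interrupted')
    return task.result()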
The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will @@ -291,57 +298,118 @@ def task_kill_job(node, transport_queue, cancellable): :raises: Return if the tasks was successfully completed :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ - initial_interval = TRANSPORT_TASK_RETRY_INITIAL_INTERVAL - max_attempts = TRANSPORT_TASK_MAXIMUM_ATTEMTPS + if node.get_state() == CalcJobState.RETRIEVING: + logger.warning(f'calculation<{node.pk}> already marked as RETRIEVING, skipping task_stash_job') + return + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_stash(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + + logger.info(f'stashing calculation<{node.pk}>') + return execmanager.stash_calculation(node, transport) + + try: + await exponential_backoff_retry( + do_stash, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption + ) + except plumpy.process_states.Interruption: + raise + except Exception as exception: + logger.warning(f'stashing calculation<{node.pk}> failed') + raise TransportTaskException(f'stash_calculation failed {max_attempts} times consecutively') from exception + else: + node.set_state(CalcJobState.RETRIEVING) + logger.info(f'stashing calculation<{node.pk}> successful') + return + + +async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will attempt to kill a job calculation. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
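Stashing is opt-in: the job only goes through this task when the node carries a `stash` option. As a sketch of what that looks like from the user side, assuming the option layout of the stashing feature this change supports (the sub-keys and values are not shown in this diff and are illustrative only):

from aiida.common.datastructures import StashMode

# Part of a CalcJob's `metadata.options`; key names are assumptions, check the
# stashing documentation of your AiiDA version.
options = {
    'stash': {
        'source_list': ['aiida.out', 'output.xml'],   # files to keep on the remote
        'target_base': '/storage/project/stash',      # where to copy them
        'stash_mode': StashMode.COPY.value,
    },
}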
+ If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) if node.get_state() in [CalcJobState.UPLOADING, CalcJobState.SUBMITTING]: logger.warning(f'CalcJob<{node.pk}> killed, it was in the {node.get_state()} state') - raise Return(True) + return True - authinfo = node.computer.get_authinfo(node.user) + authinfo = node.get_authinfo() - @coroutine - def do_kill(): + async def do_kill(): with transport_queue.request_transport(authinfo) as request: - transport = yield cancellable.with_interrupt(request) - raise Return(execmanager.kill_calculation(node, transport)) + transport = await cancellable.with_interrupt(request) + return execmanager.kill_calculation(node, transport) try: logger.info(f'scheduled request to kill CalcJob<{node.pk}>') - result = yield exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) - except plumpy.Interruption: + result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) + except plumpy.process_states.Interruption: raise - except Exception: + except Exception as exception: logger.warning(f'killing CalcJob<{node.pk}> failed') - raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'killing CalcJob<{node.pk}> successful') node.set_scheduler_state(JobState.DONE) - raise Return(result) + return result -class Waiting(plumpy.Waiting): +class Waiting(plumpy.process_states.Waiting): """The waiting state for the `CalcJob` process.""" - def __init__(self, process, done_callback, msg=None, data=None): + def __init__( + self, + process: 'CalcJob', + done_callback: Optional[Callable[..., Any]], + msg: Optional[str] = None, + data: Optional[Any] = None + ): """ - :param :class:`~plumpy.base.state_machine.StateMachine` process: The process this state belongs to + :param process: The process this state belongs to """ super().__init__(process, done_callback, msg, data) - self._task = None - self._killing = None + self._task: Optional[InterruptableFuture] = None + self._killing: Optional[plumpy.futures.Future] = None + + @property + def process(self) -> 'CalcJob': + """ + :return: The process + """ + return self.state_machine # type: ignore[return-value] def load_instance_state(self, saved_state, load_context): super().load_instance_state(saved_state, load_context) self._task = None self._killing = None - @coroutine - def execute(self): + async def execute(self) -> plumpy.process_states.State: # type: ignore[override] # pylint: disable=invalid-overridden-method """Override the execute coroutine of the base `Waiting` state.""" - # pylint: disable=too-many-branches + # pylint: disable=too-many-branches,too-many-statements node = self.process.node transport_queue = self.process.runner.transport + result: plumpy.process_states.State = self command = self.data process_status = f'Waiting for transport task: {command}' @@ -350,15 +418,18 @@ def execute(self): if command == 
UPLOAD_COMMAND: node.set_process_status(process_status) - yield self._launch_task(task_upload_job, self.process, transport_queue) - raise Return(self.submit()) + skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue) + if skip_submit: + result = self.retrieve() + else: + result = self.submit() elif command == SUBMIT_COMMAND: node.set_process_status(process_status) - yield self._launch_task(task_submit_job, node, transport_queue) - raise Return(self.update()) + await self._launch_task(task_submit_job, node, transport_queue) + result = self.update() - elif self.data == UPDATE_COMMAND: + elif command == UPDATE_COMMAND: job_done = False while not job_done: @@ -366,81 +437,104 @@ def execute(self): scheduler_state_string = scheduler_state.name if scheduler_state else 'UNKNOWN' process_status = f'Monitoring scheduler: job state {scheduler_state_string}' node.set_process_status(process_status) - job_done = yield self._launch_task(task_update_job, node, self.process.runner.job_manager) + job_done = await self._launch_task(task_update_job, node, self.process.runner.job_manager) + + if node.get_option('stash') is not None: + result = self.stash() + else: + result = self.retrieve() - raise Return(self.retrieve()) + elif command == STASH_COMMAND: + node.set_process_status(process_status) + await self._launch_task(task_stash_job, node, transport_queue) + result = self.retrieve() - elif self.data == RETRIEVE_COMMAND: + elif command == RETRIEVE_COMMAND: node.set_process_status(process_status) - # Create a temporary folder that has to be deleted by JobProcess.retrieved after successful parsing temp_folder = tempfile.mkdtemp() - yield self._launch_task(task_retrieve_job, node, transport_queue, temp_folder) - raise Return(self.parse(temp_folder)) + await self._launch_task(task_retrieve_job, node, transport_queue, temp_folder) + result = self.parse(temp_folder) else: raise RuntimeError('Unknown waiting command') except TransportTaskException as exception: - raise plumpy.PauseInterruption(f'Pausing after failed transport task: {exception}') - except plumpy.KillInterruption: - yield self._launch_task(task_kill_job, node, transport_queue) - self._killing.set_result(True) + raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}') + except plumpy.process_states.KillInterruption: + await self._launch_task(task_kill_job, node, transport_queue) + if self._killing is not None: + self._killing.set_result(True) + else: + logger.warning(f'killed CalcJob<{node.pk}> but async future was None') raise - except Return: - node.set_process_status(None) + except (plumpy.futures.CancelledError, asyncio.CancelledError): + node.set_process_status(f'Transport task {command} was cancelled') raise - except (plumpy.Interruption, plumpy.CancelledError): + except plumpy.process_states.Interruption: node.set_process_status(f'Transport task {command} was interrupted') raise + else: + node.set_process_status(None) + return result finally: # If we were trying to kill but we didn't deal with it, make sure it's set here if self._killing and not self._killing.done(): self._killing.set_result(False) - @coroutine - def _launch_task(self, coro, *args, **kwargs): + async def _launch_task(self, coro, *args, **kwargs): """Launch a coroutine as a task, making sure to make it interruptable.""" task_fn = functools.partial(coro, *args, **kwargs) try: self._task = interruptable_task(task_fn) - result = yield self._task - raise Return(result) + result = await self._task + return 
result finally: self._task = None - def upload(self): + def upload(self) -> 'Waiting': """Return the `Waiting` state that will `upload` the `CalcJob`.""" msg = 'Waiting for calculation folder upload' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPLOAD_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPLOAD_COMMAND) # type: ignore[return-value] - def submit(self): + def submit(self) -> 'Waiting': """Return the `Waiting` state that will `submit` the `CalcJob`.""" msg = 'Waiting for scheduler submission' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=SUBMIT_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=SUBMIT_COMMAND) # type: ignore[return-value] - def update(self): + def update(self) -> 'Waiting': """Return the `Waiting` state that will `update` the `CalcJob`.""" msg = 'Waiting for scheduler update' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPDATE_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPDATE_COMMAND) # type: ignore[return-value] - def retrieve(self): + def retrieve(self) -> 'Waiting': """Return the `Waiting` state that will `retrieve` the `CalcJob`.""" msg = 'Waiting to retrieve' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND) + return self.create_state( + ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND + ) # type: ignore[return-value] + + def stash(self): + """Return the `Waiting` state that will `stash` the `CalcJob`.""" + msg = 'Waiting to stash' + return self.create_state(ProcessState.WAITING, None, msg=msg, data=STASH_COMMAND) - def parse(self, retrieved_temporary_folder): + def parse(self, retrieved_temporary_folder: str) -> plumpy.process_states.Running: """Return the `Running` state that will parse the `CalcJob`. :param retrieved_temporary_folder: temporary folder used in retrieving that can be used during parsing. 
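Put together, the `Waiting` state now walks through these commands, the two new branches being `skip_submit` (an upload-only job goes straight to retrieval) and the optional stash step. A plain-Python sketch mirroring the dispatch in `execute` above, not AiiDA code:

from enum import Enum, auto


class Command(Enum):
    UPLOAD = auto()
    SUBMIT = auto()
    UPDATE = auto()
    STASH = auto()
    RETRIEVE = auto()
    PARSE = auto()


def next_command(command: Command, *, skip_submit: bool = False, stash_requested: bool = False) -> Command:
    if command is Command.UPLOAD:
        return Command.RETRIEVE if skip_submit else Command.SUBMIT
    if command is Command.SUBMIT:
        return Command.UPDATE
    if command is Command.UPDATE:
        return Command.STASH if stash_requested else Command.RETRIEVE
    if command is Command.STASH:
        return Command.RETRIEVE
    return Command.PARSE


assert next_command(Command.UPLOAD, skip_submit=True) is Command.RETRIEVE
assert next_command(Command.UPDATE, stash_requested=True) is Command.STASH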
""" - return self.create_state(ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder) + return self.create_state( + ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder + ) # type: ignore[return-value] - def interrupt(self, reason): + def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override] """Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`.""" if self._task is not None: self._task.interrupt(reason) - if isinstance(reason, plumpy.KillInterruption): + if isinstance(reason, plumpy.process_states.KillInterruption): if self._killing is None: - self._killing = plumpy.Future() + self._killing = plumpy.futures.Future() return self._killing + + return None diff --git a/aiida/engine/processes/exit_code.py b/aiida/engine/processes/exit_code.py index 0c54a5be72..c5baedebb7 100644 --- a/aiida/engine/processes/exit_code.py +++ b/aiida/engine/processes/exit_code.py @@ -8,38 +8,36 @@ # For further information please visit http://www.aiida.net # ########################################################################### """A namedtuple and namespace for ExitCodes that can be used to exit from Processes.""" -from collections import namedtuple +from typing import NamedTuple, Optional from aiida.common.extendeddicts import AttributeDict __all__ = ('ExitCode', 'ExitCodesNamespace') -class ExitCode(namedtuple('ExitCode', ['status', 'message', 'invalidates_cache'])): +class ExitCode(NamedTuple): """A simple data class to define an exit code for a :class:`~aiida.engine.processes.process.Process`. - When an instance of this clas is returned from a `Process._run()` call, it will be interpreted that the `Process` + When an instance of this class is returned from a `Process._run()` call, it will be interpreted that the `Process` should be terminated and that the exit status and message of the namedtuple should be set to the corresponding attributes of the node. - .. note:: this class explicitly sub-classes a namedtuple to not break backwards compatibility and to have it behave - exactly as a tuple. - :param status: positive integer exit status, where a non-zero value indicated the process failed, default is `0` - :type status: int - :param message: optional message with more details about the failure mode - :type message: str - :param invalidates_cache: optional flag, indicating that a process should not be used in caching - :type invalidates_cache: bool """ - def format(self, **kwargs): + status: int = 0 + message: Optional[str] = None + invalidates_cache: bool = False + + def format(self, **kwargs: str) -> 'ExitCode': """Create a clone of this exit code where the template message is replaced by the keyword arguments. :param kwargs: replacement parameters for the template message - :return: `ExitCode` + """ + if self.message is None: + raise ValueError('message is None') try: message = self.message.format(**kwargs) except KeyError: @@ -49,10 +47,6 @@ def format(self, **kwargs): return ExitCode(self.status, message, self.invalidates_cache) -# Set the defaults for the `ExitCode` attributes -ExitCode.__new__.__defaults__ = (0, None, False) - - class ExitCodesNamespace(AttributeDict): """A namespace of `ExitCode` instances that can be accessed through getattr as well as getitem. @@ -60,15 +54,13 @@ class ExitCodesNamespace(AttributeDict): `ExitCode` that needs to be retrieved or the key in the collection. 
""" - def __call__(self, identifier): + def __call__(self, identifier: str) -> ExitCode: """Return a specific exit code identified by either its exit status or label. :param identifier: the identifier of the exit code. If the type is integer, it will be interpreted as the exit code status, otherwise it be interpreted as the exit code label - :type identifier: str :returns: an `ExitCode` instance - :rtype: :class:`aiida.engine.ExitCode` :raises ValueError: if no exit code with the given label is defined for this process """ diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 1b57ea89aa..4f8c9ef999 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -13,18 +13,24 @@ import inspect import logging import signal +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING from aiida.common.lang import override from aiida.manage.manager import get_manager +from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode +from aiida.orm.utils.mixins import FunctionCalculationMixin from .process import Process +if TYPE_CHECKING: + from .exit_code import ExitCode + __all__ = ('calcfunction', 'workfunction', 'FunctionProcess') LOGGER = logging.getLogger(__name__) -def calcfunction(function): +def calcfunction(function: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to turn a standard python function into a calcfunction. Example usage: @@ -51,11 +57,10 @@ def calcfunction(function): :return: The decorated function. :rtype: callable """ - from aiida.orm import CalcFunctionNode return process_function(node_class=CalcFunctionNode)(function) -def workfunction(function): +def workfunction(function: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to turn a standard python function into a workfunction. Example usage: @@ -80,13 +85,12 @@ def workfunction(function): :type function: callable :return: The decorated function. - :rtype: callable - """ - from aiida.orm import WorkFunctionNode + + """ return process_function(node_class=WorkFunctionNode)(function) -def process_function(node_class): +def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ The base function decorator to create a FunctionProcess out of a normal python function. @@ -94,7 +98,7 @@ def process_function(node_class): :type node_class: :class:`aiida.orm.ProcessNode` """ - def decorator(function): + def decorator(function: Callable[..., Any]) -> Callable[..., Any]: """ Turn the decorated function into a FunctionProcess. @@ -103,21 +107,17 @@ def decorator(function): """ process_class = FunctionProcess.build(function, node_class=node_class) - def run_get_node(*args, **kwargs): + def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']: """ Run the FunctionProcess with the supplied inputs in a local runner. - The function will have to create a new runner for the FunctionProcess instead of using the global runner, - because otherwise if this process function were to call another one from within its scope, that would use - the same runner and it would be blocking the event loop from continuing. 
- :param args: input arguments to construct the FunctionProcess :param kwargs: input keyword arguments to construct the FunctionProcess - :return: tuple of the outputs of the process and the process node pk - :rtype: (dict, int) + :return: tuple of the outputs of the process and the process node + """ manager = get_manager() - runner = manager.create_runner(with_persistence=False) + runner = manager.get_runner() inputs = process_class.create_inputs(*args, **kwargs) # Remove all the known inputs from the kwargs @@ -140,10 +140,9 @@ def run_get_node(*args, **kwargs): def kill_process(_num, _frame): """Send the kill signal to the process in the current scope.""" - from tornado import gen LOGGER.critical('runner received interrupt, killing process %s', process.pid) result = process.kill(msg='Process was killed because the runner received an interrupt') - raise gen.Return(result) + return result # Store the current handler on the signal such that it can be restored after process has terminated original_handler = signal.getsignal(kill_signal) @@ -155,7 +154,6 @@ def kill_process(_num, _frame): # If the `original_handler` is set, that means the `kill_process` was bound, which needs to be reset if original_handler: signal.signal(signal.SIGINT, original_handler) - runner.close() store_provenance = inputs.get('metadata', {}).get('store_provenance', True) if not store_provenance: @@ -164,13 +162,13 @@ def kill_process(_num, _frame): return result, process.node - def run_get_pk(*args, **kwargs): + def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]: """Recreate the `run_get_pk` utility launcher. :param args: input arguments to construct the FunctionProcess :param kwargs: input keyword arguments to construct the FunctionProcess :return: tuple of the outputs of the process and the process node pk - :rtype: (dict, int) + """ result, node = run_get_node(*args, **kwargs) return result, node.pk @@ -181,14 +179,14 @@ def decorated_function(*args, **kwargs): result, _ = run_get_node(*args, **kwargs) return result - decorated_function.run = decorated_function - decorated_function.run_get_pk = run_get_pk - decorated_function.run_get_node = run_get_node - decorated_function.is_process_function = True - decorated_function.node_class = node_class - decorated_function.process_class = process_class - decorated_function.recreate_from = process_class.recreate_from - decorated_function.spec = process_class.spec + decorated_function.run = decorated_function # type: ignore[attr-defined] + decorated_function.run_get_pk = run_get_pk # type: ignore[attr-defined] + decorated_function.run_get_node = run_get_node # type: ignore[attr-defined] + decorated_function.is_process_function = True # type: ignore[attr-defined] + decorated_function.node_class = node_class # type: ignore[attr-defined] + decorated_function.process_class = process_class # type: ignore[attr-defined] + decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] + decorated_function.spec = process_class.spec # type: ignore[attr-defined] return decorated_function @@ -198,10 +196,10 @@ def decorated_function(*args, **kwargs): class FunctionProcess(Process): """Function process class used for turning functions into a Process""" - _func_args = None + _func_args: Sequence[str] = () @staticmethod - def _func(*_args, **_kwargs): + def _func(*_args, **_kwargs) -> dict: """ This is used internally to store the actual function that is being wrapped and will be replaced by the build method. 
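The `kill_process` handler above no longer needs Tornado at all; it is installed only for the duration of the blocking run and the original handler is restored afterwards. The save-and-restore dance, reduced to a generic helper with placeholder names:

import signal


def run_with_interrupt_handler(blocking_call, on_interrupt):
    """Run ``blocking_call()`` with a temporary SIGINT handler, restoring the original afterwards."""
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(_signum, _frame):
        on_interrupt()

    try:
        signal.signal(signal.SIGINT, handler)
        return blocking_call()
    finally:
        if original_handler is not None:
            signal.signal(signal.SIGINT, original_handler)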
@@ -209,7 +207,7 @@ def _func(*_args, **_kwargs): return {} @staticmethod - def build(func, node_class): + def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']: """ Build a Process from the given function. @@ -217,19 +215,13 @@ def build(func, node_class): these will also become inputs. :param func: The function to build a process from - :type func: callable - :param node_class: Provide a custom node class to be used, has to be constructable with no arguments. It has to be a sub class of `ProcessNode` and the mixin :class:`~aiida.orm.utils.mixins.FunctionCalculationMixin`. - :type node_class: :class:`aiida.orm.nodes.process.process.ProcessNode` :return: A Process class that represents the function - :rtype: :class:`FunctionProcess` - """ - from aiida import orm - from aiida.orm.utils.mixins import FunctionCalculationMixin - if not issubclass(node_class, orm.ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): + """ + if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func) @@ -246,7 +238,7 @@ def _define(cls, spec): # pylint: disable=unused-argument for i, arg in enumerate(args): default = () - if i >= first_default_pos: + if defaults and i >= first_default_pos: default = defaults[i - first_default_pos] # If the keyword was already specified, simply override the default @@ -257,9 +249,9 @@ def _define(cls, spec): # pylint: disable=unused-argument # Note that we cannot use `None` because the validation will call `isinstance` which does not work # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)` if default is None: - valid_type = (orm.Data, type(None)) + valid_type = (Data, type(None)) else: - valid_type = (orm.Data,) + valid_type = (Data,) spec.input(arg, valid_type=valid_type, default=default) @@ -275,7 +267,7 @@ def _define(cls, spec): # pylint: disable=unused-argument # Function processes must have a dynamic output namespace since we do not know beforehand what outputs # will be returned and the valid types for the value should be `Data` nodes as well as a dictionary because # the output namespace can be nested. - spec.outputs.valid_type = (orm.Data, dict) + spec.outputs.valid_type = (Data, dict) return type( func.__name__, (FunctionProcess,), { @@ -289,7 +281,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ) @classmethod - def validate_inputs(cls, *args, **kwargs): # pylint: disable=unused-argument + def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument """ Validate the positional and keyword arguments passed in the function call. @@ -308,11 +300,8 @@ def validate_inputs(cls, *args, **kwargs): # pylint: disable=unused-argument raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @classmethod - def create_inputs(cls, *args, **kwargs): - """Create the input args for the FunctionProcess. 
- - :rtype: dict - """ + def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Create the input args for the FunctionProcess.""" cls.validate_inputs(*args, **kwargs) ins = {} @@ -323,29 +312,28 @@ def create_inputs(cls, *args, **kwargs): return ins @classmethod - def args_to_dict(cls, *args): + def args_to_dict(cls, *args: Any) -> Dict[str, Any]: """ Create an input dictionary (of form label -> value) from supplied args. :param args: The values to use for the dictionary - :type args: list :return: A label -> value dictionary - :rtype: dict + """ return dict(list(zip(cls._func_args, args))) @classmethod - def get_or_create_db_record(cls): + def get_or_create_db_record(cls) -> 'ProcessNode': return cls._node_class() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if kwargs.get('enable_persistence', False): raise RuntimeError('Cannot persist a function process') - super().__init__(enable_persistence=False, *args, **kwargs) + super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore @property - def process_class(self): + def process_class(self) -> Callable[..., Any]: """ Return the class that represents this Process, for the FunctionProcess this is the function itself. @@ -354,33 +342,29 @@ def process_class(self): class that really represents what was being executed. :return: A Process class that represents the function - :rtype: :class:`FunctionProcess` + """ return self._func - def execute(self): + def execute(self) -> Optional[Dict[str, Any]]: """Execute the process.""" result = super().execute() # FunctionProcesses can return a single value as output, and not a dictionary, so we should also return that - if len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result: + if result and len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result: return result[self.SINGLE_OUTPUT_LINKNAME] return result @override - def _setup_db_record(self): + def _setup_db_record(self) -> None: """Set up the database record for the process.""" super()._setup_db_record() self.node.store_source_info(self._func) @override - def run(self): - """Run the process. - - :rtype: :class:`aiida.engine.ExitCode` - """ - from aiida.orm import Data + def run(self) -> Optional['ExitCode']: + """Run the process.""" from .exit_code import ExitCode # The following conditional is required for the caching to properly work. 
Even if the source node has a process @@ -394,9 +378,9 @@ def run(self): args = [None] * len(self._func_args) kwargs = {} - for name, value in self.inputs.items(): + for name, value in (self.inputs or {}).items(): try: - if self.spec().inputs[name].non_db: + if self.spec().inputs[name].non_db: # type: ignore[union-attr] # Don't consider non-database inputs continue except KeyError: diff --git a/aiida/engine/processes/futures.py b/aiida/engine/processes/futures.py index e98f25c64f..dc110fbf5a 100644 --- a/aiida/engine/processes/futures.py +++ b/aiida/engine/processes/futures.py @@ -9,34 +9,44 @@ ########################################################################### # pylint: disable=cyclic-import """Futures that can poll or receive broadcasted messages while waiting for a task to be completed.""" -import tornado.gen +import asyncio +from typing import Optional, Union -import plumpy import kiwipy +from aiida.orm import Node, load_node + __all__ = ('ProcessFuture',) -class ProcessFuture(plumpy.Future): +class ProcessFuture(asyncio.Future): """Future that waits for a process to complete using both polling and listening for broadcast events if possible.""" _filtered = None - def __init__(self, pk, loop=None, poll_interval=None, communicator=None): + def __init__( + self, + pk: int, + loop: Optional[asyncio.AbstractEventLoop] = None, + poll_interval: Union[None, int, float] = None, + communicator: Optional[kiwipy.Communicator] = None + ): """Construct a future for a process node being finished. - If a None poll_interval is supplied polling will not be used. If a communicator is supplied it will be used - to listen for broadcast messages. + If a None poll_interval is supplied polling will not be used. + If a communicator is supplied it will be used to listen for broadcast messages. :param pk: process pk :param loop: An event loop :param poll_interval: optional polling interval, if None, polling is not activated. :param communicator: optional communicator, if None, will not subscribe to broadcasts. 
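`ProcessFuture` is now a plain `asyncio.Future` whose polling branch is simply a task created on the loop. A self-contained sketch of that shape, without any AiiDA objects:

import asyncio


class PollingFuture(asyncio.Future):
    """A future resolved by a background task that polls a predicate."""

    def __init__(self, is_done, poll_interval=0.1, loop=None):
        loop = loop if loop is not None else asyncio.get_event_loop()
        super().__init__(loop=loop)
        loop.create_task(self._poll(is_done, poll_interval))

    async def _poll(self, is_done, poll_interval):
        while not self.done() and not is_done():
            await asyncio.sleep(poll_interval)
        if not self.done():
            self.set_result(True)


async def main():
    flag = []
    future = PollingFuture(lambda: bool(flag), poll_interval=0.01)
    asyncio.get_running_loop().call_later(0.05, flag.append, True)
    await future


asyncio.run(main())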
""" - from aiida.orm import load_node from .process import ProcessState - super().__init__() + # create future in specified event loop + loop = loop if loop is not None else asyncio.get_event_loop() + super().__init__(loop=loop) + assert not (poll_interval is None and communicator is None), 'Must poll or have a communicator to use' node = load_node(pk=pk) @@ -49,27 +59,31 @@ def __init__(self, pk, loop=None, poll_interval=None, communicator=None): # Try setting up a filtered broadcast subscriber if self._communicator is not None: - broadcast_filter = kiwipy.BroadcastFilter(lambda *args, **kwargs: self.set_result(node), sender=pk) + + def _subscriber(*args, **kwargs): # pylint: disable=unused-argument + if not self.done(): + self.set_result(node) + + broadcast_filter = kiwipy.BroadcastFilter(_subscriber, sender=pk) for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]: broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}') self._broadcast_identifier = self._communicator.add_broadcast_subscriber(broadcast_filter) # Start polling if poll_interval is not None: - loop.add_callback(self._poll_process, node, poll_interval) + loop.create_task(self._poll_process(node, poll_interval)) - def cleanup(self): + def cleanup(self) -> None: """Clean up the future by removing broadcast subscribers from the communicator if it still exists.""" if self._communicator is not None: self._communicator.remove_broadcast_subscriber(self._broadcast_identifier) self._communicator = None self._broadcast_identifier = None - @tornado.gen.coroutine - def _poll_process(self, node, poll_interval): + async def _poll_process(self, node: Node, poll_interval: Union[int, float]) -> None: """Poll whether the process node has reached a terminal state.""" while not self.done() and not node.is_terminated: - yield tornado.gen.sleep(poll_interval) + await asyncio.sleep(poll_interval) if not self.done(): self.set_result(node) diff --git a/aiida/engine/processes/ports.py b/aiida/engine/processes/ports.py index 1613d2169d..7a66de915e 100644 --- a/aiida/engine/processes/ports.py +++ b/aiida/engine/processes/ports.py @@ -8,11 +8,16 @@ # For further information please visit http://www.aiida.net # ########################################################################### """AiiDA specific implementation of plumpy Ports and PortNamespaces for the ProcessSpec.""" -import collections +from collections.abc import Mapping import re +from typing import Any, Callable, Dict, Optional, Sequence import warnings from plumpy import ports +from plumpy.ports import breadcrumbs_to_port + +from aiida.common.links import validate_link_label +from aiida.orm import Data, Node __all__ = ( 'PortNamespace', 'InputPort', 'OutputPort', 'CalcJobOutputPort', 'WithNonDb', 'WithSerialize', @@ -26,21 +31,21 @@ class WithNonDb: """ - A mixin that adds support to a port to flag a that should not be stored + A mixin that adds support to a port to flag that it should not be stored in the database using the non_db=True flag. The mixins have to go before the main port class in the superclass order to make sure the mixin has the chance to strip out the non_db keyword. 
""" - def __init__(self, *args, **kwargs): - self._non_db_explicitly_set = bool('non_db' in kwargs) + def __init__(self, *args, **kwargs) -> None: + self._non_db_explicitly_set: bool = bool('non_db' in kwargs) non_db = kwargs.pop('non_db', False) - super().__init__(*args, **kwargs) - self._non_db = non_db + super().__init__(*args, **kwargs) # type: ignore[call-arg] + self._non_db: bool = non_db @property - def non_db_explicitly_set(self): + def non_db_explicitly_set(self) -> bool: """Return whether the a value for `non_db` was explicitly passed in the construction of the `Port`. :return: boolean, True if `non_db` was explicitly defined during construction, False otherwise @@ -48,7 +53,7 @@ def non_db_explicitly_set(self): return self._non_db_explicitly_set @property - def non_db(self): + def non_db(self) -> bool: """Return whether the value of this `Port` should be stored as a `Node` in the database. :return: boolean, True if it should be storable as a `Node`, False otherwise @@ -56,10 +61,8 @@ def non_db(self): return self._non_db @non_db.setter - def non_db(self, non_db): + def non_db(self, non_db: bool) -> None: """Set whether the value of this `Port` should be stored as a `Node` in the database. - - :param non_db: boolean """ self._non_db_explicitly_set = True self._non_db = non_db @@ -71,19 +74,17 @@ class WithSerialize: that are not AiiDA data types. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: serializer = kwargs.pop('serializer', None) - super().__init__(*args, **kwargs) - self._serializer = serializer + super().__init__(*args, **kwargs) # type: ignore[call-arg] + self._serializer: Callable[[Any], 'Data'] = serializer - def serialize(self, value): + def serialize(self, value: Any) -> 'Data': """Serialize the given value if it is not already a Data type and a serializer function is defined :param value: the value to be serialized :returns: a serialized version of the value or the unchanged value """ - from aiida.orm import Data - if self._serializer is None or isinstance(value, Data): return value @@ -96,11 +97,9 @@ class InputPort(WithSerialize, WithNonDb, ports.InputPort): value serialization to database storable types and support non database storable input types as well. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Override the constructor to check the type of the default if set and warn if not immutable.""" # pylint: disable=redefined-builtin,too-many-arguments - from aiida.orm import Node - if 'default' in kwargs: default = kwargs['default'] # If the default is specified and it is a node instance, raise a warning. 
This is to try and prevent that @@ -112,7 +111,7 @@ def __init__(self, *args, **kwargs): super(InputPort, self).__init__(*args, **kwargs) - def get_description(self): + def get_description(self) -> Dict[str, str]: """ Return a description of the InputPort, which will be a dictionary of its attributes @@ -127,13 +126,13 @@ def get_description(self): class CalcJobOutputPort(ports.OutputPort): """Sub class of plumpy.OutputPort which adds the `_pass_to_parser` attribute.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: pass_to_parser = kwargs.pop('pass_to_parser', False) super().__init__(*args, **kwargs) - self._pass_to_parser = pass_to_parser + self._pass_to_parser: bool = pass_to_parser @property - def pass_to_parser(self): + def pass_to_parser(self) -> bool: return self._pass_to_parser @@ -143,7 +142,7 @@ class PortNamespace(WithNonDb, ports.PortNamespace): serialization of a given mapping onto the ports of the PortNamespace. """ - def __setitem__(self, key, port): + def __setitem__(self, key: str, port: ports.Port) -> None: """Ensure that a `Port` being added inherits the `non_db` attribute if not explicitly defined at construction. The reasoning is that if a `PortNamespace` has `non_db=True`, which is different from the default value, very @@ -157,13 +156,13 @@ def __setitem__(self, key, port): self.validate_port_name(key) - if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set: - port.non_db = self.non_db + if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set: # type: ignore[attr-defined] + port.non_db = self.non_db # type: ignore[attr-defined] super().__setitem__(key, port) @staticmethod - def validate_port_name(port_name): + def validate_port_name(port_name: str) -> None: """Validate the given port name. Valid port names adhere to the following restrictions: @@ -181,8 +180,6 @@ def validate_port_name(port_name): :raise TypeError: if the port name is not a string type :raise ValueError: if the port name is invalid """ - from aiida.common.links import validate_link_label - try: validate_link_label(port_name) except ValueError as exception: @@ -195,7 +192,7 @@ def validate_port_name(port_name): if any([len(entry) > PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES for entry in consecutive_underscores]): raise ValueError(f'invalid port name `{port_name}`: more than two consecutive underscores') - def serialize(self, mapping, breadcrumbs=()): + def serialize(self, mapping: Optional[Dict[str, Any]], breadcrumbs: Sequence[str] = ()) -> Optional[Dict[str, Any]]: """Serialize the given mapping onto this `Portnamespace`. It will recursively call this function on any nested `PortNamespace` or the serialize function on any `Ports`. 
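The recursion described here, reduced to plain dictionaries: nested namespaces recurse, leaf ports get their serializer applied, and keys that are not declared pass through unchanged (an illustration only, not the `PortNamespace` implementation):

from collections.abc import Mapping


def serialize_mapping(namespace: dict, mapping, breadcrumbs=()):
    if mapping is None:
        return None
    if not isinstance(mapping, Mapping):
        raise TypeError(f'`{".".join(breadcrumbs)}` expected a mapping, got {type(mapping)}')

    result = {}
    for name, value in mapping.items():
        spec = namespace.get(name)
        if isinstance(spec, dict):                        # nested namespace: recurse
            result[name] = serialize_mapping(spec, value, (*breadcrumbs, name))
        elif callable(spec):                              # leaf port: apply its serializer
            result[name] = spec(value)
        else:                                             # not declared: pass through untouched
            result[name] = value
    return result


namespace = {'x': int, 'nested': {'y': float}}
assert serialize_mapping(namespace, {'x': '1', 'nested': {'y': '2.5'}, 'z': 'raw'}) == {
    'x': 1, 'nested': {'y': 2.5}, 'z': 'raw'
}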
@@ -204,26 +201,27 @@ def serialize(self, mapping, breadcrumbs=()): :param breadcrumbs: a tuple with the namespaces of parent namespaces :returns: the serialized mapping """ - from plumpy.ports import breadcrumbs_to_port - if mapping is None: return None - breadcrumbs += (self.name,) + breadcrumbs = (*breadcrumbs, self.name) - if not isinstance(mapping, collections.Mapping): - port = breadcrumbs_to_port(breadcrumbs) - raise TypeError(f'port namespace `{port}` received `{type(mapping)}` instead of a dictionary') + if not isinstance(mapping, Mapping): + port_name = breadcrumbs_to_port(breadcrumbs) + raise TypeError(f'port namespace `{port_name}` received `{type(mapping)}` instead of a dictionary') result = {} for name, value in mapping.items(): if name in self: + port = self[name] if isinstance(port, PortNamespace): result[name] = port.serialize(value, breadcrumbs) - else: + elif isinstance(port, InputPort): result[name] = port.serialize(value) + else: + raise AssertionError(f'port does not have a serialize method: {port}') else: result[name] = value diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index ae8cf25c7a..bf5df7b6b7 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -8,35 +8,48 @@ # For further information please visit http://www.aiida.net # ########################################################################### """The AiiDA process class""" +import asyncio import collections +from collections.abc import Mapping import enum import inspect -import uuid +import logging +from uuid import UUID import traceback - -from pika.exceptions import ConnectionClosed - -import plumpy -from plumpy import ProcessState +from types import TracebackType +from typing import ( + Any, cast, Dict, Iterable, Iterator, List, MutableMapping, Optional, Type, Tuple, Union, TYPE_CHECKING +) + +from aio_pika.exceptions import ConnectionClosed +import plumpy.exceptions +import plumpy.futures +import plumpy.processes +import plumpy.persistence +from plumpy.process_states import ProcessState, Finished from kiwipy.communications import UnroutableError from aiida import orm +from aiida.orm.utils import serialize from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import classproperty, override from aiida.common.links import LinkType from aiida.common.log import LOG_LEVEL_REPORT -from .exit_code import ExitCode +from .exit_code import ExitCode, ExitCodesNamespace from .builder import ProcessBuilder from .ports import InputPort, OutputPort, PortNamespace, PORT_NAMESPACE_SEPARATOR from .process_spec import ProcessSpec +if TYPE_CHECKING: + from aiida.engine.runners import Runner + __all__ = ('Process', 'ProcessState') -@plumpy.auto_persist('_parent_pid', '_enable_persistence') -class Process(plumpy.Process): +@plumpy.persistence.auto_persist('_parent_pid', '_enable_persistence') +class Process(plumpy.processes.Process): """ This class represents an AiiDA process which can be executed and will have full provenance saved in the database. @@ -46,88 +59,109 @@ class Process(plumpy.Process): _node_class = orm.ProcessNode _spec_class = ProcessSpec - SINGLE_OUTPUT_LINKNAME = 'result' + SINGLE_OUTPUT_LINKNAME: str = 'result' class SaveKeys(enum.Enum): """ Keys used to identify things in the saved instance state bundle. 
""" - CALC_ID = 'calc_id' + CALC_ID: str = 'calc_id' + + @classmethod + def spec(cls) -> ProcessSpec: + return super().spec() # type: ignore[return-value] @classmethod - def define(cls, spec): - # yapf: disable + def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] + """Define the specification of the process, including its inputs, outputs and known exit codes. + + A `metadata` input namespace is defined, with optional ports that are not stored in the database. + + """ super().define(spec) spec.input_namespace(spec.metadata_key, required=False, non_db=True) - spec.input(f'{spec.metadata_key}.store_provenance', valid_type=bool, default=True, - help='If set to `False` provenance will not be stored in the database.') - spec.input(f'{spec.metadata_key}.description', valid_type=str, required=False, - help='Description to set on the process node.') - spec.input(f'{spec.metadata_key}.label', valid_type=str, required=False, - help='Label to set on the process node.') - spec.input(f'{spec.metadata_key}.call_link_label', valid_type=str, default='CALL', - help='The label to use for the `CALL` link if the process is called by another process.') + spec.input( + f'{spec.metadata_key}.store_provenance', + valid_type=bool, + default=True, + help='If set to `False` provenance will not be stored in the database.' + ) + spec.input( + f'{spec.metadata_key}.description', + valid_type=str, + required=False, + help='Description to set on the process node.' + ) + spec.input( + f'{spec.metadata_key}.label', valid_type=str, required=False, help='Label to set on the process node.' + ) + spec.input( + f'{spec.metadata_key}.call_link_label', + valid_type=str, + default='CALL', + help='The label to use for the `CALL` link if the process is called by another process.' + ) spec.exit_code(1, 'ERROR_UNSPECIFIED', message='The process has failed with an unspecified error.') spec.exit_code(2, 'ERROR_LEGACY_FAILURE', message='The process failed with legacy failure mode.') spec.exit_code(10, 'ERROR_INVALID_OUTPUT', message='The process returned an invalid output.') spec.exit_code(11, 'ERROR_MISSING_OUTPUT', message='The process did not register a required output.') @classmethod - def get_builder(cls): + def get_builder(cls) -> ProcessBuilder: return ProcessBuilder(cls) @classmethod - def get_or_create_db_record(cls): + def get_or_create_db_record(cls) -> orm.ProcessNode: """ Create a process node that represents what happened in this process. :return: A process node - :rtype: :class:`aiida.orm.ProcessNode` """ return cls._node_class() - def __init__(self, inputs=None, logger=None, runner=None, parent_pid=None, enable_persistence=True): + def __init__( + self, + inputs: Optional[Dict[str, Any]] = None, + logger: Optional[logging.Logger] = None, + runner: Optional['Runner'] = None, + parent_pid: Optional[int] = None, + enable_persistence: bool = True + ) -> None: """ Process constructor. 
:param inputs: process inputs - :type inputs: dict - :param logger: aiida logger - :type logger: :class:`logging.Logger` - :param runner: process runner - :type: :class:`aiida.engine.runners.Runner` - :param parent_pid: id of parent process - :type parent_pid: int - :param enable_persistence: whether to persist this process - :type enable_persistence: bool + """ from aiida.manage import manager self._runner = runner if runner is not None else manager.get_manager().get_runner() + assert self._runner.communicator is not None, 'communicator not set for runner' super().__init__( inputs=self.spec().inputs.serialize(inputs), logger=logger, loop=self._runner.loop, - communicator=self.runner.communicator) + communicator=self._runner.communicator + ) - self._node = None + self._node: Optional[orm.ProcessNode] = None self._parent_pid = parent_pid self._enable_persistence = enable_persistence if self._enable_persistence and self.runner.persister is None: self.logger.warning('Disabling persistence, runner does not have a persister') self._enable_persistence = False - def init(self): + def init(self) -> None: super().init() if self._logger is None: self.set_logger(self.node.logger) @classmethod - def get_exit_statuses(cls, exit_code_labels): + def get_exit_statuses(cls, exit_code_labels: Iterable[str]) -> List[int]: """Return the exit status (integers) for the given exit code labels. :param exit_code_labels: a list of strings that reference exit code labels of this process class @@ -138,37 +172,34 @@ def get_exit_statuses(cls, exit_code_labels): return [getattr(exit_codes, label).status for label in exit_code_labels] @classproperty - def exit_codes(cls): # pylint: disable=no-self-argument + def exit_codes(cls) -> ExitCodesNamespace: # pylint: disable=no-self-argument """Return the namespace of exit codes defined for this WorkChain through its ProcessSpec. The namespace supports getitem and getattr operations with an ExitCode label to retrieve a specific code. Additionally, the namespace can also be called with either the exit code integer status to retrieve it. :returns: ExitCodesNamespace of ExitCode named tuples - :rtype: :class:`aiida.engine.ExitCodesNamespace` + """ return cls.spec().exit_codes @classproperty - def spec_metadata(cls): # pylint: disable=no-self-argument - """Return the metadata port namespace of the process specification of this process. - - :return: metadata dictionary - :rtype: dict - """ - return cls.spec().inputs['metadata'] + def spec_metadata(cls) -> PortNamespace: # pylint: disable=no-self-argument + """Return the metadata port namespace of the process specification of this process.""" + return cls.spec().inputs['metadata'] # type: ignore[return-value] @property - def node(self): + def node(self) -> orm.ProcessNode: """Return the ProcessNode used by this process to represent itself in the database. :return: instance of sub class of ProcessNode - :rtype: :class:`aiida.orm.ProcessNode` + """ + assert self._node is not None return self._node @property - def uuid(self): + def uuid(self) -> str: # type: ignore[override] """Return the UUID of the process which corresponds to the UUID of its associated `ProcessNode`. :return: the UUID associated to this process instance @@ -176,32 +207,43 @@ def uuid(self): return self.node.uuid @property - def metadata(self): + def metadata(self) -> AttributeDict: """Return the metadata that were specified when this process instance was launched. 
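The `metadata` ports defined in `define` above are supplied like any other input when launching a process. A sketch with a trivial placeholder work chain (the class, label and description are invented for illustration; running it requires a configured profile):

from aiida import load_profile
from aiida.engine import run_get_node, WorkChain

load_profile()


class NoopWorkChain(WorkChain):
    """Placeholder work chain used only to show the metadata inputs."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.outline(cls.do_nothing)

    def do_nothing(self):
        pass


_, node = run_get_node(NoopWorkChain, metadata={
    'label': 'my calculation',
    'description': 'short human readable description',
    'store_provenance': True,    # set to False to skip storing provenance
    'call_link_label': 'CALL',   # only relevant when called by another process
})
assert node.label == 'my calculation'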
:return: metadata dictionary - :rtype: dict + """ try: + assert self.inputs is not None return self.inputs.metadata - except AttributeError: + except (AssertionError, AttributeError): return AttributeDict() - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: """ Save the current state in a chechpoint if persistence is enabled and the process state is not terminal If the persistence call excepts with a PersistenceError, it will be caught and a warning will be logged. """ if self._enable_persistence and not self._state.is_terminal(): + if self.runner.persister is None: + self.logger.exception( + 'No persister set to save checkpoint, this means you will ' + 'not be able to restart in case of a crash until the next successful checkpoint.' + ) + return None try: self.runner.persister.save_checkpoint(self) - except plumpy.PersistenceError: - self.logger.exception('Exception trying to save checkpoint, this means you will ' - 'not be able to restart in case of a crash until the next successful checkpoint.') + except plumpy.exceptions.PersistenceError: + self.logger.exception( + 'Exception trying to save checkpoint, this means you will ' + 'not be able to restart in case of a crash until the next successful checkpoint.' + ) @override - def save_instance_state(self, out_state, save_context): + def save_instance_state( + self, out_state: MutableMapping[str, Any], save_context: Optional[plumpy.persistence.LoadSaveContext] + ) -> None: """Save instance state. See documentation of :meth:`!plumpy.processes.Process.save_instance_state`. @@ -213,21 +255,23 @@ def save_instance_state(self, out_state, save_context): out_state[self.SaveKeys.CALC_ID.value] = self.pid - def get_provenance_inputs_iterator(self): + def get_provenance_inputs_iterator(self) -> Iterator[Tuple[str, Union[InputPort, PortNamespace]]]: """Get provenance input iterator. :rtype: filter """ + assert self.inputs is not None return filter(lambda kv: not kv[0].startswith('_'), self.inputs.items()) @override - def load_instance_state(self, saved_state, load_context): + def load_instance_state( + self, saved_state: MutableMapping[str, Any], load_context: plumpy.persistence.LoadSaveContext + ) -> None: """Load instance state. 
:param saved_state: saved instance state - :param load_context: - :type load_context: :class:`!plumpy.persistence.LoadSaveContext` + """ from aiida.manage import manager @@ -241,20 +285,17 @@ def load_instance_state(self, saved_state, load_context): if self.SaveKeys.CALC_ID.value in saved_state: self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) - self._pid = self.node.pk + self._pid = self.node.pk # pylint: disable=attribute-defined-outside-init else: - self._pid = self._create_and_setup_db_record() + self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state') - def kill(self, msg=None): + def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]: """ Kill the process and all the children calculations it called :param msg: message - :type msg: str - - :rtype: bool """ self.node.logger.info(f'Request to kill Process<{self.node.pk}>') @@ -266,35 +307,45 @@ def kill(self, msg=None): if result is not False and not had_been_terminated: killing = [] for child in self.node.called: + if self.runner.controller is None: + self.logger.info('no controller available to kill child<%s>', child.pk) + continue try: result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>') - if isinstance(result, plumpy.Future): + result = asyncio.wrap_future(result) # type: ignore[arg-type] + if asyncio.isfuture(result): killing.append(result) except ConnectionClosed: self.logger.info('no connection available to kill child<%s>', child.pk) except UnroutableError: self.logger.info('kill signal was unable to reach child<%s>', child.pk) - if isinstance(result, plumpy.Future): + if asyncio.isfuture(result): # We ourselves are waiting to be killed so add it to the list - killing.append(result) + killing.append(result) # type: ignore[arg-type] if killing: # We are waiting for things to be killed, so return the 'gathered' future - result = plumpy.gather(killing) + kill_future = plumpy.futures.gather(*killing) + result = self.loop.create_future() + + def done(done_future: plumpy.futures.Future): + is_all_killed = all(done_future.result()) + result.set_result(is_all_killed) # type: ignore[union-attr] + + kill_future.add_done_callback(done) return result @override - def out(self, output_port, value=None): + def out(self, output_port: str, value: Any = None) -> None: """Attach output to output port. The name of the port will be used as the link label. :param output_port: name of output port - :type output_port: str - :param value: value to put inside output port + """ if value is None: # In this case assume that output_port is the actual value and there is just one return value @@ -303,7 +354,7 @@ def out(self, output_port, value=None): return super().out(output_port, value) - def out_many(self, out_dict): + def out_many(self, out_dict: Dict[str, Any]) -> None: """Attach outputs to multiple output ports. Keys of the dictionary will be used as output port names, values as outputs. 
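In `kill` above, the concurrent futures returned by the controller are wrapped with `asyncio.wrap_future` and reduced to a single boolean once they are all done. The reduction itself is ordinary asyncio, sketched here without any AiiDA objects:

import asyncio


async def kill_all(child_futures):
    results = await asyncio.gather(*child_futures)
    return all(results)   # True only if every child reports a successful kill


async def main():
    loop = asyncio.get_running_loop()
    children = [loop.create_future() for _ in range(3)]
    for child in children:
        child.set_result(True)
    print(await kill_all(children))


asyncio.run(main())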
@@ -314,39 +365,40 @@ def out_many(self, out_dict): for key, value in out_dict.items(): self.out(key, value) - def on_create(self): + def on_create(self) -> None: """Called when a Process is created.""" super().on_create() # If parent PID hasn't been supplied try to get it from the stack if self._parent_pid is None and Process.current(): current = Process.current() if isinstance(current, Process): - self._parent_pid = current.pid - self._pid = self._create_and_setup_db_record() + self._parent_pid = current.pid # type: ignore[assignment] + self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init @override - def on_entering(self, state): + def on_entering(self, state: plumpy.process_states.State) -> None: super().on_entering(state) # Update the node attributes every time we enter a new state - def on_entered(self, from_state): + def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: + """After entering a new state, save a checkpoint and update the latest process state change timestamp.""" # pylint: disable=cyclic-import from aiida.engine.utils import set_process_state_change_timestamp self.update_node_state(self._state) self._save_checkpoint() - # Update the latest process state change timestamp set_process_state_change_timestamp(self) super().on_entered(from_state) @override - def on_terminated(self): + def on_terminated(self) -> None: """Called when a Process enters a terminal state.""" super().on_terminated() if self._enable_persistence: try: + assert self.runner.persister is not None self.runner.persister.delete_checkpoint(self.pid) - except Exception: # pylint: disable=broad-except - self.logger.exception('Failed to delete checkpoint') + except Exception as error: # pylint: disable=broad-except + self.logger.exception('Failed to delete checkpoint: %s', error) try: self.node.seal() @@ -354,7 +406,7 @@ def on_terminated(self): pass @override - def on_except(self, exc_info): + def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: """ Log the exception by calling the report method with formatted stack trace from exception info object and store the exception string as a node attribute @@ -362,18 +414,16 @@ def on_except(self, exc_info): :param exc_info: the sys.exc_info() object (type, value, traceback) """ super().on_except(exc_info) - self.node.set_exception(''.join(traceback.format_exception(exc_info[0], exc_info[1], None))) + self.node.set_exception(''.join(traceback.format_exception(exc_info[0], exc_info[1], None)).rstrip()) self.report(''.join(traceback.format_exception(*exc_info))) @override - def on_finish(self, result, successful): + def on_finish(self, result: Union[int, ExitCode], successful: bool) -> None: """ Set the finish status on the process node. 
:param result: result of the process - :type result: int or :class:`aiida.engine.ExitCode` - :param successful: whether execution was successful - :type successful: bool + """ super().on_finish(result, successful) @@ -389,23 +439,24 @@ def on_finish(self, result, successful): self.node.set_exit_status(result.status) self.node.set_exit_message(result.message) else: - raise ValueError('the result should be an integer, ExitCode or None, got {} {} {}'.format( - type(result), result, self.pid)) + raise ValueError( + f'the result should be an integer, ExitCode or None, got {type(result)} {result} {self.pid}' + ) @override - def on_paused(self, msg=None): + def on_paused(self, msg: Optional[str] = None) -> None: """ The Process was paused so set the paused attribute on the process node :param msg: message - :type msg: str + """ super().on_paused(msg) self._save_checkpoint() self.node.pause() @override - def on_playing(self): + def on_playing(self) -> None: """ The Process was unpaused so remove the paused attribute on the process node """ @@ -413,14 +464,13 @@ def on_playing(self): self.node.unpause() @override - def on_output_emitting(self, output_port, value): + def on_output_emitting(self, output_port: str, value: Any) -> None: """ The process has emitted a value on the given output port. :param output_port: The output port name the value was emitted on - :type output_port: str - :param value: The value emitted + """ super().on_output_emitting(output_port, value) @@ -428,39 +478,36 @@ def on_output_emitting(self, output_port, value): if isinstance(output_port, OutputPort) and not isinstance(value, orm.Data): raise TypeError(f'Processes can only return `orm.Data` instances as output, got {value.__class__}') - def set_status(self, status): + def set_status(self, status: Optional[str]) -> None: """ The status of the Process is about to be changed, so we reflect this is in node's attribute proxy. :param status: the status message - :type status: str + """ super().set_status(status) self.node.set_process_status(status) - def submit(self, process, *args, **kwargs): + def submit(self, process: Type['Process'], *args, **kwargs) -> orm.ProcessNode: """Submit process for execution. :param process: process - :type process: :class:`aiida.engine.Process` + :return: the calculation node of the process """ return self.runner.submit(process, *args, **kwargs) @property - def runner(self): - """Get process runner. - - :rtype: :class:`aiida.engine.runners.Runner` - """ + def runner(self) -> 'Runner': + """Get process runner.""" return self._runner - def get_parent_calc(self): + def get_parent_calc(self) -> Optional[orm.ProcessNode]: """ Get the parent process node :return: the parent process node if there is one - :rtype: :class:`aiida.orm.ProcessNode` + """ # Can't get it if we don't know our parent if self._parent_pid is None: @@ -469,12 +516,11 @@ def get_parent_calc(self): return orm.load_node(pk=self._parent_pid) @classmethod - def build_process_type(cls): + def build_process_type(cls) -> str: """ The process type. 
:return: string of the process type - :rtype: str Note: This could be made into a property 'process_type' but in order to have it be a property of the class it would need to be defined in the metaclass, see https://bugs.python.org/issue20659 @@ -493,29 +539,25 @@ def build_process_type(cls): return process_type - def report(self, msg, *args, **kwargs): + def report(self, msg: str, *args, **kwargs) -> None: """Log a message to the logger, which should get saved to the database through the attached DbLogHandler. The pk, class name and function name of the caller are prepended to the given message :param msg: message to log - :type msg: str - :param args: args to pass to the log call - :type args: list - :param kwargs: kwargs to pass to the log call - :type kwargs: dict + """ message = f'[{self.node.pk}|{self.__class__.__name__}|{inspect.stack()[1][3]}]: {msg}' self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs) - def _create_and_setup_db_record(self): + def _create_and_setup_db_record(self) -> Union[int, UUID]: """ Create and setup the database record for this process - :return: the uuid of the process - :rtype: :class:`!uuid.UUID` + :return: the uuid or pk of the process + """ self._node = self.get_or_create_db_record() self._setup_db_record() @@ -523,7 +565,7 @@ def _create_and_setup_db_record(self): try: self.node.store_all() if self.node.is_finished_ok: - self._state = ProcessState.FINISHED + self._state = Finished(self, None, True) # pylint: disable=attribute-defined-outside-init for entry in self.node.get_outgoing(link_type=LinkType.RETURN): if entry.link_label.endswith(f'_{entry.node.pk}'): continue @@ -542,35 +584,33 @@ def _create_and_setup_db_record(self): if self.node.pk is not None: return self.node.pk - return uuid.UUID(self.node.uuid) + return UUID(self.node.uuid) @override - def encode_input_args(self, inputs): + def encode_input_args(self, inputs: Dict[str, Any]) -> str: # pylint: disable=no-self-use """ Encode input arguments such that they may be saved in a Bundle :param inputs: A mapping of the inputs as passed to the process :return: The encoded (serialized) inputs """ - from aiida.orm.utils import serialize return serialize.serialize(inputs) @override - def decode_input_args(self, encoded): + def decode_input_args(self, encoded: str) -> Dict[str, Any]: # pylint: disable=no-self-use """ Decode saved input arguments as they came from the saved instance state Bundle :param encoded: encoded (serialized) inputs :return: The decoded input args """ - from aiida.orm.utils import serialize return serialize.deserialize(encoded) - def update_node_state(self, state): + def update_node_state(self, state: plumpy.process_states.State) -> None: self.update_outputs() self.node.set_process_state(state.LABEL) - def update_outputs(self): + def update_outputs(self) -> None: """Attach new outputs to the node since the last call. Does nothing, if self.metadata.store_provenance is False. 
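As a hypothetical illustration of how `report` is typically used: the messages are persisted on the process node through the DbLogHandler and are what `verdi process report <pk>` prints afterwards. The context key `children` is invented:

from aiida.engine import WorkChain

class ReportingExample(WorkChain):  # hypothetical work chain
    def inspect(self):
        node = self.ctx.children[-1]  # assumes a child process node was put in the context earlier
        if not node.is_finished_ok:
            self.report(f'{node.process_label}<{node.pk}> failed with exit status {node.exit_status}')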
@@ -594,7 +634,7 @@ def update_outputs(self): output.store() - def _setup_db_record(self): + def _setup_db_record(self) -> None: """ Create the database record for this process and the links with respect to its inputs @@ -631,9 +671,9 @@ def _setup_db_record(self): self._setup_metadata() self._setup_inputs() - def _setup_metadata(self): + def _setup_metadata(self) -> None: """Store the metadata on the ProcessNode.""" - version_info = self.runner.plugin_version_provider.get_version_info(self) + version_info = self.runner.plugin_version_provider.get_version_info(self.__class__) self.node.set_attribute_many(version_info) for name, metadata in self.metadata.items(): @@ -652,7 +692,7 @@ def _setup_metadata(self): else: raise RuntimeError(f'unsupported metadata key: {name}') - def _setup_inputs(self): + def _setup_inputs(self) -> None: """Create the links between the input nodes and the ProcessNode that represents this process.""" for name, node in self._flat_inputs().items(): @@ -671,7 +711,7 @@ def _setup_inputs(self): elif isinstance(self.node, orm.WorkflowNode): self.node.add_incoming(node, LinkType.INPUT_WORK, name) - def _flat_inputs(self): + def _flat_inputs(self) -> Dict[str, Any]: """ Return a flattened version of the parsed inputs dictionary. @@ -679,12 +719,13 @@ def _flat_inputs(self): is not passed, as those are dealt with separately in `_setup_metadata`. :return: flat dictionary of parsed inputs - :rtype: dict + """ + assert self.inputs is not None inputs = {key: value for key, value in self.inputs.items() if key != self.spec().metadata_key} return dict(self._flatten_inputs(self.spec().inputs, inputs)) - def _flat_outputs(self): + def _flat_outputs(self) -> Dict[str, Any]: """ Return a flattened version of the registered outputs dictionary. 
@@ -694,104 +735,103 @@ def _flat_outputs(self): """ return dict(self._flatten_outputs(self.spec().outputs, self.outputs)) - def _flatten_inputs(self, port, port_value, parent_name='', separator=PORT_NAMESPACE_SEPARATOR): + def _flatten_inputs( + self, + port: Union[None, InputPort, PortNamespace], + port_value: Any, + parent_name: str = '', + separator: str = PORT_NAMESPACE_SEPARATOR + ) -> List[Tuple[str, Any]]: """ Function that will recursively flatten the inputs dictionary, omitting inputs for ports that are marked as being non database storable :param port: port against which to map the port value, can be InputPort or PortNamespace - :type port: :class:`plumpy.ports.Port` - :param port_value: value for the current port, can be a Mapping - :param parent_name: the parent key with which to prefix the keys - :type parent_name: str - :param separator: character to use for the concatenation of keys - :type separator: str - :return: flat list of inputs - :rtype: list + """ if (port is None and isinstance(port_value, orm.Node)) or (isinstance(port, InputPort) and not port.non_db): return [(parent_name, port_value)] - if port is None and isinstance(port_value, collections.Mapping) or isinstance(port, PortNamespace): + if port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace): items = [] for name, value in port_value.items(): prefixed_key = parent_name + separator + name if parent_name else name try: - nested_port = port[name] + nested_port = cast(Union[InputPort, PortNamespace], port[name]) if port else None except (KeyError, TypeError): nested_port = None sub_items = self._flatten_inputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator) + port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator + ) items.extend(sub_items) return items assert (port is None) or (isinstance(port, InputPort) and port.non_db) return [] - def _flatten_outputs(self, port, port_value, parent_name='', separator=PORT_NAMESPACE_SEPARATOR): + def _flatten_outputs( + self, + port: Union[None, OutputPort, PortNamespace], + port_value: Any, + parent_name: str = '', + separator: str = PORT_NAMESPACE_SEPARATOR + ) -> List[Tuple[str, Any]]: """ Function that will recursively flatten the outputs dictionary. 
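The recursion in `_flatten_inputs`/`_flatten_outputs` can be pictured with a toy version that only handles plain mappings; the real methods additionally consult the port specification and skip `non_db` input ports. The `__` separator here is an assumption standing in for `PORT_NAMESPACE_SEPARATOR`:

from collections.abc import Mapping

def flatten(value, parent_name='', separator='__'):
    """Toy analogue of the recursion above, for plain mappings only."""
    if not isinstance(value, Mapping):
        return [(parent_name, value)]
    items = []
    for name, child in value.items():
        prefixed = parent_name + separator + name if parent_name else name
        items.extend(flatten(child, prefixed, separator))
    return items

print(flatten({'parameters': 1, 'pseudos': {'Si': 2}}))
# [('parameters', 1), ('pseudos__Si', 2)]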
:param port: port against which to map the port value, can be OutputPort or PortNamespace - :type port: :class:`plumpy.ports.Port` - :param port_value: value for the current port, can be a Mapping - :type parent_name: str - :param parent_name: the parent key with which to prefix the keys - :type parent_name: str - :param separator: character to use for the concatenation of keys - :type separator: str :return: flat list of outputs - :rtype: list + """ if port is None and isinstance(port_value, orm.Node) or isinstance(port, OutputPort): return [(parent_name, port_value)] - if (port is None and isinstance(port_value, collections.Mapping) or isinstance(port, PortNamespace)): + if (port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace)): items = [] for name, value in port_value.items(): prefixed_key = parent_name + separator + name if parent_name else name try: - nested_port = port[name] + nested_port = cast(Union[OutputPort, PortNamespace], port[name]) if port else None except (KeyError, TypeError): nested_port = None sub_items = self._flatten_outputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator) + port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator + ) items.extend(sub_items) return items assert port is None, port return [] - def exposed_inputs(self, process_class, namespace=None, agglomerate=True): - """ - Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. + def exposed_inputs( + self, + process_class: Type['Process'], + namespace: Optional[str] = None, + agglomerate: bool = True + ) -> AttributeDict: + """Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. :param process_class: Process class whose inputs to try and retrieve - :type process_class: :class:`aiida.engine.Process` - :param namespace: PortNamespace in which to look for the inputs - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched for inputs. Inputs in lower-lying namespaces take precedence. 
- :type agglomerate: bool :returns: exposed inputs - :rtype: dict + """ exposed_inputs = {} @@ -805,9 +845,9 @@ def exposed_inputs(self, process_class, namespace=None, agglomerate=True): else: inputs = self.inputs for part in sub_namespace.split('.'): - inputs = inputs[part] + inputs = inputs[part] # type: ignore[index] try: - port_namespace = self.spec().inputs.get_port(sub_namespace) + port_namespace = self.spec().inputs.get_port(sub_namespace) # type: ignore[assignment] except KeyError: raise ValueError(f'this process does not contain the "{sub_namespace}" input namespace') @@ -815,26 +855,26 @@ def exposed_inputs(self, process_class, namespace=None, agglomerate=True): exposed_inputs_list = self.spec()._exposed_inputs[sub_namespace][process_class] # pylint: disable=protected-access for name in port_namespace.ports.keys(): - if name in inputs and name in exposed_inputs_list: + if inputs and name in inputs and name in exposed_inputs_list: exposed_inputs[name] = inputs[name] return AttributeDict(exposed_inputs) - def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True): + def exposed_outputs( + self, + node: orm.ProcessNode, + process_class: Type['Process'], + namespace: Optional[str] = None, + agglomerate: bool = True + ) -> AttributeDict: """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node`` :param node: process node whose outputs to try and retrieve - :type node: :class:`aiida.orm.nodes.process.ProcessNode` - :param namespace: Namespace in which to search for exposed outputs. - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched for outputs. Outputs in lower-lying namespaces take precedence. - :type agglomerate: bool :returns: exposed outputs - :rtype: dict """ namespace_separator = self.spec().namespace_separator @@ -843,9 +883,7 @@ def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True) # maps the exposed name to all outputs that belong to it top_namespace_map = collections.defaultdict(list) link_types = (LinkType.CREATE, LinkType.RETURN) - process_outputs_dict = { - entry.link_label: entry.node for entry in node.get_outgoing(link_type=link_types) - } + process_outputs_dict = {entry.link_label: entry.node for entry in node.get_outgoing(link_type=link_types)} for port_name in process_outputs_dict: top_namespace = port_name.split(namespace_separator)[0] @@ -870,30 +908,27 @@ def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True) return AttributeDict(result) @staticmethod - def _get_namespace_list(namespace=None, agglomerate=True): + def _get_namespace_list(namespace: Optional[str] = None, agglomerate: bool = True) -> List[Optional[str]]: """Get the list of namespaces in a given namespace. :param namespace: name space - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched. - :type agglomerate: bool :returns: namespace list - :rtype: list + """ if not agglomerate: return [namespace] - namespace_list = [None] + namespace_list: List[Optional[str]] = [None] if namespace is not None: split_ns = namespace.split('.') namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)]) return namespace_list @classmethod - def is_valid_cache(cls, node): + def is_valid_cache(cls, node: orm.ProcessNode) -> bool: """Check if the given node can be cached from. .. 
warning :: When overriding this method, make sure to call @@ -909,7 +944,7 @@ def is_valid_cache(cls, node): return True -def get_query_string_from_process_type_string(process_type_string): # pylint: disable=invalid-name +def get_query_string_from_process_type_string(process_type_string: str) -> str: # pylint: disable=invalid-name """ Take the process type string of a Node and create the queryable type string. diff --git a/aiida/engine/processes/process_spec.py b/aiida/engine/processes/process_spec.py index 334a8e0794..4e73005f2a 100644 --- a/aiida/engine/processes/process_spec.py +++ b/aiida/engine/processes/process_spec.py @@ -8,7 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """AiiDA specific implementation of plumpy's ProcessSpec.""" -import plumpy +from typing import Optional + +import plumpy.process_spec + +from aiida.orm import Dict from .exit_code import ExitCode, ExitCodesNamespace from .ports import InputPort, PortNamespace, CalcJobOutputPort @@ -16,32 +20,32 @@ __all__ = ('ProcessSpec', 'CalcJobProcessSpec') -class ProcessSpec(plumpy.ProcessSpec): +class ProcessSpec(plumpy.process_spec.ProcessSpec): """Default process spec for process classes defined in `aiida-core`. This sub class defines custom classes for input ports and port namespaces. It also adds support for the definition of exit codes and retrieving them subsequently. """ - METADATA_KEY = 'metadata' - METADATA_OPTIONS_KEY = 'options' + METADATA_KEY: str = 'metadata' + METADATA_OPTIONS_KEY: str = 'options' INPUT_PORT_TYPE = InputPort PORT_NAMESPACE_TYPE = PortNamespace - def __init__(self): + def __init__(self) -> None: super().__init__() self._exit_codes = ExitCodesNamespace() @property - def metadata_key(self): + def metadata_key(self) -> str: return self.METADATA_KEY @property - def options_key(self): + def options_key(self) -> str: return self.METADATA_OPTIONS_KEY @property - def exit_codes(self): + def exit_codes(self) -> ExitCodesNamespace: """ Return the namespace of exit codes defined for this ProcessSpec @@ -49,7 +53,7 @@ def exit_codes(self): """ return self._exit_codes - def exit_code(self, status, label, message, invalidates_cache=False): + def exit_code(self, status: int, label: str, message: str, invalidates_cache: bool = False) -> None: """ Add an exit code to the ProcessSpec @@ -76,24 +80,36 @@ def exit_code(self, status, label, message, invalidates_cache=False): self._exit_codes[label] = ExitCode(status, message, invalidates_cache=invalidates_cache) + # override return type to aiida's PortNamespace subclass + + @property + def ports(self) -> PortNamespace: + return super().ports # type: ignore[return-value] + + @property + def inputs(self) -> PortNamespace: + return super().inputs # type: ignore[return-value] + + @property + def outputs(self) -> PortNamespace: + return super().outputs # type: ignore[return-value] + class CalcJobProcessSpec(ProcessSpec): """Process spec intended for the `CalcJob` process class.""" OUTPUT_PORT_TYPE = CalcJobOutputPort - def __init__(self): + def __init__(self) -> None: super().__init__() - self._default_output_node = None + self._default_output_node: Optional[str] = None @property - def default_output_node(self): + def default_output_node(self) -> Optional[str]: return self._default_output_node @default_output_node.setter - def default_output_node(self, port_name): - from aiida.orm import Dict - + def default_output_node(self, port_name: str) -> None: if port_name not in 
self.outputs: raise ValueError(f'{port_name} is not a registered output port') diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index bea66d0a5b..9b0cf508c9 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -14,4 +14,4 @@ from .utils import * from .workchain import * -__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) +__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/processes/workchains/awaitable.py b/aiida/engine/processes/workchains/awaitable.py index fee97be995..ea8954ae92 100644 --- a/aiida/engine/processes/workchains/awaitable.py +++ b/aiida/engine/processes/workchains/awaitable.py @@ -9,6 +9,7 @@ ########################################################################### """Enums and function for the awaitables of Processes.""" from enum import Enum +from typing import Union from plumpy.utils import AttributesDict from aiida.orm import ProcessNode @@ -31,7 +32,7 @@ class AwaitableAction(Enum): APPEND = 'append' -def construct_awaitable(target): +def construct_awaitable(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Construct an instance of the Awaitable class that will contain the information related to the action to be taken with respect to the context once the awaitable diff --git a/aiida/engine/processes/workchains/context.py b/aiida/engine/processes/workchains/context.py index c0c9f31bb4..a22bc0cc02 100644 --- a/aiida/engine/processes/workchains/context.py +++ b/aiida/engine/processes/workchains/context.py @@ -8,14 +8,17 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Convenience functions to add awaitables to the Context of a WorkChain.""" -from .awaitable import construct_awaitable, AwaitableAction +from typing import Union + +from aiida.orm import ProcessNode +from .awaitable import construct_awaitable, Awaitable, AwaitableAction __all__ = ('ToContext', 'assign_', 'append_') ToContext = dict -def assign_(target): +def assign_(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Convenience function that will construct an Awaitable for a given class instance with the context action set to ASSIGN. When the awaitable target is completed @@ -24,14 +27,14 @@ def assign_(target): :param target: an instance of a Process or Awaitable :returns: the awaitable - :rtype: Awaitable + """ awaitable = construct_awaitable(target) awaitable.action = AwaitableAction.ASSIGN return awaitable -def append_(target): +def append_(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Convenience function that will construct an Awaitable for a given class instance with the context action set to APPEND. 
When the awaitable target is completed @@ -40,7 +43,7 @@ def append_(target): :param target: an instance of a Process or Awaitable :returns: the awaitable - :rtype: Awaitable + """ awaitable = construct_awaitable(target) awaitable.action = AwaitableAction.APPEND diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 7bf5d368bd..5719e1496f 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -9,6 +9,9 @@ ########################################################################### """Base implementation of `WorkChain` class that implements a simple automated restart mechanism for sub processes.""" import functools +from inspect import getmembers +from types import FunctionType +from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING from aiida import orm from aiida.common import AttributeDict @@ -17,10 +20,17 @@ from .workchain import WorkChain from .utils import ProcessHandlerReport, process_handler +if TYPE_CHECKING: + from aiida.engine.processes import ExitCode, PortNamespace, Process, ProcessSpec + __all__ = ('BaseRestartWorkChain',) -def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=unused-argument +def validate_handler_overrides( + process_class: 'BaseRestartWorkChain', + handler_overrides: Optional[orm.Dict], + ctx: 'PortNamespace' # pylint: disable=unused-argument +) -> Optional[str]: """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a @@ -36,7 +46,7 @@ def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint :param ctx: the `PortNamespace` in which the port is embedded """ if not handler_overrides: - return + return None for handler, override in handler_overrides.get_dict().items(): if not isinstance(handler, str): @@ -48,6 +58,8 @@ def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint if not isinstance(override, bool): return f'The value of key `{handler}` is not a boolean.' + return None + class BaseRestartWorkChain(WorkChain): """Base restart work chain. @@ -101,11 +113,19 @@ def handle_problem(self, node): `inspect_process`. Refer to their respective documentation for details. 
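A hypothetical sketch tying `append_` and `to_context` together with the `exposed_inputs` machinery further up; `Child` and `FanOut` are invented classes:

from aiida.engine import WorkChain, append_
from aiida.orm import Int

class Child(WorkChain):  # hypothetical child
    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input('x', valid_type=Int)
        spec.outline(cls.step)

    def step(self):
        pass

class FanOut(WorkChain):  # hypothetical parent
    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.expose_inputs(Child, namespace='child')
        spec.outline(cls.submit_children, cls.inspect_children)

    def submit_children(self):
        inputs = self.exposed_inputs(Child, namespace='child')
        for _ in range(3):
            # append_: each resolved child node is appended to the list `self.ctx.children`;
            # assign_ (the default) would overwrite the context key on every completion instead
            self.to_context(children=append_(self.submit(Child, **inputs)))

    def inspect_children(self):
        assert all(child.is_finished_ok for child in self.ctx.children)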
""" - _process_class = None + _process_class: Optional[Type['Process']] = None _considered_handlers_extra = 'considered_handlers' + @property + def process_class(self) -> Type['Process']: + """Return the process class to run in the loop.""" + from ..process import Process # pylint: disable=cyclic-import + if self._process_class is None or not issubclass(self._process_class, Process): + raise ValueError('no valid Process class defined for `_process_class` attribute') + return self._process_class + @classmethod - def define(cls, spec): + def define(cls, spec: 'ProcessSpec') -> None: # type: ignore[override] """Define the process specification.""" # yapf: disable super().define(spec) @@ -126,25 +146,28 @@ def define(cls, spec): message='The maximum number of iterations was exceeded.') spec.exit_code(402, 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE', message='The process failed for an unknown reason, twice in a row.') + # yapf: enable - def setup(self): + def setup(self) -> None: """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`.""" - overrides = self.inputs.handler_overrides.get_dict() if 'handler_overrides' in self.inputs else {} + overrides = self.inputs.handler_overrides.get_dict() if (self.inputs and + 'handler_overrides' in self.inputs) else {} self.ctx.handler_overrides = overrides - self.ctx.process_name = self._process_class.__name__ + self.ctx.process_name = self.process_class.__name__ self.ctx.unhandled_failure = False self.ctx.is_finished = False self.ctx.iteration = 0 - def should_run_process(self): + def should_run_process(self) -> bool: """Return whether a new process should be run. This is the case as long as the last process has not finished successfully and the maximum number of restarts has not yet been exceeded. """ - return not self.ctx.is_finished and self.ctx.iteration < self.inputs.max_iterations.value + max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr] + return not self.ctx.is_finished and self.ctx.iteration < max_iterations - def run_process(self): + def run_process(self) -> ToContext: """Run the next process, taking the input dictionary from the context at `self.ctx.inputs`.""" self.ctx.iteration += 1 @@ -156,8 +179,8 @@ def run_process(self): # Set the `CALL` link label unwrapped_inputs.setdefault('metadata', {})['call_link_label'] = f'iteration_{self.ctx.iteration:02d}' - inputs = self._wrap_bare_dict_inputs(self._process_class.spec().inputs, unwrapped_inputs) - node = self.submit(self._process_class, **inputs) + inputs = self._wrap_bare_dict_inputs(self.process_class.spec().inputs, unwrapped_inputs) + node = self.submit(self.process_class, **inputs) # Add a new empty list to the `BaseRestartWorkChain._considered_handlers_extra` extra. This will contain the # name and return value of all class methods, decorated with `process_handler`, that are called during @@ -170,7 +193,7 @@ def run_process(self): return ToContext(children=append_(node)) - def inspect_process(self): # pylint: disable=too-many-branches + def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-branches """Analyse the results of the previous process and call the handlers when necessary. If the process is excepted or killed, the work chain will abort. 
Otherwise any attached handlers will be called @@ -202,10 +225,11 @@ def inspect_process(self): # pylint: disable=too-many-branches last_report = None # Sort the handlers with a priority defined, based on their priority in reverse order - for handler in sorted(self.get_process_handlers(), key=lambda handler: handler.priority, reverse=True): + get_priority = lambda handler: handler.priority + for handler in sorted(self.get_process_handlers(), key=get_priority, reverse=True): # Skip if the handler is enabled, either explicitly through `handler_overrides` or by default - if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): + if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): # type: ignore[attr-defined] continue # Even though the `handler` is an instance method, the `get_process_handlers` method returns unbound methods @@ -236,7 +260,7 @@ def inspect_process(self): # pylint: disable=too-many-branches self.ctx.unhandled_failure = True self.report('{}<{}> failed and error was not handled, restarting once more'.format(*report_args)) - return + return None # Here either the process finished successful or at least one handler returned a report so it can no longer be # considered to be an unhandled failed process and therefore we reset the flag @@ -260,16 +284,21 @@ def inspect_process(self): # pylint: disable=too-many-branches # Otherwise the process was successful and no handler returned anything so we consider the work done self.ctx.is_finished = True - def results(self): + return None + + def results(self) -> Optional['ExitCode']: """Attach the outputs specified in the output specification from the last completed process.""" node = self.ctx.children[self.ctx.iteration - 1] # We check the `is_finished` attribute of the work chain and not the successfulness of the last process # because the error handlers in the last iteration can have qualified a "failed" process as satisfactory # for the outcome of the work chain and so have marked it as `is_finished=True`. - if not self.ctx.is_finished and self.ctx.iteration >= self.inputs.max_iterations.value: - self.report('reached the maximum number of iterations {}: last ran {}<{}>'.format( - self.inputs.max_iterations.value, self.ctx.process_name, node.pk)) + max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr] + if not self.ctx.is_finished and self.ctx.iteration >= max_iterations: + self.report( + f'reached the maximum number of iterations {max_iterations}: ' + f'last ran {self.ctx.process_name}<{node.pk}>' + ) return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED # pylint: disable=no-member self.report(f'work chain completed after {self.ctx.iteration} iterations') @@ -284,16 +313,17 @@ def results(self): else: self.out(name, output) - def __init__(self, *args, **kwargs): + return None + + def __init__(self, *args, **kwargs) -> None: """Construct the instance.""" - from ..process import Process # pylint: disable=cyclic-import super().__init__(*args, **kwargs) - if self._process_class is None or not issubclass(self._process_class, Process): - raise ValueError('no valid Process class defined for `_process_class` attribute') + # try retrieving process class + self.process_class # pylint: disable=pointless-statement @classmethod - def is_process_handler(cls, process_handler_name): + def is_process_handler(cls, process_handler_name: Union[str, FunctionType]) -> bool: """Return whether the given method name corresponds to a process handler of this class. 
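Handlers can then be toggled per submission through the `handler_overrides` input validated above; a hypothetical example using the `SomeBaseWorkChain` sketched earlier and an invented handler name:

from aiida.orm import Dict

builder = SomeBaseWorkChain.get_builder()
# disable the (hypothetical) `handle_any_failure` handler for this run only
builder.handler_overrides = Dict(dict={'handle_any_failure': False})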
:param process_handler_name: string name of the instance method @@ -308,15 +338,14 @@ def is_process_handler(cls, process_handler_name): return getattr(handler, 'decorator', None) == process_handler @classmethod - def get_process_handlers(cls): - from inspect import getmembers + def get_process_handlers(cls) -> List[FunctionType]: return [method[1] for method in getmembers(cls) if cls.is_process_handler(method[1])] def on_terminated(self): """Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs.""" super().on_terminated() - if self.inputs.clean_workdir.value is False: + if self.inputs.clean_workdir.value is False: # type: ignore[union-attr] self.report('remote folders will not be cleaned') return @@ -333,7 +362,7 @@ def on_terminated(self): if cleaned_calcs: self.report(f"cleaned remote folders of calculations: {' '.join(cleaned_calcs)}") - def _wrap_bare_dict_inputs(self, port_namespace, inputs): + def _wrap_bare_dict_inputs(self, port_namespace: 'PortNamespace', inputs: Dict[str, Any]) -> AttributeDict: """Wrap bare dictionaries in `inputs` in a `Dict` node if dictated by the corresponding inputs portnamespace. :param port_namespace: a `PortNamespace` diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index 53dceb3a60..e5cfdc6cc3 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -8,35 +8,45 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utilities for `WorkChain` implementations.""" -from collections import namedtuple from functools import partial from inspect import getfullargspec from types import FunctionType # pylint: disable=no-name-in-module +from typing import List, Optional, Union, NamedTuple from wrapt import decorator from ..exit_code import ExitCode __all__ = ('ProcessHandlerReport', 'process_handler') -ProcessHandlerReport = namedtuple('ProcessHandlerReport', 'do_break exit_code') -ProcessHandlerReport.__new__.__defaults__ = (False, ExitCode()) -"""A namedtuple to define a process handler report for a :class:`aiida.engine.BaseRestartWorkChain`. -This namedtuple should be returned by a process handler of a work chain instance if the condition of the handler was -met by the completed process. If no further handling should be performed after this method the `do_break` field should -be set to `True`. If the handler encountered a fatal error and the work chain needs to be terminated, an `ExitCode` with -non-zero exit status can be set. This exit code is what will be set on the work chain itself. This works because the -value of the `exit_code` field returned by the handler, will in turn be returned by the `inspect_process` step and -returning a non-zero exit code from any work chain step will instruct the engine to abort the work chain. +class ProcessHandlerReport(NamedTuple): + """A namedtuple to define a process handler report for a :class:`aiida.engine.BaseRestartWorkChain`. -:param do_break: boolean, set to `True` if no further process handlers should be called, default is `False` -:param exit_code: an instance of the :class:`~aiida.engine.processes.exit_code.ExitCode` tuple. If not explicitly set, - the default `ExitCode` will be instantiated which has status `0` meaning that the work chain step will be considered - successful and the work chain will continue to the next step. 
-""" + This namedtuple should be returned by a process handler of a work chain instance if the condition of the handler was + met by the completed process. If no further handling should be performed after this method the `do_break` field + should be set to `True`. + If the handler encountered a fatal error and the work chain needs to be terminated, an `ExitCode` with + non-zero exit status can be set. This exit code is what will be set on the work chain itself. This works because the + value of the `exit_code` field returned by the handler, will in turn be returned by the `inspect_process` step and + returning a non-zero exit code from any work chain step will instruct the engine to abort the work chain. + + :param do_break: boolean, set to `True` if no further process handlers should be called, default is `False` + :param exit_code: an instance of the :class:`~aiida.engine.processes.exit_code.ExitCode` tuple. + If not explicitly set, the default `ExitCode` will be instantiated, + which has status `0` meaning that the work chain step will be considered + successful and the work chain will continue to the next step. + """ + do_break: bool = False + exit_code: ExitCode = ExitCode() -def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): +def process_handler( + wrapped: Optional[FunctionType] = None, + *, + priority: int = 0, + exit_codes: Union[None, ExitCode, List[ExitCode]] = None, + enabled: bool = True +) -> FunctionType: """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler. The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an @@ -55,7 +65,7 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): `do_break` attribute should be set to `True`. If the work chain is to be aborted entirely, the `exit_code` of the report can be set to an `ExitCode` instance with a non-zero status. - :param cls: the work chain class to register the process handler with + :param wrapped: the work chain method to register the process handler with :param priority: optional integer that defines the order in which registered handlers will be called during the handling of a finished process. Higher priorities will be handled first. Default value is `0`. Multiple handlers with the same priority is allowed, but the order of those is not well defined. @@ -67,7 +77,9 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): basis through the input `handler_overrides`. 
""" if wrapped is None: - return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled) + return partial( + process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled + ) # type: ignore[return-value] if not isinstance(wrapped, FunctionType): raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.') @@ -89,9 +101,9 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): if len(handler_args) != 2: raise TypeError(f'process handler `{wrapped.__name__}` has invalid signature: should be (self, node)') - wrapped.decorator = process_handler - wrapped.priority = priority - wrapped.enabled = enabled + wrapped.decorator = process_handler # type: ignore[attr-defined] + wrapped.priority = priority # type: ignore[attr-defined] + wrapped.enabled = enabled # type: ignore[attr-defined] @decorator def wrapper(wrapped, instance, args, kwargs): @@ -99,7 +111,9 @@ def wrapper(wrapped, instance, args, kwargs): # When the handler will be called by the `BaseRestartWorkChain` it will pass the node as the only argument node = args[0] - if exit_codes is not None and node.exit_status not in [exit_code.status for exit_code in exit_codes]: + if exit_codes is not None and node.exit_status not in [ + exit_code.status for exit_code in exit_codes # type: ignore[union-attr] + ]: result = None else: result = wrapped(*args, **kwargs) diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index f0c1f96541..4978e3594f 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -8,17 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Components for the WorkChain concept of the workflow engine.""" -import collections +import collections.abc import functools +import logging +from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING -import plumpy -from plumpy import auto_persist, Wait, Continue -from plumpy.workchains import if_, while_, return_, _PropagateReturn +from plumpy.persistence import auto_persist +from plumpy.process_states import Wait, Continue +from plumpy.workchains import if_, while_, return_, _PropagateReturn, Stepper, WorkChainSpec as PlumpyWorkChainSpec from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import override -from aiida.orm import Node, WorkChainNode +from aiida.orm import Node, ProcessNode, WorkChainNode from aiida.orm.utils import load_node from ..exit_code import ExitCode @@ -26,10 +28,13 @@ from ..process import Process, ProcessState from .awaitable import Awaitable, AwaitableTarget, AwaitableAction, construct_awaitable +if TYPE_CHECKING: + from aiida.engine.runners import Runner + __all__ = ('WorkChain', 'if_', 'while_', 'return_') -class WorkChainSpec(ProcessSpec, plumpy.WorkChainSpec): +class WorkChainSpec(ProcessSpec, PlumpyWorkChainSpec): pass @@ -42,22 +47,21 @@ class WorkChain(Process): _STEPPER_STATE = 'stepper_state' _CONTEXT = 'CONTEXT' - def __init__(self, inputs=None, logger=None, runner=None, enable_persistence=True): + def __init__( + self, + inputs: Optional[dict] = None, + logger: Optional[logging.Logger] = None, + runner: Optional['Runner'] = None, + enable_persistence: bool = True + ) -> None: """Construct a WorkChain instance. 
Construct the instance only if it is a sub class of `WorkChain`, otherwise raise `InvalidOperation`. :param inputs: work chain inputs - :type inputs: dict - :param logger: aiida logger - :type logger: :class:`logging.Logger` - :param runner: work chain runner - :type: :class:`aiida.engine.runners.Runner` - :param enable_persistence: whether to persist this work chain - :type enable_persistence: bool """ if self.__class__ == WorkChain: @@ -65,21 +69,22 @@ def __init__(self, inputs=None, logger=None, runner=None, enable_persistence=Tru super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) - self._stepper = None - self._awaitables = [] + self._stepper: Optional[Stepper] = None + self._awaitables: List[Awaitable] = [] self._context = AttributeDict() - @property - def ctx(self): - """Get context. + @classmethod + def spec(cls) -> WorkChainSpec: + return super().spec() # type: ignore[return-value] - :rtype: :class:`aiida.common.extendeddicts.AttributeDict` - """ + @property + def ctx(self) -> AttributeDict: + """Get the context.""" return self._context @override def save_instance_state(self, out_state, save_context): - """Save instance stace. + """Save instance state. :param out_state: state to save in @@ -105,7 +110,7 @@ def load_instance_state(self, saved_state, load_context): self._stepper = None stepper_state = saved_state.get(self._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) # type: ignore[arg-type] self.set_logger(self.node.logger) @@ -116,7 +121,7 @@ def on_run(self): super().on_run() self.node.set_stepper_state_info(str(self._stepper)) - def insert_awaitable(self, awaitable): + def insert_awaitable(self, awaitable: Awaitable) -> None: """Insert an awaitable that should be terminated before before continuing to the next step. :param awaitable: the thing to await @@ -137,7 +142,7 @@ def insert_awaitable(self, awaitable): self._update_process_status() - def resolve_awaitable(self, awaitable, value): + def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: """Resolve an awaitable. Precondition: must be an awaitable that was previously inserted. @@ -162,9 +167,12 @@ def resolve_awaitable(self, awaitable, value): awaitable.resolved = True - self._update_process_status() + if not self.has_terminated(): + # the process may be terminated, for example, if the process was killed or excepted + # then we should not try to update it + self._update_process_status() - def to_context(self, **kwargs): + def to_context(self, **kwargs: Union[Awaitable, ProcessNode]) -> None: """Add a dictionary of awaitables to the context. 
This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will @@ -175,7 +183,7 @@ def to_context(self, **kwargs): awaitable.key = key self.insert_awaitable(awaitable) - def _update_process_status(self): + def _update_process_status(self) -> None: """Set the process status with a message accounting the current sub processes that we are waiting for.""" if self._awaitables: status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" @@ -184,11 +192,11 @@ def _update_process_status(self): self.node.set_process_status(None) @override - def run(self): - self._stepper = self.spec().get_outline().create_stepper(self) + def run(self) -> Any: + self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type] return self._do_step() - def _do_step(self): + def _do_step(self) -> Any: """Execute the next step in the outline and return the result. If the stepper returns a non-finished status and the return value is of type ToContext, the contents of the @@ -199,9 +207,10 @@ def _do_step(self): from .context import ToContext self._awaitables = [] - result = None + result: Any = None try: + assert self._stepper is not None finished, stepper_result = self._stepper.step() except _PropagateReturn as exception: finished, result = True, exception.exit_code @@ -226,22 +235,22 @@ def _do_step(self): return Continue(self._do_step) - def _store_nodes(self, data): + def _store_nodes(self, data: Any) -> None: """Recurse through a data structure and store any unstored nodes that are found along the way :param data: a data structure potentially containing unstored nodes """ if isinstance(data, Node) and not data.is_stored: data.store() - elif isinstance(data, collections.Mapping): + elif isinstance(data, collections.abc.Mapping): for _, value in data.items(): self._store_nodes(value) - elif isinstance(data, collections.Sequence) and not isinstance(data, str): + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str): for value in data: self._store_nodes(value) @override - def on_exiting(self): + def on_exiting(self) -> None: """Ensure that any unstored nodes in the context are stored, before the state is exited After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will @@ -254,14 +263,15 @@ def on_exiting(self): # An uncaught exception here will have bizarre and disastrous consequences self.logger.exception('exception in _store_nodes called in on_exiting') - def on_wait(self, awaitables): + def on_wait(self, awaitables: Sequence[Awaitable]): + """Entering the WAITING state.""" super().on_wait(awaitables) if self._awaitables: self.action_awaitables() else: self.call_soon(self.resume) - def action_awaitables(self): + def action_awaitables(self) -> None: """Handle the awaitables that are currently registered with the work chain. 
Depending on the class type of the awaitable's target a different callback @@ -270,12 +280,12 @@ def action_awaitables(self): """ for awaitable in self._awaitables: if awaitable.target == AwaitableTarget.PROCESS: - callback = functools.partial(self._run_task, self.on_process_finished, awaitable) + callback = functools.partial(self.call_soon, self.on_process_finished, awaitable) self.runner.call_on_process_finish(awaitable.pk, callback) else: assert f"invalid awaitable target '{awaitable.target}'" - def on_process_finished(self, awaitable): + def on_process_finished(self, awaitable: Awaitable) -> None: """Callback function called by the runner when the process instance identified by pk is completed. The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index a0c43ed6d2..93752c3bec 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -9,22 +9,25 @@ ########################################################################### # pylint: disable=global-statement """Runners that can run and submit processes.""" -import collections +import asyncio import functools import logging import signal import threading +from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union import uuid import kiwipy -import plumpy -import tornado.ioloop +from plumpy.persistence import Persister +from plumpy.process_comms import RemoteProcessThreadController +from plumpy.events import set_event_loop_policy, reset_event_loop_policy +from plumpy.communications import wrap_communicator from aiida.common import exceptions -from aiida.orm import load_node +from aiida.orm import load_node, ProcessNode from aiida.plugins.utils import PluginVersionProvider -from .processes import futures, ProcessState +from .processes import futures, Process, ProcessBuilder, ProcessState from .processes.calcjobs import manager from . import transports from . import utils @@ -33,41 +36,52 @@ LOGGER = logging.getLogger(__name__) -ResultAndNode = collections.namedtuple('ResultAndNode', ['result', 'node']) -ResultAndPk = collections.namedtuple('ResultAndPk', ['result', 'pk']) + +class ResultAndNode(NamedTuple): + result: Dict[str, Any] + node: ProcessNode + + +class ResultAndPk(NamedTuple): + result: Dict[str, Any] + pk: int + + +TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +# run can also be process function, but it is not clear what type this should be +TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name class Runner: # pylint: disable=too-many-public-methods """Class that can launch processes by running in the current interpreter or by submitting them to the daemon.""" - _persister = None - _communicator = None - _controller = None - _closed = False - - def __init__(self, poll_interval=0, loop=None, communicator=None, rmq_submit=False, persister=None): + _persister: Optional[Persister] = None + _communicator: Optional[kiwipy.Communicator] = None + _controller: Optional[RemoteProcessThreadController] = None + _closed: bool = False + + def __init__( + self, + poll_interval: Union[int, float] = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + communicator: Optional[kiwipy.Communicator] = None, + rmq_submit: bool = False, + persister: Optional[Persister] = None + ): """Construct a new runner. 
:param poll_interval: interval in seconds between polling for status of active sub processes - :param loop: an event loop to use, if none is suppled a new one will be created - :type loop: :class:`tornado.ioloop.IOLoop` + :param loop: an asyncio event loop, if none is suppled a new one will be created :param communicator: the communicator to use - :type communicator: :class:`kiwipy.Communicator` :param rmq_submit: if True, processes will be submitted to RabbitMQ, otherwise they will be scheduled here :param persister: the persister to use to persist processes - :type persister: :class:`plumpy.Persister` + """ assert not (rmq_submit and persister is None), \ 'Must supply a persister if you want to submit using communicator' - # Runner take responsibility to clear up loop only if the loop was created by Runner - self._do_close_loop = False - if loop is not None: - self._loop = loop - else: - self._loop = tornado.ioloop.IOLoop() - self._do_close_loop = True - + set_event_loop_policy() + self._loop = loop if loop is not None else asyncio.get_event_loop() self._poll_interval = poll_interval self._rmq_submit = rmq_submit self._transport = transports.TransportQueue(self._loop) @@ -76,96 +90,86 @@ def __init__(self, poll_interval=0, loop=None, communicator=None, rmq_submit=Fal self._plugin_version_provider = PluginVersionProvider() if communicator is not None: - self._communicator = plumpy.wrap_communicator(communicator, self._loop) - self._controller = plumpy.RemoteProcessThreadController(communicator) + self._communicator = wrap_communicator(communicator, self._loop) + self._controller = RemoteProcessThreadController(communicator) elif self._rmq_submit: LOGGER.warning('Disabling RabbitMQ submission, no communicator provided') self._rmq_submit = False - def __enter__(self): + def __enter__(self) -> 'Runner': return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() @property - def loop(self): - """ - Get the event loop of this runner - - :return: the event loop - :rtype: :class:`tornado.ioloop.IOLoop` - """ + def loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop of this runner.""" return self._loop @property - def transport(self): + def transport(self) -> transports.TransportQueue: return self._transport @property - def persister(self): + def persister(self) -> Optional[Persister]: + """Get the persister used by this runner.""" return self._persister @property - def communicator(self): - """ - Get the communicator used by this runner - - :return: the communicator - :rtype: :class:`kiwipy.Communicator` - """ + def communicator(self) -> Optional[kiwipy.Communicator]: + """Get the communicator used by this runner.""" return self._communicator @property - def plugin_version_provider(self): + def plugin_version_provider(self) -> PluginVersionProvider: return self._plugin_version_provider @property - def job_manager(self): + def job_manager(self) -> manager.JobManager: return self._job_manager @property - def controller(self): + def controller(self) -> Optional[RemoteProcessThreadController]: + """Get the controller used by this runner.""" return self._controller @property - def is_daemon_runner(self): + def is_daemon_runner(self) -> bool: """Return whether the runner is a daemon runner, which means it submits processes over RabbitMQ. 
:return: True if the runner is a daemon runner - :rtype: bool """ return self._rmq_submit - def is_closed(self): + def is_closed(self) -> bool: return self._closed - def start(self): + def start(self) -> None: """Start the internal event loop.""" - self._loop.start() + self._loop.run_forever() - def stop(self): + def stop(self) -> None: """Stop the internal event loop.""" self._loop.stop() - def run_until_complete(self, future): + def run_until_complete(self, future: asyncio.Future) -> Any: """Run the loop until the future has finished and return the result.""" with utils.loop_scope(self._loop): - return self._loop.run_sync(lambda: future) + return self._loop.run_until_complete(future) - def close(self): + def close(self) -> None: """Close the runner by stopping the loop.""" assert not self._closed self.stop() - if self._do_close_loop: - self._loop.close() + reset_event_loop_policy() self._closed = True - def instantiate_process(self, process, *args, **inputs): + def instantiate_process(self, process: TYPE_RUN_PROCESS, *args, **inputs): from .utils import instantiate_process return instantiate_process(self, process, *args, **inputs) - def submit(self, process, *args, **inputs): + def submit(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any): """ Submit the process with the supplied inputs to this runner immediately returning control to the interpreter. The return value will be the calculation node of the submitted process @@ -177,24 +181,26 @@ def submit(self, process, *args, **inputs): assert not utils.is_process_function(process), 'Cannot submit a process function' assert not self._closed - process = self.instantiate_process(process, *args, **inputs) + process_inited = self.instantiate_process(process, *args, **inputs) - if not process.metadata.store_provenance: + if not process_inited.metadata.store_provenance: raise exceptions.InvalidOperation('cannot submit a process with `store_provenance=False`') - if process.metadata.get('dry_run', False): + if process_inited.metadata.get('dry_run', False): raise exceptions.InvalidOperation('cannot submit a process from within another with `dry_run=True`') if self._rmq_submit: - self.persister.save_checkpoint(process) - process.close() - self.controller.continue_process(process.pid, nowait=False, no_reply=True) + assert self.persister is not None, 'runner does not have a persister' + assert self.controller is not None, 'runner does not have a controller' + self.persister.save_checkpoint(process_inited) + process_inited.close() + self.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) else: - self.loop.add_callback(process.step_until_terminated) + self.loop.create_task(process_inited.step_until_terminated()) - return process.node + return process_inited.node - def schedule(self, process, *args, **inputs): + def schedule(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any) -> ProcessNode: """ Schedule a process to be executed by this runner @@ -205,11 +211,11 @@ def schedule(self, process, *args, **inputs): assert not utils.is_process_function(process), 'Cannot submit a process function' assert not self._closed - process = self.instantiate_process(process, *args, **inputs) - self.loop.add_callback(process.step_until_terminated) - return process.node + process_inited = self.instantiate_process(process, *args, **inputs) + self.loop.create_task(process_inited.step_until_terminated()) + return process_inited.node - def _run(self, process, *args, **inputs): + def _run(self, process: TYPE_RUN_PROCESS, *args: 
Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -221,24 +227,34 @@ def _run(self, process, *args, **inputs): assert not self._closed if utils.is_process_function(process): - result, node = process.run_get_node(*args, **inputs) + result, node = process.run_get_node(*args, **inputs) # type: ignore[union-attr] return result, node with utils.loop_scope(self.loop): - process = self.instantiate_process(process, *args, **inputs) + process_inited = self.instantiate_process(process, *args, **inputs) def kill_process(_num, _frame): """Send the kill signal to the process in the current scope.""" - LOGGER.critical('runner received interrupt, killing process %s', process.pid) - process.kill(msg='Process was killed because the runner received an interrupt') + if process_inited.is_killing: + LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid) + return + LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid) + process_inited.kill(msg='Process was killed because the runner received an interrupt') - signal.signal(signal.SIGINT, kill_process) - signal.signal(signal.SIGTERM, kill_process) + original_handler_int = signal.getsignal(signal.SIGINT) + original_handler_term = signal.getsignal(signal.SIGTERM) - process.execute() - return process.outputs, process.node + try: + signal.signal(signal.SIGINT, kill_process) + signal.signal(signal.SIGTERM, kill_process) + process_inited.execute() + finally: + signal.signal(signal.SIGINT, original_handler_int) + signal.signal(signal.SIGTERM, original_handler_term) - def run(self, process, *args, **inputs): + return process_inited.outputs, process_inited.node + + def run(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -250,7 +266,7 @@ def run(self, process, *args, **inputs): result, _ = self._run(process, *args, **inputs) return result - def run_get_node(self, process, *args, **inputs): + def run_get_node(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndNode: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -262,7 +278,7 @@ def run_get_node(self, process, *args, **inputs): result, node = self._run(process, *args, **inputs) return ResultAndNode(result, node) - def run_get_pk(self, process, *args, **inputs): + def run_get_pk(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndPk: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -274,7 +290,7 @@ def run_get_pk(self, process, *args, **inputs): result, node = self._run(process, *args, **inputs) return ResultAndPk(result, node.pk) - def call_on_process_finish(self, pk, callback): + def call_on_process_finish(self, pk: int, callback: Callable[[], Any]) -> None: """Schedule a callback when the process of the given pk is terminated. 
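# Illustrative sketch (not aiida-core code) of the save-and-restore signal handling introduced in
# `Runner._run` above: install SIGINT/SIGTERM handlers that kill the running process, and always
# restore the previous handlers afterwards. `do_work` and `kill` are hypothetical stand-ins for
# `process_inited.execute()` and `process_inited.kill(msg=...)`.
import signal


def run_with_kill_on_interrupt(do_work, kill):
    """Run `do_work()` while interrupts trigger `kill(msg)`; restore the original handlers on exit."""

    def handler(_num, _frame):
        kill('Process was killed because the runner received an interrupt')

    original_handler_int = signal.getsignal(signal.SIGINT)
    original_handler_term = signal.getsignal(signal.SIGTERM)

    try:
        signal.signal(signal.SIGINT, handler)
        signal.signal(signal.SIGTERM, handler)
        return do_work()
    finally:
        signal.signal(signal.SIGINT, original_handler_int)
        signal.signal(signal.SIGTERM, original_handler_term)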
This method will add a broadcast subscriber that will listen for state changes of the target process to be @@ -284,6 +300,8 @@ def call_on_process_finish(self, pk, callback): :param pk: pk of the process :param callback: function to be called upon process termination """ + assert self.communicator is not None, 'communicator not set for runner' + node = load_node(pk=pk) subscriber_identifier = str(uuid.uuid4()) event = threading.Event() @@ -301,17 +319,17 @@ def inline_callback(event, *args, **kwargs): # pylint: disable=unused-argument callback() finally: event.set() - self._communicator.remove_broadcast_subscriber(subscriber_identifier) + self.communicator.remove_broadcast_subscriber(subscriber_identifier) # type: ignore[union-attr] broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk) for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]: broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}') LOGGER.info('adding subscriber for broadcasts of %d', pk) - self._communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) + self.communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) self._poll_process(node, functools.partial(inline_callback, event)) - def get_process_future(self, pk): + def get_process_future(self, pk: int) -> futures.ProcessFuture: """Return a future for a process. The future will have the process node as the result when finished. @@ -329,6 +347,6 @@ def _poll_process(self, node, callback): if node.is_terminated: args = [node.__class__.__name__, node.pk] LOGGER.info('%s<%d> confirmed to be terminated by backup polling mechanism', *args) - self._loop.add_callback(callback) + self._loop.call_soon(callback) else: self._loop.call_later(self._poll_interval, self._poll_process, node, callback) diff --git a/aiida/engine/transports.py b/aiida/engine/transports.py index eb8cae8e0a..d301235e27 100644 --- a/aiida/engine/transports.py +++ b/aiida/engine/transports.py @@ -8,11 +8,15 @@ # For further information please visit http://www.aiida.net # ########################################################################### """A transport queue to batch process multiple tasks that require a Transport.""" -from collections import namedtuple import contextlib import logging import traceback -from tornado import concurrent, gen, ioloop +from typing import Awaitable, Dict, Hashable, Iterator, Optional +import asyncio +import contextvars + +from aiida.orm import AuthInfo +from aiida.transports import Transport _LOGGER = logging.getLogger(__name__) @@ -22,7 +26,7 @@ class TransportRequest: def __init__(self): super().__init__() - self.future = concurrent.Future() + self.future: asyncio.Future = asyncio.Future() self.count = 0 @@ -37,31 +41,29 @@ class TransportQueue: up to that point. This way opening of transports (a costly operation) can be minimised. 
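# Minimal sketch (assumptions, not aiida-core code) of the batching idea described in the
# `TransportQueue` docstring above: concurrent requesters share a single `asyncio.Future` plus a
# counter, so the costly open happens at most once and the shared entry is dropped only when the
# last requester is done. `open_resource` is a hypothetical callable standing in for opening a
# transport from an `AuthInfo`.
import asyncio
import contextlib


class SharedRequest:
    """A future shared by all requesters of the same resource, with a reference count."""

    def __init__(self):
        self.future: asyncio.Future = asyncio.Future()
        self.count = 0


@contextlib.contextmanager
def request_shared(requests: dict, key, open_resource, loop: asyncio.AbstractEventLoop):
    request = requests.get(key)
    if request is None:
        request = requests[key] = SharedRequest()
        # Defer the actual open; every requester simply awaits the shared future.
        loop.call_soon(lambda: request.future.set_result(open_resource()))
    request.count += 1
    try:
        yield request.future  # the caller does: resource = await request
    finally:
        request.count -= 1
        if request.count == 0:
            requests.pop(key, None)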
""" - AuthInfoEntry = namedtuple('AuthInfoEntry', ['authinfo', 'transport', 'callbacks', 'callback_handle']) - def __init__(self, loop=None): + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): """ - :param loop: The event loop to use, will use `tornado.ioloop.IOLoop.current()` if not supplied - :type loop: :class:`tornado.ioloop.IOLoop` + :param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied """ - self._loop = loop if loop is not None else ioloop.IOLoop.current() - self._transport_requests = {} + self._loop = loop if loop is not None else asyncio.get_event_loop() + self._transport_requests: Dict[Hashable, TransportRequest] = {} - def loop(self): + @property + def loop(self) -> asyncio.AbstractEventLoop: """ Get the loop being used by this transport queue """ return self._loop @contextlib.contextmanager - def request_transport(self, authinfo): + def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Transport]]: """ Request a transport from an authinfo. Because the client is not allowed to request a transport immediately they will instead be given back a future - that can be yielded to get the transport:: + that can be awaited to get the transport:: - @tornado.gen.coroutine - def transport_task(transport_queue, authinfo): + async def transport_task(transport_queue, authinfo): with transport_queue.request_transport(authinfo) as request: - transport = yield request + transport = await request # Do some work with the transport :param authinfo: The authinfo to be used to get transport @@ -80,7 +82,7 @@ def transport_task(transport_queue, authinfo): def do_open(): """ Actually open the transport """ - if transport_request.count > 0: + if transport_request and transport_request.count > 0: # The user still wants the transport so open it _LOGGER.debug('Transport request opening transport for %s', authinfo) try: @@ -95,13 +97,21 @@ def do_open(): transport_request.future.set_result(transport) # Save the handle so that we can cancel the callback if the user no longer wants it - open_callback_handle = self._loop.call_later(safe_open_interval, do_open) + # Note: Don't pass the Process context, since (a) it is not needed by `do_open` and (b) the transport is + # passed around to many places, including outside aiida-core (e.g. paramiko). Anyone keeping a reference + # to this handle would otherwise keep the Process context (and thus the process itself) in memory. 
+ # See https://github.com/aiidateam/aiida-core/issues/4698 + open_callback_handle = self._loop.call_later( + safe_open_interval, do_open, context=contextvars.Context() + ) # type: ignore[call-arg] try: transport_request.count += 1 yield transport_request.future - except gen.Return: - # Have to have this special case so tornado returns are propagated up to the loop + except asyncio.CancelledError: # pylint: disable=try-except-raise + # note this is only required in python<=3.7, + # where asyncio.CancelledError inherits from Exception + _LOGGER.debug('Transport task cancelled') raise except Exception: _LOGGER.error('Exception whilst using transport:\n%s', traceback.format_exc()) @@ -115,6 +125,6 @@ def do_open(): _LOGGER.debug('Transport request closing transport for %s', authinfo) transport_request.future.result().close() elif open_callback_handle is not None: - self._loop.remove_timeout(open_callback_handle) + open_callback_handle.cancel() self._transport_requests.pop(authinfo.id, None) diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index 8f96f703aa..3cbef87015 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -10,11 +10,15 @@ # pylint: disable=invalid-name """Utilities for the workflow engine.""" +import asyncio import contextlib +from datetime import datetime import logging +from typing import Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING -import tornado.ioloop -from tornado import concurrent, gen +if TYPE_CHECKING: + from .processes import Process, ProcessBuilder + from .runners import Runner __all__ = ('interruptable_task', 'InterruptableFuture', 'is_process_function') @@ -23,7 +27,9 @@ PROCESS_STATE_CHANGE_DESCRIPTION = 'The last time a process of type {}, changed state' -def instantiate_process(runner, process, *args, **inputs): +def instantiate_process( + runner: 'Runner', process: Union['Process', Type['Process'], 'ProcessBuilder'], *args, **inputs +) -> 'Process': """ Return an instance of the process with the given inputs. 
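# Hedged illustration of the `call_later(..., context=contextvars.Context())` change shown above:
# scheduling the callback with a fresh, empty `contextvars.Context` prevents the returned handle
# from retaining the caller's context (and everything reachable from it). The variable names here
# are made up for the example.
import asyncio
import contextvars

current_process = contextvars.ContextVar('current_process', default=None)


async def main():
    loop = asyncio.get_running_loop()
    current_process.set('some large Process object')

    def callback():
        # Runs in the empty context, so it neither sees nor keeps alive `current_process`.
        print('current_process inside callback:', current_process.get())

    loop.call_later(0.01, callback, context=contextvars.Context())
    await asyncio.sleep(0.05)


asyncio.run(main())  # prints: current_process inside callback: None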
The function can deal with various types of the `process`: @@ -50,7 +56,7 @@ def instantiate_process(runner, process, *args, **inputs): process_class = builder.process_class inputs.update(**builder._inputs(prune=True)) # pylint: disable=protected-access elif is_process_function(process): - process_class = process.process_class + process_class = process.process_class # type: ignore[attr-defined] elif issubclass(process, Process): process_class = process else: @@ -61,73 +67,77 @@ def instantiate_process(runner, process, *args, **inputs): return process -class InterruptableFuture(concurrent.Future): +class InterruptableFuture(asyncio.Future): """A future that can be interrupted by calling `interrupt`.""" - def interrupt(self, reason): + def interrupt(self, reason: Exception) -> None: """This method should be called to interrupt the coroutine represented by this InterruptableFuture.""" self.set_exception(reason) - @gen.coroutine - def with_interrupt(self, yieldable): + async def with_interrupt(self, coro: Awaitable[Any]) -> Any: """ - Yield a yieldable which will be interrupted if this future is interrupted :: + Return the result of a coroutine which will be interrupted if this future is interrupted :: - from tornado import ioloop, gen - loop = ioloop.IOLoop.current() + import asyncio + loop = asyncio.get_event_loop() interruptable = InterutableFuture() - loop.add_callback(interruptable.interrupt, RuntimeError("STOP")) - loop.run_sync(lambda: interruptable.with_interrupt(gen.sleep(2))) + loop.call_soon(interruptable.interrupt, RuntimeError("STOP")) + loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(2.))) >>> RuntimeError: STOP - :param yieldable: The yieldable - :return: The result of the yieldable + :param coro: The coroutine that can be interrupted + :return: The result of the coroutine """ - # Wait for one of the two to finish, if it's us that finishes we expect that it was - # because of an exception that will have been raised automatically - wait_iterator = gen.WaitIterator(yieldable, self) - result = yield wait_iterator.next() # pylint: disable=stop-iteration-return - if not wait_iterator.current_index == 0: - raise RuntimeError(f"This interruptible future had it's result set unexpectedly to {result}") + task = asyncio.ensure_future(coro) + wait_iter = asyncio.as_completed({self, task}) + result = await next(wait_iter) + if self.done(): + raise RuntimeError(f"This interruptible future had its result set unexpectedly to '{result}'") - result = yield [yieldable, self][0] - raise gen.Return(result) + return result -def interruptable_task(coro, loop=None): +def interruptable_task( + coro: Callable[[InterruptableFuture], Awaitable[Any]], + loop: Optional[asyncio.AbstractEventLoop] = None +) -> InterruptableFuture: """ Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it.
- :param coro: the coroutine that should be made interruptable - :param loop: the event loop in which to run the coroutine, by default uses tornado.ioloop.IOLoop.current() + :param coro: the coroutine that should be made interruptable with an object of InterruptableFuture as the last parameter + :param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop() :return: an InterruptableFuture """ - loop = loop or tornado.ioloop.IOLoop.current() + loop = loop or asyncio.get_event_loop() future = InterruptableFuture() - @gen.coroutine - def execute_coroutine(): + async def execute_coroutine(): """Coroutine that wraps the original coroutine and sets it result on the future only if not already set.""" try: - result = yield coro(future) + result = await coro(future) except Exception as exception: # pylint: disable=broad-except if not future.done(): future.set_exception(exception) + else: + LOGGER.warning( + 'Interruptable future set to %s before its coro %s is done. %s', future.result(), coro.__name__, + str(exception) + ) else: # If the future has not been set elsewhere, i.e. by the interrupt call, by the time that the coroutine # is executed, set the future's result to the result of the coroutine if not future.done(): future.set_result(result) - loop.add_callback(execute_coroutine) + loop.create_task(execute_coroutine()) return future -def ensure_coroutine(fct): +def ensure_coroutine(fct: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: """ Ensure that the given function ``fct`` is a coroutine @@ -136,42 +146,46 @@ def ensure_coroutine(fct): :param fct: the function :returns: the coroutine """ - if tornado.gen.is_coroutine_function(fct): + if asyncio.iscoroutinefunction(fct): return fct - @tornado.gen.coroutine - def wrapper(*args, **kwargs): - raise tornado.gen.Return(fct(*args, **kwargs)) + async def wrapper(*args, **kwargs): + return fct(*args, **kwargs) return wrapper -@gen.coroutine -def exponential_backoff_retry(fct, initial_interval=10.0, max_attempts=5, logger=None, ignore_exceptions=None): +async def exponential_backoff_retry( + fct: Callable[..., Any], + initial_interval: Union[int, float] = 10.0, + max_attempts: int = 5, + logger: Optional[logging.Logger] = None, + ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None +) -> Any: """ Coroutine to call a function, recalling it with an exponential backoff in the case of an exception This coroutine will loop ``max_attempts`` times, calling the ``fct`` function, breaking immediately when the call - finished without raising an exception, at which point the returned result will be raised, wrapped in a - ``tornado.gen.Result`` instance. If an exception is caught, the function will yield a ``tornado.gen.sleep`` with a - time interval equal to the ``initial_interval`` multiplied by ``2*N`` where ``N`` is the number of excepted calls. + finished without raising an exception, at which point the result will be returned. If an exception is caught, the + function will await an ``asyncio.sleep`` with a time interval equal to the ``initial_interval`` multiplied by + ``2 ** (N - 1)`` where ``N`` is the number of excepted calls.
:param fct: the function to call, which will be turned into a coroutine first if it is not already :param initial_interval: the time to wait after the first caught exception before calling the coroutine again :param max_attempts: the maximum number of times to call the coroutine before re-raising the exception - :param ignore_exceptions: list or tuple of exceptions to ignore, i.e. when caught do nothing and simply re-raise - :raises: ``tornado.gen.Result`` if the ``coro`` call completes within ``max_attempts`` retries without raising + :param ignore_exceptions: exceptions to ignore, i.e. when caught do nothing and simply re-raise + :return: result if the ``coro`` call completes within ``max_attempts`` retries without raising """ if logger is None: logger = LOGGER - result = None + result: Any = None coro = ensure_coroutine(fct) interval = initial_interval for iteration in range(max_attempts): try: - result = yield coro() + result = await coro() break # Finished successfully except Exception as exception: # pylint: disable=broad-except @@ -188,13 +202,13 @@ def exponential_backoff_retry(fct, initial_interval=10.0, max_attempts=5, logger raise else: logger.exception('iteration %d of %s excepted, retrying after %d seconds', count, coro_name, interval) - yield gen.sleep(interval) + await asyncio.sleep(interval) interval *= 2 - raise gen.Return(result) + return result -def is_process_function(function): +def is_process_function(function: Any) -> bool: """Return whether the given function is a process function :param function: a function @@ -206,7 +220,7 @@ def is_process_function(function): return False -def is_process_scoped(): +def is_process_scoped() -> bool: """Return whether the current scope is within a process. :returns: True if the current scope is within a nested process, False otherwise @@ -216,23 +230,23 @@ def is_process_scoped(): @contextlib.contextmanager -def loop_scope(loop): +def loop_scope(loop) -> Iterator[None]: """ Make an event loop current for the scope of the context :param loop: The event loop to make current for the duration of the scope - :type loop: :class:`tornado.ioloop.IOLoop` + :type loop: asyncio event loop """ - current = tornado.ioloop.IOLoop.current() + current = asyncio.get_event_loop() try: - loop.make_current() + asyncio.set_event_loop(loop) yield finally: - current.make_current() + asyncio.set_event_loop(current) -def set_process_state_change_timestamp(process): +def set_process_state_change_timestamp(process: 'Process') -> None: """ Set the global setting that reflects the last time a process changed state, for the process type of the given process, to the current timestamp. The process type will be determined based on @@ -266,7 +280,7 @@ def set_process_state_change_timestamp(process): process.logger.debug(f'could not update the {key} setting because of a UniquenessError: {exception}') -def get_process_state_change_timestamp(process_type=None): +def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Optional[datetime]: """ Get the global setting that reflects the last time a process of the given process type changed its state. The returned value will be the corresponding timestamp or None if the setting does not exist. 
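# Hypothetical usage sketch for `exponential_backoff_retry` above; `flaky_fetch` is a made-up
# coroutine that fails twice before succeeding, and with `initial_interval=0.1` the retries wait
# 0.1 s and then 0.2 s, doubling after every failed attempt (assumes aiida-core with this change
# is installed).
import asyncio

from aiida.engine.utils import exponential_backoff_retry

attempts = {'count': 0}


async def flaky_fetch():
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise IOError('transient failure')
    return 'ok'


result = asyncio.run(exponential_backoff_retry(flaky_fetch, initial_interval=0.1, max_attempts=5))
print(result)  # 'ok', after two retried failures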
@@ -291,7 +305,7 @@ def get_process_state_change_timestamp(process_type=None): else: process_types = [process_type] - timestamps = [] + timestamps: List[datetime] = [] for process_type_key in process_types: key = PROCESS_STATE_CHANGE_KEY.format(process_type_key) diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index 6e81cf4bff..d387b9eb15 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -8,19 +8,15 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Definition of caching mechanism and configuration for calculations.""" -import os import re -import copy import keyword from enum import Enum from collections import namedtuple from contextlib import contextmanager, suppress -import yaml -from wrapt import decorator - from aiida.common import exceptions from aiida.common.lang import type_check +from aiida.manage.configuration import get_config_option from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP @@ -30,91 +26,117 @@ class ConfigKeys(Enum): """Valid keys for caching configuration.""" - DEFAULT = 'default' - ENABLED = 'enabled' - DISABLED = 'disabled' + DEFAULT = 'caching.default_enabled' + ENABLED = 'caching.enabled_for' + DISABLED = 'caching.disabled_for' -DEFAULT_CONFIG = { - ConfigKeys.DEFAULT.value: False, - ConfigKeys.ENABLED.value: [], - ConfigKeys.DISABLED.value: [], -} +class _ContextCache: + """Cache options, accounting for when in enable_caching or disable_caching contexts.""" + def __init__(self): + self._default_all = None + self._enable = [] + self._disable = [] -def _get_config(config_file): - """Return the caching configuration. + def clear(self): + """Clear caching overrides.""" + self.__init__() - :param config_file: the absolute path to the caching configuration file - :return: the configuration dictionary - """ - from aiida.manage.configuration import get_profile + def enable_all(self): + self._default_all = 'enable' + + def disable_all(self): + self._default_all = 'disable' + + def enable(self, identifier): + self._enable.append(identifier) + with suppress(ValueError): + self._disable.remove(identifier) + + def disable(self, identifier): + self._disable.append(identifier) + with suppress(ValueError): + self._enable.remove(identifier) + + def get_options(self): + """Return the options, applying any context overrides.""" + + if self._default_all == 'disable': + return False, [], [] + + if self._default_all == 'enable': + return True, [], [] - profile = get_profile() + default = get_config_option(ConfigKeys.DEFAULT.value) + enabled = get_config_option(ConfigKeys.ENABLED.value)[:] + disabled = get_config_option(ConfigKeys.DISABLED.value)[:] - if profile is None: - exceptions.ConfigurationError('no profile has been loaded') + for ident in self._disable: + disabled.append(ident) + with suppress(ValueError): + enabled.remove(ident) + + for ident in self._enable: + enabled.append(ident) + with suppress(ValueError): + disabled.remove(ident) - try: - with open(config_file, 'r', encoding='utf8') as handle: - config = yaml.safe_load(handle)[profile.name] - except (OSError, IOError, KeyError): - # No config file, or no config for this profile - return DEFAULT_CONFIG + # Check validity of enabled and disabled entries + try: + for identifier in enabled + disabled: + _validate_identifier_pattern(identifier=identifier) + except ValueError as exc: + raise exceptions.ConfigurationError('Invalid identifier 
pattern in enable or disable list.') from exc - # Validate configuration - for key in config: - if key not in DEFAULT_CONFIG: - raise exceptions.ConfigurationError(f"Configuration error: Invalid key '{key}' in cache_config.yml") + return default, enabled, disabled - # Add defaults where key is either completely missing or specifies no values in which case it will be `None` - for key, default_config in DEFAULT_CONFIG.items(): - if key not in config or config[key] is None: - config[key] = default_config - try: - type_check(config[ConfigKeys.DEFAULT.value], bool) - type_check(config[ConfigKeys.ENABLED.value], list) - type_check(config[ConfigKeys.DISABLED.value], list) - except TypeError as exc: - raise exceptions.ConfigurationError('Invalid type in caching configuration file.') from exc +_CONTEXT_CACHE = _ContextCache() - # Check validity of enabled and disabled entries - try: - for identifier in config[ConfigKeys.ENABLED.value] + config[ConfigKeys.DISABLED.value]: - _validate_identifier_pattern(identifier=identifier) - except ValueError as exc: - raise exceptions.ConfigurationError('Invalid identifier pattern in enable or disable list.') from exc - return config +@contextmanager +def enable_caching(*, identifier=None): + """Context manager to enable caching, either for a specific node class, or globally. + .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter. -_CONFIG = {} + :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. + If not provided, caching is enabled for all classes. + :type identifier: str + """ + type_check(identifier, str, allow_none=True) + if identifier is None: + _CONTEXT_CACHE.enable_all() + else: + _validate_identifier_pattern(identifier=identifier) + _CONTEXT_CACHE.enable(identifier) + yield + _CONTEXT_CACHE.clear() -def configure(config_file=None): - """Reads the caching configuration file and sets the _CONFIG variable.""" - # pylint: disable=global-statement - if config_file is None: - from aiida.manage.configuration import get_config - config = get_config() - config_file = os.path.join(config.dirpath, 'cache_config.yml') +@contextmanager +def disable_caching(*, identifier=None): + """Context manager to disable caching, either for a specific node class, or globally. - global _CONFIG - _CONFIG.clear() - _CONFIG.update(_get_config(config_file=config_file)) + .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter. + :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. + If not provided, caching is disabled for all classes. + :type identifier: str + """ + type_check(identifier, str, allow_none=True) -@decorator -def _with_config(wrapped, _, args, kwargs): - """Function decorator to load the caching configuration for the scope of the wrapped function.""" - if not _CONFIG: - configure() - return wrapped(*args, **kwargs) + if identifier is None: + _CONTEXT_CACHE.disable_all() + else: + _validate_identifier_pattern(identifier=identifier) + _CONTEXT_CACHE.disable(identifier) + yield + _CONTEXT_CACHE.clear() -@_with_config def get_use_cache(*, identifier=None): """Return whether the caching mechanism should be used for the given process type according to the configuration. 
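# Hedged usage sketch for the caching context managers defined above (assumes a configured AiiDA
# profile is loaded; the identifiers are example patterns and may need adapting to the plugins
# actually installed).
from aiida.manage.caching import disable_caching, enable_caching, get_use_cache

with enable_caching(identifier='aiida.calculations:*'):
    # The wildcard pattern matches, so caching is reported as enabled for this process type.
    assert get_use_cache(identifier='aiida.calculations:templatereplacer')

with disable_caching():
    # The global override wins: caching is reported as disabled for everything.
    assert not get_use_cache(identifier='aiida.calculations:templatereplacer')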
@@ -126,17 +148,13 @@ def get_use_cache(*, identifier=None): """ type_check(identifier, str, allow_none=True) + default, enabled, disabled = _CONTEXT_CACHE.get_options() + if identifier is not None: type_check(identifier, str) - enable_matches = [ - pattern for pattern in _CONFIG[ConfigKeys.ENABLED.value] - if _match_wildcard(string=identifier, pattern=pattern) - ] - disable_matches = [ - pattern for pattern in _CONFIG[ConfigKeys.DISABLED.value] - if _match_wildcard(string=identifier, pattern=pattern) - ] + enable_matches = [pattern for pattern in enabled if _match_wildcard(string=identifier, pattern=pattern)] + disable_matches = [pattern for pattern in disabled if _match_wildcard(string=identifier, pattern=pattern)] if enable_matches and disable_matches: # If both enable and disable have matching identifier, we search for @@ -172,65 +190,7 @@ def get_use_cache(*, identifier=None): return True if disable_matches: return False - return _CONFIG[ConfigKeys.DEFAULT.value] - - -@contextmanager -@_with_config -def _reset_config(): - """Reset the configuration by clearing the contents of the global config variable.""" - # pylint: disable=global-statement - global _CONFIG - config_copy = copy.deepcopy(_CONFIG) - yield - _CONFIG.clear() - _CONFIG.update(config_copy) - - -@contextmanager -def enable_caching(*, identifier=None): - """Context manager to enable caching, either for a specific node class, or globally. - - .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter. - - :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. - :type identifier: str - """ - - type_check(identifier, str, allow_none=True) - with _reset_config(): - if identifier is None: - _CONFIG[ConfigKeys.DEFAULT.value] = True - _CONFIG[ConfigKeys.DISABLED.value] = [] - else: - _validate_identifier_pattern(identifier=identifier) - _CONFIG[ConfigKeys.ENABLED.value].append(identifier) - with suppress(ValueError): - _CONFIG[ConfigKeys.DISABLED.value].remove(identifier) - yield - - -@contextmanager -def disable_caching(*, identifier=None): - """Context manager to disable caching, either for a specific node class, or globally. - - .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter. - - :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. 
- :type identifier: str - """ - type_check(identifier, str, allow_none=True) - - with _reset_config(): - if identifier is None: - _CONFIG[ConfigKeys.DEFAULT.value] = False - _CONFIG[ConfigKeys.ENABLED.value] = [] - else: - _validate_identifier_pattern(identifier=identifier) - _CONFIG[ConfigKeys.DISABLED.value].append(identifier) - with suppress(ValueError): - _CONFIG[ConfigKeys.ENABLED.value].remove(identifier) - yield + return default def _match_wildcard(*, string, pattern): diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index 942b9ac5c2..568fa9992e 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -9,6 +9,8 @@ ########################################################################### # pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import """Modules related to the configuration of an AiiDA instance.""" +import os +import shutil import warnings from aiida.common.warnings import AiidaDeprecationWarning @@ -68,7 +70,6 @@ def load_profile(profile=None): def get_config_path(): """Returns path to .aiida configuration directory.""" - import os from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) @@ -87,7 +88,6 @@ def load_config(create=False): :rtype: :class:`~aiida.manage.configuration.config.Config` :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False """ - import os from aiida.common import exceptions from .config import Config @@ -101,9 +101,45 @@ def load_config(create=False): except ValueError: raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') + _merge_deprecated_cache_yaml(config, filepath) + return config +def _merge_deprecated_cache_yaml(config, filepath): + """Merge the deprecated cache_config.yml into the config.""" + from aiida.common import timezone + cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') + if not os.path.exists(cache_path): + return + + cache_path_backup = None + # Keep generating a new backup filename based on the current time until it does not exist + while not cache_path_backup or os.path.isfile(cache_path_backup): + cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" + + warnings.warn( + f'cache_config.yml use is deprecated, merging into config.json and moving to: {cache_path_backup}', + AiidaDeprecationWarning + ) + import yaml + with open(cache_path, 'r', encoding='utf8') as handle: + cache_config = yaml.safe_load(handle) + for profile_name, data in cache_config.items(): + if profile_name not in config.profile_names: + warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) + continue + for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), + ('disabled', 'caching.disabled_for')]: + if key in data: + value = data[key] + # in case of empty key + value = [] if value is None and key != 'default' else value + config.set_option(option_name, value, scope=profile_name) + config.store() + shutil.move(cache_path, cache_path_backup) + + def get_profile(): """Return the currently loaded profile. 
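# Illustrative sketch (not the aiida-core implementation) of the key mapping performed by
# `_merge_deprecated_cache_yaml` above: a per-profile section of the legacy `cache_config.yml`,
# e.g. `{'default': True, 'enabled': ['some.entry:point'], 'disabled': None}`, is rewritten onto
# the new profile-scoped options in `config.json`.
LEGACY_TO_OPTION = {
    'default': 'caching.default_enabled',
    'enabled': 'caching.enabled_for',
    'disabled': 'caching.disabled_for',
}


def convert_legacy_profile_section(data: dict) -> dict:
    """Map a legacy per-profile caching section onto the new option names."""
    options = {}
    for key, option_name in LEGACY_TO_OPTION.items():
        if key in data:
            value = data[key]
            # An empty YAML key is loaded as `None`; normalise the list-valued options to `[]`.
            options[option_name] = [] if value is None and key != 'default' else value
    return options


print(convert_legacy_profile_section({'default': True, 'enabled': ['some.entry:point'], 'disabled': None}))
# {'caching.default_enabled': True, 'caching.enabled_for': ['some.entry:point'], 'caching.disabled_for': []}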
diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py index f8ac82fc1c..04f281b0a0 100644 --- a/aiida/manage/configuration/config.py +++ b/aiida/manage/configuration/config.py @@ -8,16 +8,50 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module that defines the configuration file of an AiiDA instance and functions to create and load it.""" +from functools import lru_cache +from importlib import resources import os import shutil import tempfile +from typing import Any, Dict, Optional, Sequence, Tuple + +import jsonschema from aiida.common import json +from aiida.common.exceptions import ConfigurationError -from .options import get_option, parse_option, NO_DEFAULT +from . import schema as schema_module +from .options import get_option, get_option_names, Option, parse_option, NO_DEFAULT from .profile import Profile -__all__ = ('Config',) +__all__ = ('Config', 'config_schema', 'ConfigValidationError') + +SCHEMA_FILE = 'config-v5.schema.json' + + +@lru_cache(1) +def config_schema() -> Dict[str, Any]: + """Return the configuration schema.""" + return json.loads(resources.read_text(schema_module, SCHEMA_FILE, encoding='utf8')) + + +class ConfigValidationError(ConfigurationError): + """Configuration error raised when the file contents fails validation.""" + + def __init__( + self, message: str, keypath: Sequence[Any] = (), schema: Optional[dict] = None, filepath: Optional[str] = None + ): + super().__init__(message) + self._message = message + self._keypath = keypath + self._filepath = filepath + self._schema = schema + + def __str__(self) -> str: + prefix = f'{self._filepath}:' if self._filepath else '' + path = '/' + '/'.join(str(k) for k in self._keypath) + ': ' if self._keypath else '' + schema = f'\n schema:\n {self._schema}' if self._schema else '' + return f'Validation Error: {prefix}{path}{self._message}{schema}' class Config: # pylint: disable=too-many-public-methods @@ -29,6 +63,7 @@ class Config: # pylint: disable=too-many-public-methods KEY_DEFAULT_PROFILE = 'default_profile' KEY_PROFILES = 'profiles' KEY_OPTIONS = 'options' + KEY_SCHEMA = '$schema' @classmethod def from_file(cls, filepath): @@ -86,26 +121,40 @@ def _backup(cls, filepath): return filepath_backup - def __init__(self, filepath, config): + @staticmethod + def validate(config: dict, filepath: Optional[str] = None): + """Validate a configuration dictionary.""" + try: + jsonschema.validate(instance=config, schema=config_schema()) + except jsonschema.ValidationError as error: + raise ConfigValidationError( + message=error.message, keypath=error.path, schema=error.schema, filepath=filepath + ) + + def __init__(self, filepath: str, config: dict, validate: bool = True): """Instantiate a configuration object from a configuration dictionary and its filepath. If an empty dictionary is passed, the constructor will create the skeleton configuration dictionary. 
:param filepath: the absolute filepath of the configuration file :param config: the content of the configuration file in dictionary form + :param validate: validate the dictionary against the schema """ from .migrations import CURRENT_CONFIG_VERSION, OLDEST_COMPATIBLE_CONFIG_VERSION - version = config.get(self.KEY_VERSION, {}) - current_version = version.get(self.KEY_VERSION_CURRENT, CURRENT_CONFIG_VERSION) - compatible_version = version.get(self.KEY_VERSION_OLDEST_COMPATIBLE, OLDEST_COMPATIBLE_CONFIG_VERSION) + if validate: + self.validate(config, filepath) self._filepath = filepath - self._current_version = current_version - self._oldest_compatible_version = compatible_version + self._schema = config.get(self.KEY_SCHEMA, None) + version = config.get(self.KEY_VERSION, {}) + self._current_version = version.get(self.KEY_VERSION_CURRENT, CURRENT_CONFIG_VERSION) + self._oldest_compatible_version = version.get( + self.KEY_VERSION_OLDEST_COMPATIBLE, OLDEST_COMPATIBLE_CONFIG_VERSION + ) self._profiles = {} - known_keys = [self.KEY_VERSION, self.KEY_PROFILES, self.KEY_OPTIONS, self.KEY_DEFAULT_PROFILE] + known_keys = [self.KEY_SCHEMA, self.KEY_VERSION, self.KEY_PROFILES, self.KEY_OPTIONS, self.KEY_DEFAULT_PROFILE] unknown_keys = set(config.keys()) - set(known_keys) if unknown_keys: @@ -148,15 +197,17 @@ def handle_invalid(self, message): echo.echo_warning(f'backup of the original config file written to: `{filepath_backup}`') @property - def dictionary(self): + def dictionary(self) -> dict: """Return the dictionary representation of the config as it would be written to file. :return: dictionary representation of config as it should be written to file """ - config = { - self.KEY_VERSION: self.version_settings, - self.KEY_PROFILES: {name: profile.dictionary for name, profile in self._profiles.items()} - } + config = {} + if self._schema: + config[self.KEY_SCHEMA] = self._schema + + config[self.KEY_VERSION] = self.version_settings + config[self.KEY_PROFILES] = {name: profile.dictionary for name, profile in self._profiles.items()} if self._default_profile: config[self.KEY_DEFAULT_PROFILE] = self._default_profile @@ -321,6 +372,8 @@ def set_option(self, option_name, option_value, scope=None, override=True): :param option_value: the option value :param scope: set the option for this profile or globally if not specified :param override: boolean, if False, will not override the option if it already exists + + :returns: the parsed value (potentially cast to a valid type) """ option, parsed_value = parse_option(option_name, option_value) @@ -332,12 +385,14 @@ def set_option(self, option_name, option_value, scope=None, override=True): return if not option.global_only and scope is not None: - self.get_profile(scope).set_option(option.key, value, override=override) + self.get_profile(scope).set_option(option.name, value, override=override) else: - if option.key not in self.options or override: - self.options[option.key] = value + if option.name not in self.options or override: + self.options[option.name] = value - def unset_option(self, option_name, scope=None): + return value + + def unset_option(self, option_name: str, scope=None): """Unset a configuration option for a certain scope. 
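# Minimal sketch of the jsonschema validation pattern used by `Config.validate` above: validate a
# dictionary against a schema and re-raise the library error as a domain-specific exception that
# carries the offending key path. The toy schema and exception below are stand-ins, not the real
# `config-v5.schema.json` or `ConfigValidationError`.
import jsonschema

TOY_SCHEMA = {
    'type': 'object',
    'properties': {'daemon.timeout': {'type': 'integer', 'minimum': 0}},
}


class ToyValidationError(Exception):

    def __init__(self, message, keypath=()):
        path = '/'.join(str(key) for key in keypath)
        super().__init__(f'/{path}: {message}')


def validate_config(config: dict) -> None:
    try:
        jsonschema.validate(instance=config, schema=TOY_SCHEMA)
    except jsonschema.ValidationError as error:
        raise ToyValidationError(error.message, keypath=list(error.path)) from error


validate_config({'daemon.timeout': 20})       # passes silently
# validate_config({'daemon.timeout': 'ten'})  # would raise ToyValidationError: /daemon.timeout: 'ten' is not of type 'integer'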
:param option_name: the name of the configuration option @@ -346,9 +401,9 @@ def unset_option(self, option_name, scope=None): option = get_option(option_name) if scope is not None: - self.get_profile(scope).unset_option(option.key) + self.get_profile(scope).unset_option(option.name) else: - self.options.pop(option.key, None) + self.options.pop(option.name, None) def get_option(self, option_name, scope=None, default=True): """Get a configuration option for a certain scope. @@ -364,12 +419,36 @@ def get_option(self, option_name, scope=None, default=True): default_value = option.default if default and option.default is not NO_DEFAULT else None if scope is not None: - value = self.get_profile(scope).get_option(option.key, default_value) + value = self.get_profile(scope).get_option(option.name, default_value) else: - value = self.options.get(option.key, default_value) + value = self.options.get(option.name, default_value) return value + def get_options(self, scope: Optional[str] = None) -> Dict[str, Tuple[Option, str, Any]]: + """Return a dictionary of all option values and their source ('profile', 'global', or 'default'). + + :param scope: the profile name or globally if not specified + :returns: (option, source, value) + """ + profile = self.get_profile(scope) if scope else None + output = {} + for name in get_option_names(): + option = get_option(name) + if profile and name in profile.options: + value = profile.options.get(name) + source = 'profile' + elif name in self.options: + value = self.options.get(name) + source = 'global' + elif 'default' in option.schema: + value = option.default + source = 'default' + else: + continue + output[name] = (option, source, value) + return output + def store(self): """Write the current config to file. diff --git a/aiida/manage/configuration/migrations/migrations.py b/aiida/manage/configuration/migrations/migrations.py index 7d094e74d3..54bd123e5a 100644 --- a/aiida/manage/configuration/migrations/migrations.py +++ b/aiida/manage/configuration/migrations/migrations.py @@ -15,8 +15,8 @@ # If the configuration file format is changed, the current version number should be upped and a migration added. # When the configuration file format is changed in a backwards-incompatible way, the oldest compatible version should # be set to the new current version. 
-CURRENT_CONFIG_VERSION = 4 -OLDEST_COMPATIBLE_CONFIG_VERSION = 3 +CURRENT_CONFIG_VERSION = 5 +OLDEST_COMPATIBLE_CONFIG_VERSION = 5 class ConfigMigration: @@ -98,10 +98,46 @@ def _3_add_message_broker(config): return config +def _4_simplify_options(config): + """Remove unnecessary difference between file/internal representation of options""" + conversions = { + 'runner_poll_interval': 'runner.poll.interval', + 'daemon_default_workers': 'daemon.default_workers', + 'daemon_timeout': 'daemon.timeout', + 'daemon_worker_process_slots': 'daemon.worker_process_slots', + 'db_batch_size': 'db.batch_size', + 'verdi_shell_auto_import': 'verdi.shell.auto_import', + 'logging_aiida_log_level': 'logging.aiida_loglevel', + 'logging_db_log_level': 'logging.db_loglevel', + 'logging_plumpy_log_level': 'logging.plumpy_loglevel', + 'logging_kiwipy_log_level': 'logging.kiwipy_loglevel', + 'logging_paramiko_log_level': 'logging.paramiko_loglevel', + 'logging_alembic_log_level': 'logging.alembic_loglevel', + 'logging_sqlalchemy_loglevel': 'logging.sqlalchemy_loglevel', + 'logging_circus_log_level': 'logging.circus_loglevel', + 'user_email': 'autofill.user.email', + 'user_first_name': 'autofill.user.first_name', + 'user_last_name': 'autofill.user.last_name', + 'user_institution': 'autofill.user.institution', + 'show_deprecations': 'warnings.showdeprecations', + 'task_retry_initial_interval': 'transport.task_retry_initial_interval', + 'task_maximum_attempts': 'transport.task_maximum_attempts' + } + for current, new in conversions.items(): + for profile in config.get('profiles', {}).values(): + if current in profile.get('options', {}): + profile['options'][new] = profile['options'].pop(current) + if current in config.get('options', {}): + config['options'][new] = config['options'].pop(current) + + return config + + # Maps the initial config version to the ConfigMigration which updates it. 
_MIGRATION_LOOKUP = { 0: ConfigMigration(migrate_function=lambda x: x, version=1, version_oldest_compatible=0), 1: ConfigMigration(migrate_function=_1_add_profile_uuid, version=2, version_oldest_compatible=0), 2: ConfigMigration(migrate_function=_2_simplify_default_profiles, version=3, version_oldest_compatible=3), - 3: ConfigMigration(migrate_function=_3_add_message_broker, version=4, version_oldest_compatible=3) + 3: ConfigMigration(migrate_function=_3_add_message_broker, version=4, version_oldest_compatible=3), + 4: ConfigMigration(migrate_function=_4_simplify_options, version=5, version_oldest_compatible=5) } diff --git a/aiida/manage/configuration/options.py b/aiida/manage/configuration/options.py index 176bb5c713..281880eb9f 100644 --- a/aiida/manage/configuration/options.py +++ b/aiida/manage/configuration/options.py @@ -8,244 +8,131 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Definition of known configuration options and methods to parse and get option values.""" -import collections +from typing import Any, Dict, List, Tuple -__all__ = ('get_option', 'get_option_names', 'parse_option') - -NO_DEFAULT = () -DEFAULT_DAEMON_WORKERS = 1 -DEFAULT_DAEMON_TIMEOUT = 20 # Default timeout in seconds for circus client calls -DEFAULT_DAEMON_WORKER_PROCESS_SLOTS = 200 -VALID_LOG_LEVELS = ['CRITICAL', 'ERROR', 'WARNING', 'REPORT', 'INFO', 'DEBUG'] - -Option = collections.namedtuple( - 'Option', ['name', 'key', 'valid_type', 'valid_values', 'default', 'description', 'global_only'] -) - -CONFIG_OPTIONS = { - 'runner.poll.interval': { - 'key': 'runner_poll_interval', - 'valid_type': 'int', - 'valid_values': None, - 'default': 60, - 'description': 'The polling interval in seconds to be used by process runners', - 'global_only': False, - }, - 'daemon.default_workers': { - 'key': 'daemon_default_workers', - 'valid_type': 'int', - 'valid_values': None, - 'default': DEFAULT_DAEMON_WORKERS, - 'description': 'The default number of workers to be launched by `verdi daemon start`', - 'global_only': False, - }, - 'daemon.timeout': { - 'key': 'daemon_timeout', - 'valid_type': 'int', - 'valid_values': None, - 'default': DEFAULT_DAEMON_TIMEOUT, - 'description': 'The timeout in seconds for calls to the circus client', - 'global_only': False, - }, - 'daemon.worker_process_slots': { - 'key': 'daemon_worker_process_slots', - 'valid_type': 'int', - 'valid_values': None, - 'default': DEFAULT_DAEMON_WORKER_PROCESS_SLOTS, - 'description': 'The maximum number of concurrent process tasks that each daemon worker can handle', - 'global_only': False, - }, - 'db.batch_size': { - 'key': 'db_batch_size', - 'valid_type': 'int', - 'valid_values': None, - 'default': 100000, - 'description': - 'Batch size for bulk CREATE operations in the database. 
Avoids hitting MaxAllocSize of PostgreSQL' - '(1GB) when creating large numbers of database records in one go.', - 'global_only': False, - }, - 'verdi.shell.auto_import': { - 'key': 'verdi_shell_auto_import', - 'valid_type': 'string', - 'valid_values': None, - 'default': '', - 'description': 'Additional modules/functions/classes to be automatically loaded in `verdi shell`', - 'global_only': False, - }, - 'logging.aiida_loglevel': { - 'key': 'logging_aiida_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'REPORT', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger', - 'global_only': False, - }, - 'logging.db_loglevel': { - 'key': 'logging_db_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'REPORT', - 'description': 'Minimum level to log to the DbLog table', - 'global_only': False, - }, - 'logging.tornado_loglevel': { - 'key': 'logging_tornado_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `tornado` logger', - 'global_only': False, - }, - 'logging.plumpy_loglevel': { - 'key': 'logging_plumpy_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger', - 'global_only': False, - }, - 'logging.kiwipy_loglevel': { - 'key': 'logging_kiwipy_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger', - 'global_only': False, - }, - 'logging.paramiko_loglevel': { - 'key': 'logging_paramiko_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger', - 'global_only': False, - }, - 'logging.alembic_loglevel': { - 'key': 'logging_alembic_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger', - 'global_only': False, - }, - 'logging.sqlalchemy_loglevel': { - 'key': 'logging_sqlalchemy_loglevel', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'WARNING', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger', - 'global_only': False, - }, - 'logging.circus_loglevel': { - 'key': 'logging_circus_log_level', - 'valid_type': 'string', - 'valid_values': VALID_LOG_LEVELS, - 'default': 'INFO', - 'description': 'Minimum level to log to daemon log and the `DbLog` table for the `circus` logger', - 'global_only': False, - }, - 'user.email': { - 'key': 'user_email', - 'valid_type': 'string', - 'valid_values': None, - 'default': NO_DEFAULT, - 'description': 'Default user email to use when creating new profiles.', - 'global_only': True, - }, - 'user.first_name': { - 'key': 'user_first_name', - 'valid_type': 'string', - 'valid_values': None, - 'default': NO_DEFAULT, - 'description': 'Default user first name to use when creating new profiles.', - 'global_only': True, - }, - 'user.last_name': { - 'key': 'user_last_name', - 'valid_type': 'string', - 'valid_values': None, - 'default': NO_DEFAULT, - 'description': 'Default user last name to use 
when creating new profiles.', - 'global_only': True, - }, - 'user.institution': { - 'key': 'user_institution', - 'valid_type': 'string', - 'valid_values': None, - 'default': NO_DEFAULT, - 'description': 'Default user institution to use when creating new profiles.', - 'global_only': True, - }, - 'warnings.showdeprecations': { - 'key': 'show_deprecations', - 'valid_type': 'bool', - 'valid_values': None, - 'default': True, - 'description': 'Boolean whether to print AiiDA deprecation warnings', - 'global_only': False, - }, -} - - -def get_option(option_name): - """Return a configuration option.configuration - - :param option_name: the name of the configuration option - :return: the configuration option - :raises ValueError: if the configuration option does not exist - """ - try: - option = Option(option_name, **CONFIG_OPTIONS[option_name]) - except KeyError: - raise ValueError(f'the option {option_name} does not exist') - else: - return option +import jsonschema +from aiida.common.exceptions import ConfigurationError -def get_option_names(): - """Return a list of available option names. +__all__ = ('get_option', 'get_option_names', 'parse_option', 'Option') - :return: list of available option names - """ - return CONFIG_OPTIONS.keys() +NO_DEFAULT = () -def parse_option(option_name, option_value): +class Option: + """Represent a configuration option schema.""" + + def __init__(self, name: str, schema: Dict[str, Any]): + self._name = name + self._schema = schema + + def __str__(self) -> str: + return f'Option(name={self._name})' + + @property + def name(self) -> str: + return self._name + + @property + def schema(self) -> Dict[str, Any]: + return self._schema + + @property + def valid_type(self) -> Any: + return self._schema.get('type', None) + + @property + def default(self) -> Any: + return self._schema.get('default', NO_DEFAULT) + + @property + def description(self) -> str: + return self._schema.get('description', '') + + @property + def global_only(self) -> bool: + return self._schema.get('global_only', False) + + def validate(self, value: Any, cast: bool = True) -> Any: + """Validate a value + + :param value: The input value + :param cast: Attempt to cast the value to the required type + + :return: The output value + :raise: ConfigValidationError + + """ + # pylint: disable=too-many-branches + from .config import ConfigValidationError + from aiida.manage.caching import _validate_identifier_pattern + + if cast: + try: + if self.valid_type == 'boolean': + if isinstance(value, str): + if value.strip().lower() in ['0', 'false', 'f']: + value = False + elif value.strip().lower() in ['1', 'true', 't']: + value = True + else: + value = bool(value) + elif self.valid_type == 'string': + value = str(value) + elif self.valid_type == 'integer': + value = int(value) + elif self.valid_type == 'number': + value = float(value) + elif self.valid_type == 'array' and isinstance(value, str): + value = value.split() + except ValueError: + pass + + try: + jsonschema.validate(instance=value, schema=self.schema) + except jsonschema.ValidationError as exc: + raise ConfigValidationError(message=exc.message, keypath=[self.name, *(exc.path or [])], schema=exc.schema) + + # special caching validation + if self.name in ('caching.enabled_for', 'caching.disabled_for'): + for i, identifier in enumerate(value): + try: + _validate_identifier_pattern(identifier=identifier) + except ValueError as exc: + raise ConfigValidationError(message=str(exc), keypath=[self.name, str(i)]) + + return value + + +def get_schema_options() -> 
Dict[str, Dict[str, Any]]: + """Return schema for options.""" + from .config import config_schema + schema = config_schema() + return schema['definitions']['options']['properties'] + + +def get_option_names() -> List[str]: + """Return a list of available option names.""" + return list(get_schema_options()) + + +def get_option(name: str) -> Option: + """Return option.""" + options = get_schema_options() + if name not in options: + raise ConfigurationError(f'the option {name} does not exist') + return Option(name, options[name]) + + +def parse_option(option_name: str, option_value: Any) -> Tuple[Option, Any]: """Parse and validate a value for a configuration option. :param option_name: the name of the configuration option :param option_value: the option value :return: a tuple of the option and the parsed value + """ option = get_option(option_name) - - value = False - - if option.valid_type == 'bool': - if isinstance(option_value, str): - if option_value.strip().lower() in ['0', 'false', 'f']: - value = False - elif option_value.strip().lower() in ['1', 'true', 't']: - value = True - else: - raise ValueError(f'option {option.name} expects a boolean value') - else: - value = bool(option_value) - elif option.valid_type == 'string': - value = str(option_value) - elif option.valid_type == 'int': - value = int(option_value) - elif option.valid_type == 'list_of_str': - value = option_value.split() - else: - raise NotImplementedError(f'Type string {option.valid_type} not implemented yet') - - if option.valid_values is not None: - if value not in option.valid_values: - raise ValueError( - '{} is not among the list of accepted values for option {}.\nThe valid values are: ' - '{}'.format(value, option.name, ', '.join(option.valid_values)) - ) + value = option.validate(option_value, cast=True) return option, value diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py index e8276d3f06..593302116a 100644 --- a/aiida/manage/configuration/profile.py +++ b/aiida/manage/configuration/profile.py @@ -12,6 +12,8 @@ import os from aiida.common import exceptions + +from .options import parse_option from .settings import DAEMON_DIR, DAEMON_LOG_DIR __all__ = ('Profile',) @@ -270,8 +272,9 @@ def set_option(self, option_key, value, override=True): :param option_value: the option value :param override: boolean, if False, will not override the option if it already exists """ + _, parsed_value = parse_option(option_key, value) # ensure the value is validated if option_key not in self.options or override: - self.options[option_key] = value + self.options[option_key] = parsed_value def unset_option(self, option_key): self.options.pop(option_key, None) diff --git a/.ci/polish/polish_workchains/__init__.py b/aiida/manage/configuration/schema/__init__.py similarity index 100% rename from .ci/polish/polish_workchains/__init__.py rename to aiida/manage/configuration/schema/__init__.py diff --git a/aiida/manage/configuration/schema/config-v5.schema.json b/aiida/manage/configuration/schema/config-v5.schema.json new file mode 100644 index 0000000000..43ff1e87fb --- /dev/null +++ b/aiida/manage/configuration/schema/config-v5.schema.json @@ -0,0 +1,391 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "description": "Schema for AiiDA configuration files, format version 5", + "type": "object", + "definitions": { + "options": { + "type": "object", + "properties": { + "runner.poll.interval": { + "type": "integer", + "default": 60, + "minimum": 0, + "description": "Polling interval in 
seconds to be used by process runners" + }, + "daemon.default_workers": { + "type": "integer", + "default": 1, + "minimum": 1, + "description": "Default number of workers to be launched by `verdi daemon start`" + }, + "daemon.timeout": { + "type": "integer", + "default": 20, + "minimum": 0, + "description": "Timeout in seconds for calls to the circus client" + }, + "daemon.worker_process_slots": { + "type": "integer", + "default": 200, + "minimum": 1, + "description": "Maximum number of concurrent process tasks that each daemon worker can handle" + }, + "db.batch_size": { + "type": "integer", + "default": 100000, + "minimum": 1, + "description": "Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL (1GB) when creating large numbers of database records in one go." + }, + "verdi.shell.auto_import": { + "type": "string", + "default": "", + "description": "Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by ':'" + }, + "logging.aiida_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "REPORT", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger" + }, + "logging.db_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "REPORT", + "description": "Minimum level to log to the DbLog table" + }, + "logging.plumpy_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger" + }, + "logging.kiwipy_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger" + }, + "logging.paramiko_loglevel": { + "key": "logging_paramiko_log_level", + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger" + }, + "logging.alembic_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger" + }, + "logging.sqlalchemy_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger" + }, + "logging.circus_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "INFO", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `circus` logger" + }, + "logging.aiopika_loglevel": { + "type": "string", + "enum": [ + "CRITICAL", + "ERROR", + "WARNING", + "REPORT", + "INFO", + "DEBUG" + ], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aio_pika` logger" + }, + "warnings.showdeprecations": { + "type": "boolean", + "default": true, + "description": "Whether to print AiiDA deprecation warnings" + 
}, + "transport.task_retry_initial_interval": { + "type": "integer", + "default": 20, + "minimum": 1, + "description": "Initial time interval for the exponential backoff mechanism." + }, + "transport.task_maximum_attempts": { + "type": "integer", + "default": 5, + "minimum": 1, + "description": "Maximum number of transport task attempts before a Process is Paused." + }, + "rmq.task_timeout": { + "type": "integer", + "default": 10, + "minimum": 1, + "description": "Timeout in seconds for communications with RabbitMQ" + }, + "caching.default_enabled": { + "type": "boolean", + "default": false, + "description": "Enable calculation caching by default" + }, + "caching.enabled_for": { + "description": "Calculation entry points to enable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "caching.disabled_for": { + "description": "Calculation entry points to disable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "autofill.user.email": { + "type": "string", + "global_only": true, + "description": "Default user email to use when creating new profiles." + }, + "autofill.user.first_name": { + "type": "string", + "global_only": true, + "description": "Default user first name to use when creating new profiles." + }, + "autofill.user.last_name": { + "type": "string", + "global_only": true, + "description": "Default user last name to use when creating new profiles." + }, + "autofill.user.institution": { + "type": "string", + "global_only": true, + "description": "Default user institution to use when creating new profiles." + } + } + }, + "profile": { + "type": "object", + "required": [ + "AIIDADB_REPOSITORY_URI", + "AIIDADB_BACKEND", + "AIIDADB_ENGINE", + "AIIDADB_HOST", + "AIIDADB_NAME", + "AIIDADB_PASS", + "AIIDADB_PORT", + "AIIDADB_USER" + ], + "properties": { + "PROFILE_UUID": { + "description": "The profile's unique key", + "type": "string" + }, + "AIIDADB_REPOSITORY_URI": { + "type": "string", + "description": "URI to the AiiDA object store" + }, + "AIIDADB_ENGINE": { + "type": "string", + "default": "postgresql_psycopg2" + }, + "AIIDADB_BACKEND": { + "type": "string", + "enum": [ + "django", + "sqlalchemy" + ], + "default": "django" + }, + "AIIDADB_NAME": { + "type": "string" + }, + "AIIDADB_PORT": { + "type": ["integer", "string"], + "minimum": 1, + "pattern": "\\d+", + "default": 5432 + }, + "AIIDADB_HOST": { + "type": [ + "string", + "null" + ], + "default": null + }, + "AIIDADB_USER": { + "type": "string" + }, + "AIIDADB_PASS": { + "type": [ + "string", + "null" + ], + "default": null + }, + "broker_protocol": { + "description": "Protocol for connecting to the RabbitMQ server", + "type": "string", + "enum": [ + "amqp", + "amqps" + ], + "default": "amqp" + }, + "broker_username": { + "description": "Username for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_password": { + "description": "Password for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_host": { + "description": "Hostname of the RabbitMQ server", + "type": "string", + "default": "127.0.0.1" + }, + "broker_port": { + "description": "Port of the RabbitMQ server", + "type": "integer", + "minimum": 1, + "default": 5672 + }, + "broker_virtual_host": { + "description": "RabbitMQ virtual host to connect to", + "type": "string", + "default": "" + }, + "broker_parameters": { + "description": "RabbitMQ arguments that will be encoded as query parameters", + "type": "object", + "default": { + 
"heartbeat": 600 + }, + "properties": { + "heartbeat": { + "description": "After how many seconds the peer TCP connection should be considered unreachable", + "type": "integer", + "default": 600, + "minimum": 0 + } + } + }, + "default_user_email": { + "type": [ + "string", + "null" + ], + "default": null + }, + "options": { + "description": "Profile specific options", + "$ref": "#/definitions/options" + } + } + } + }, + "required": [], + "properties": { + "CONFIG_VERSION": { + "description": "The configuration version", + "type": "object", + "required": [ + "CURRENT", + "OLDEST_COMPATIBLE" + ], + "properties": { + "CURRENT": { + "description": "Version number of configuration file format", + "type": "integer", + "const": 5 + }, + "OLDEST_COMPATIBLE": { + "description": "Version number of oldest configuration file format this file is compatible with", + "type": "integer", + "const": 5 + } + } + }, + "profiles": { + "description": "Configured profiles", + "type": "object", + "patternProperties": { + ".+": { + "$ref": "#/definitions/profile" + } + } + }, + "default_profile": { + "description": "Default profile to use", + "type": "string" + }, + "options": { + "description": "Global options", + "$ref": "#/definitions/options" + } + } +} diff --git a/aiida/manage/database/delete/nodes.py b/aiida/manage/database/delete/nodes.py index 47860be84b..03a7edc47f 100644 --- a/aiida/manage/database/delete/nodes.py +++ b/aiida/manage/database/delete/nodes.py @@ -7,120 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Function to delete nodes from the database.""" +"""Functions to delete nodes from the database, preserving provenance integrity.""" +from typing import Callable, Iterable, Optional, Set, Tuple, Union +import warnings -import click -from aiida.cmdline.utils import echo +def delete_nodes( + pks: Iterable[int], + verbosity: Optional[int] = None, + dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + force: Optional[bool] = None, + **traversal_rules: bool +) -> Tuple[Set[int], bool]: + """Delete nodes given a list of "starting" PKs. -def delete_nodes(pks, verbosity=0, dry_run=False, force=False, **kwargs): - """Delete nodes by a list of pks. + .. deprecated:: 1.6.0 + This function has been moved and will be removed in `v2.0.0`. + It should now be imported using `from aiida.tools import delete_nodes` - This command will delete not only the specified nodes, but also the ones that are - linked to these and should be also deleted in order to keep a consistent provenance - according to the rules explained in the concepts section of the documentation. - In summary: - - 1. If a DATA node is deleted, any process nodes linked to it will also be deleted. - - 2. If a CALC node is deleted, any incoming WORK node (callers) will be deleted as - well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes - (outputs) will be deleted by default but this can be disabled. - - 3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as - well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by - default, but deletion of either of both kind of connected nodes can be enabled. - - These rules are 'recursive', so if a CALC node is deleted, then its output DATA - nodes will be deleted as well, and then any CALC node that may have those as - inputs, and so on. 
- - :param pks: a list of the PKs of the nodes to delete - :param bool force: do not ask for confirmation to delete nodes. - :param int verbosity: 0 prints nothing, - 1 prints just sums and total, - 2 prints individual nodes. - - :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names - are toggleable and what the defaults are. - :param bool dry_run: - Just perform a dry run and do not delete anything. Print statistics according - to the verbosity level set. - :param bool force: - Do not ask for confirmation to delete nodes. """ - # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements - from aiida.backends.utils import delete_nodes_and_connections - from aiida.common import exceptions - from aiida.orm import Node, QueryBuilder, load_node - from aiida.tools.graph.graph_traversers import get_nodes_delete - - starting_pks = [] - for pk in pks: - try: - load_node(pk) - except exceptions.NotExistent: - echo.echo_warning(f'warning: node with pk<{pk}> does not exist, skipping') - else: - starting_pks.append(pk) - - # An empty set might be problematic for the queries done below. - if not starting_pks: - if verbosity: - echo.echo('Nothing to delete') - return - - pks_set_to_delete = get_nodes_delete(starting_pks, **kwargs)['nodes'] - - if verbosity > 0: - echo.echo( - 'I {} delete {} node{}'.format( - 'would' if dry_run else 'will', len(pks_set_to_delete), 's' if len(pks_set_to_delete) > 1 else '' - ) - ) - if verbosity > 1: - builder = QueryBuilder().append( - Node, filters={'id': { - 'in': pks_set_to_delete - }}, project=('uuid', 'id', 'node_type', 'label') - ) - echo.echo(f"The nodes I {'would' if dry_run else 'will'} delete:") - for uuid, pk, type_string, label in builder.iterall(): - try: - short_type_string = type_string.split('.')[-2] - except IndexError: - short_type_string = type_string - echo.echo(f' {uuid} {pk} {short_type_string} {label}') - - if dry_run: - if verbosity > 0: - echo.echo('\nThis was a dry run, exiting without deleting anything') - return - - # Asking for user confirmation here - if force: - pass - else: - echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks_set_to_delete)} NODES! THIS CANNOT BE UNDONE!') - if not click.confirm('Shall I continue?'): - echo.echo('Exiting without deleting') - return - - # Recover the list of folders to delete before actually deleting the nodes. I will delete the folders only later, - # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders - repositories = [load_node(pk)._repository for pk in pks_set_to_delete] # pylint: disable=protected-access - - if verbosity > 0: - echo.echo('Starting node deletion...') - delete_nodes_and_connections(pks_set_to_delete) - - if verbosity > 0: - echo.echo('Nodes deleted from database, deleting files from the repository now...') + from aiida.common.warnings import AiidaDeprecationWarning + from aiida.tools import delete_nodes as _delete - # If we are here, we managed to delete the entries from the DB. - # I can now delete the folders - for repository in repositories: - repository.erase(force=True) + warnings.warn( + 'This function has been moved and will be removed in `v2.0.0`.' 
+ 'It should now be imported using `from aiida.tools import delete_nodes`', AiidaDeprecationWarning + ) # pylint: disable=no-member - if verbosity > 0: - echo.echo('Deletion completed.') + return _delete(pks, verbosity, dry_run, force, **traversal_rules) diff --git a/aiida/manage/external/postgres.py b/aiida/manage/external/postgres.py index c4faff90c1..0a6ff8f937 100644 --- a/aiida/manage/external/postgres.py +++ b/aiida/manage/external/postgres.py @@ -116,6 +116,8 @@ def check_dbuser(self, dbuser): :param str dbuser: Name of the user to be created or reused. :returns: tuple (dbuser, created) """ + if not self.interactive: + return dbuser, not self.dbuser_exists(dbuser) create = True while create and self.dbuser_exists(dbuser): echo.echo_info(f'Database user "{dbuser}" already exists!') @@ -163,6 +165,8 @@ def check_db(self, dbname): :param str dbname: Name of the database to be created or reused. :returns: tuple (dbname, created) """ + if not self.interactive: + return dbname, not self.db_exists(dbname) create = True while create and self.db_exists(dbname): echo.echo_info(f'database {dbname} already exists!') diff --git a/aiida/manage/external/rmq.py b/aiida/manage/external/rmq.py index 6b66c8d9c8..c7cccfd149 100644 --- a/aiida/manage/external/rmq.py +++ b/aiida/manage/external/rmq.py @@ -9,17 +9,25 @@ ########################################################################### # pylint: disable=cyclic-import """Components to communicate tasks to RabbitMQ.""" -import collections +import asyncio +from collections.abc import Mapping import logging +import traceback -from tornado import gen from kiwipy import communications, Future +import pamqp.encode import plumpy from aiida.common.extendeddicts import AttributeDict __all__ = ('RemoteException', 'CommunicationTimeout', 'DeliveryFailed', 'ProcessLauncher', 'BROKER_DEFAULTS') +# The following statement enables support for RabbitMQ 3.5 because without it, connections established by `aiormq` will +# fail because the interpretation of the types of integers passed in connection parameters has changed after that +# version. Once RabbitMQ 3.5 is no longer supported (it has been EOL since October 2016) this can be removed. This +# should also allow to remove the direct dependency on `pamqp` entirely. +pamqp.encode.support_deprecated_rabbitmq() + LOGGER = logging.getLogger(__name__) RemoteException = plumpy.RemoteException @@ -120,7 +128,7 @@ def _store_inputs(inputs): try: node.store() except AttributeError: - if isinstance(node, collections.Mapping): + if isinstance(node, Mapping): _store_inputs(node) @@ -146,14 +154,13 @@ def handle_continue_exception(node, exception, message): """ from aiida.engine import ProcessState - if not node.is_excepted: + if not node.is_excepted and not node.is_sealed: node.logger.exception(message) - node.set_exception(str(exception)) + node.set_exception(''.join(traceback.format_exception(type(exception), exception, None)).rstrip()) node.set_process_state(ProcessState.EXCEPTED) node.seal() - @gen.coroutine - def _continue(self, communicator, pid, nowait, tag=None): + async def _continue(self, communicator, pid, nowait, tag=None): """Continue the task. Note that the task may already have been completed, as indicated from the corresponding the node, in which @@ -180,7 +187,7 @@ def _continue(self, communicator, pid, nowait, tag=None): # we raise `Return` instead of `TaskRejected` because the latter would cause the task to be resent and start # to ping-pong between RabbitMQ and the daemon workers. 
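As the deprecation shim above indicates, `delete_nodes` now lives in `aiida.tools` and returns the set of affected PKs together with a flag indicating whether the deletion was performed. A hedged usage sketch; the PK is a placeholder and `dry_run=True` keeps the call side-effect free:

from aiida import load_profile
from aiida.tools import delete_nodes

load_profile()

# 1234 is a placeholder PK; with dry_run=True nothing is actually deleted.
pks, deleted = delete_nodes([1234], dry_run=True)
print(f'{len(pks)} node(s) would be deleted (performed: {deleted})')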
LOGGER.exception('Cannot continue process<%d>', pid) - raise gen.Return(False) from exception + return False if node.is_terminated: @@ -195,14 +202,18 @@ def _continue(self, communicator, pid, nowait, tag=None): elif node.is_killed: future.set_exception(plumpy.KilledError()) - raise gen.Return(future.result()) + return future.result() try: - result = yield super()._continue(communicator, pid, nowait, tag) + result = await super()._continue(communicator, pid, nowait, tag) except ImportError as exception: message = 'the class of the process could not be imported.' self.handle_continue_exception(node, exception, message) raise + except asyncio.CancelledError: # pylint: disable=try-except-raise + # note this is only required in python<=3.7, + # where asyncio.CancelledError inherits from Exception + raise except Exception as exception: message = 'failed to recreate the process instance in order to continue it.' self.handle_continue_exception(node, exception, message) @@ -215,4 +226,4 @@ def _continue(self, communicator, pid, nowait, tag=None): LOGGER.exception('failed to serialize the result for process<%d>', pid) raise - raise gen.Return(serialized) + return serialized diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index ed79afd026..8f8bdfd1f1 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -9,11 +9,23 @@ ########################################################################### # pylint: disable=cyclic-import """AiiDA manager for global settings""" +import asyncio import functools +from typing import Any, Optional, TYPE_CHECKING -__all__ = ('get_manager', 'reset_manager') +if TYPE_CHECKING: + from kiwipy.rmq import RmqThreadCommunicator + from plumpy.process_comms import RemoteProcessThreadController + + from aiida.backends.manager import BackendManager + from aiida.engine.daemon.client import DaemonClient + from aiida.engine.runners import Runner + from aiida.manage.configuration.config import Config + from aiida.manage.configuration.profile import Profile + from aiida.orm.implementation import Backend + from aiida.engine.persistence import AiiDAPersister -MANAGER = None +__all__ = ('get_manager', 'reset_manager') class Manager: @@ -32,34 +44,62 @@ class Manager: * reset manager cache when loading a new profile """ + def __init__(self) -> None: + self._backend: Optional['Backend'] = None + self._backend_manager: Optional['BackendManager'] = None + self._config: Optional['Config'] = None + self._daemon_client: Optional['DaemonClient'] = None + self._profile: Optional['Profile'] = None + self._communicator: Optional['RmqThreadCommunicator'] = None + self._process_controller: Optional['RemoteProcessThreadController'] = None + self._persister: Optional['AiiDAPersister'] = None + self._runner: Optional['Runner'] = None + + def close(self) -> None: + """Reset the global settings entirely and release any global objects.""" + if self._communicator is not None: + self._communicator.close() + if self._runner is not None: + self._runner.stop() + + self._backend = None + self._backend_manager = None + self._config = None + self._profile = None + self._communicator = None + self._daemon_client = None + self._process_controller = None + self._persister = None + self._runner = None + @staticmethod - def get_config(): + def get_config() -> 'Config': """Return the current config. 
:return: current loaded config instance - :rtype: :class:`~aiida.manage.configuration.config.Config` :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized + """ from .configuration import get_config return get_config() @staticmethod - def get_profile(): + def get_profile() -> Optional['Profile']: """Return the current loaded profile, if any :return: current loaded profile instance - :rtype: :class:`~aiida.manage.configuration.profile.Profile` or None + """ from .configuration import get_profile return get_profile() - def unload_backend(self): + def unload_backend(self) -> None: """Unload the current backend and its corresponding database environment.""" manager = self.get_backend_manager() manager.reset_backend_environment() self._backend = None - def _load_backend(self, schema_check=True): + def _load_backend(self, schema_check: bool = True) -> 'Backend': """Load the backend for the currently configured profile and return it. .. note:: this will reconstruct the `Backend` instance in `self._backend` so the preferred method to load the @@ -67,7 +107,7 @@ def _load_backend(self, schema_check=True): :param schema_check: force a database schema check if the database environment has not yet been loaded :return: the database backend - :rtype: :class:`aiida.orm.implementation.Backend` + """ from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA from aiida.common import ConfigurationError, InvalidOperation @@ -86,8 +126,9 @@ def _load_backend(self, schema_check=True): # Do NOT reload the backend environment if already loaded, simply reload the backend instance after if configuration.BACKEND_UUID is None: - manager = self.get_backend_manager() - manager.load_backend_environment(profile, validate_schema=schema_check) + from aiida.backends import get_backend_manager + backend_manager = get_backend_manager(profile.database_backend) + backend_manager.load_backend_environment(profile, validate_schema=schema_check) configuration.BACKEND_UUID = profile.uuid backend_type = profile.database_backend @@ -107,44 +148,52 @@ def _load_backend(self, schema_check=True): return self._backend @property - def backend_loaded(self): + def backend_loaded(self) -> bool: """Return whether a database backend has been loaded. :return: boolean, True if database backend is currently loaded, False otherwise """ return self._backend is not None - def get_backend_manager(self): + def get_backend_manager(self) -> 'BackendManager': """Return the database backend manager. .. note:: this is not the actual backend, but a manager class that is necessary for database operations that go around the actual ORM. For example when the schema version has not yet been validated. :return: the database backend manager - :rtype: :class:`aiida.backend.manager.BackendManager` + """ + from aiida.backends import get_backend_manager + from aiida.common import ConfigurationError + if self._backend_manager is None: - from aiida.backends import get_backend_manager - self._backend_manager = get_backend_manager(self.get_profile().database_backend) + self._load_backend() + profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' 
+ ) + self._backend_manager = get_backend_manager(profile.database_backend) return self._backend_manager - def get_backend(self): + def get_backend(self) -> 'Backend': """Return the database backend :return: the database backend - :rtype: :class:`aiida.orm.implementation.Backend` + """ if self._backend is None: self._load_backend() return self._backend - def get_persister(self): + def get_persister(self) -> 'AiiDAPersister': """Return the persister :return: the current persister instance - :rtype: :class:`plumpy.Persister` + """ from aiida.engine import persistence @@ -153,18 +202,20 @@ def get_persister(self): return self._persister - def get_communicator(self): + def get_communicator(self) -> 'RmqThreadCommunicator': """Return the communicator :return: a global communicator instance - :rtype: :class:`kiwipy.Communicator` + """ if self._communicator is None: self._communicator = self.create_communicator() return self._communicator - def create_communicator(self, task_prefetch_count=None, with_orm=True): + def create_communicator( + self, task_prefetch_count: Optional[int] = None, with_orm: bool = True + ) -> 'RmqThreadCommunicator': """Create a Communicator. :param task_prefetch_count: optional specify how many tasks this communicator take simultaneously @@ -172,12 +223,17 @@ def create_communicator(self, task_prefetch_count=None, with_orm=True): This is used by verdi status to get a communicator without needing to load the dbenv. :return: the communicator instance - :rtype: :class:`~kiwipy.rmq.communicator.RmqThreadCommunicator` + """ + from aiida.common import ConfigurationError from aiida.manage.external import rmq import kiwipy.rmq profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + ) if task_prefetch_count is None: task_prefetch_count = self.get_config().get_option('daemon.worker_process_slots', profile.name) @@ -202,16 +258,17 @@ def create_communicator(self, task_prefetch_count=None, with_orm=True): task_exchange=rmq.get_task_exchange_name(prefix), task_queue=rmq.get_launch_queue_name(prefix), task_prefetch_count=task_prefetch_count, + async_task_timeout=self.get_config().get_option('rmq.task_timeout', profile.name), # This is needed because the verdi commands will call this function and when called in unit tests the # testing_mode cannot be set. testing_mode=profile.is_test_profile, ) - def get_daemon_client(self): + def get_daemon_client(self) -> 'DaemonClient': """Return the daemon client for the current profile. 
:return: the daemon client - :rtype: :class:`aiida.daemon.client.DaemonClient` + :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found :raises aiida.common.ProfileConfigurationError: if the given profile does not exist """ @@ -222,52 +279,57 @@ def get_daemon_client(self): return self._daemon_client - def get_process_controller(self): + def get_process_controller(self) -> 'RemoteProcessThreadController': """Return the process controller :return: the process controller instance - :rtype: :class:`plumpy.RemoteProcessThreadController` + """ - import plumpy + from plumpy.process_comms import RemoteProcessThreadController if self._process_controller is None: - self._process_controller = plumpy.RemoteProcessThreadController(self.get_communicator()) + self._process_controller = RemoteProcessThreadController(self.get_communicator()) return self._process_controller - def get_runner(self): + def get_runner(self, **kwargs) -> 'Runner': """Return a runner that is based on the current profile settings and can be used globally by the code. :return: the global runner - :rtype: :class:`aiida.engine.runners.Runner` + """ if self._runner is None: - self._runner = self.create_runner() + self._runner = self.create_runner(**kwargs) return self._runner - def set_runner(self, new_runner): + def set_runner(self, new_runner: 'Runner') -> None: """Set the currently used runner :param new_runner: the new runner to use - :type new_runner: :class:`aiida.engine.runners.Runner` + """ if self._runner is not None: self._runner.close() self._runner = new_runner - def create_runner(self, with_persistence=True, **kwargs): + def create_runner(self, with_persistence: bool = True, **kwargs: Any) -> 'Runner': """Create and return a new runner :param with_persistence: create a runner with persistence enabled - :type with_persistence: bool + :return: a new runner instance - :rtype: :class:`aiida.engine.runners.Runner` + """ + from aiida.common import ConfigurationError from aiida.engine import runners config = self.get_config() profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + ) poll_interval = 0.0 if profile.is_test_profile else config.get_option('runner.poll.interval', profile.name) settings = {'rmq_submit': False, 'poll_interval': poll_interval} @@ -282,17 +344,17 @@ def create_runner(self, with_persistence=True, **kwargs): return runners.Runner(**settings) - def create_daemon_runner(self, loop=None): + def create_daemon_runner(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Runner': """Create and return a new daemon runner. This is used by workers when the daemon is running and in testing. 
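The typed `Manager` accessors above are the intended entry points for the global runner and the RabbitMQ process controller, and they now raise a `ConfigurationError` when no profile is loaded. A small sketch, assuming a default profile and a reachable RabbitMQ broker:

from aiida import load_profile
from aiida.manage.manager import get_manager

load_profile()
manager = get_manager()

# Lazily creates (and caches) the global communicator and the controller on top of it.
controller = manager.get_process_controller()

# The cached global runner; `create_runner()` would return a fresh, independent one instead.
runner = manager.get_runner()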
- :param loop: the (optional) tornado event loop to use - :type loop: :class:`tornado.ioloop.IOLoop` + :param loop: the (optional) asyncio event loop to use + :return: a runner configured to work in the daemon configuration - :rtype: :class:`aiida.engine.runners.Runner` + """ - import plumpy + from plumpy.persistence import LoadSaveContext from aiida.engine import persistence from aiida.manage.external import rmq @@ -303,52 +365,27 @@ def create_daemon_runner(self, loop=None): task_receiver = rmq.ProcessLauncher( loop=runner_loop, persister=self.get_persister(), - load_context=plumpy.LoadSaveContext(runner=runner), + load_context=LoadSaveContext(runner=runner), loader=persistence.get_object_loader() ) + assert runner.communicator is not None, 'communicator not set for runner' runner.communicator.add_task_subscriber(task_receiver) return runner - def close(self): - """Reset the global settings entirely and release any global objects.""" - if self._communicator is not None: - self._communicator.stop() - if self._runner is not None: - self._runner.stop() - - self._backend = None - self._backend_manager = None - self._config = None - self._profile = None - self._communicator = None - self._daemon_client = None - self._process_controller = None - self._persister = None - self._runner = None - def __init__(self): - super().__init__() - self._backend = None # type: aiida.orm.implementation.Backend - self._backend_manager = None # type: aiida.backend.manager.BackendManager - self._config = None # type: aiida.manage.configuration.config.Config - self._daemon_client = None # type: aiida.daemon.client.DaemonClient - self._profile = None # type: aiida.manage.configuration.profile.Profile - self._communicator = None # type: kiwipy.rmq.RmqThreadCommunicator - self._process_controller = None # type: plumpy.RemoteProcessThreadController - self._persister = None # type: aiida.engine.persistence.AiiDAPersister - self._runner = None # type: aiida.engine.runners.Runner +MANAGER: Optional[Manager] = None -def get_manager(): +def get_manager() -> Manager: global MANAGER # pylint: disable=global-statement if MANAGER is None: MANAGER = Manager() return MANAGER -def reset_manager(): +def reset_manager() -> None: global MANAGER # pylint: disable=global-statement if MANAGER is not None: MANAGER.close() diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 2512546bc4..08c203b358 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -17,13 +17,24 @@ * aiida_local_code_factory """ +import asyncio import shutil import tempfile import pytest +from aiida.common.log import AIIDA_LOGGER from aiida.manage.tests import test_manager, get_test_backend_name, get_test_profile_name +@pytest.fixture(scope='function') +def aiida_caplog(caplog): + """A copy of pytest's caplog fixture, which allows ``AIIDA_LOGGER`` to propagate.""" + propogate = AIIDA_LOGGER.propagate + AIIDA_LOGGER.propagate = True + yield caplog + AIIDA_LOGGER.propagate = propogate + + @pytest.fixture(scope='session', autouse=True) def aiida_profile(): """Set up AiiDA test profile for the duration of the tests. 
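The `aiida_caplog` fixture above temporarily re-enables propagation on `AIIDA_LOGGER` so that pytest's log capture sees AiiDA records. A sketch of a test relying on it, assuming `aiida.manage.tests.pytest_fixtures` is loaded via the project's `conftest.py`; the logger name and message are made up:

import logging


def test_emits_warning(aiida_caplog):
    # Children of the `aiida` logger now propagate to the root logger, so caplog sees them.
    logger = logging.getLogger('aiida.some_plugin')
    with aiida_caplog.at_level(logging.WARNING):
        logger.warning('something looks off')
    assert 'something looks off' in aiida_caplog.text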
@@ -59,6 +70,19 @@ def clear_database_before_test(aiida_profile): yield +@pytest.fixture(scope='function') +def temporary_event_loop(): + """Create a temporary loop for independent test case""" + current = asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + yield + finally: + loop.close() + asyncio.set_event_loop(current) + + @pytest.fixture(scope='function') def temp_dir(): """Get a temporary directory. @@ -84,7 +108,7 @@ def aiida_localhost(temp_dir): Usage:: def test_1(aiida_localhost): - name = aiida_localhost.get_name() + label = aiida_localhost.get_label() # proceed to set up code or use 'aiida_local_code_factory' instead diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 4e83ac05c4..358163fd6f 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -26,17 +26,7 @@ class Computer(entities.Entity): """ - Base class to map a node in the DB + its permanent repository counterpart. - - Stores attributes starting with an underscore. - - Caches files and attributes before the first save, and saves everything only on store(). - After the call to store(), attributes cannot be changed. - - Only after storing (or upon loading from uuid) metadata can be modified - and in this case they are directly set on the db. - - In the plugin, also set the _plugin_type_string, to be set in the DB in the 'type' field. + Computer entity. """ # pylint: disable=too-many-public-methods @@ -68,6 +58,26 @@ def get(self, **filters): return super().get(**filters) + def get_or_create(self, label=None, **kwargs): + """ + Try to retrieve a Computer from the DB with the given arguments; + create (and store) a new Computer if such a Computer was not present yet. + + :param label: computer label + :type label: str + + :return: (computer, created) where computer is the computer (new or existing, + in any case already stored) and created is a boolean saying + :rtype: (:class:`aiida.orm.Computer`, bool) + """ + if not label: + raise ValueError('Computer label must be provided') + + try: + return False, self.get(label=label) + except exceptions.NotExistent: + return True, Computer(backend=self.backend, label=label, **kwargs) + def list_names(self): """Return a list with all the names of the computers in the DB. 
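The new `get_or_create` on the computer collection above returns a `(created, computer)` pair; as written, a newly constructed computer is returned unstored, so storing is left to the caller. A hedged sketch in which the label and setup values are placeholders:

from aiida import load_profile
from aiida.orm import Computer

load_profile()

# Placeholder label/configuration; kwargs are only used when a new computer is constructed.
created, computer = Computer.objects.get_or_create(
    label='localhost',
    hostname='localhost',
    transport_type='local',
    scheduler_type='direct',
)
if created:
    computer.store()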
@@ -721,7 +731,8 @@ def get_configuration(self, user=None): config = {} try: - authinfo = backend.authinfos.get(self, user) + # Need to pass the backend entity here, not just self + authinfo = backend.authinfos.get(self._backend_entity, user) config = authinfo.get_auth_params() except exceptions.NotExistent: pass diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index abdcf798ab..1e6f2b0521 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -67,9 +67,7 @@ def store(self): if self._dbmodel.dbnode.id is None or self._dbmodel.user.id is None: raise exceptions.ModificationNotAllowed('The corresponding node and/or user are not stored') - # `contextlib.suppress` provides empty context and can be replaced with `contextlib.nullcontext` after we drop - # support for python 3.6 - with suppress_auto_now([(models.DbComment, ['mtime'])]) if self.mtime else contextlib.suppress(): + with suppress_auto_now([(models.DbComment, ['mtime'])]) if self.mtime else contextlib.nullcontext(): super().store() @property diff --git a/aiida/orm/implementation/django/nodes.py b/aiida/orm/implementation/django/nodes.py index d8f527e5fd..af47942246 100644 --- a/aiida/orm/implementation/django/nodes.py +++ b/aiida/orm/implementation/django/nodes.py @@ -201,10 +201,8 @@ def store(self, links=None, with_transaction=True, clean=True): # pylint: disab if clean: self.clean_values() - # `contextlib.suppress` provides empty context and can be replaced with `contextlib.nullcontext` after we drop - # support for python 3.6 - with transaction.atomic() if with_transaction else contextlib.suppress(): - with suppress_auto_now([(models.DbNode, ['mtime'])]) if self.mtime else contextlib.suppress(): + with transaction.atomic() if with_transaction else contextlib.nullcontext(): + with suppress_auto_now([(models.DbNode, ['mtime'])]) if self.mtime else contextlib.nullcontext(): # We need to save the node model instance itself first such that it has a pk # that can be used in the foreign keys that will be needed for setting the # attributes and links diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 8cc831b850..5284720f0d 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -9,7 +9,7 @@ ########################################################################### """SQLA groups""" -import collections +from collections.abc import Iterable import logging from aiida.backends import sqlalchemy as sa @@ -317,7 +317,7 @@ def query( if past_days is not None: filters.append(DbGroup.time >= past_days) if nodes: - if not isinstance(nodes, collections.Iterable): + if not isinstance(nodes, Iterable): nodes = [nodes] if not all(isinstance(n, (SqlaNode, DbNode)) for n in nodes): diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 0023f8a107..8ed0d10aa4 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -8,9 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub classes for data structures.""" - from .array import ArrayData, BandsData, KpointsData, ProjectionData, TrajectoryData, XyData -from .base import BaseType +from .base import BaseType, to_aiida_type from .bool import Bool from .cif import CifData from .code import Code @@ -22,7 +21,7 
@@ from .list import List from .numeric import NumericType from .orbital import OrbitalData -from .remote import RemoteData +from .remote import RemoteData, RemoteStashData, RemoteStashFolderData from .singlefile import SinglefileData from .str import Str from .structure import StructureData @@ -30,6 +29,6 @@ __all__ = ( 'Data', 'BaseType', 'ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData', 'Bool', - 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'SinglefileData', - 'Str', 'StructureData', 'UpfData', 'NumericType' + 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'RemoteStashData', + 'RemoteStashFolderData', 'SinglefileData', 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' ) diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index e0709a3692..a3aa1630d0 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -372,7 +372,7 @@ def _validate_kpoints_weights(self, kpoints, weights): else: raise ValueError(f'kpoints must be a list of lists in {self._dimension}D case') - if kpoints.dtype != numpy.dtype(numpy.float): + if kpoints.dtype != numpy.dtype(float): raise ValueError(f'kpoints must be an array of type floats. Found instead {kpoints.dtype}') if kpoints.shape[1] < self._dimension: @@ -385,7 +385,7 @@ def _validate_kpoints_weights(self, kpoints, weights): weights = numpy.array(weights) if weights.shape[0] != kpoints.shape[0]: raise ValueError(f'Found {weights.shape[0]} weights but {kpoints.shape[0]} kpoints') - if weights.dtype != numpy.dtype(numpy.float): + if weights.dtype != numpy.dtype(float): raise ValueError(f'weights must be an array of type floats. Found instead {weights.dtype}') return kpoints, weights diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index ce43a4452e..5bc4d90e17 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -11,7 +11,7 @@ AiiDA class to deal with crystal structure trajectories. 
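With `to_aiida_type` now re-exported from `aiida.orm.nodes.data` as shown above, plain Python values can be wrapped into the corresponding `Data` nodes without importing from the `base` submodule. A minimal sketch, assuming a loaded profile:

from aiida import load_profile
from aiida.orm.nodes.data import to_aiida_type

load_profile()

# The single-dispatch converter picks the matching Data subclass, e.g. float -> Float.
node = to_aiida_type(3.14)
print(type(node).__name__, node.value)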
""" -import collections +import collections.abc from .array import ArrayData @@ -35,7 +35,7 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti """ import numpy - if not isinstance(symbols, collections.Iterable): + if not isinstance(symbols, collections.abc.Iterable): raise TypeError('TrajectoryData.symbols must be of type list') if any([not isinstance(i, str) for i in symbols]): raise TypeError('TrajectoryData.symbols must be a 1d list of strings') diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py new file mode 100644 index 0000000000..2f88d7edbc --- /dev/null +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Module with data plugins that represent remote resources and so effectively are symbolic links.""" +from .base import RemoteData +from .stash import RemoteStashData, RemoteStashFolderData + +__all__ = ('RemoteData', 'RemoteStashData', 'RemoteStashFolderData') diff --git a/aiida/orm/nodes/data/remote.py b/aiida/orm/nodes/data/remote/base.py similarity index 95% rename from aiida/orm/nodes/data/remote.py rename to aiida/orm/nodes/data/remote/base.py index ba8b8e52e0..b293e2e6b9 100644 --- a/aiida/orm/nodes/data/remote.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -11,7 +11,7 @@ import os from aiida.orm import AuthInfo -from .data import Data +from ..data import Data __all__ = ('RemoteData',) @@ -79,7 +79,7 @@ def getfile(self, relpath, destpath): full_path, self.computer.label # pylint: disable=no-member ) - ) + ) from exception raise def listdir(self, relpath='.'): @@ -102,7 +102,7 @@ def listdir(self, relpath='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -115,7 +115,7 @@ def listdir(self, relpath='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -139,7 +139,7 @@ def listdir_withattributes(self, path='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -152,7 +152,7 @@ def listdir_withattributes(self, path='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -176,8 +176,8 @@ def _validate(self): try: self.get_remote_path() - except AttributeError: - raise ValidationError("attribute 'remote_path' not set.") + except AttributeError as exception: + raise ValidationError("attribute 'remote_path' not set.") from exception computer = self.computer if computer is None: diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py new file mode 100644 index 0000000000..f744240cfc --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Module with data plugins that represent files of completed calculations jobs that have been stashed.""" +from .base import RemoteStashData +from .folder import RemoteStashFolderData + +__all__ = ('RemoteStashData', 'RemoteStashFolderData') diff --git a/aiida/orm/nodes/data/remote/stash/base.py b/aiida/orm/nodes/data/remote/stash/base.py new file mode 100644 index 0000000000..f904643bab --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/base.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +"""Data plugin that models an archived 
folder on a remote computer.""" +from aiida.common.datastructures import StashMode +from aiida.common.lang import type_check +from ...data import Data + +__all__ = ('RemoteStashData',) + + +class RemoteStashData(Data): + """Data plugin that models an archived folder on a remote computer. + + A stashed folder is essentially an instance of ``RemoteData`` that has been archived. Archiving in this context can + simply mean copying the content of the folder to another location on the same or another filesystem as long as it is + on the same machine. In addition, the folder may have been compressed into a single file for efficiency or even + written to tape. The ``stash_mode`` attribute will distinguish how the folder was stashed which will allow the + implementation to also `unstash` it and transform it back into a ``RemoteData`` such that it can be used as an input + for new ``CalcJobs``. + + This class is a non-storable base class that merely registers the ``stash_mode`` attribute. Only its subclasses, + that actually implement a certain stash mode, can be instantiated and therefore stored. The reason for this design + is that because the behavior of the class can change significantly based on the mode employed to stash the files and + implementing all these variants in the same class will lead to an unintuitive interface where certain properties or + methods of the class will only be available or function properly based on the ``stash_mode``. + """ + + _storable = False + + def __init__(self, stash_mode: StashMode, **kwargs): + """Construct a new instance + + :param stash_mode: the stashing mode with which the data was stashed on the remote. + """ + super().__init__(**kwargs) + self.stash_mode = stash_mode + + @property + def stash_mode(self) -> StashMode: + """Return the mode with which the data was stashed on the remote. + + :return: the stash mode. + """ + return StashMode(self.get_attribute('stash_mode')) + + @stash_mode.setter + def stash_mode(self, value: StashMode): + """Set the mode with which the data was stashed on the remote. + + :param value: the stash mode. + """ + type_check(value, StashMode) + self.set_attribute('stash_mode', value.value) diff --git a/aiida/orm/nodes/data/remote/stash/folder.py b/aiida/orm/nodes/data/remote/stash/folder.py new file mode 100644 index 0000000000..7d7c00b2fc --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/folder.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""Data plugin that models a stashed folder on a remote computer.""" +import typing + +from aiida.common.datastructures import StashMode +from aiida.common.lang import type_check +from .base import RemoteStashData + +__all__ = ('RemoteStashFolderData',) + + +class RemoteStashFolderData(RemoteStashData): + """Data plugin that models a folder with files of a completed calculation job that has been stashed through a copy. + + This data plugin can and should be used to stash files if and only if the stash mode is `StashMode.COPY`. + """ + + _storable = True + + def __init__(self, stash_mode: StashMode, target_basepath: str, source_list: typing.List, **kwargs): + """Construct a new instance + + :param stash_mode: the stashing mode with which the data was stashed on the remote. + :param target_basepath: the target basepath. + :param source_list: the list of source files. 
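The stash data plugins introduced above model archived calculation folders; `RemoteStashFolderData` covers the plain-copy case and simply records the mode, target path and source list as attributes. A hedged construction sketch in which the target path and file names are placeholders and the node is left unstored:

from aiida import load_profile
from aiida.common.datastructures import StashMode
from aiida.orm.nodes.data import RemoteStashFolderData

load_profile()

# Only StashMode.COPY is accepted by this subclass; paths below are placeholders.
stash = RemoteStashFolderData(
    stash_mode=StashMode.COPY,
    target_basepath='/archive/stash/example_calc',
    source_list=['aiida.out', 'output.xml'],
)
print(stash.stash_mode, stash.target_basepath)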
+ """ + super().__init__(stash_mode, **kwargs) + self.target_basepath = target_basepath + self.source_list = source_list + + if stash_mode != StashMode.COPY: + raise ValueError('`RemoteStashFolderData` can only be used with `stash_mode == StashMode.COPY`.') + + @property + def target_basepath(self) -> str: + """Return the target basepath. + + :return: the target basepath. + """ + return self.get_attribute('target_basepath') + + @target_basepath.setter + def target_basepath(self, value: str): + """Set the target basepath. + + :param value: the target basepath. + """ + type_check(value, str) + self.set_attribute('target_basepath', value) + + @property + def source_list(self) -> typing.Union[typing.List, typing.Tuple]: + """Return the list of source files that were stashed. + + :return: the list of source files. + """ + return self.get_attribute('source_list') + + @source_list.setter + def source_list(self, value: typing.Union[typing.List, typing.Tuple]): + """Set the list of source files that were stashed. + + :param value: the list of source files. + """ + type_check(value, (list, tuple)) + self.set_attribute('source_list', value) diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index 17c03663ed..eecc0484d3 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -11,6 +11,7 @@ import inspect import os import warnings +import pathlib from aiida.common import exceptions from aiida.common.warnings import AiidaDeprecationWarning @@ -102,7 +103,7 @@ def set_file(self, file, filename=None): """ # pylint: disable=redefined-builtin - if isinstance(file, str): + if isinstance(file, (str, pathlib.Path)): is_filelike = False key = os.path.basename(file) diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index 8b2827d5e0..f6187c4485 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -115,8 +115,12 @@ def get_pymatgen_version(): """ if not has_pymatgen(): return None - import pymatgen - return pymatgen.__version__ + try: + from pymatgen import __version__ + except ImportError: + # this was changed in version 2022.0.3 + from pymatgen.core import __version__ + return __version__ def has_spglib(): @@ -1852,7 +1856,7 @@ def _get_object_pymatgen_structure(self, **kwargs): .. note:: Requires the pymatgen module (version >= 3.0.13, usage of earlier versions may cause errors) """ - from pymatgen import Structure + from pymatgen.core.structure import Structure if self.pbc != (True, True, True): raise ValueError('Periodic boundary conditions must apply in all three dimensions of real space') @@ -1862,7 +1866,7 @@ def _get_object_pymatgen_structure(self, **kwargs): if (kwargs.pop('add_spin', False) and any([n.endswith('1') or n.endswith('2') for n in self.get_kind_names()])): # case when spins are defined -> no partial occupancy allowed - from pymatgen import Specie + from pymatgen.core.periodic_table import Specie oxidation_state = 0 # now I always set the oxidation_state to zero for site in self.sites: kind = self.get_kind(site.kind_name) @@ -1907,7 +1911,7 @@ def _get_object_pymatgen_molecule(self, **kwargs): .. 
note:: Requires the pymatgen module (version >= 3.0.13, usage of earlier versions may cause errors) """ - from pymatgen import Molecule + from pymatgen.core.structure import Molecule if kwargs: raise ValueError(f'Unrecognized parameters passed to pymatgen converter: {kwargs.keys()}') diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index a1664d1f9a..a30a1d1135 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -9,9 +9,14 @@ ########################################################################### # pylint: disable=too-many-lines,too-many-arguments """Package for node ORM classes.""" +import datetime import importlib +from logging import Logger import warnings import traceback +from typing import Any, Dict, IO, Iterator, List, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING +from uuid import UUID from aiida.common import exceptions from aiida.common.escaping import sql_string_match @@ -32,20 +37,25 @@ from ..querybuilder import QueryBuilder from ..users import User +if TYPE_CHECKING: + from aiida.repository import File + from ..implementation import Backend + from ..implementation.nodes import BackendNode + __all__ = ('Node',) -_NO_DEFAULT = tuple() +_NO_DEFAULT = tuple() # type: ignore[var-annotated] class WarnWhenNotEntered: """Temporary wrapper to warn when `Node.open` is called outside of a context manager.""" - def __init__(self, fileobj, name): - self._fileobj = fileobj + def __init__(self, fileobj: Union[IO[str], IO[bytes]], name: str) -> None: + self._fileobj: Union[IO[str], IO[bytes]] = fileobj self._name = name self._was_entered = False - def _warn_if_not_entered(self, method): + def _warn_if_not_entered(self, method) -> None: """Fire a warning if the object wrapper has not yet been entered.""" if not self._was_entered: msg = f'\nThe method `{method}` was called on the return value of `{self._name}.open()`' + \ @@ -62,34 +72,34 @@ def _warn_if_not_entered(self, method): warnings.warn(msg, AiidaDeprecationWarning) # pylint: disable=no-member - def __enter__(self): + def __enter__(self) -> Union[IO[str], IO[bytes]]: self._was_entered = True return self._fileobj.__enter__() - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self._fileobj.__exit__(*args) - def __getattr__(self, key): + def __getattr__(self, key: str): if key == '_fileobj': return self._fileobj return getattr(self._fileobj, key) - def __del__(self): + def __del__(self) -> None: self._warn_if_not_entered('del') - def __iter__(self): + def __iter__(self) -> Iterator[Union[str, bytes]]: return self._fileobj.__iter__() - def __next__(self): + def __next__(self) -> Union[str, bytes]: return self._fileobj.__next__() - def read(self, *args, **kwargs): + def read(self, *args: Any, **kwargs: Any) -> Union[str, bytes]: self._warn_if_not_entered('read') return self._fileobj.read(*args, **kwargs) - def close(self, *args, **kwargs): + def close(self, *args: Any, **kwargs: Any) -> None: self._warn_if_not_entered('close') - return self._fileobj.close(*args, **kwargs) + return self._fileobj.close(*args, **kwargs) # type: ignore[call-arg] class Node(Entity, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): @@ -113,7 +123,7 @@ class Node(Entity, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractN class Collection(EntityCollection): """The collection of nodes.""" - def delete(self, node_id): + def delete(self, node_id: int) -> None: """Delete a `Node` from the collection with the given id :param node_id: the node id @@ -134,14 
+144,14 @@ def delete(self, node_id): repository.erase(force=True) # This will be set by the metaclass call - _logger = None + _logger: Optional[Logger] = None # A tuple of attribute names that can be updated even after node is stored # Requires Sealable mixin, but needs empty tuple for base class - _updatable_attributes = tuple() + _updatable_attributes: Tuple[str, ...] = tuple() # A tuple of attribute names that will be ignored when creating the hash. - _hash_ignored_attributes = tuple() + _hash_ignored_attributes: Tuple[str, ...] = tuple() # Flag that determines whether the class can be cached. _cachable = False @@ -154,15 +164,21 @@ def delete(self, node_id): _unstorable_message = 'only Data, WorkflowNode, CalculationNode or their subclasses can be stored' # These are to be initialized in the `initialization` method - _incoming_cache = None - _repository = None + _incoming_cache: Optional[List[LinkTriple]] = None + _repository: Optional[Repository] = None @classmethod - def from_backend_entity(cls, backend_entity): + def from_backend_entity(cls, backend_entity: 'BackendNode') -> 'Node': entity = super().from_backend_entity(backend_entity) return entity - def __init__(self, backend=None, user=None, computer=None, **kwargs): + def __init__( + self, + backend: Optional['Backend'] = None, + user: Optional[User] = None, + computer: Optional[Computer] = None, + **kwargs: Any + ) -> None: backend = backend or get_manager().get_backend() if computer and not computer.is_stored: @@ -179,10 +195,24 @@ def __init__(self, backend=None, user=None, computer=None, **kwargs): ) super().__init__(backend_entity) - def __repr__(self): + @property + def backend_entity(self) -> 'BackendNode': + return super().backend_entity + + def __eq__(self, other: Any) -> bool: + """Fallback equality comparison by uuid (can be overwritten by specific types)""" + if isinstance(other, Node) and self.uuid == other.uuid: + return True + return super().__eq__(other) + + def __hash__(self) -> int: + """Python-Hash: Implementation that is compatible with __eq__""" + return UUID(self.uuid).int + + def __repr__(self) -> str: return f'<{self.__class__.__name__}: {str(self)}>' - def __str__(self): + def __str__(self) -> str: if not self.is_stored: return f'uuid: {self.uuid} (unstored)' @@ -196,7 +226,7 @@ def __deepcopy__(self, memo): """Deep copying a Node is not supported in general, but only for the Data sub class.""" raise exceptions.InvalidOperation('deep copying a base Node is not supported') - def initialize(self): + def initialize(self) -> None: """ Initialize internal variables for the backend node @@ -210,7 +240,7 @@ def initialize(self): # Calls the initialisation from the RepositoryMixin self._repository = Repository(uuid=self.uuid, is_stored=self.is_stored, base_path=self._repository_base_path) - def _validate(self): + def _validate(self) -> bool: """Check if the attributes and files retrieved from the database are valid. Must be able to work even before storing: therefore, use the `get_attr` and similar methods that automatically @@ -222,7 +252,7 @@ def _validate(self): # pylint: disable=no-self-use return True - def validate_storability(self): + def validate_storability(self) -> None: """Verify that the current node is allowed to be stored. 
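The `__eq__`/`__hash__` pair added to `Node` above makes two Python objects referring to the same stored node compare equal and hash identically, based on the UUID. A small sketch, assuming a loaded profile:

from aiida import load_profile
from aiida.orm import Int, load_node

load_profile()

node = Int(5).store()
same = load_node(uuid=node.uuid)

# Equality and hashing are based on the UUID, so duplicates collapse in a set.
assert node == same
assert len({node, same}) == 1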
:raises `aiida.common.exceptions.StoringNotAllowed`: if the node does not match all requirements for storing @@ -237,13 +267,13 @@ def validate_storability(self): raise exceptions.StoringNotAllowed(msg) @classproperty - def class_node_type(cls): + def class_node_type(cls) -> str: """Returns the node type of this node (sub) class.""" # pylint: disable=no-self-argument,no-member return cls._plugin_type_string @property - def logger(self): + def logger(self) -> Optional[Logger]: """Return the logger configured for this Node. :return: Logger object @@ -251,16 +281,16 @@ def logger(self): return self._logger @property - def uuid(self): + def uuid(self) -> str: """Return the node UUID. :return: the string representation of the UUID - :rtype: str + """ return self.backend_entity.uuid @property - def node_type(self): + def node_type(self) -> str: """Return the node type. :return: the node type @@ -268,7 +298,7 @@ def node_type(self): return self.backend_entity.node_type @property - def process_type(self): + def process_type(self) -> Optional[str]: """Return the node process type. :return: the process type @@ -276,7 +306,7 @@ def process_type(self): return self.backend_entity.process_type @process_type.setter - def process_type(self, value): + def process_type(self, value: str) -> None: """Set the node process type. :param value: the new value to set @@ -284,7 +314,7 @@ def process_type(self, value): self.backend_entity.process_type = value @property - def label(self): + def label(self) -> str: """Return the node label. :return: the label @@ -292,7 +322,7 @@ def label(self): return self.backend_entity.label @label.setter - def label(self, value): + def label(self, value: str) -> None: """Set the label. :param value: the new value to set @@ -300,7 +330,7 @@ def label(self, value): self.backend_entity.label = value @property - def description(self): + def description(self) -> str: """Return the node description. :return: the description @@ -308,7 +338,7 @@ def description(self): return self.backend_entity.description @description.setter - def description(self, value): + def description(self, value: str) -> None: """Set the description. :param value: the new value to set @@ -316,7 +346,7 @@ def description(self, value): self.backend_entity.description = value @property - def computer(self): + def computer(self) -> Optional[Computer]: """Return the computer of this node. :return: the computer or None @@ -328,7 +358,7 @@ def computer(self): return None @computer.setter - def computer(self, computer): + def computer(self, computer: Optional[Computer]) -> None: """Set the computer of this node. :param computer: a `Computer` @@ -344,7 +374,7 @@ def computer(self, computer): self.backend_entity.computer = computer @property - def user(self): + def user(self) -> User: """Return the user of this node. :return: the user @@ -353,7 +383,7 @@ def user(self): return User.from_backend_entity(self.backend_entity.user) @user.setter - def user(self, user): + def user(self, user: User) -> None: """Set the user of this node. :param user: a `User` @@ -365,7 +395,7 @@ def user(self, user): self.backend_entity.user = user.backend_entity @property - def ctime(self): + def ctime(self) -> datetime.datetime: """Return the node ctime. :return: the ctime @@ -373,14 +403,14 @@ def ctime(self): return self.backend_entity.ctime @property - def mtime(self): + def mtime(self) -> datetime.datetime: """Return the node mtime. 
:return: the mtime """ return self.backend_entity.mtime - def list_objects(self, path=None, key=None): + def list_objects(self, path: Optional[str] = None, key: Optional[str] = None) -> List['File']: """Return a list of the objects contained in this repository, optionally in the given sub directory. .. deprecated:: 1.4.0 @@ -391,6 +421,8 @@ def list_objects(self, path=None, key=None): :return: a list of `File` named tuples representing the objects present in directory with the given path :raises FileNotFoundError: if the `path` does not exist in the repository of this node """ + assert self._repository is not None, 'repository not initialised' + if key is not None: if path is not None: raise ValueError('cannot specify both `path` and `key`.') @@ -402,7 +434,7 @@ def list_objects(self, path=None, key=None): return self._repository.list_objects(path) - def list_object_names(self, path=None, key=None): + def list_object_names(self, path: Optional[str] = None, key: Optional[str] = None) -> List[str]: """Return a list of the object names contained in this repository, optionally in the given sub directory. .. deprecated:: 1.4.0 @@ -410,8 +442,10 @@ def list_object_names(self, path=None, key=None): :param path: the relative path of the object within the repository. :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given path + """ + assert self._repository is not None, 'repository not initialised' + if key is not None: if path is not None: raise ValueError('cannot specify both `path` and `key`.') @@ -423,7 +457,7 @@ def list_object_names(self, path=None, key=None): return self._repository.list_object_names(path) - def open(self, path=None, mode='r', key=None): + def open(self, path: Optional[str] = None, mode: str = 'r', key: Optional[str] = None) -> WarnWhenNotEntered: """Open a file handle to the object with the given path. .. deprecated:: 1.4.0 @@ -436,6 +470,8 @@ def open(self, path=None, mode='r', key=None): :param key: fully qualified identifier for the object within the repository :param mode: the mode under which to open the handle """ + assert self._repository is not None, 'repository not initialised' + if key is not None: if path is not None: raise ValueError('cannot specify both `path` and `key`.') @@ -453,7 +489,7 @@ def open(self, path=None, mode='r', key=None): return WarnWhenNotEntered(self._repository.open(path, mode), repr(self)) - def get_object(self, path=None, key=None): + def get_object(self, path: Optional[str] = None, key: Optional[str] = None) -> 'File': """Return the object with the given path. .. deprecated:: 1.4.0 @@ -463,6 +499,8 @@ def get_object(self, path=None, key=None): :param key: fully qualified identifier for the object within the repository :return: a `File` named tuple """ + assert self._repository is not None, 'repository not initialised' + if key is not None: if path is not None: raise ValueError('cannot specify both `path` and `key`.') @@ -477,7 +515,10 @@ def get_object(self, path=None, key=None): return self._repository.get_object(path) - def get_object_content(self, path=None, mode='r', key=None): + def get_object_content(self, + path: Optional[str] = None, + mode: str = 'r', + key: Optional[str] = None) -> Union[str, bytes]: """Return the content of a object with the given path. .. 
deprecated:: 1.4.0 @@ -486,6 +527,8 @@ def get_object_content(self, path=None, mode='r', key=None): :param path: the relative path of the object within the repository. :param key: fully qualified identifier for the object within the repository """ + assert self._repository is not None, 'repository not initialised' + if key is not None: if path is not None: raise ValueError('cannot specify both `path` and `key`.') @@ -503,7 +546,14 @@ def get_object_content(self, path=None, mode='r', key=None): return self._repository.get_object_content(path, mode) - def put_object_from_tree(self, filepath, path=None, contents_only=True, force=False, key=None): + def put_object_from_tree( + self, + filepath: str, + path: Optional[str] = None, + contents_only: bool = True, + force: bool = False, + key: Optional[str] = None + ) -> None: """Store a new object under `path` with the contents of the directory located at `filepath` on this file system. .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. @@ -528,6 +578,8 @@ def put_object_from_tree(self, filepath, path=None, contents_only=True, force=Fa :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ + assert self._repository is not None, 'repository not initialised' + if force: warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member @@ -547,7 +599,15 @@ def put_object_from_tree(self, filepath, path=None, contents_only=True, force=Fa self._repository.put_object_from_tree(filepath, path, contents_only, force) - def put_object_from_file(self, filepath, path=None, mode=None, encoding=None, force=False, key=None): + def put_object_from_file( + self, + filepath: str, + path: Optional[str] = None, + mode: Optional[str] = None, + encoding: Optional[str] = None, + force: bool = False, + key: Optional[str] = None + ) -> None: """Store a new object under `path` with contents of the file located at `filepath` on this file system. .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. @@ -572,6 +632,8 @@ def put_object_from_file(self, filepath, path=None, mode=None, encoding=None, fo :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ + assert self._repository is not None, 'repository not initialised' + # Note that the defaults of `mode` and `encoding` had to be change to `None` from `w` and `utf-8` resptively, in # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, @@ -601,7 +663,15 @@ def put_object_from_file(self, filepath, path=None, mode=None, encoding=None, fo self._repository.put_object_from_file(filepath, path, mode, encoding, force) - def put_object_from_filelike(self, handle, path=None, mode='w', encoding='utf8', force=False, key=None): + def put_object_from_filelike( + self, + handle: IO[Any], + path: Optional[str] = None, + mode: str = 'w', + encoding: str = 'utf8', + force: bool = False, + key: Optional[str] = None + ) -> None: """Store a new object under `path` with contents of filelike object `handle`. .. 
warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. @@ -621,6 +691,8 @@ def put_object_from_filelike(self, handle, path=None, mode='w', encoding='utf8', :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ + assert self._repository is not None, 'repository not initialised' + if force: warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member @@ -638,7 +710,7 @@ def put_object_from_filelike(self, handle, path=None, mode='w', encoding='utf8', self._repository.put_object_from_filelike(handle, path, mode, encoding, force) - def delete_object(self, path=None, force=False, key=None): + def delete_object(self, path: Optional[str] = None, force: bool = False, key: Optional[str] = None) -> None: """Delete the object from the repository. .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. @@ -654,6 +726,8 @@ def delete_object(self, path=None, force=False, key=None): :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ + assert self._repository is not None, 'repository not initialised' + if force: warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member @@ -671,7 +745,7 @@ def delete_object(self, path=None, force=False, key=None): self._repository.delete_object(path, force) - def add_comment(self, content, user=None): + def add_comment(self, content: str, user: Optional[User] = None) -> Comment: """Add a new comment. :param content: string with comment @@ -681,7 +755,7 @@ def add_comment(self, content, user=None): user = user or User.objects.get_default() return Comment(node=self, user=user, content=content).store() - def get_comment(self, identifier): + def get_comment(self, identifier: int) -> Comment: """Return a comment corresponding to the given identifier. :param identifier: the comment pk @@ -689,16 +763,16 @@ def get_comment(self, identifier): :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment :return: the comment """ - return Comment.objects.get(dbnode_id=self.pk, pk=identifier) + return Comment.objects.get(dbnode_id=self.pk, id=identifier) - def get_comments(self): + def get_comments(self) -> List[Comment]: """Return a sorted list of comments for this node. :return: the list of comments, sorted by pk """ return Comment.objects.find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) - def update_comment(self, identifier, content): + def update_comment(self, identifier: int, content: str) -> None: """Update the content of an existing comment. :param identifier: the comment pk @@ -706,17 +780,17 @@ def update_comment(self, identifier, content): :raise aiida.common.NotExistent: if the comment with the given id does not exist :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment """ - comment = Comment.objects.get(dbnode_id=self.pk, pk=identifier) + comment = Comment.objects.get(dbnode_id=self.pk, id=identifier) comment.set_content(content) - def remove_comment(self, identifier): + def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-use """Delete an existing comment. 
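A rough sketch of the repository accessors annotated in the hunks above (all deprecated since 1.4.0, as noted); `data` is an assumed unstored Data node and the file names are made up:

import io

data.put_object_from_file('/tmp/input.txt', path='input.txt')                    # copy a file from disk
data.put_object_from_filelike(io.StringIO('cutoff 500\n'), path='settings.in')   # write from a file-like handle

print(data.list_object_names())                 # e.g. ['input.txt', 'settings.in']
print(data.get_object_content('settings.in'))   # str by default; pass mode='rb' for bytes
with data.open('input.txt') as handle:          # the wrapper warns if not used as a context manager
    first_line = handle.readline()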
:param identifier: the comment pk """ - Comment.objects.delete(dbnode_id=self.pk, comment=identifier) + Comment.objects.delete(identifier) - def add_incoming(self, source, link_type, link_label): + def add_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Add a link of the given type from a given node to ourself. :param source: the node from which the link is coming @@ -733,7 +807,7 @@ def add_incoming(self, source, link_type, link_label): else: self._add_incoming_cache(source, link_type, link_label) - def validate_incoming(self, source, link_type, link_label): + def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Validate adding a link of the given type from a given node to ourself. This function will first validate the types of the inputs, followed by the node and link types and validate @@ -760,7 +834,7 @@ def validate_incoming(self, source, link_type, link_label): if builder.count() > 0: raise ValueError('the link you are attempting to create would generate a cycle in the graph') - def validate_outgoing(self, target, link_type, link_label): # pylint: disable=unused-argument,no-self-use + def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: # pylint: disable=unused-argument,no-self-use """Validate adding a link of the given type from ourself to a given node. The validity of the triple (source, link, target) should be validated in the `validate_incoming` call. @@ -776,7 +850,7 @@ def validate_outgoing(self, target, link_type, link_label): # pylint: disable=u type_check(link_type, LinkType, f'link_type should be a LinkType enum but got: {type(link_type)}') type_check(target, Node, f'target should be a `Node` instance but got: {type(target)}') - def _add_incoming_cache(self, source, link_type, link_label): + def _add_incoming_cache(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Add an incoming link to the cache. .. note: the proposed link is not validated in this function, so this should not be called directly @@ -787,6 +861,8 @@ def _add_incoming_cache(self, source, link_type, link_label): :param link_label: the link label :raise aiida.common.UniquenessError: if the given link triple already exists in the cache """ + assert self._incoming_cache is not None, 'incoming_cache not initialised' + link_triple = LinkTriple(source, link_type, link_label) if link_triple in self._incoming_cache: @@ -795,8 +871,13 @@ def _add_incoming_cache(self, source, link_type, link_label): self._incoming_cache.append(link_triple) def get_stored_link_triples( - self, node_class=None, link_type=(), link_label_filter=None, link_direction='incoming', only_uuid=False - ): + self, + node_class: Type['Node'] = None, + link_type: Union[LinkType, Sequence[LinkType]] = (), + link_label_filter: Optional[str] = None, + link_direction: str = 'incoming', + only_uuid: bool = False + ) -> List[LinkTriple]: """Return the list of stored link triples directly incoming to or outgoing of this node. Note this will only return link triples that are stored in the database. Anything in the cache is ignored. @@ -804,7 +885,7 @@ def get_stored_link_triples( :param node_class: If specified, should be a class, and it filters only elements of that (subclass of) type :param link_type: Only get inputs of this link type, if empty tuple then returns all inputs of all link types. :param link_label_filter: filters the incoming nodes by its link label. 
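The comment API keeps its shape but now resolves identifiers via `id` and deletes by pk directly; for example, with `node` an assumed stored node:

comment = node.add_comment('Converged on the third run')                   # returns the stored Comment
node.update_comment(comment.pk, 'Converged on the third run (rechecked)')
for entry in node.get_comments():                                          # sorted by pk
    print(entry.pk, entry.content)
node.remove_comment(comment.pk)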
This should be a regex statement as - one would pass directly to a QuerBuilder filter statement with the 'like' operation. + one would pass directly to a QueryBuilder filter statement with the 'like' operation. :param link_direction: `incoming` or `outgoing` to get the incoming or outgoing links, respectively. :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries """ @@ -815,8 +896,8 @@ def get_stored_link_triples( raise TypeError(f'link_type should be a LinkType or tuple of LinkType: got {link_type}') node_class = node_class or Node - node_filters = {'id': {'==': self.id}} - edge_filters = {} + node_filters: Dict[str, Any] = {'id': {'==': self.id}} + edge_filters: Dict[str, Any] = {} if link_type: edge_filters['type'] = {'in': [t.value for t in link_type]} @@ -847,7 +928,13 @@ def get_stored_link_triples( return [LinkTriple(entry[0], LinkType(entry[1]), entry[2]) for entry in builder.all()] - def get_incoming(self, node_class=None, link_type=(), link_label_filter=None, only_uuid=False): + def get_incoming( + self, + node_class: Type['Node'] = None, + link_type: Union[LinkType, Sequence[LinkType]] = (), + link_label_filter: Optional[str] = None, + only_uuid: bool = False + ) -> LinkManager: """Return a list of link triples that are (directly) incoming into this node. :param node_class: If specified, should be a class or tuple of classes, and it filters only @@ -858,6 +945,8 @@ def get_incoming(self, node_class=None, link_type=(), link_label_filter=None, on Here wildcards (% and _) can be passed in link label filter as we are using "like" in QB. :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries """ + assert self._incoming_cache is not None, 'incoming_cache not initialised' + if not isinstance(link_type, tuple): link_type = (link_type,) @@ -888,7 +977,13 @@ def get_incoming(self, node_class=None, link_type=(), link_label_filter=None, on return LinkManager(link_triples) - def get_outgoing(self, node_class=None, link_type=(), link_label_filter=None, only_uuid=False): + def get_outgoing( + self, + node_class: Type['Node'] = None, + link_type: Union[LinkType, Sequence[LinkType]] = (), + link_label_filter: Optional[str] = None, + only_uuid: bool = False + ) -> LinkManager: """Return a list of link triples that are (directly) outgoing of this node. :param node_class: If specified, should be a class or tuple of classes, and it filters only @@ -902,20 +997,23 @@ def get_outgoing(self, node_class=None, link_type=(), link_label_filter=None, on link_triples = self.get_stored_link_triples(node_class, link_type, link_label_filter, 'outgoing', only_uuid) return LinkManager(link_triples) - def has_cached_links(self): + def has_cached_links(self) -> bool: """Feturn whether there are unstored incoming links in the cache. :return: boolean, True when there are links in the incoming cache, False otherwise """ + assert self._incoming_cache is not None, 'incoming_cache not initialised' return bool(self._incoming_cache) - def store_all(self, with_transaction=True, use_cache=None): + def store_all(self, with_transaction: bool = True, use_cache=None) -> 'Node': """Store the node, together with all input links. Unstored nodes from cached incoming linkswill also be stored. :parameter with_transaction: if False, do not use a transaction because the caller will already have opened one. 
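Linking and traversal with the signatures above, as a sketch: an input link added to an unstored calculation node sits in the incoming cache until the node is stored, and traversal returns a `LinkManager` of `LinkTriple`s. The `Int`/`CalculationNode` pair and the label are purely illustrative.

from aiida import orm
from aiida.common.links import LinkType

data = orm.Int(5).store()
calc = orm.CalculationNode()
calc.add_incoming(data, link_type=LinkType.INPUT_CALC, link_label='x')
assert calc.has_cached_links()      # the triple is cached until the calculation node is stored
calc.store()

for triple in calc.get_incoming(link_type=LinkType.INPUT_CALC).all():
    print(triple.link_label, triple.node.pk)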
""" + assert self._incoming_cache is not None, 'incoming_cache not initialised' + if use_cache is not None: warnings.warn( # pylint: disable=no-member 'the `use_cache` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning @@ -934,7 +1032,7 @@ def store_all(self, with_transaction=True, use_cache=None): return self.store(with_transaction) - def store(self, with_transaction=True, use_cache=None): # pylint: disable=arguments-differ + def store(self, with_transaction: bool = True, use_cache=None) -> 'Node': # pylint: disable=arguments-differ """Store the node in the database while saving its attributes and repository directory. After being called attributes cannot be changed anymore! Instead, extras can be changed only AFTER calling @@ -983,12 +1081,14 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum return self - def _store(self, with_transaction=True, clean=True): + def _store(self, with_transaction: bool = True, clean: bool = True) -> 'Node': """Store the node in the database while saving its attributes and repository directory. :param with_transaction: if False, do not use a transaction because the caller will already have opened one. :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ + assert self._repository is not None, 'repository not initialised' + # First store the repository folder such that if this fails, there won't be an incomplete node in the database. # On the flipside, in the case that storing the node does fail, the repository will now have an orphaned node # directory which will have to be cleaned manually sometime. @@ -1007,19 +1107,24 @@ def _store(self, with_transaction=True, clean=True): return self - def verify_are_parents_stored(self): + def verify_are_parents_stored(self) -> None: """Verify that all `parent` nodes are already stored. :raise aiida.common.ModificationNotAllowed: if one of the source nodes of incoming links is not stored. 
""" + assert self._incoming_cache is not None, 'incoming_cache not initialised' + for link_triple in self._incoming_cache: if not link_triple.node.is_stored: raise exceptions.ModificationNotAllowed( f'Cannot store because source node of link triple {link_triple} is not stored' ) - def _store_from_cache(self, cache_node, with_transaction): + def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None: """Store this node from an existing cache node.""" + assert self._repository is not None, 'repository not initialised' + assert cache_node._repository is not None, 'cache repository not initialised' # pylint: disable=protected-access + from aiida.orm.utils.mixins import Sealable assert self.node_type == cache_node.node_type @@ -1045,36 +1150,50 @@ def _store_from_cache(self, cache_node, with_transaction): self._add_outputs_from_cache(cache_node) self.set_extra('_aiida_cached_from', cache_node.uuid) - def _add_outputs_from_cache(self, cache_node): + def _add_outputs_from_cache(self, cache_node: 'Node') -> None: """Replicate the output links and nodes from the cached node onto this node.""" for entry in cache_node.get_outgoing(link_type=LinkType.CREATE): new_node = entry.node.clone() new_node.add_incoming(self, link_type=LinkType.CREATE, link_label=entry.link_label) new_node.store() - def get_hash(self, ignore_errors=True, **kwargs): - """Return the hash for this node based on its attributes.""" + def get_hash(self, ignore_errors: bool = True, **kwargs: Any) -> Optional[str]: + """Return the hash for this node based on its attributes. + + :param ignore_errors: return ``None`` on ``aiida.common.exceptions.HashingError`` (logging the exception) + """ if not self.is_stored: raise exceptions.InvalidOperation('You can get the hash only after having stored the node') return self._get_hash(ignore_errors=ignore_errors, **kwargs) - def _get_hash(self, ignore_errors=True, **kwargs): + def _get_hash(self, ignore_errors: bool = True, **kwargs: Any) -> Optional[str]: """ Return the hash for this node based on its attributes. This will always work, even before storing. 
+ + :param ignore_errors: return ``None`` on ``aiida.common.exceptions.HashingError`` (logging the exception) """ try: return make_hash(self._get_objects_to_hash(), **kwargs) - except Exception: # pylint: disable=broad-except + except exceptions.HashingError: if not ignore_errors: raise + if self.logger: + self.logger.exception('Node hashing failed') + return None - def _get_objects_to_hash(self): + def _get_objects_to_hash(self) -> List[Any]: """Return a list of objects which should be included in the hash.""" + assert self._repository is not None, 'repository not initialised' + top_level_module = self.__module__.split('.', 1)[0] + try: + version = importlib.import_module(top_level_module).__version__ # type: ignore[attr-defined] + except (ImportError, AttributeError) as exc: + raise exceptions.HashingError("The node's package version could not be determined") from exc objects = [ - importlib.import_module(self.__module__.split('.', 1)[0]).__version__, + version, { key: val for key, val in self.attributes_items() @@ -1085,15 +1204,15 @@ def _get_objects_to_hash(self): ] return objects - def rehash(self): + def rehash(self) -> None: """Regenerate the stored hash of the Node.""" self.set_extra(_HASH_EXTRA_KEY, self.get_hash()) - def clear_hash(self): + def clear_hash(self) -> None: """Sets the stored hash of the Node to None.""" self.set_extra(_HASH_EXTRA_KEY, None) - def get_cache_source(self): + def get_cache_source(self) -> Optional[str]: """Return the UUID of the node that was used in creating this node from the cache, or None if it was not cached. :return: source node UUID or None @@ -1101,14 +1220,14 @@ def get_cache_source(self): return self.get_extra('_aiida_cached_from', None) @property - def is_created_from_cache(self): + def is_created_from_cache(self) -> bool: """Return whether this node was created from a cached node. :return: boolean, True if the node was created by cloning a cached node, False otherwise """ return self.get_cache_source() is not None - def _get_same_node(self): + def _get_same_node(self) -> Optional['Node']: """Returns a stored node from which the current Node can be cached or None if it does not exist If a node is returned it is a valid cache, meaning its `_aiida_hash` extra matches `self.get_hash()`. @@ -1125,7 +1244,7 @@ def _get_same_node(self): except StopIteration: return None - def get_all_same_nodes(self): + def get_all_same_nodes(self) -> List['Node']: """Return a list of stored nodes which match the type and hash of the current node. All returned nodes are valid caches, meaning their `_aiida_hash` extra matches `self.get_hash()`. @@ -1135,7 +1254,7 @@ def get_all_same_nodes(self): """ return list(self._iter_all_same_nodes()) - def _iter_all_same_nodes(self, allow_before_store=False): + def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']: """ Returns an iterator of all same nodes. @@ -1156,22 +1275,21 @@ def _iter_all_same_nodes(self, allow_before_store=False): return (node for node in nodes_identical if node.is_valid_cache) @property - def is_valid_cache(self): + def is_valid_cache(self) -> bool: """Hook to exclude certain `Node` instances from being considered a valid cache.""" # pylint: disable=no-self-use return True - def get_description(self): + def get_description(self) -> str: """Return a string with a description of the node. 
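The hashing and caching helpers typed above combine roughly as follows, with `node` an assumed stored node (`get_hash` now returns `None` only on a logged `HashingError` when `ignore_errors=True`):

node_hash = node.get_hash()               # None if hashing failed and ignore_errors=True
node.rehash()                             # recompute and store the hash in the '_aiida_hash' extra
same_nodes = node.get_all_same_nodes()    # stored nodes of the same type with a matching hash
if node.is_created_from_cache:
    print('cached from', node.get_cache_source())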
:return: a description string - :rtype: str """ # pylint: disable=no-self-use return '' @staticmethod - def get_schema(): + def get_schema() -> Dict[str, Any]: """ Every node property contains: - display_name: display name of the property diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 15e3dd6f03..4a84f892b0 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -14,4 +14,4 @@ from .process import * from .workflow import * -__all__ = (calculation.__all__ + process.__all__ + workflow.__all__) +__all__ = (calculation.__all__ + process.__all__ + workflow.__all__) # type: ignore[name-defined] diff --git a/aiida/orm/nodes/process/calculation/calcfunction.py b/aiida/orm/nodes/process/calculation/calcfunction.py index 9749283d5a..bc3bbf64c2 100644 --- a/aiida/orm/nodes/process/calculation/calcfunction.py +++ b/aiida/orm/nodes/process/calculation/calcfunction.py @@ -8,19 +8,23 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class for calculation function processes.""" +from typing import TYPE_CHECKING from aiida.common.links import LinkType from aiida.orm.utils.mixins import FunctionCalculationMixin from .calculation import CalculationNode +if TYPE_CHECKING: + from aiida.orm import Node + __all__ = ('CalcFunctionNode',) class CalcFunctionNode(FunctionCalculationMixin, CalculationNode): """ORM class for all nodes representing the execution of a calcfunction.""" - def validate_outgoing(self, target, link_type, link_label): + def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: """ Validate adding a link of the given type from ourself to a given node. diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index 311ccf8a6b..ccfa5d921a 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -8,17 +8,30 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class for calculation job processes.""" - +import datetime +from typing import Any, AnyStr, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING import warnings from aiida.common import exceptions from aiida.common.datastructures import CalcJobState from aiida.common.lang import classproperty from aiida.common.links import LinkType +from aiida.common.folders import Folder from aiida.common.warnings import AiidaDeprecationWarning from .calculation import CalculationNode +if TYPE_CHECKING: + from aiida.engine.processes.builder import ProcessBuilder + from aiida.orm import FolderData + from aiida.orm.authinfos import AuthInfo + from aiida.orm.utils.calcjob import CalcJobResultManager + from aiida.parsers import Parser + from aiida.schedulers.datastructures import JobInfo, JobState + from aiida.tools.calculations import CalculationTools + from aiida.transports import Transport + __all__ = ('CalcJobNode',) @@ -45,7 +58,7 @@ class CalcJobNode(CalculationNode): _tools = None @property - def tools(self): + def tools(self) -> 'CalculationTools': """Return the calculation tools that are registered for the process type associated with this calculation. 
If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the @@ -76,7 +89,7 @@ def tools(self): return self._tools @classproperty - def _updatable_attributes(cls): # pylint: disable=no-self-argument + def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + ( cls.CALC_JOB_STATE_KEY, cls.REMOTE_WORKDIR_KEY, @@ -91,7 +104,7 @@ def _updatable_attributes(cls): # pylint: disable=no-self-argument ) @classproperty - def _hash_ignored_attributes(cls): # pylint: disable=no-self-argument + def _hash_ignored_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._hash_ignored_attributes + ( 'queue_name', 'account', @@ -101,7 +114,7 @@ def _hash_ignored_attributes(cls): # pylint: disable=no-self-argument 'max_memory_kb', ) - def _get_objects_to_hash(self): + def _get_objects_to_hash(self) -> List[Any]: """Return a list of objects which should be included in the hash. This method is purposefully overridden from the base `Node` class, because we do not want to include the @@ -112,7 +125,7 @@ def _get_objects_to_hash(self): """ from importlib import import_module objects = [ - import_module(self.__module__.split('.', 1)[0]).__version__, + import_module(self.__module__.split('.', 1)[0]).__version__, # type: ignore[attr-defined] { key: val for key, val in self.attributes_items() @@ -127,7 +140,7 @@ def _get_objects_to_hash(self): ] return objects - def get_builder_restart(self): + def get_builder_restart(self) -> 'ProcessBuilder': """Return a `ProcessBuilder` that is ready to relaunch the same `CalcJob` that created this node. The process class will be set based on the `process_type` of this node and the inputs of the builder will be @@ -137,15 +150,14 @@ def get_builder_restart(self): In addition to prepopulating the input nodes, which is implemented by the base `ProcessNode` class, here we also add the `options` that were passed in the `metadata` input of the `CalcJob` process. - :return: `~aiida.engine.processes.builder.ProcessBuilder` instance """ builder = super().get_builder_restart() - builder.metadata.options = self.get_options() + builder.metadata.options = self.get_options() # type: ignore[attr-defined] return builder @property - def _raw_input_folder(self): + def _raw_input_folder(self) -> Folder: """ Get the input folder object. 
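A sketch of restarting from a finished calculation job with the annotated builder helper (`node` is an assumed `CalcJobNode`; the actual resubmission is left commented out):

builder = node.get_builder_restart()    # process class and inputs reconstructed from this node
print(builder.metadata.options)         # the original scheduler options are prepopulated as well
# adjust inputs or options as needed, then:
# from aiida.engine import submit
# submit(builder)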
@@ -154,13 +166,15 @@ def _raw_input_folder(self): """ from aiida.common.exceptions import NotExistent + assert self._repository is not None, 'repository not initialised' + return_folder = self._repository._get_base_folder() # pylint: disable=protected-access if return_folder.exists(): return return_folder raise NotExistent('the `_raw_input_folder` has not yet been created') - def get_option(self, name): + def get_option(self, name: str) -> Optional[Any]: """ Retun the value of an option that was set for this CalcJobNode @@ -170,7 +184,7 @@ def get_option(self, name): """ return self.get_attribute(name, None) - def set_option(self, name, value): + def set_option(self, name: str, value: Any) -> None: """ Set an option to the given value @@ -181,21 +195,21 @@ def set_option(self, name, value): """ self.set_attribute(name, value) - def get_options(self): + def get_options(self) -> Dict[str, Any]: """ Return the dictionary of options set for this CalcJobNode :return: dictionary of the options and their values """ options = {} - for name in self.process_class.spec_options.keys(): + for name in self.process_class.spec_options.keys(): # type: ignore[attr-defined] value = self.get_option(name) if value is not None: options[name] = value return options - def set_options(self, options): + def set_options(self, options: Dict[str, Any]) -> None: """ Set the options for this CalcJobNode @@ -204,7 +218,7 @@ def set_options(self, options): for name, value in options.items(): self.set_option(name, value) - def get_state(self): + def get_state(self) -> Optional[CalcJobState]: """Return the calculation job active sub state. The calculation job state serves to give more granular state information to `CalcJobs`, in addition to the @@ -223,10 +237,9 @@ def get_state(self): return state - def set_state(self, state): + def set_state(self, state: CalcJobState) -> None: """Set the calculation active job state. - :param state: a string with the state from ``aiida.common.datastructures.CalcJobState``. :raise: ValueError if state is invalid """ if not isinstance(state, CalcJobState): @@ -234,21 +247,21 @@ def set_state(self, state): self.set_attribute(self.CALC_JOB_STATE_KEY, state.value) - def delete_state(self): + def delete_state(self) -> None: """Delete the calculation job state attribute if it exists.""" try: self.delete_attribute(self.CALC_JOB_STATE_KEY) except AttributeError: pass - def set_remote_workdir(self, remote_workdir): + def set_remote_workdir(self, remote_workdir: str) -> None: """Set the absolute path to the working directory on the remote computer where the calculation is run. :param remote_workdir: absolute filepath to the remote working directory """ self.set_attribute(self.REMOTE_WORKDIR_KEY, remote_workdir) - def get_remote_workdir(self): + def get_remote_workdir(self) -> Optional[str]: """Return the path to the remote (on cluster) scratch folder of the calculation. :return: a string with the remote path @@ -256,10 +269,10 @@ def get_remote_workdir(self): return self.get_attribute(self.REMOTE_WORKDIR_KEY, None) @staticmethod - def _validate_retrieval_directive(directives): + def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str, str]]]) -> None: """Validate a list or tuple of file retrieval directives. 
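The option and state accessors above in a short sketch (`node` is an assumed `CalcJobNode`; the calculation job state is an updatable attribute, so it can be set even on a stored node):

from aiida.common.datastructures import CalcJobState

print(node.get_option('resources'))     # a single option, or None if it was never set
print(node.get_options())               # Dict[str, Any] of all options defined on the node
node.set_state(CalcJobState.PARSING)    # anything other than a CalcJobState raises ValueError
print(node.get_state(), node.get_remote_workdir())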
- :param directives: a list or tuple of file retrieveal directives + :param directives: a list or tuple of file retrieval directives :raise ValueError: if the format of the directives is invalid """ if not isinstance(directives, (tuple, list)): @@ -284,7 +297,7 @@ def _validate_retrieval_directive(directives): if not isinstance(directive[2], int): raise ValueError('invalid directive, three element has to be an integer representing the depth') - def set_retrieve_list(self, retrieve_list): + def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None: """Set the retrieve list. This list of directives will instruct the daemon what files to retrieve after the calculation has completed. @@ -295,14 +308,14 @@ def set_retrieve_list(self, retrieve_list): self._validate_retrieval_directive(retrieve_list) self.set_attribute(self.RETRIEVE_LIST_KEY, retrieve_list) - def get_retrieve_list(self): + def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: """Return the list of files/directories to be retrieved on the cluster after the calculation has completed. :return: a list of file directives """ return self.get_attribute(self.RETRIEVE_LIST_KEY, None) - def set_retrieve_temporary_list(self, retrieve_temporary_list): + def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None: """Set the retrieve temporary list. The retrieve temporary list stores files that are retrieved after completion and made available during parsing @@ -313,7 +326,7 @@ def set_retrieve_temporary_list(self, retrieve_temporary_list): self._validate_retrieval_directive(retrieve_temporary_list) self.set_attribute(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list) - def get_retrieve_temporary_list(self): + def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: """Return list of files to be retrieved from the cluster which will be available during parsing. :return: a list of file directives @@ -360,7 +373,7 @@ def get_retrieve_singlefile_list(self): """ return self.get_attribute(self.RETRIEVE_SINGLE_FILE_LIST_KEY, None) - def set_job_id(self, job_id): + def set_job_id(self, job_id: Union[int, str]) -> None: """Set the job id that was assigned to the calculation by the scheduler. .. note:: the id will always be stored as a string @@ -369,14 +382,14 @@ def set_job_id(self, job_id): """ return self.set_attribute(self.SCHEDULER_JOB_ID_KEY, str(job_id)) - def get_job_id(self): + def get_job_id(self) -> Optional[str]: """Return job id that was assigned to the calculation by the scheduler. :return: the string representation of the scheduler job id """ return self.get_attribute(self.SCHEDULER_JOB_ID_KEY, None) - def set_scheduler_state(self, state): + def set_scheduler_state(self, state: 'JobState') -> None: """Set the scheduler state. :param state: an instance of `JobState` @@ -390,7 +403,7 @@ def set_scheduler_state(self, state): self.set_attribute(self.SCHEDULER_STATE_KEY, state.value) self.set_attribute(self.SCHEDULER_LAST_CHECK_TIME_KEY, timezone.datetime_to_isoformat(timezone.now())) - def get_scheduler_state(self): + def get_scheduler_state(self) -> Optional['JobState']: """Return the status of the calculation according to the cluster scheduler. :return: a JobState enum instance. 
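Per the validation above, retrieval directives are either plain relative paths or three-element tuples of (remote path with wildcards, local target directory, integer depth); the file names here are illustrative:

node.set_retrieve_list([
    'aiida.out',                 # a plain relative path in the remote working directory
    ('out_dir/*.dat', '.', 0),   # (remote glob, local target directory, depth)
])
node.set_job_id(12345)           # always stored as a string
print(node.get_retrieve_list(), node.get_job_id())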
@@ -404,7 +417,7 @@ def get_scheduler_state(self): return JobState(state) - def get_scheduler_lastchecktime(self): + def get_scheduler_lastchecktime(self) -> Optional[datetime.datetime]: """Return the time of the last update of the scheduler state by the daemon or None if it was never set. :return: a datetime object or None @@ -417,30 +430,38 @@ def get_scheduler_lastchecktime(self): return value - def set_detailed_job_info(self, detailed_job_info): + def set_detailed_job_info(self, detailed_job_info: Optional[dict]) -> None: """Set the detailed job info dictionary. :param detailed_job_info: a dictionary with metadata with the accounting of a completed job """ self.set_attribute(self.SCHEDULER_DETAILED_JOB_INFO_KEY, detailed_job_info) - def get_detailed_job_info(self): + def get_detailed_job_info(self) -> Optional[dict]: """Return the detailed job info dictionary. + The scheduler is polled for the detailed job info after the job is completed and ready to be retrieved. + :return: the dictionary with detailed job info if defined or None """ return self.get_attribute(self.SCHEDULER_DETAILED_JOB_INFO_KEY, None) - def set_last_job_info(self, last_job_info): + def set_last_job_info(self, last_job_info: 'JobInfo') -> None: """Set the last job info. :param last_job_info: a `JobInfo` object """ self.set_attribute(self.SCHEDULER_LAST_JOB_INFO_KEY, last_job_info.get_dict()) - def get_last_job_info(self): + def get_last_job_info(self) -> Optional['JobInfo']: """Return the last information asked to the scheduler about the status of the job. + The last job info is updated on every poll of the scheduler, except for the final poll when the job drops from + the scheduler's job queue. + For completed jobs, the last job info therefore contains the "second-to-last" job info that still shows the job + as running. Please use :meth:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.get_detailed_job_info` + instead. + :return: a `JobInfo` object (that closely resembles a dictionary) or None. """ from aiida.schedulers.datastructures import JobInfo @@ -454,28 +475,26 @@ def get_last_job_info(self): return job_info - def get_authinfo(self): + def get_authinfo(self) -> 'AuthInfo': """Return the `AuthInfo` that is configured for the `Computer` set for this node. :return: `AuthInfo` """ - from aiida.orm.authinfos import AuthInfo - computer = self.computer if computer is None: raise exceptions.NotExistent('No computer has been set for this calculation') - return AuthInfo.from_backend_entity(self.backend.authinfos.get(computer=computer, user=self.user)) + return computer.get_authinfo(self.user) - def get_transport(self): + def get_transport(self) -> 'Transport': """Return the transport for this calculation. :return: `Transport` configured with the `AuthInfo` associated to the computer of this node """ return self.get_authinfo().get_transport() - def get_parser_class(self): + def get_parser_class(self) -> Optional[Type['Parser']]: """Return the output parser object for this calculation or None if no parser is set. :return: a `Parser` class. @@ -491,11 +510,11 @@ def get_parser_class(self): return None @property - def link_label_retrieved(self): + def link_label_retrieved(self) -> str: """Return the link label used for the retrieved FolderData node.""" return 'retrieved' - def get_retrieved_node(self): + def get_retrieved_node(self) -> Optional['FolderData']: """Return the retrieved data folder. 
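A hedged sketch of the scheduler and transport helpers (`node` is an assumed stored `CalcJobNode` with a configured computer and a remote working directory set):

detailed = node.get_detailed_job_info()    # accounting dict from the final scheduler poll, or None
last = node.get_last_job_info()            # JobInfo from the last regular poll, or None

authinfo = node.get_authinfo()             # AuthInfo for the node's computer and user
with node.get_transport() as transport:    # Transport opened from that AuthInfo
    remote_files = transport.listdir(node.get_remote_workdir())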
:return: the retrieved FolderData node or None if not found @@ -507,7 +526,7 @@ def get_retrieved_node(self): return None @property - def res(self): + def res(self) -> 'CalcJobResultManager': """ To be used to get direct access to the parsed parameters. @@ -520,7 +539,7 @@ def res(self): from aiida.orm.utils.calcjob import CalcJobResultManager return CalcJobResultManager(self) - def get_scheduler_stdout(self): + def get_scheduler_stdout(self) -> Optional[AnyStr]: """Return the scheduler stderr output if the calculation has finished and been retrieved, None otherwise. :return: scheduler stderr output or None @@ -538,7 +557,7 @@ def get_scheduler_stdout(self): return stdout - def get_scheduler_stderr(self): + def get_scheduler_stderr(self) -> Optional[AnyStr]: """Return the scheduler stdout output if the calculation has finished and been retrieved, None otherwise. :return: scheduler stdout output or None @@ -556,6 +575,9 @@ def get_scheduler_stderr(self): return stderr - def get_description(self): - """Return a string with a description of the node based on its properties.""" - return self.get_state() + def get_description(self) -> str: + """Return a description of the node based on its properties.""" + state = self.get_state() + if not state: + return '' + return state.value diff --git a/aiida/orm/nodes/process/calculation/calculation.py b/aiida/orm/nodes/process/calculation/calculation.py index 51a1c7d4d9..4dd8b9bf23 100644 --- a/aiida/orm/nodes/process/calculation/calculation.py +++ b/aiida/orm/nodes/process/calculation/calculation.py @@ -25,25 +25,23 @@ class CalculationNode(ProcessNode): _unstorable_message = 'storing for this node has been disabled' @property - def inputs(self): + def inputs(self) -> NodeLinksManager: """Return an instance of `NodeLinksManager` to manage incoming INPUT_CALC links The returned Manager allows you to easily explore the nodes connected to this node via an incoming INPUT_CALC link. The incoming nodes are reachable by their link labels which are attributes of the manager. - :return: `NodeLinksManager` """ return NodeLinksManager(node=self, link_type=LinkType.INPUT_CALC, incoming=True) @property - def outputs(self): + def outputs(self) -> NodeLinksManager: """Return an instance of `NodeLinksManager` to manage outgoing CREATE links The returned Manager allows you to easily explore the nodes connected to this node via an outgoing CREATE link. The outgoing nodes are reachable by their link labels which are attributes of the manager. 
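On a finished and parsed calculation job the result accessors look roughly like this (the `energy` key is hypothetical and depends on the parser):

retrieved = node.get_retrieved_node()    # the FolderData linked with label 'retrieved', or None
print(node.res.energy)                   # shortcut into the parsed default output (hypothetical key)
print(node.get_scheduler_stdout())       # scheduler output if retrieved, otherwise None
print(node.get_description())            # now the CalcJobState value, or '' when no state is set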
- :return: `NodeLinksManager` """ return NodeLinksManager(node=self, link_type=LinkType.CREATE, incoming=False) diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index e78b6a4b8d..63409b4857 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -10,8 +10,10 @@ """Module with `Node` sub class for processes.""" import enum +from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING -from plumpy import ProcessState +from plumpy.process_states import ProcessState from aiida.common.links import LinkType from aiida.common.lang import classproperty @@ -19,6 +21,10 @@ from ..node import Node +if TYPE_CHECKING: + from aiida.engine.processes import Process + from aiida.engine.processes.builder import ProcessBuilder + __all__ = ('ProcessNode',) @@ -48,7 +54,7 @@ class ProcessNode(Sealable, Node): _unstorable_message = 'only Data, WorkflowNode, CalculationNode or their subclasses can be stored' - def __str__(self): + def __str__(self) -> str: base = super().__str__() if self.process_type: return f'{base} ({self.process_type})' @@ -56,7 +62,7 @@ def __str__(self): return f'{base}' @classproperty - def _updatable_attributes(cls): + def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + ( cls.PROCESS_PAUSED_KEY, @@ -79,7 +85,7 @@ def logger(self): from aiida.orm.utils.log import create_logger_adapter return create_logger_adapter(self._logger, self) - def get_builder_restart(self): + def get_builder_restart(self) -> 'ProcessBuilder': """Return a `ProcessBuilder` that is ready to relaunch the process that created this node. The process class will be set based on the `process_type` of this node and the inputs of the builder will be @@ -94,7 +100,7 @@ def get_builder_restart(self): return builder @property - def process_class(self): + def process_class(self) -> Type['Process']: """Return the process class that was used to create this node. :return: `Process` class @@ -128,7 +134,7 @@ def process_class(self): return process_class - def set_process_type(self, process_type_string): + def set_process_type(self, process_type_string: str) -> None: """ Set the process type string. 
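The `NodeLinksManager` returned by `inputs`/`outputs` exposes connected nodes as attributes keyed by link label; for example (labels are illustrative, `calc` an assumed stored calculation node):

parameters = calc.inputs.parameters    # node reached through an INPUT_CALC link labelled 'parameters'
retrieved = calc.outputs.retrieved     # node reached through a CREATE link labelled 'retrieved'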
@@ -137,7 +143,7 @@ def set_process_type(self, process_type_string): self.process_type = process_type_string @property - def process_label(self): + def process_label(self) -> Optional[str]: """ Return the process label @@ -145,7 +151,7 @@ def process_label(self): """ return self.get_attribute(self.PROCESS_LABEL_KEY, None) - def set_process_label(self, label): + def set_process_label(self, label: str) -> None: """ Set the process label @@ -154,7 +160,7 @@ def set_process_label(self, label): self.set_attribute(self.PROCESS_LABEL_KEY, label) @property - def process_state(self): + def process_state(self) -> Optional[ProcessState]: """ Return the process state @@ -167,7 +173,7 @@ def process_state(self): return ProcessState(state) - def set_process_state(self, state): + def set_process_state(self, state: Union[str, ProcessState]): """ Set the process state @@ -178,7 +184,7 @@ def set_process_state(self, state): return self.set_attribute(self.PROCESS_STATE_KEY, state) @property - def process_status(self): + def process_status(self) -> Optional[str]: """ Return the process status @@ -188,7 +194,7 @@ def process_status(self): """ return self.get_attribute(self.PROCESS_STATUS_KEY, None) - def set_process_status(self, status): + def set_process_status(self, status: Optional[str]) -> None: """ Set the process status @@ -210,7 +216,7 @@ def set_process_status(self, status): return self.set_attribute(self.PROCESS_STATUS_KEY, status) @property - def is_terminated(self): + def is_terminated(self) -> bool: """ Return whether the process has terminated @@ -222,7 +228,7 @@ def is_terminated(self): return self.is_excepted or self.is_finished or self.is_killed @property - def is_excepted(self): + def is_excepted(self) -> bool: """ Return whether the process has excepted @@ -234,7 +240,7 @@ def is_excepted(self): return self.process_state == ProcessState.EXCEPTED @property - def is_killed(self): + def is_killed(self) -> bool: """ Return whether the process was killed @@ -246,7 +252,7 @@ def is_killed(self): return self.process_state == ProcessState.KILLED @property - def is_finished(self): + def is_finished(self) -> bool: """ Return whether the process has finished @@ -259,7 +265,7 @@ def is_finished(self): return self.process_state == ProcessState.FINISHED @property - def is_finished_ok(self): + def is_finished_ok(self) -> bool: """ Return whether the process has finished successfully @@ -271,7 +277,7 @@ def is_finished_ok(self): return self.is_finished and self.exit_status == 0 @property - def is_failed(self): + def is_failed(self) -> bool: """ Return whether the process has failed @@ -283,7 +289,7 @@ def is_failed(self): return self.is_finished and self.exit_status != 0 @property - def exit_status(self): + def exit_status(self) -> Optional[int]: """ Return the exit status of the process @@ -291,7 +297,7 @@ def exit_status(self): """ return self.get_attribute(self.EXIT_STATUS_KEY, None) - def set_exit_status(self, status): + def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: """ Set the exit status of the process @@ -309,7 +315,7 @@ def set_exit_status(self, status): return self.set_attribute(self.EXIT_STATUS_KEY, status) @property - def exit_message(self): + def exit_message(self) -> Optional[str]: """ Return the exit message of the process @@ -317,7 +323,7 @@ def exit_message(self): """ return self.get_attribute(self.EXIT_MESSAGE_KEY, None) - def set_exit_message(self, message): + def set_exit_message(self, message: Optional[str]) -> None: """ Set the exit message of the process, if 
None nothing will be done @@ -332,7 +338,7 @@ def set_exit_message(self, message): return self.set_attribute(self.EXIT_MESSAGE_KEY, message) @property - def exception(self): + def exception(self) -> Optional[str]: """ Return the exception of the process or None if the process is not excepted. @@ -345,7 +351,7 @@ def exception(self): return None - def set_exception(self, exception): + def set_exception(self, exception: str) -> None: """ Set the exception of the process @@ -357,7 +363,7 @@ def set_exception(self, exception): return self.set_attribute(self.EXCEPTION_KEY, exception) @property - def checkpoint(self): + def checkpoint(self) -> Optional[Dict[str, Any]]: """ Return the checkpoint bundle set for the process @@ -365,7 +371,7 @@ def checkpoint(self): """ return self.get_attribute(self.CHECKPOINT_KEY, None) - def set_checkpoint(self, checkpoint): + def set_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ Set the checkpoint bundle set for the process @@ -373,7 +379,7 @@ def set_checkpoint(self, checkpoint): """ return self.set_attribute(self.CHECKPOINT_KEY, checkpoint) - def delete_checkpoint(self): + def delete_checkpoint(self) -> None: """ Delete the checkpoint bundle set for the process """ @@ -383,7 +389,7 @@ def delete_checkpoint(self): pass @property - def paused(self): + def paused(self) -> bool: """ Return whether the process is paused @@ -391,7 +397,7 @@ def paused(self): """ return self.get_attribute(self.PROCESS_PAUSED_KEY, False) - def pause(self): + def pause(self) -> None: """ Mark the process as paused by setting the corresponding attribute. @@ -400,7 +406,7 @@ def pause(self): """ return self.set_attribute(self.PROCESS_PAUSED_KEY, True) - def unpause(self): + def unpause(self) -> None: """ Mark the process as unpaused by removing the corresponding attribute. @@ -413,7 +419,7 @@ def unpause(self): pass @property - def called(self): + def called(self) -> List['ProcessNode']: """ Return a list of nodes that the process called @@ -422,7 +428,7 @@ def called(self): return self.get_outgoing(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)).all_nodes() @property - def called_descendants(self): + def called_descendants(self) -> List['ProcessNode']: """ Return a list of all nodes that have been called downstream of this process @@ -437,7 +443,7 @@ def called_descendants(self): return descendants @property - def caller(self): + def caller(self) -> Optional['ProcessNode']: """ Return the process node that called this process node, or None if it does not have a caller @@ -450,7 +456,7 @@ def caller(self): else: return caller - def validate_incoming(self, source, link_type, link_label): + def validate_incoming(self, source: Node, link_type: LinkType, link_label: str) -> None: """Validate adding a link of the given type from a given node to ourself. Adding an input link to a `ProcessNode` once it is stored is illegal because this should be taken care of @@ -468,7 +474,7 @@ def validate_incoming(self, source, link_type, link_label): raise ValueError('attempted to add an input link after the process node was already stored.') @property - def is_valid_cache(self): + def is_valid_cache(self) -> bool: """ Return whether the node is valid for caching @@ -490,7 +496,7 @@ def is_valid_cache(self): return is_valid_cache_func(self) - def _get_objects_to_hash(self): + def _get_objects_to_hash(self) -> List[Any]: """ Return a list of objects which should be included in the hash. 
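Inspecting a terminated process node with the typed state properties above, as a sketch:

if node.is_terminated:
    if node.is_finished_ok:
        print('finished with exit status 0')
    elif node.is_failed:
        print(f'failed with exit status {node.exit_status}: {node.exit_message}')
    elif node.is_excepted:
        print(f'excepted: {node.exception}')

for child in node.called:    # processes called by this one via CALL_CALC / CALL_WORK links
    print(child.process_label, child.process_state)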
""" diff --git a/aiida/orm/nodes/process/workflow/workchain.py b/aiida/orm/nodes/process/workflow/workchain.py index 1c4e7a01be..07f0f8a0b3 100644 --- a/aiida/orm/nodes/process/workflow/workchain.py +++ b/aiida/orm/nodes/process/workflow/workchain.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class for workchain processes.""" +from typing import Optional, Tuple from aiida.common.lang import classproperty @@ -22,12 +23,12 @@ class WorkChainNode(WorkflowNode): STEPPER_STATE_INFO_KEY = 'stepper_state_info' @classproperty - def _updatable_attributes(cls): + def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,) @property - def stepper_state_info(self): + def stepper_state_info(self) -> Optional[str]: """ Return the stepper state info @@ -35,7 +36,7 @@ def stepper_state_info(self): """ return self.get_attribute(self.STEPPER_STATE_INFO_KEY, None) - def set_stepper_state_info(self, stepper_state_info): + def set_stepper_state_info(self, stepper_state_info: str) -> None: """ Set the stepper state info diff --git a/aiida/orm/nodes/process/workflow/workflow.py b/aiida/orm/nodes/process/workflow/workflow.py index 1ccd20141a..8a48d1ba06 100644 --- a/aiida/orm/nodes/process/workflow/workflow.py +++ b/aiida/orm/nodes/process/workflow/workflow.py @@ -8,12 +8,16 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class for workflow processes.""" +from typing import TYPE_CHECKING from aiida.common.links import LinkType from aiida.orm.utils.managers import NodeLinksManager from ..process import ProcessNode +if TYPE_CHECKING: + from aiida.orm import Node + __all__ = ('WorkflowNode',) @@ -25,7 +29,7 @@ class WorkflowNode(ProcessNode): _unstorable_message = 'storing for this node has been disabled' @property - def inputs(self): + def inputs(self) -> NodeLinksManager: """Return an instance of `NodeLinksManager` to manage incoming INPUT_WORK links The returned Manager allows you to easily explore the nodes connected to this node @@ -37,7 +41,7 @@ def inputs(self): return NodeLinksManager(node=self, link_type=LinkType.INPUT_WORK, incoming=True) @property - def outputs(self): + def outputs(self) -> NodeLinksManager: """Return an instance of `NodeLinksManager` to manage outgoing RETURN links The returned Manager allows you to easily explore the nodes connected to this node @@ -48,7 +52,7 @@ def outputs(self): """ return NodeLinksManager(node=self, link_type=LinkType.RETURN, incoming=False) - def validate_outgoing(self, target, link_type, link_label): + def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: """Validate adding a link of the given type from ourself to a given node. 
A workflow cannot 'create' Data, so if we receive an outgoing link to an unstored Data node, that means diff --git a/aiida/orm/nodes/process/workflow/workfunction.py b/aiida/orm/nodes/process/workflow/workfunction.py index 11de0f144d..c97dc1095b 100644 --- a/aiida/orm/nodes/process/workflow/workfunction.py +++ b/aiida/orm/nodes/process/workflow/workfunction.py @@ -8,19 +8,23 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class for workflow function processes.""" +from typing import TYPE_CHECKING from aiida.common.links import LinkType from aiida.orm.utils.mixins import FunctionCalculationMixin from .workflow import WorkflowNode +if TYPE_CHECKING: + from aiida.orm import Node + __all__ = ('WorkFunctionNode',) class WorkFunctionNode(FunctionCalculationMixin, WorkflowNode): """ORM class for all nodes representing the execution of a workfunction.""" - def validate_outgoing(self, target, link_type, link_label): + def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: """ Validate adding a link of the given type from ourself to a given node. diff --git a/aiida/restapi/api.py b/aiida/restapi/api.py index 796e9a074f..586e84c74c 100644 --- a/aiida/restapi/api.py +++ b/aiida/restapi/api.py @@ -25,13 +25,8 @@ class App(Flask): def __init__(self, *args, **kwargs): - # Decide whether or not to catch the internal server exceptions ( - # default is True) - catch_internal_server = True - try: - catch_internal_server = kwargs.pop('catch_internal_server') - except KeyError: - pass + # Decide whether or not to catch the internal server exceptions (default is True) + catch_internal_server = kwargs.pop('catch_internal_server', True) # Basic initialization super().__init__(*args, **kwargs) @@ -95,12 +90,17 @@ def __init__(self, app=None, **kwargs): configuration and PREFIX """ - from aiida.restapi.resources import ProcessNode, CalcJobNode, Computer, User, Group, Node, ServerInfo + from aiida.restapi.common.config import CLI_DEFAULTS + from aiida.restapi.resources import ( + ProcessNode, CalcJobNode, Computer, User, Group, Node, ServerInfo, QueryBuilder + ) self.app = app super().__init__(app=app, prefix=kwargs['PREFIX'], catch_all_404s=True) + posting = kwargs.pop('posting', CLI_DEFAULTS['POSTING']) + self.add_resource( ServerInfo, '/', @@ -111,6 +111,15 @@ def __init__(self, app=None, **kwargs): resource_class_kwargs=kwargs ) + if posting: + self.add_resource( + QueryBuilder, + '/querybuilder/', + endpoint='querybuilder', + strict_slashes=False, + resource_class_kwargs=kwargs, + ) + ## Add resources and endpoints to the api self.add_resource( Computer, diff --git a/aiida/restapi/common/config.py b/aiida/restapi/common/config.py index 117cc95db4..0569640824 100644 --- a/aiida/restapi/common/config.py +++ b/aiida/restapi/common/config.py @@ -47,4 +47,5 @@ 'WSGI_PROFILE': False, 'HOOKUP_APP': True, 'CATCH_INTERNAL_SERVER': False, + 'POSTING': True, # Include POST endpoints (currently only /querybuilder) } diff --git a/aiida/restapi/common/identifiers.py b/aiida/restapi/common/identifiers.py index eb7ea85207..870a904065 100644 --- a/aiida/restapi/common/identifiers.py +++ b/aiida/restapi/common/identifiers.py @@ -69,7 +69,7 @@ def construct_full_type(node_type, process_type): :return: the full type, which is a unique identifier """ if node_type is None: - process_type = '' + node_type = '' if process_type is None: process_type = '' diff --git 
a/aiida/restapi/resources.py b/aiida/restapi/resources.py index 18572264e9..b4f9083a57 100644 --- a/aiida/restapi/resources.py +++ b/aiida/restapi/resources.py @@ -207,6 +207,133 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- return self.utils.build_response(status=200, headers=headers, data=data) +class QueryBuilder(BaseResource): + """ + Representation of a QueryBuilder REST API resource (instantiated with a queryhelp JSON). + + It supports POST requests taking in JSON :py:func:`~aiida.orm.querybuilder.QueryBuilder.queryhelp` + objects and returning the :py:class:`~aiida.orm.querybuilder.QueryBuilder` result accordingly. + """ + from aiida.restapi.translator.nodes.node import NodeTranslator + + _translator_class = NodeTranslator + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # HTTP Request method decorators + if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)): + self.method_decorators.update({'post': list(kwargs['get_decorators'])}) + + def get(self): # pylint: disable=arguments-differ + """Static return to state information about this endpoint.""" + data = { + 'message': ( + 'Method Not Allowed. Use HTTP POST requests to use the AiiDA QueryBuilder. ' + 'POST JSON data, which MUST be a valid QueryBuilder.queryhelp dictionary as a JSON object. ' + 'See the documentation at https://aiida.readthedocs.io/projects/aiida-core/en/latest/topics/' + 'database.html?highlight=QueryBuilder#the-queryhelp for more information.' + ), + } + + headers = self.utils.build_headers(url=request.url, total_count=1) + return self.utils.build_response( + status=405, # Method Not Allowed + headers=headers, + data={ + 'method': request.method, + 'url': unquote(request.url), + 'url_root': unquote(request.url_root), + 'path': unquote(request.path), + 'query_string': request.query_string.decode('utf-8'), + 'resource_type': self.__class__.__name__, + 'data': data, + }, + ) + + def post(self): # pylint: disable=too-many-branches + """ + POST method to pass query help JSON. + + If the posted JSON is not a valid QueryBuilder queryhelp, the request will fail with an internal server error. + + This uses the NodeTranslator in order to best return Nodes according to the general AiiDA + REST API data format, while still allowing the return of other AiiDA entities. + + :return: QueryBuilder result of AiiDA entities in "standard" REST API format. + """ + # pylint: disable=protected-access + self.trans._query_help = request.get_json(force=True) + # While the data may be correct JSON, it MUST be a single JSON Object, + # equivalent of a QuieryBuilder.queryhelp dictionary. + assert isinstance(self.trans._query_help, dict), ( + 'POSTed data MUST be a valid QueryBuilder.queryhelp dictionary. ' + f'Got instead (type: {type(self.trans._query_help)}): {self.trans._query_help}' + ) + self.trans.__label__ = self.trans._result_type = self.trans._query_help['path'][-1]['tag'] + + # Handle empty list projections + number_projections = len(self.trans._query_help['project']) + empty_projections_counter = 0 + skip_tags = [] + for tag, projections in tuple(self.trans._query_help['project'].items()): + if projections == [{'*': {}}]: + self.trans._query_help['project'][tag] = self.trans._default + elif not projections: + empty_projections_counter += 1 + skip_tags.append(tag) + else: + # Use projections as given, no need to "correct" them. 
+ pass + + if empty_projections_counter == number_projections: + # No projections have been specified in the queryhelp. + # To be true to the QueryBuilder response, the last entry in path + # is the only entry to be returned, all without edges/links. + self.trans._query_help['project'][self.trans.__label__] = self.trans._default + + self.trans.init_qb() + + data = {} + if self.trans.get_total_count(): + if empty_projections_counter == number_projections: + # "Normal" REST API retrieval can be used. + data = self.trans.get_results() + else: + # Since the "normal" REST API retrieval relies on single-tag retrieval, + # we must instead be more creative with how we retrieve the results here. + # So we opt for a dictionary, with the tags being the keys. + for tag in self.trans._query_help['project']: + if tag in skip_tags: + continue + self.trans.__label__ = tag + data.update(self.trans.get_formatted_result(tag)) + + # Remove 'full_type's when they're `None` + for tag, entities in list(data.items()): + updated_entities = [] + for entity in entities: + if entity.get('full_type') is None: + entity.pop('full_type', None) + updated_entities.append(entity) + data[tag] = updated_entities + + headers = self.utils.build_headers(url=request.url, total_count=self.trans.get_total_count()) + return self.utils.build_response( + status=200, + headers=headers, + data={ + 'method': request.method, + 'url': unquote(request.url), + 'url_root': unquote(request.url_root), + 'path': unquote(request.path), + 'query_string': request.query_string.decode('utf-8'), + 'resource_type': self.__class__.__name__, + 'data': data, + }, + ) + + class Node(BaseResource): """ Differs from BaseResource in trans.set_query() mostly because it takes diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index b551d12769..dde845de70 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -39,6 +39,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) :param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application :param hookup: If true, hook up application to built-in server, else just return it. This parameter is deprecated as of AiiDA 1.2.1. If you don't intend to run the API (hookup=False) use `configure_api` instead. + :param posting: Whether or not to include POST-enabled endpoints (currently only `/querybuilder`). :returns: tuple (app, api) if hookup==False or runs app if hookup==True """ @@ -80,6 +81,7 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k :param catch_internal_server: If true, catch and print internal server errors with full python traceback. Useful during app development. :param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in the web application + :param posting: Whether or not to include POST-enabled endpoints (currently only `/querybuilder`). 
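(Illustrative note, not part of the patch.) With ``posting`` enabled, the new ``/querybuilder/`` endpoint accepts a JSON-serialized ``QueryBuilder.queryhelp`` via POST and returns the matches in the standard REST API format. A minimal sketch, assuming a REST API started with ``verdi restapi`` at its default host, port and ``/api/v4`` prefix:

.. code-block:: python

    # Sketch: POST a QueryBuilder.queryhelp to the new /querybuilder/ endpoint.
    # Host, port and prefix below are assumed `verdi restapi` defaults; adjust as needed.
    import requests

    from aiida import load_profile
    from aiida.orm import Dict, QueryBuilder

    load_profile()

    qb = QueryBuilder()
    qb.append(Dict, tag='dict', project=['id', 'ctime'])

    # `queryhelp` is the JSON-serializable dictionary representation of the query.
    response = requests.post('http://127.0.0.1:5000/api/v4/querybuilder/', json=qb.queryhelp)
    response.raise_for_status()
    print(response.json()['data'])  # results keyed by tag, as in the standard REST API format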
:returns: Flask RESTful API :rtype: :py:class:`flask_restful.Api` @@ -89,6 +91,7 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k config = kwargs.pop('config', CLI_DEFAULTS['CONFIG_DIR']) catch_internal_server = kwargs.pop('catch_internal_server', CLI_DEFAULTS['CATCH_INTERNAL_SERVER']) wsgi_profile = kwargs.pop('wsgi_profile', CLI_DEFAULTS['WSGI_PROFILE']) + posting = kwargs.pop('posting', CLI_DEFAULTS['POSTING']) if kwargs: raise ValueError(f'Unknown keyword arguments: {kwargs}') @@ -121,4 +124,4 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30]) # Instantiate and return a Flask RESTful API by associating its app - return flask_api(app, **API_CONFIG) + return flask_api(app, posting=posting, **config_module.API_CONFIG) diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index 69eaedb0dd..8b8e4d3d2c 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -563,10 +563,10 @@ def get_formatted_result(self, label): for node_entry in results[result_name]: # construct full_type and add it to every node - try: - node_entry['full_type'] = construct_full_type(node_entry['node_type'], node_entry['process_type']) - except KeyError: - node_entry['full_type'] = None + node_entry['full_type'] = ( + construct_full_type(node_entry.get('node_type'), node_entry.get('process_type')) + if node_entry.get('node_type') or node_entry.get('process_type') else None + ) return results diff --git a/aiida/sphinxext/process.py b/aiida/sphinxext/process.py index 4c80cefbeb..077b49af1c 100644 --- a/aiida/sphinxext/process.py +++ b/aiida/sphinxext/process.py @@ -117,9 +117,12 @@ def build_content(self): content += self.build_doctree(title='Outputs:', port_namespace=self.process_spec.outputs) if hasattr(self.process_spec, 'get_outline'): - outline = self.process_spec.get_outline() - if outline is not None: - content += self.build_outline_doctree(outline=outline) + try: + outline = self.process_spec.get_outline() + if outline is not None: + content += self.build_outline_doctree(outline=outline) + except AssertionError: + pass return content def build_doctree(self, title, port_namespace): diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index fe4146ea57..ffdf77d6e5 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -25,5 +25,8 @@ from .data.array.kpoints import * from .data.structure import * from .dbimporters import * +from .graph import * -__all__ = (calculations.__all__ + data.array.kpoints.__all__ + data.structure.__all__ + dbimporters.__all__) +__all__ = ( + calculations.__all__ + data.array.kpoints.__all__ + data.structure.__all__ + dbimporters.__all__ + graph.__all__ +) diff --git a/aiida/tools/dbimporters/plugins/materialsproject.py b/aiida/tools/dbimporters/plugins/materialsproject.py index f4d1ced5c3..be16390e2f 100644 --- a/aiida/tools/dbimporters/plugins/materialsproject.py +++ b/aiida/tools/dbimporters/plugins/materialsproject.py @@ -12,7 +12,7 @@ import os import requests -from pymatgen import MPRester +from pymatgen.ext.matproj import MPRester from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults diff --git a/aiida/tools/graph/__init__.py b/aiida/tools/graph/__init__.py index 2776a55f97..c095d1619a 100644 --- a/aiida/tools/graph/__init__.py +++ b/aiida/tools/graph/__init__.py @@ -7,3 +7,8 @@ # For further information on the 
license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=wildcard-import,undefined-variable +"""Provides tools for traversing the provenance graph.""" +from .deletions import * + +__all__ = deletions.__all__ diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py new file mode 100644 index 0000000000..b151f7d3c8 --- /dev/null +++ b/aiida/tools/graph/deletions.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Functions to delete entities from the database, preserving provenance integrity.""" +import logging +from typing import Callable, Iterable, Optional, Set, Tuple, Union +import warnings + +from aiida.backends.utils import delete_nodes_and_connections +from aiida.common.log import AIIDA_LOGGER +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.orm import Group, Node, QueryBuilder, load_node +from aiida.tools.graph.graph_traversers import get_nodes_delete + +__all__ = ('DELETE_LOGGER', 'delete_nodes', 'delete_group_nodes') + +DELETE_LOGGER = AIIDA_LOGGER.getChild('delete') + + +def delete_nodes( + pks: Iterable[int], + verbosity: Optional[int] = None, + dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + force: Optional[bool] = None, + **traversal_rules: bool +) -> Tuple[Set[int], bool]: + """Delete nodes given a list of "starting" PKs. + + This command will delete not only the specified nodes, but also the ones that are + linked to these and should be also deleted in order to keep a consistent provenance + according to the rules explained in the Topics - Provenance section of the documentation. + In summary: + + 1. If a DATA node is deleted, any process nodes linked to it will also be deleted. + + 2. If a CALC node is deleted, any incoming WORK node (callers) will be deleted as + well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes + (outputs) will be deleted by default but this can be disabled. + + 3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as + well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by + default, but deletion of either of both kind of connected nodes can be enabled. + + These rules are 'recursive', so if a CALC node is deleted, then its output DATA + nodes will be deleted as well, and then any CALC node that may have those as + inputs, and so on. + + .. deprecated:: 1.6.0 + The `verbosity` keyword will be removed in `v2.0.0`, set the level of `DELETE_LOGGER` instead. + + .. deprecated:: 1.6.0 + The `force` keyword will be removed in `v2.0.0`, use the `dry_run` option instead. + + :param pks: a list of starting PKs of the nodes to delete + (the full set will be based on the traversal rules) + + :param dry_run: + If True, return the pks to delete without deleting anything. + If False, delete the pks without confirmation + If callable, a function that return True/False, based on the pks, e.g. 
``dry_run=lambda pks: True`` + + :param traversal_rules: graph traversal rules. + See :const:`aiida.common.links.GraphTraversalRules` for what rule names + are toggleable and what the defaults are. + + :returns: (pks to delete, whether they were deleted) + + """ + # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements + + if verbosity is not None: + warnings.warn( + 'The verbosity option is deprecated and will be removed in `aiida-core==2.0.0`. ' + 'Set the level of DELETE_LOGGER instead', AiidaDeprecationWarning + ) # pylint: disable=no-member + + if force is not None: + warnings.warn( + 'The force option is deprecated and will be removed in `aiida-core==2.0.0`. ' + 'Use dry_run instead', AiidaDeprecationWarning + ) # pylint: disable=no-member + if force is True: + dry_run = False + + def _missing_callback(_pks: Iterable[int]): + for _pk in _pks: + DELETE_LOGGER.warning(f'warning: node with pk<{_pk}> does not exist, skipping') + + pks_set_to_delete = get_nodes_delete(pks, get_links=False, missing_callback=_missing_callback, + **traversal_rules)['nodes'] + + DELETE_LOGGER.info('%s Node(s) marked for deletion', len(pks_set_to_delete)) + + if pks_set_to_delete and DELETE_LOGGER.level == logging.DEBUG: + builder = QueryBuilder().append( + Node, filters={'id': { + 'in': pks_set_to_delete + }}, project=('uuid', 'id', 'node_type', 'label') + ) + DELETE_LOGGER.debug('Node(s) to delete:') + for uuid, pk, type_string, label in builder.iterall(): + try: + short_type_string = type_string.split('.')[-2] + except IndexError: + short_type_string = type_string + DELETE_LOGGER.debug(f' {uuid} {pk} {short_type_string} {label}') + + if dry_run is True: + DELETE_LOGGER.info('This was a dry run, exiting without deleting anything') + return (pks_set_to_delete, False) + + # confirm deletion + if callable(dry_run) and dry_run(pks_set_to_delete): + DELETE_LOGGER.info('This was a dry run, exiting without deleting anything') + return (pks_set_to_delete, False) + + if not pks_set_to_delete: + return (pks_set_to_delete, True) + + # Recover the list of folders to delete before actually deleting the nodes. I will delete the folders only later, + # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders + repositories = [load_node(pk)._repository for pk in pks_set_to_delete] # pylint: disable=protected-access + + DELETE_LOGGER.info('Starting node deletion...') + delete_nodes_and_connections(pks_set_to_delete) + + DELETE_LOGGER.info('Nodes deleted from database, deleting files from the repository now...') + + # If we are here, we managed to delete the entries from the DB. + # I can now delete the folders + for repository in repositories: + repository.erase(force=True) + + DELETE_LOGGER.info('Deletion of nodes completed.') + + return (pks_set_to_delete, True) + + +def delete_group_nodes( + pks: Iterable[int], + dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + **traversal_rules: bool +) -> Tuple[Set[int], bool]: + """Delete nodes contained in a list of groups (not the groups themselves!). + + This command will delete not only the nodes, but also the ones that are + linked to these and should be also deleted in order to keep a consistent provenance + according to the rules explained in the concepts section of the documentation. + In summary: + + 1. If a DATA node is deleted, any process nodes linked to it will also be deleted. + + 2. 
If a CALC node is deleted, any incoming WORK node (callers) will be deleted as + well whereas any incoming DATA node (inputs) will be kept. Outgoing DATA nodes + (outputs) will be deleted by default but this can be disabled. + + 3. If a WORK node is deleted, any incoming WORK node (callers) will be deleted as + well, but all DATA nodes will be kept. Outgoing WORK or CALC nodes will be kept by + default, but deletion of either of both kind of connected nodes can be enabled. + + These rules are 'recursive', so if a CALC node is deleted, then its output DATA + nodes will be deleted as well, and then any CALC node that may have those as + inputs, and so on. + + :param pks: a list of the groups + + :param dry_run: + If True, return the pks to delete without deleting anything. + If False, delete the pks without confirmation + If callable, a function that return True/False, based on the pks, e.g. ``dry_run=lambda pks: True`` + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names + are toggleable and what the defaults are. + + :returns: (node pks to delete, whether they were deleted) + + """ + group_node_query = QueryBuilder().append( + Group, + filters={ + 'id': { + 'in': list(pks) + } + }, + tag='groups', + ).append(Node, project='id', with_group='groups') + group_node_query.distinct() + node_pks = group_node_query.all(flat=True) + return delete_nodes(node_pks, dry_run=dry_run, **traversal_rules) diff --git a/aiida/tools/graph/graph_traversers.py b/aiida/tools/graph/graph_traversers.py index cee4e9e52a..c731a4a672 100644 --- a/aiida/tools/graph/graph_traversers.py +++ b/aiida/tools/graph/graph_traversers.py @@ -8,34 +8,60 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for functions to traverse AiiDA graphs.""" +import sys +from typing import Any, Callable, cast, Dict, Iterable, List, Mapping, Optional, Set from numpy import inf -from aiida.common.links import GraphTraversalRules, LinkType - -def get_nodes_delete(starting_pks, get_links=False, **kwargs): +from aiida import orm +from aiida.common import exceptions +from aiida.common.links import GraphTraversalRules, LinkType +from aiida.orm.utils.links import LinkQuadruple +from aiida.tools.graph.age_entities import Basket +from aiida.tools.graph.age_rules import UpdateRule, RuleSequence, RuleSaveWalkers, RuleSetWalkers + +if sys.version_info >= (3, 8): + from typing import TypedDict + + class TraverseGraphOutput(TypedDict, total=False): + nodes: Set[int] + links: Optional[Set[LinkQuadruple]] + rules: Dict[str, bool] +else: + TraverseGraphOutput = Mapping[str, Any] + + +def get_nodes_delete( + starting_pks: Iterable[int], + get_links: bool = False, + missing_callback: Optional[Callable[[Iterable[int]], None]] = None, + **traversal_rules: bool +) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected to a list of initial nodes through any sequence of specified authorized links and directions for deletion. - :type starting_pks: list or tuple or set :param starting_pks: Contains the (valid) pks of the starting nodes. - :param bool get_links: + :param get_links: Pass True to also return the links between all nodes (found + initial). - :param bool create_forward: will traverse CREATE links in the forward direction. - :param bool call_calc_forward: will traverse CALL_CALC links in the forward direction. 
- :param bool call_work_forward: will traverse CALL_WORK links in the forward direction. + :param missing_callback: A callback to handle missing starting_pks or if None raise NotExistent + For example to ignore them: ``missing_callback=lambda missing_pks: None`` + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names + are toggleable and what the defaults are. + """ - traverse_links = validate_traversal_rules(GraphTraversalRules.DELETE, **kwargs) + traverse_links = validate_traversal_rules(GraphTraversalRules.DELETE, **traversal_rules) traverse_output = traverse_graph( starting_pks, get_links=get_links, links_forward=traverse_links['forward'], - links_backward=traverse_links['backward'] + links_backward=traverse_links['backward'], + missing_callback=missing_callback ) function_output = { @@ -44,30 +70,31 @@ def get_nodes_delete(starting_pks, get_links=False, **kwargs): 'rules': traverse_links['rules_applied'] } - return function_output + return cast(TraverseGraphOutput, function_output) -def get_nodes_export(starting_pks, get_links=False, **kwargs): +def get_nodes_export( + starting_pks: Iterable[int], get_links: bool = False, **traversal_rules: bool +) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected to a list of initial nodes through any sequence of specified authorized links and directions for export. This will also return the links and the traversal rules parsed. - :type starting_pks: list or tuple or set :param starting_pks: Contains the (valid) pks of the starting nodes. - :param bool get_links: + :param get_links: Pass True to also return the links between all nodes (found + initial). - :param bool input_calc_forward: will traverse INPUT_CALC links in the forward direction. - :param bool create_backward: will traverse CREATE links in the backward direction. - :param bool return_backward: will traverse RETURN links in the backward direction. - :param bool input_work_forward: will traverse INPUT_WORK links in the forward direction. - :param bool call_calc_backward: will traverse CALL_CALC links in the backward direction. - :param bool call_work_backward: will traverse CALL_WORK links in the backward direction. + :param input_calc_forward: will traverse INPUT_CALC links in the forward direction. + :param create_backward: will traverse CREATE links in the backward direction. + :param return_backward: will traverse RETURN links in the backward direction. + :param input_work_forward: will traverse INPUT_WORK links in the forward direction. + :param call_calc_backward: will traverse CALL_CALC links in the backward direction. + :param call_work_backward: will traverse CALL_WORK links in the backward direction. """ - traverse_links = validate_traversal_rules(GraphTraversalRules.EXPORT, **kwargs) + traverse_links = validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) traverse_output = traverse_graph( starting_pks, @@ -82,50 +109,49 @@ def get_nodes_export(starting_pks, get_links=False, **kwargs): 'rules': traverse_links['rules_applied'] } - return function_output + return cast(TraverseGraphOutput, function_output) -def validate_traversal_rules(ruleset=GraphTraversalRules.DEFAULT, **kwargs): +def validate_traversal_rules( + ruleset: GraphTraversalRules = GraphTraversalRules.DEFAULT, **traversal_rules: bool +) -> dict: """ Validates the keywords with a ruleset template and returns a parsed dictionary ready to be used. 
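(Illustrative note, not part of the patch.) The new ``missing_callback`` argument of ``get_nodes_delete`` lets callers handle non-existent starting PKs themselves instead of receiving a ``NotExistent`` exception. A small sketch with made-up PKs and the ``create_forward`` traversal rule:

.. code-block:: python

    # Sketch: compute the deletion set while tolerating missing starting PKs.
    from aiida import load_profile
    from aiida.tools.graph.graph_traversers import get_nodes_delete

    load_profile()

    def warn_missing(missing_pks):
        # Called with the starting PKs that were not found in the database.
        print(f'skipping non-existent PKs: {sorted(missing_pks)}')

    result = get_nodes_delete([101, 102, 103], get_links=False, missing_callback=warn_missing, create_forward=True)
    print(result['nodes'])  # set of PKs that the DELETE traversal rules would remove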
- :type ruleset: :py:class:`aiida.common.links.GraphTraversalRules` :param ruleset: Ruleset template used to validate the set of rules. - :param bool input_calc_forward: will traverse INPUT_CALC links in the forward direction. - :param bool input_calc_backward: will traverse INPUT_CALC links in the backward direction. - :param bool create_forward: will traverse CREATE links in the forward direction. - :param bool create_backward: will traverse CREATE links in the backward direction. - :param bool return_forward: will traverse RETURN links in the forward direction. - :param bool return_backward: will traverse RETURN links in the backward direction. - :param bool input_work_forward: will traverse INPUT_WORK links in the forward direction. - :param bool input_work_backward: will traverse INPUT_WORK links in the backward direction. - :param bool call_calc_forward: will traverse CALL_CALC links in the forward direction. - :param bool call_calc_backward: will traverse CALL_CALC links in the backward direction. - :param bool call_work_forward: will traverse CALL_WORK links in the forward direction. - :param bool call_work_backward: will traverse CALL_WORK links in the backward direction. + :param input_calc_forward: will traverse INPUT_CALC links in the forward direction. + :param input_calc_backward: will traverse INPUT_CALC links in the backward direction. + :param create_forward: will traverse CREATE links in the forward direction. + :param create_backward: will traverse CREATE links in the backward direction. + :param return_forward: will traverse RETURN links in the forward direction. + :param return_backward: will traverse RETURN links in the backward direction. + :param input_work_forward: will traverse INPUT_WORK links in the forward direction. + :param input_work_backward: will traverse INPUT_WORK links in the backward direction. + :param call_calc_forward: will traverse CALL_CALC links in the forward direction. + :param call_calc_backward: will traverse CALL_CALC links in the backward direction. + :param call_work_forward: will traverse CALL_WORK links in the forward direction. + :param call_work_backward: will traverse CALL_WORK links in the backward direction. 
""" - from aiida.common import exceptions - if not isinstance(ruleset, GraphTraversalRules): raise TypeError( f'ruleset input must be of type aiida.common.links.GraphTraversalRules\ninstead, it is: {type(ruleset)}' ) - rules_applied = {} - links_forward = [] - links_backward = [] + rules_applied: Dict[str, bool] = {} + links_forward: List[LinkType] = [] + links_backward: List[LinkType] = [] for name, rule in ruleset.value.items(): follow = rule.default - if name in kwargs: + if name in traversal_rules: if not rule.toggleable: raise ValueError(f'input rule {name} is not toggleable for ruleset {ruleset}') - follow = kwargs.pop(name) + follow = traversal_rules.pop(name) if not isinstance(follow, bool): raise ValueError(f'the value of rule {name} must be boolean, but it is: {follow}') @@ -141,8 +167,8 @@ def validate_traversal_rules(ruleset=GraphTraversalRules.DEFAULT, **kwargs): rules_applied[name] = follow - if kwargs: - error_message = f"unrecognized keywords: {', '.join(kwargs.keys())}" + if traversal_rules: + error_message = f"unrecognized keywords: {', '.join(traversal_rules.keys())}" raise exceptions.ValidationError(error_message) valid_output = { @@ -154,39 +180,36 @@ def validate_traversal_rules(ruleset=GraphTraversalRules.DEFAULT, **kwargs): return valid_output -def traverse_graph(starting_pks, max_iterations=None, get_links=False, links_forward=(), links_backward=()): +def traverse_graph( + starting_pks: Iterable[int], + max_iterations: Optional[int] = None, + get_links: bool = False, + links_forward: Iterable[LinkType] = (), + links_backward: Iterable[LinkType] = (), + missing_callback: Optional[Callable[[Iterable[int]], None]] = None +) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected to a list of initial nodes through any sequence of specified links. Optionally, it may also return the links that connect these nodes. - :type starting_pks: list or tuple or set :param starting_pks: Contains the (valid) pks of the starting nodes. - :type max_iterations: int or None :param max_iterations: The number of iterations to apply the set of rules (a value of 'None' will iterate until no new nodes are added). - :param bool get_links: - Pass True to also return the links between all nodes (found + initial). + :param get_links: Pass True to also return the links between all nodes (found + initial). - :type links_forward: aiida.common.links.LinkType - :param links_forward: - List with all the links that should be traversed in the forward direction. + :param links_forward: List with all the links that should be traversed in the forward direction. + :param links_backward: List with all the links that should be traversed in the backward direction. - :type links_backward: aiida.common.links.LinkType - :param links_backward: - List with all the links that should be traversed in the backward direction. 
+ :param missing_callback: A callback to handle missing starting_pks or if None raise NotExistent """ # pylint: disable=too-many-locals,too-many-statements,too-many-branches - from aiida import orm - from aiida.tools.graph.age_entities import Basket - from aiida.tools.graph.age_rules import UpdateRule, RuleSequence, RuleSaveWalkers, RuleSetWalkers - from aiida.common import exceptions if max_iterations is None: - max_iterations = inf + max_iterations = cast(int, inf) elif not (isinstance(max_iterations, int) or max_iterations is inf): raise TypeError('Max_iterations has to be an integer or infinity') @@ -204,31 +227,31 @@ def traverse_graph(starting_pks, max_iterations=None, get_links=False, links_for linktype_list.append(linktype.value) filters_backwards = {'type': {'in': linktype_list}} - if not isinstance(starting_pks, (list, set, tuple)): - raise TypeError(f'starting_pks must be of type list, set or tuple\ninstead, it is {type(starting_pks)}') - - if not starting_pks: - if get_links: - output = {'nodes': set(), 'links': set()} - else: - output = {'nodes': set(), 'links': None} - return output + if not isinstance(starting_pks, Iterable): # pylint: disable=isinstance-second-argument-not-valid-type + raise TypeError(f'starting_pks must be an iterable\ninstead, it is {type(starting_pks)}') if any([not isinstance(pk, int) for pk in starting_pks]): raise TypeError(f'one of the starting_pks is not of type int:\n {starting_pks}') operational_set = set(starting_pks) + if not operational_set: + if get_links: + return {'nodes': set(), 'links': set()} + return {'nodes': set(), 'links': None} + query_nodes = orm.QueryBuilder() query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}}) existing_pks = set(query_nodes.all(flat=True)) missing_pks = operational_set.difference(existing_pks) - if missing_pks: + if missing_pks and missing_callback is None: raise exceptions.NotExistent( - f'The following pks are not in the database and must be pruned before this call: {missing_pks}' + f'The following pks are not in the database and must be pruned before this call: {missing_pks}' ) + elif missing_pks and missing_callback is not None: + missing_callback(missing_pks) rules = [] - basket = Basket(nodes=operational_set) + basket = Basket(nodes=existing_pks) # When max_iterations is finite, the order of traversal may affect the result # (its not the same to first go backwards and then forwards than vice-versa) @@ -269,4 +292,4 @@ def traverse_graph(starting_pks, max_iterations=None, get_links=False, links_for if get_links: output['links'] = results['nodes_nodes'].keyset - return output + return cast(TraverseGraphOutput, output) diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index b94f4a307d..6ceb6480a0 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -280,7 +280,7 @@ def export( _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses) # write the link data - if traverse_output['links']: + if traverse_output['links'] is not None: with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress: for link in traverse_output['links']: progress.update() diff --git a/aiida/tools/ipython/aiida_magic_register.py b/aiida/tools/ipython/aiida_magic_register.py index b059e19384..f458ddc46a 100644 --- a/aiida/tools/ipython/aiida_magic_register.py +++ b/aiida/tools/ipython/aiida_magic_register.py @@ -15,20 +15,15 @@ 
The start up folder is usually at ``.ipython/profile_default/startup/`` """ +# DOCUMENTATION MARKER if __name__ == '__main__': try: import aiida del aiida except ImportError: + # AiiDA is not installed in this Python environment pass else: - import IPython - # pylint: disable=ungrouped-imports - from aiida.tools.ipython.ipython_magics import load_ipython_extension - - # Get the current Ipython session - IPYSESSION = IPython.get_ipython() - - # Register the line magic - load_ipython_extension(IPYSESSION) + from aiida.tools.ipython.ipython_magics import register_ipython_extension + register_ipython_extension() diff --git a/aiida/tools/ipython/ipython_magics.py b/aiida/tools/ipython/ipython_magics.py index 8ae158a726..114a1b1ae2 100644 --- a/aiida/tools/ipython/ipython_magics.py +++ b/aiida/tools/ipython/ipython_magics.py @@ -34,10 +34,8 @@ In [2]: %aiida """ -from IPython import version_info # pylint: disable=no-name-in-module -from IPython.core import magic # pylint: disable=no-name-in-module,import-error - -from aiida.common import json +from IPython import version_info, get_ipython +from IPython.core import magic def add_to_ns(local_ns, name, obj): @@ -99,6 +97,8 @@ def _repr_json_(self): """ Output in JSON format. """ + from aiida.common import json + obj = {'current_state': self.current_state} if version_info[0] >= 3: return obj @@ -130,11 +130,10 @@ def _repr_latex_(self): return latex - def _repr_pretty_(self, pretty_print, cycle): + def _repr_pretty_(self, pretty_print, cycle): # pylint: disable=unused-argument """ Output in text format. """ - # pylint: disable=unused-argument if self.is_warning: warning_str = '** ' else: @@ -146,6 +145,24 @@ def _repr_pretty_(self, pretty_print, cycle): def load_ipython_extension(ipython): """ - Triggers the load of all the AiiDA magic commands. + Registers the %aiida IPython extension. + + .. deprecated:: v3.0.0 + Use :py:func:`~aiida.tools.ipython.ipython_magics.register_ipython_extension` instead. + """ + register_ipython_extension(ipython) + + +def register_ipython_extension(ipython=None): """ + Registers the %aiida IPython extension. + + The %aiida IPython extension provides the same environment as the `verdi shell`. + + :param ipython: InteractiveShell instance. If omitted, the global InteractiveShell is used. 
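(Illustrative note, not part of the patch.) In a running IPython or Jupyter session the new helper can be used directly:

.. code-block:: python

    # Sketch: register the %aiida magic in the current IPython/Jupyter session.
    from aiida.tools.ipython.ipython_magics import register_ipython_extension

    register_ipython_extension()  # picks up the current InteractiveShell when none is passed
    # Afterwards `%aiida` is available, providing the same environment as `verdi shell`.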
+ + """ + if ipython is None: + ipython = get_ipython() + ipython.register_magics(AiiDALoaderMagics) diff --git a/aiida/transports/plugins/ssh.py b/aiida/transports/plugins/ssh.py index f5e1ca021f..5f1a765608 100644 --- a/aiida/transports/plugins/ssh.py +++ b/aiida/transports/plugins/ssh.py @@ -36,7 +36,8 @@ def parse_sshconfig(computername): import paramiko config = paramiko.SSHConfig() try: - config.parse(open(os.path.expanduser('~/.ssh/config'), encoding='utf8')) + with open(os.path.expanduser('~/.ssh/config'), encoding='utf8') as fhandle: + config.parse(fhandle) except IOError: # No file found, so empty configuration pass @@ -1264,7 +1265,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): # if self.getcwd() is not None: escaped_folder = escape_for_bash(self.getcwd()) - command_to_execute = (f'cd {escaped_folder} && {command}') + command_to_execute = (f'cd {escaped_folder} && ( {command} )') else: command_to_execute = command diff --git a/docs/source/conf.py b/docs/source/conf.py index 633078f48f..f99ab1eb34 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -97,7 +97,6 @@ 'get_started/**', 'howto/installation_more/index.rst', 'import_export/**', - 'internals/engine.rst', 'internals/global_design.rst', 'internals/orm.rst', 'scheduler/index.rst', diff --git a/docs/source/developer_guide/core/caching.rst b/docs/source/developer_guide/core/caching.rst deleted file mode 100644 index 3885222ae6..0000000000 --- a/docs/source/developer_guide/core/caching.rst +++ /dev/null @@ -1,58 +0,0 @@ -Caching: implementation details -+++++++++++++++++++++++++++++++ - -This section covers some details of the caching mechanism which are not discussed in the :ref:`user guide `. -If you are developing plugins and want to modify the caching behavior of your classes, we recommend you read :ref:`this section ` first. - -.. _devel_controlling_hashing: - -Controlling hashing -------------------- - -Below are some methods you can use to control how the hashes of calculation and data classes are computed: - -* To ignore specific attributes, a :py:class:`~aiida.orm.nodes.Node` subclass can have a ``_hash_ignored_attributes`` attribute. - This is a list of attribute names, which are ignored when creating the hash. -* For calculations, the ``_hash_ignored_inputs`` attribute lists inputs that should be ignored when creating the hash. -* To add things which should be considered in the hash, you can override the :meth:`~aiida.orm.nodes.Node._get_objects_to_hash` method. Note that doing so overrides the behavior described above, so you should make sure to use the ``super()`` method. -* Pass a keyword argument to :meth:`~aiida.orm.nodes.Node.get_hash`. - These are passed on to :meth:`~aiida.common.hashing.make_hash`. - -.. _devel_controlling_caching: - -Controlling caching -------------------- - -There are several methods you can use to disable caching for particular nodes: - -On the level of generic :class:`aiida.orm.nodes.Node`: - -* The :meth:`~aiida.orm.nodes.Node.is_valid_cache` property determines whether a particular node can be used as a cache. This is used for example to disable caching from failed calculations. -* Node classes have a ``_cachable`` attribute, which can be set to ``False`` to completely switch off caching for nodes of that class. This avoids performing queries for the hash altogether. 
- -On the level of :class:`aiida.engine.processes.process.Process` and :class:`aiida.orm.nodes.process.ProcessNode`: - -* The :meth:`ProcessNode.is_valid_cache ` calls :meth:`Process.is_valid_cache `, passing the node itself. This can be used in :class:`~aiida.engine.processes.process.Process` subclasses (e.g. in calculation plugins) to implement custom ways of invalidating the cache. -* The ``spec.exit_code`` has a keyword argument ``invalidates_cache``. If this is set to ``True``, returning that exit code means the process is no longer considered a valid cache. This is implemented in :meth:`Process.is_valid_cache `. - - -The ``WorkflowNode`` example -............................ - -As discussed in the :ref:`user guide `, nodes which can have ``RETURN`` links cannot be cached. -This is enforced on two levels: - -* The ``_cachable`` property is set to ``False`` in the :class:`~aiida.orm.nodes.Node`, and only re-enabled in :class:`~aiida.orm.nodes.process.calculation.calculation.CalculationNode` (which affects CalcJobs and calcfunctions). - This means that a :class:`~aiida.orm.nodes.process.workflow.workflow.WorkflowNode` will not be cached. -* The ``_store_from_cache`` method, which is used to "clone" an existing node, will raise an error if the existing node has any ``RETURN`` links. - This extra safe-guard prevents cases where a user might incorrectly override the ``_cachable`` property on a ``WorkflowNode`` subclass. - -Design guidelines ------------------ - -When modifying the hashing/caching behaviour of your classes, keep in mind that cache matches can go wrong in two ways: - -* False negatives, where two nodes *should* have the same hash but do not -* False positives, where two different nodes get the same hash by mistake - -False negatives are **highly preferrable** because they only increase the runtime of your calculations, while false positives can lead to wrong results. diff --git a/docs/source/developer_guide/core/internals.rst b/docs/source/developer_guide/core/internals.rst index ce32443512..efb9168270 100644 --- a/docs/source/developer_guide/core/internals.rst +++ b/docs/source/developer_guide/core/internals.rst @@ -405,7 +405,7 @@ In case a method is renamed or removed, this is the procedure to follow: - Our ``AiidaDeprecationWarning`` does not inherit from ``DeprecationWarning``, so it will not be "hidden" by python - User can disable our warnings (and only those) by using AiiDA properties with:: - verdi config warnings.showdeprecations False + verdi config set warnings.showdeprecations False Changing the config.json structure ++++++++++++++++++++++++++++++++++ diff --git a/docs/source/howto/data.rst b/docs/source/howto/data.rst index 5f86102074..f0bfdd18eb 100644 --- a/docs/source/howto/data.rst +++ b/docs/source/howto/data.rst @@ -10,8 +10,8 @@ How to work with data Importing data ============== -AiiDA allows users to export data from their database into an export archive file, which can be imported in any other AiiDA database. -If you have an AiiDA export archive that you would like to import, you can use the ``verdi import`` command (see :ref:`the reference section` for details). +AiiDA allows users to export data from their database into an export archive file, which can be imported into any other AiiDA database. +If you have an AiiDA export archive that you would like to import, you can use the ``verdi archive import`` command (see :ref:`the reference section` for details). .. 
note:: For information on exporting and importing data via AiiDA archives, see :ref:`"How to share data"`. @@ -71,7 +71,7 @@ Then we just construct an instance of that class, passing the file of interest a Note that after construction, you will get an *unstored* node. This means that at this point your data is not yet stored in the database and you can first inspect it and optionally modify it. If you are happy with the results, you can store the new data permanently by calling the :py:meth:`~aiida.orm.nodes.node.Node.store` method. -Every node is assigned a Universal Unique Identifer (UUID) upon creation and once stored it is also assigned a primary key (PK), which can be retrieved through the ``node.uuid`` and ``node.pk`` properties, respectively. +Every node is assigned a Universal Unique Identifier (UUID) upon creation and once stored it is also assigned a primary key (PK), which can be retrieved through the ``node.uuid`` and ``node.pk`` properties, respectively. You can use these identifiers to reference and or retrieve a node. Ways to find and retrieve data that have previously been imported are described in section :ref:`"How to find data"`. @@ -129,7 +129,7 @@ However, they have to be of the same ORM-type (e.g. all have to be subclasses of .. code-block:: python qb = QueryBuilder() # Instantiating instance. One instance -> one query - qb.append([CalcJobNode, WorkChainNode]) # Setting first vertice of path, either WorkChainNode or Job. + qb.append([CalcJobNode, WorkChainNode]) # Setting first vertices of path, either WorkChainNode or Job. .. note:: @@ -148,7 +148,7 @@ There are several ways to obtain data from a query: .. code-block:: python qb = QueryBuilder() # Instantiating instance - qb.append(CalcJobNode) # Setting first vertice of path + qb.append(CalcJobNode) # Setting first vertices of path first_row = qb.first() # Returns a list (!) of the results of the first row @@ -507,10 +507,15 @@ From the command line interface: Are you sure to delete Group? [y/N]: y Success: Group deleted. -.. important:: - Any deletion operation related to groups won't affect the nodes themselves. - For example if you delete a group, the nodes that belonged to the group will remain in the database. - The same happens if you remove nodes from the group -- they will remain in the database but won't belong to the group anymore. +Any deletion operation related to groups, by default, will not affect the nodes themselves. +For example if you delete a group, the nodes that belonged to the group will remain in the database. +The same happens if you remove nodes from the group -- they will remain in the database but won't belong to the group anymore. + +If you also wish to delete the nodes, when deleting the group, use the ``--delete-nodes`` option: + +.. code-block:: console + + $ verdi group delete another_group --delete-nodes Copy one group into another ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -764,7 +769,7 @@ Deleting data By default, every time you run or submit a new calculation, AiiDA will create for you new nodes in the database, and will never replace or delete data. There are cases, however, when it might be useful to delete nodes that are not useful anymore, for instance test runs or incorrect/wrong data and calculations. -For this case, AiiDA provides the ``verdi node delete`` command to remove the nodes from the provenance graph. 
+For this case, AiiDA provides the ``verdi node delete`` command and the :py:func:`~aiida.tools.graph.deletions.delete_nodes` function to remove the nodes from the provenance graph. .. caution:: Once the data is deleted, there is no way to recover it (unless you made a backup). @@ -780,6 +785,13 @@ In addition, there are a number of additional rules that are not mandatory to en For instance, you can set ``--create-forward`` if, when deleting a calculation, you want to delete also the data it produced (using instead ``--no-create-forward`` will delete the calculation only, keeping the output data: note that this effectively strips out the provenance information of the output data). The full list of these flags is available from the help command ``verdi node delete -h``. +.. code-block:: python + + from aiida.tools import delete_nodes + pks_to_be_deleted = delete_nodes( + [1, 2, 3], dry_run=True, create_forward=True, call_calc_forward=True, call_work_forward=True + ) + Deleting computers ------------------ To delete a computer, you can use ``verdi computer delete``. @@ -805,3 +817,79 @@ This command will delete both the file repository and the database. .. danger:: It is not possible to restore a deleted profile unless it was previously backed up! + +.. _how-to:data:transfer: + +Transferring data +================= + +.. versionadded:: 1.6.0 + +.. danger:: + + This feature is still in beta version and its API might change in the near future. + It is therefore not recommended that you rely on it for your public/production workflows. + + Moreover, feedback on its implementation is much appreciated (at https://github.com/aiidateam/aiida-core/issues/4811). + +When a calculation job is launched, AiiDA will create a :py:class:`~aiida.orm.nodes.data.remote.RemoteData` node that is attached as an output node to the calculation node with the label ``remote_folder``. +The input files generated by the ``CalcJob`` plugin are copied to this remote folder and, since the job is executed there as well, the code will produce its output files in that same remote folder. +Since the :py:class:`~aiida.orm.nodes.data.remote.RemoteData` node only explicitly stores the filepath on the remote computer, and not its actual contents, it functions more or less like a symbolic link. +That means that if the remote folder gets deleted, there will be no way to retrieve its contents. +The ``CalcJob`` plugin can for that reason specify some files that should be :ref:`retrieved` and stored locally in a :py:class:`~aiida.orm.nodes.data.folder.FolderData` node for safekeeping, which is attached to the calculation node as an output with the label ``retrieved_folder``. + +Although the :ref:`retrieve_list` allows you to specify which output files are to be retrieved locally, this has to be done *before* the calculation is submitted. +In order to provide more flexibility in deciding which files of completed calculation jobs are to be stored locally, even after they have terminated, AiiDA ships with the :py:class:`~aiida.calculations.transfer.TransferCalculation` plugin. +This calculation plugin makes it possible to retrieve files from a remote machine and save them in a local :py:class:`~aiida.orm.nodes.data.folder.FolderData`. +The specifications of what to copy are provided through an ``instructions`` input of type :py:class:`~aiida.orm.nodes.data.dict.Dict`: + +.. code-block:: ipython + + In [1]: instructions_cont = {} + ... instructions_cont['retrieve_files'] = True + ... instructions_cont['symlink_files'] = [ + ... ('node_keyname', 'source/path/filename', 'target/path/filename'), + ... ] + ...
instructions_node = orm.Dict(dict=instructions_cont) + +The ``'source/path/filename'`` and ``'target/path/filename'`` are both relative paths (to their respective folders). +The ``node_keyname`` is a string that will be used when providing the source :py:class:`~aiida.orm.nodes.data.remote.RemoteData` node to the calculation. +You also need to provide the computer between which the transfer will occur: + +.. code-block:: ipython + + In [2]: transfer_builder = CalculationFactory('core.transfer').get_builder() + ... transfer_builder.instructions = instructions_node + ... transfer_builder.source_nodes = {'node_keyname': source_node} + ... transfer_builder.metadata.computer = source_node.computer + +The variable ``source_node`` here corresponds to the ``RemoteData`` node whose contents need to be retrieved. +Finally, you just run or submit the calculation as you would do with any other: + +.. code-block:: ipython + + In [2]: from aiida.engine import submit + ... submit(transfer_builder) + +You can also use this to copy local files into a new :py:class:`~aiida.orm.nodes.data.remote.RemoteData` folder. +For this you first have to adapt the instructions to set ``'retrieve_files'`` to ``False`` and use a ``'local_files'`` list instead of the ``'symlink_files'``: + +.. code-block:: ipython + + In [1]: instructions_cont = {} + ... instructions_cont['retrieve_files'] = False + ... instructions_cont['local_files'] = [ + ... ('node_keyname', 'source/path/filename', 'target/path/filename'), + ... ] + ... instructions_node = orm.Dict(dict=instructions_cont) + +It is also relevant to note that, in this case, the ``source_node`` will be of type :py:class:`~aiida.orm.nodes.data.folder.FolderData` so you will have to manually select the computer to where you want to copy the files. +You can do this by looking at your available computers running ``verdi computer list`` and using the label shown to load it with :py:func:`~aiida.orm.utils.load_computer`: + +.. code-block:: ipython + + In [2]: transfer_builder.metadata.computer = load_computer('some-computer-label') + +Both when uploading or retrieving, you can copy multiple files by appending them to the list of the ``local_files`` or ``symlink_files`` keys in the instructions input, respectively. +It is also possible to copy files from any number of nodes by providing several ``source_node`` s, each with a different ``'node_keyname'``. +The target node will always be one (so you can *"gather"* files in a single call, but not *"distribute"* them). diff --git a/docs/source/howto/exploring.rst b/docs/source/howto/exploring.rst index 13a6bba8c3..5eff89a5c1 100644 --- a/docs/source/howto/exploring.rst +++ b/docs/source/howto/exploring.rst @@ -10,11 +10,11 @@ Incoming and outgoing links =========================== The provenance graph in AiiDA is a :ref:`directed graph `. -The vertices of the graph are the *nodes* and the edges that connect them are called *links*. +The vertices of the graph are the *nodes*, and the edges that connect them are called *links*. Since the graph is directed, any node can have *incoming* and *outgoing* links that connect it to neighboring nodes. To discover the neighbors of a given node, you can use the methods :meth:`~aiida.orm.nodes.node.Node.get_incoming` and :meth:`~aiida.orm.nodes.node.Node.get_outgoing`. -They have the exact same interface but will return the neighbors connected to the current node with link coming into it, or with links going out of it, respectively. 
+They have the exact same interface but will return the neighbors connected to the current node with a link coming into it or with links going out of it, respectively. For example, for a given ``node``, to inspect all the neighboring nodes from which a link is incoming to the ``node``: .. code-block:: python @@ -22,7 +22,7 @@ For example, for a given ``node``, to inspect all the neighboring nodes from whi node.get_incoming() This will return an instance of the :class:`~aiida.orm.utils.links.LinkManager`. -From that manager you can request the results in a specific format. +From that manager, you can request the results in a specific format. If you are only interested in the neighboring nodes themselves, you can call the :class:`~aiida.orm.utils.links.LinkManager.all_nodes` method: .. code-block:: python diff --git a/docs/source/howto/faq.rst b/docs/source/howto/faq.rst index 4c6ae14ad8..1481a2e26c 100644 --- a/docs/source/howto/faq.rst +++ b/docs/source/howto/faq.rst @@ -39,11 +39,11 @@ Simply reloading your shell will solve the problem. Why are calculation jobs taking very long to run on remote machines even though the actual computation time should be fast? =========================================================================================================================== -First make sure that the calculation is not actually waiting in the queue of the scheduler, but it is actually running or has already completed. +First, make sure that the calculation is not actually waiting in the queue of the scheduler, but it is actually running or has already completed. If it then still takes seemingly a lot of time for AiiDA to update your calculations, there are a couple of explanations. First, if you are running many processes, your daemon workers may simply be busy managing other calculations and workflows. -If that is not the case, you may be witnessing the effects of the built in throttling mechanisms of AiiDA's engine. -To ensure that the AiiDA daemon does not overload remote computers or their schedulers, there are built in limits to how often the daemon workers are allowed to open an SSH connection, or poll the scheduler. +If that is not the case, you may be witnessing the effects of the built-in throttling mechanisms of AiiDA's engine. +To ensure that the AiiDA daemon does not overload remote computers or their schedulers, there are built-in limits to how often the daemon workers are allowed to open an SSH connection, or poll the scheduler. To determine the minimum transport and job polling interval, use ``verdi computer configure show `` and ``computer.get_minimum_job_poll_interval()``, respectively. You can lower these values using: @@ -71,10 +71,21 @@ To determine exactly what might be going wrong, first :ref:`set the loglevel `_. Make sure that the PYTHONPATH is correctly defined automatically when starting your shell, so for example if you are using bash, add it to your ``.bashrc``. + +.. _how-to:faq:caching-not-enabled: + +Why is caching not enabled by default? +====================================== + +Caching is designed to work in an unobtrusive way and simply save time and valuable computational resources. +However, this design is a double-edged sword, in that a user who is not aware of this functionality can be caught off guard by the results of their calculations. + +The caching mechanism comes with some limitations and caveats that are important to understand. +Refer to the :ref:`topics:provenance:caching:limitations` section for more details.
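(Illustrative note, not part of the patch.) To complement the reworded ``exploring.rst`` passage above, a short sketch of walking a node's incoming links; the PK is made up:

.. code-block:: python

    # Sketch: inspect the neighbors connected to a node through incoming links.
    from aiida import load_profile
    from aiida.orm import load_node

    load_profile()

    node = load_node(1234)          # hypothetical PK of a stored node
    manager = node.get_incoming()   # LinkManager over all incoming links
    for neighbor in manager.all_nodes():
        print(neighbor.pk, neighbor.__class__.__name__)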
diff --git a/docs/source/howto/include/images/caching.png b/docs/source/howto/include/images/caching.png
index 02563f5933..2c90688a3f 100644
Binary files a/docs/source/howto/include/images/caching.png and b/docs/source/howto/include/images/caching.png differ
diff --git a/docs/source/howto/include/images/caching.svg b/docs/source/howto/include/images/caching.svg
index a1b5acf03b..011145622b 100644
--- a/docs/source/howto/include/images/caching.svg
+++ b/docs/source/howto/include/images/caching.svg
@@ -7,782 +7,783 @@
[SVG markup omitted: source regenerated with Inkscape 1.0 (previously 0.92.4); recoverable diagram text labels: C, C', D1, D2, D3, D4, D5, CREATE, RESULT, RETRIEVED]
diff --git a/docs/source/howto/include/images/workflow_error_handling_basic.svg b/docs/source/howto/include/images/workflow_error_handling_basic.svg
new file mode 100644
index 0000000000..13c3b7c833
--- /dev/null
+++ b/docs/source/howto/include/images/workflow_error_handling_basic.svg
@@ -0,0 +1,1558 @@
[SVG markup omitted: new diagram; recoverable text labels: W1 PHONONS, C1 pw.x, C2 ph.x, C3 q2r.x, C4 matdyn.x]
diff --git a/docs/source/howto/include/images/workflow_error_handling_basic_failed.png b/docs/source/howto/include/images/workflow_error_handling_basic_failed.png
new file mode 100644
index 0000000000..13406b8fdd
Binary files /dev/null and b/docs/source/howto/include/images/workflow_error_handling_basic_failed.png differ
diff --git a/docs/source/howto/include/images/workflow_error_handling_basic_success.png b/docs/source/howto/include/images/workflow_error_handling_basic_success.png
new file mode 100644
index 0000000000..3114bdc4d7
Binary files /dev/null and b/docs/source/howto/include/images/workflow_error_handling_basic_success.png differ
diff --git a/docs/source/howto/include/images/workflow_error_handling_flow_base.png b/docs/source/howto/include/images/workflow_error_handling_flow_base.png
new file mode 100644
index 0000000000..7dfc939249
Binary files /dev/null and b/docs/source/howto/include/images/workflow_error_handling_flow_base.png differ
diff --git a/docs/source/howto/include/images/workflow_error_handling_flow_base.svg b/docs/source/howto/include/images/workflow_error_handling_flow_base.svg
new file mode 100644
index 0000000000..9293811e6d
--- /dev/null
+++ b/docs/source/howto/include/images/workflow_error_handling_flow_base.svg
@@ -0,0 +1,349 @@
[SVG markup omitted: new flow diagram; its recoverable text labels begin with SUCCESS? and continue below]
+ NO + + + YES + + + + FINALIZE + + + HANDLEERRORS + + LAUNCHPROCESS + START + + diff --git a/docs/source/howto/include/images/workflow_error_handling_flow_loop.png b/docs/source/howto/include/images/workflow_error_handling_flow_loop.png new file mode 100644 index 0000000000..22934a6bd6 Binary files /dev/null and b/docs/source/howto/include/images/workflow_error_handling_flow_loop.png differ diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index cfb28f8784..070e2ba600 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -11,6 +11,7 @@ How-To Guides ssh plugin_codes workflows + workchains_restart data exploring share_data diff --git a/docs/source/howto/installation.rst b/docs/source/howto/installation.rst index f741ff6f16..f9a2f75774 100644 --- a/docs/source/howto/installation.rst +++ b/docs/source/howto/installation.rst @@ -102,53 +102,132 @@ This file is shell specific, but likely one of the following: Configuring profile options --------------------------- -AiiDA provides various configurational options for profiles, which can be controlled with the :ref:`verdi config` command. -To set a configurational option, simply pass the name of the option and the value to set ``verdi config OPTION_NAME OPTION_VALUE``. -The available options are tab-completed, so simply type ``verdi config`` and thit twice to list them. - -For example, if you want to change the default number of workers that are created when you start the daemon, you can run: - -.. code:: bash - - $ verdi config daemon.default_workers 5 - Success: daemon.default_workers set to 5 for profile-one - -You can check the currently defined value of any option by simply calling the command without specifying a value, for example: -.. code:: bash +AiiDA provides various configurational options for profiles, which can be controlled with the :ref:`verdi config` command. - $ verdi config daemon.default_workers +To view all configuration options set for the current profile: + +.. code:: console + + $ verdi config list + name source value + ------------------------------------- -------- ------------ + autofill.user.email global abc@test.com + autofill.user.first_name global chris + autofill.user.institution global epfl + autofill.user.last_name global sewell + caching.default_enabled default False + caching.disabled_for default + caching.enabled_for default + daemon.default_workers default 1 + daemon.timeout profile 20 + daemon.worker_process_slots default 200 + db.batch_size default 100000 + logging.aiida_loglevel default REPORT + logging.alembic_loglevel default WARNING + logging.circus_loglevel default INFO + logging.db_loglevel default REPORT + logging.kiwipy_loglevel default WARNING + logging.paramiko_loglevel default WARNING + logging.plumpy_loglevel default WARNING + logging.sqlalchemy_loglevel default WARNING + rmq.task_timeout default 10 + runner.poll.interval profile 50 + transport.task_maximum_attempts global 6 + transport.task_retry_initial_interval default 20 + verdi.shell.auto_import default + warnings.showdeprecations default True + +Configuration option values are taken, in order of priority, from either the profile specific setting, the global setting (applies to all profiles), or the default value. + +You can also filter by a prefix: + +.. 
code:: console + + $ verdi config list transport + name source value + ------------------------------------- -------- ------------ + transport.task_maximum_attempts global 6 + transport.task_retry_initial_interval default 20 + +To show the full information for a configuration option or get its current value: + +.. code:: console + + $ verdi config show transport.task_maximum_attempts + schema: + default: 5 + description: Maximum number of transport task attempts before a Process is Paused. + minimum: 1 + type: integer + values: + default: 5 + global: 6 + profile: + $ verdi config get transport.task_maximum_attempts + 6 + +You can also retrieve the value *via* the API: + +.. code-block:: ipython + + In [1]: from aiida import get_config_option + In [2]: get_config_option('transport.task_maximum_attempts') + Out[2]: 6 + +To set a value, at the profile or global level: + +.. code-block:: console + + $ verdi config set transport.task_maximum_attempts 10 + Success: 'transport.task_maximum_attempts' set to 10 for 'quicksetup' profile + $ verdi config set --global transport.task_maximum_attempts 20 + Success: 'transport.task_maximum_attempts' set to 20 globally + $ verdi config show transport.task_maximum_attempts + schema: + type: integer + default: 5 + minimum: 1 + description: Maximum number of transport task attempts before a Process is Paused. + values: + default: 5 + global: 20 + profile: 10 + $ verdi config get transport.task_maximum_attempts + 10 + +.. tip:: + + By default any option set through ``verdi config`` will be applied to the current default profile. + To change the profile you can use the :ref:`profile option`. + +Similarly to unset a value: + +.. code-block:: console + + $ verdi config unset transport.task_maximum_attempts + Success: 'transport.task_maximum_attempts' unset for 'quicksetup' profile + $ verdi config unset --global transport.task_maximum_attempts + Success: 'transport.task_maximum_attempts' unset globally + $ verdi config show transport.task_maximum_attempts + schema: + type: integer + default: 5 + minimum: 1 + description: Maximum number of transport task attempts before a Process is Paused. + values: + default: 5 + global: + profile: + $ verdi config get transport.task_maximum_attempts 5 -If no value is displayed, it means that no value has ever explicitly been set for this particular option and the default will always be used. -By default any option set through ``verdi config`` will be applied to the current default profile. -To change the profile you can use the :ref:`profile option`. - -To undo the configuration of a particular option and reset it so the default value is used, you can use the ``--unset`` option: - -.. code:: bash - - $ verdi config daemon.default_workers --unset - Success: daemon.default_workers unset for profile-one - -If you want to set a particular option that should be applied to all profiles, you can use the ``--global`` flag: - -.. code:: bash - - $ verdi config daemon.default_workers 5 --global - Success: daemon.default_workers set to 5 globally - -and just as on a per-profile basis, this can be undone with the ``--unset`` flag: - -.. code:: bash - - $ verdi config daemon.default_workers --unset --global - Success: daemon.default_workers unset globally - .. important:: Changes that affect the daemon (e.g. ``logging.aiida_loglevel``) will only take affect after restarting the daemon. +.. seealso:: :ref:`How-to configure caching ` + .. 
_how-to:installation:configure:instance-isolation: @@ -166,7 +245,7 @@ By default, each AiiDA instance (each installation) will store associated profil A best practice is to always separate the profiles together with the code to which they belong. The typical approach is to place the configuration folder in the virtual environment itself and have it automatically selected whenever the environment is activated. -The location of the AiiDA configuration folder, can be controlled with the ``AIIDA_PATH`` environment variable. +The location of the AiiDA configuration folder can be controlled with the ``AIIDA_PATH`` environment variable. This allows us to change the configuration folder automatically, by adding the following lines to the activation script of a virtual environment. For example, if the path of your virtual environment is ``/home/user/.virtualenvs/aiida``, add the following line: diff --git a/docs/source/howto/run_codes.rst b/docs/source/howto/run_codes.rst index 7610543fc1..784a62ffbd 100644 --- a/docs/source/howto/run_codes.rst +++ b/docs/source/howto/run_codes.rst @@ -4,14 +4,14 @@ How to run external codes ************************* -This how-to walks you through the steps of setting up a (possibly remote) compute resource, setting up a code on that computer and submitting a calculation through AiiDA (similar to the :ref:`introductory tutorial `, but in more detail). +This how-to walks you through the steps of setting up a (possibly remote) compute resource, setting up a code on that computer, and submitting a calculation through AiiDA (similar to the :ref:`introductory tutorial `, but in more detail). To run an external code with AiiDA, you need an appropriate :ref:`calculation plugin `. In the following, we assume that a plugin for your code is already available from the `aiida plugin registry `_ and installed on your machine. Refer to the :ref:`how-to:plugins-install` section for details on how to install an existing plugin. If a plugin for your code is not yet available, see :ref:`how-to:plugin-codes`. -Throughout the process you will be prompted for information on the computer and code. +Throughout the process, you will be prompted for information on the computer and code. In these prompts: * Type ``?`` followed by ```` to get help on what is being asked at any prompt. @@ -150,7 +150,7 @@ This command will perform various tests to make sure that AiiDA can connect to t Mitigating connection overloads ---------------------------------- -Some compute resources, particularly large supercomputing centres, may not tolerate submitting too many jobs at once, executing scheduler commands too frequently or opening too many SSH connections. +Some compute resources, particularly large supercomputing centers, may not tolerate submitting too many jobs at once, executing scheduler commands too frequently, or opening too many SSH connections. * Limit the number of jobs in the queue. @@ -257,7 +257,7 @@ At the end of these steps, you will be prompted to edit a script, where you can * *before* running the submission script (after the 'Pre execution script' lines), and * *after* running the submission script (after the 'Post execution script' separator). -Use this for instance to load modules or set variables that are needed by the code, such as: +Use this, for instance, to load modules or set variables that are needed by the code, such as: .. code-block:: bash @@ -399,51 +399,77 @@ See :ref:`topics:processes:usage:launching` and :ref:`topics:processes:usage:mon .. 
_how-to:run-codes:caching: -How to save computational resources using caching -================================================= +How to save compute time with caching +===================================== + +Over the course of a project, you may end up re-running the same calculations multiple times - be it because two workflows include the same calculation or because one needs to restart a workflow that failed due to some infrastructure problem. -There are numerous reasons why you might need to re-run calculations you have already run before. -Maybe you run a great number of complex workflows in high-throughput that each may repeat the same calculation, or you may have to restart an entire workflow that failed somewhere half-way through. Since AiiDA stores the full provenance of each calculation, it can detect whether a calculation has been run before and, instead of running it again, simply reuse its outputs, thereby saving valuable computational resources. This is what we mean by **caching** in AiiDA. +With caching enabled, AiiDA searches the database for a calculation of the same :ref:`hash`. +If found, AiiDA creates a copy of the calculation node and its results, thus ensuring that the resulting provenance graph is independent of whether caching is enabled or not (see :numref:`fig_caching`). + +.. _fig_caching: +.. figure:: include/images/caching.png + :align: center + :height: 350px + + When reusing the results of a calculation **C** for a new calculation **C'**, AiiDA simply makes a copy of the result nodes and links them up as usual. + This diagram depicts the same input node **D1** being used for both calculations, but an input node **D1'** with the same *hash* as **D1** would trigger the cache as well. + +Caching happens on the *calculation* level (no caching at the workflow level, see :ref:`topics:provenance:caching:limitations`). +By default, both successful and failed calculations enter the cache (more details in :ref:`topics:provenance:caching:control-caching`). + .. _how-to:run-codes:caching:enable: How to enable caching --------------------- -Caching is **not enabled by default**. -The reason is that it is designed to work in an unobtrusive way and simply save time and valuable computational resources. -However, this design is a double-egded sword, in that a user that might not be aware of this functionality, can be caught off guard by the results of their calculations. +.. important:: Caching is **not** enabled by default, see :ref:`the faq `. -.. important:: +Caching is controlled on a per-profile level via the :ref:`verdi config cli `. - The caching mechanism comes with some limitations and caveats that are important to understand. - Refer to the :ref:`topics:provenance:caching:limitations` section for more details. +View your current caching configuration: -In order to enable caching for your profile (here called ``aiida_profile``), place the following ``cache_config.yml`` file in your ``.aiida`` configuration folder: +.. code-block:: console -.. 
code-block:: yaml + $ verdi config list caching + name source value + ----------------------- -------- ------- + caching.default_enabled default False + caching.disabled_for default + caching.enabled_for default - aiida_profile: - default: True +Enable caching for your current profile or globally (for all profiles): -From this point onwards, when you launch a new calculation, AiiDA will compare its hash (depending both on the type of calculation and its inputs, see :ref:`topics:provenance:caching:hashing`) against other calculations already present in your database. -If another calculation with the same hash is found, AiiDA will reuse its results without repeating the actual calculation. +.. code-block:: console -In order to ensure that the provenance graph with and without caching is the same, AiiDA creates both a new calculation node and a copy of the output data nodes as shown in :numref:`fig_caching`. + $ verdi config set caching.default_enabled True + Success: 'caching.default_enabled' set to True for 'quicksetup' profile -.. _fig_caching: -.. figure:: include/images/caching.png - :align: center - :height: 350px + $ verdi config set -g caching.default_enabled True + Success: 'caching.default_enabled' set to True globally - When reusing the results of a calculation **C** for a new calculation **C'**, AiiDA simply makes a copy of the result nodes and links them up as usual. + $ verdi config list caching + name source value + ----------------------- -------- ------- + caching.default_enabled profile True + caching.disabled_for default + caching.enabled_for default + +.. versionchanged:: 1.6.0 + + Configuring caching via the ``cache_config.yml`` is deprecated as of AiiDA 1.6.0. + Existing ``cache_config.yml`` files will be migrated to the central ``config.json`` file automatically. + + +From this point onwards, when you launch a new calculation, AiiDA will compare its hash (a fixed size string, unique for a calulation's type and inputs, see :ref:`topics:provenance:caching:hashing`) against other calculations already present in your database. +If another calculation with the same hash is found, AiiDA will reuse its results without repeating the actual calculation. .. note:: - AiiDA uses the *hashes* of the input nodes **D1** and **D2** when searching the calculation cache. - That is to say, if the input of **C'** were new nodes **D1'** and **D2'** with the same content (hash) as **D1**, **D2**, the cache would trigger as well. + In contrast to caching, hashing **is** enabled by default, i.e. hashes for all your calculations will already have been computed. .. _how-to:run-codes:caching:configure: @@ -455,48 +481,80 @@ The caching mechanism can be configured on a process class level, meaning the ru Class level ........... -Besides an on/off switch per profile, the ``.aiida/cache_config.yml`` provides control over caching at the level of specific calculations using their corresponding entry point strings (see the output of ``verdi plugin list aiida.calculations``): +Besides the on/off switch set by ``caching.default_enabled``, caching can be controlled at the level of specific calculations using their corresponding entry point strings (see the output of ``verdi plugin list aiida.calculations``): + +.. 
code-block:: console + + $ verdi config set caching.disabled_for aiida.calculations:templatereplacer + Success: 'caching.disabled_for' set to ['aiida.calculations:templatereplacer'] for 'quicksetup' profile + $ verdi config set caching.enabled_for aiida.calculations:quantumespresso.pw + Success: 'caching.enabled_for' set to ['aiida.calculations:quantumespresso.pw'] for 'quicksetup' profile + $ verdi config set --append caching.enabled_for aiida.calculations:other + Success: 'caching.enabled_for' set to ['aiida.calculations:quantumespresso.pw', 'aiida.calculations:other'] for 'quicksetup' profile + $ verdi config list caching + name source value + ----------------------- -------- ------------------------------------- + caching.default_enabled profile True + caching.disabled_for profile aiida.calculations:templatereplacer + caching.enabled_for profile aiida.calculations:quantumespresso.pw + aiida.calculations:other + +In this example, caching is enabled by default, but explicitly disabled for calculations of the ``TemplatereplacerCalculation`` class, identified by its corresponding ``aiida.calculations:templatereplacer`` entry point string. +It also shows how to enable caching for particular calculations (which has no effect here due to the profile-wide default). + +.. tip:: To set multiple entry-points at once, use a ``,`` delimiter. -.. code-block:: yaml +For the available entry-points in your environment, you can list which are enabled/disabled using: - aiida_profile: - default: False - enabled: - - aiida.calculations:quantumespresso.pw - disabled: - - aiida.calculations:templatereplacer +.. code-block:: console -In this example, where ``aiida_profile`` is the name of the profile, caching is disabled by default, but explicitly enabled for calculaions of the ``PwCalculation`` class, identified by its corresponding ``aiida.calculations:quantumespresso.pw`` entry point string. -It also shows how to disable caching for particular calculations (which has no effect here due to the profile-wide default). + $ verdi config caching + aiida.calculations:arithmetic.add + aiida.calculations:core.transfer + aiida.workflows:arithmetic.add_multiply + aiida.workflows:arithmetic.multiply_add + $ verdi config caching --disabled + aiida.calculations:templatereplacer For calculations which do not have an entry point, you need to specify the fully qualified Python name instead. -For example, the ``seekpath_structure_analysis`` calcfunction defined in ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis`` is labeled as ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis.seekpath_structure_analysis``. +For example, the ``seekpath_structure_analysis`` calcfunction defined in ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis`` is labelled as ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis.seekpath_structure_analysis``. From an existing :class:`~aiida.orm.nodes.process.calculation.CalculationNode`, you can get the identifier string through the ``process_type`` attribute. The caching configuration also accepts ``*`` wildcards. -For example, the following configuration enables caching for all calculation entry points defined by ``aiida-quantumespresso``, and the ``seekpath_structure_analysis`` calcfunction. -Note that the ``*.seekpath_structure_analysis`` entry needs to be quoted, because it starts with ``*`` which is a special character in YAML. 
+For example, the following configuration disables caching for all calculation entry points. -.. code-block:: yaml +.. code-block:: console - aiida_profile: - default: False - enabled: - - aiida.calculations:quantumespresso.* - - '*.seekpath_structure_analysis' + $ verdi config set caching.disabled_for 'aiida.calculations:*' + Success: 'caching.disabled_for' set to ['aiida.calculations:*'] for 'quicksetup' profile + $ verdi config caching + aiida.workflows:arithmetic.add_multiply + aiida.workflows:arithmetic.multiply_add + $ verdi config caching --disabled + aiida.calculations:arithmetic.add + aiida.calculations:core.transfer + aiida.calculations:templatereplacer Any entry with a wildcard is overridden by a more specific entry. -The following configuration enables caching for all ``aiida.calculation`` entry points, except those of ``aiida-quantumespresso``: +The following configuration disables caching for all ``aiida.calculation`` entry points, except those of ``arithmetic``: -.. code-block:: yaml - - aiida_profile: - default: False - enabled: - - aiida.calculations:* - disabled: - - aiida.calculations:quantumespresso.* +.. code-block:: console + $ verdi config set caching.enabled_for 'aiida.calculations:arithmetic.*' + Success: 'caching.enabled_for' set to ['aiida.calculations:arithmetic.*'] for 'quicksetup' profile + $ verdi config list caching + name source value + ----------------------- -------- ------------------------------- + caching.default_enabled profile True + caching.disabled_for profile aiida.calculations:* + caching.enabled_for profile aiida.calculations:arithmetic.* + $ verdi config caching + aiida.calculations:arithmetic.add + aiida.workflows:arithmetic.add_multiply + aiida.workflows:arithmetic.multiply_add + $ verdi config caching --disabled + aiida.calculations:core.transfer + aiida.calculations:templatereplacer Instance level .............. diff --git a/docs/source/howto/share_data.rst b/docs/source/howto/share_data.rst index 6c8319023b..37174415e8 100644 --- a/docs/source/howto/share_data.rst +++ b/docs/source/howto/share_data.rst @@ -24,13 +24,13 @@ Exporting those results together with their provenance is as easy as: .. code-block:: console - $ verdi export create my-calculations.aiida --nodes 12 123 1234 + $ verdi archive create my-calculations.aiida --nodes 12 123 1234 As usual, you can use any identifier (label, PK or UUID) to specify the nodes to be exported. The resulting archive file ``my-calculations.aiida`` contains all information pertaining to the exported nodes. The default traversal rules make sure to include the complete provenance of any node specified and should be sufficient for most cases. -See ``verdi export create --help`` for ways to modify the traversal rules. +See ``verdi archive create --help`` for ways to modify the traversal rules. .. tip:: @@ -53,7 +53,7 @@ Then export the group: .. code-block:: console - $ verdi export create my-calculations.aiida --groups my-results + $ verdi archive create my-calculations.aiida --groups my-results Publishing AiiDA archive files ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -69,27 +69,27 @@ When publishing AiiDA archives on the `Materials Cloud Archive`_, you also get a Importing an archive ^^^^^^^^^^^^^^^^^^^^ -Use ``verdi import`` to import AiiDA archives into your current AiiDA profile. -``verdi import`` accepts URLs, e.g.: +Use ``verdi archive import`` to import AiiDA archives into your current AiiDA profile. +``verdi archive import`` accepts URLs, e.g.: .. 
code-block:: console - $ verdi import "https://archive.materialscloud.org/record/file?file_id=2a59c9e7-9752-47a8-8f0e-79bcdb06842c&filename=SSSP_1.1_PBE_efficiency.aiida&record_id=23" + $ verdi archive import "https://archive.materialscloud.org/record/file?file_id=2a59c9e7-9752-47a8-8f0e-79bcdb06842c&filename=SSSP_1.1_PBE_efficiency.aiida&record_id=23" During import, AiiDA will avoid identifier collisions and node duplication based on UUIDs (and email comparisons for :py:class:`~aiida.orm.users.User` entries). By default, existing entities will be updated with the most recent changes. -Node extras and comments have special modes for determining how to import them - for more details, see ``verdi import --help``. +Node extras and comments have special modes for determining how to import them - for more details, see ``verdi archive import --help``. .. tip:: The AiiDA archive format has evolved over time, but you can still import archives created with previous AiiDA versions. If an outdated archive version is detected during import, the archive file will be automatically migrated to the newest version (within a temporary folder) and the import retried. - You can also use ``verdi export migrate`` to create updated archive files from existing archive files (or update them in place). + You can also use ``verdi archive migrate`` to create updated archive files from existing archive files (or update them in place). -.. tip:: In order to get a quick overview of an archive file *without* importing it into your AiiDA profile, use ``verdi export inspect``: +.. tip:: In order to get a quick overview of an archive file *without* importing it into your AiiDA profile, use ``verdi archive inspect``: .. code-block:: console - $ verdi export inspect sssp-efficiency.aiida + $ verdi archive inspect sssp-efficiency.aiida -------------- ----- Version aiida 1.2.1 Version format 0.9 diff --git a/docs/source/howto/ssh.rst b/docs/source/howto/ssh.rst index ad590e53c1..94e3b2e87c 100644 --- a/docs/source/howto/ssh.rst +++ b/docs/source/howto/ssh.rst @@ -191,26 +191,6 @@ This section explains how to use the ``proxy_command`` feature of ``ssh`` in ord This method can also be used to automatically tunnel into virtual private networks, if you have an account on a proxy/jumphost server with access to the network. -Requirements -^^^^^^^^^^^^ - -The ``netcat`` tool needs to be present on the *PROXY* server (executable may be named ``netcat`` or ``nc``). -``netcat`` simply takes the standard input and redirects it to a given TCP port. - -.. dropdown:: Installing netcat - - If neither ``netcat`` or ``nc`` are available, you will need to install it on your own. - You can download a `netcat distribution `_, unzip the downloaded package, ``cd`` into the folder and execute something like: - - .. code-block:: console - - $ ./configure --prefix=. - $ make - $ make install - - This usually creates a subfolder ``bin``, containing the ``netcat`` and ``nc`` executables. - Write down the full path to ``nc`` which we will need later. - SSH configuration @@ -222,14 +202,9 @@ Edit the ``~/.ssh/config`` file on the computer on which you installed AiiDA (or Hostname FULLHOSTNAME_TARGET User USER_TARGET IdentityFile ~/.ssh/aiida - ProxyCommand ssh USER_PROXY@FULLHOSTNAME_PROXY ABSPATH_NETCAT %h %p - -replacing the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers, and replacing ``ABSPATH_NETCAT`` with the result of ``which netcat`` (or ``which nc``). - -.. 
note:: - - If desired/necessary for your netcat implementation, hide warnings and errors that may occur during the proxying/tunneling by redirecting stdout and stderr, e.g. by appending ``2> /dev/null`` to the ``ProxyCommand``. + ProxyCommand ssh -W %h:%p USER_PROXY@FULLHOSTNAME_PROXY +replacing the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers. This should allow you to directly connect to the *TARGET* server using @@ -240,15 +215,6 @@ This should allow you to directly connect to the *TARGET* server using For a *passwordless* connection, you need to follow the instructions :ref:`how-to:ssh:passwordless` *twice*: once for the connection from your computer to the *PROXY* server, and once for the connection from the *PROXY* server to the *TARGET* server. -.. warning:: - - There are occasionally ``netcat`` implementations, which keep running after you close your SSH connection, resulting in a growing number of open SSH connections between the *PROXY* server and the *TARGET* server. - If you suspect an issue, it may be worth connecting to the *PROXY* server and checking how many ``netcat`` processes are running, e.g. via: - - .. code-block:: console - - $ ps -aux | grep netcat - AiiDA configuration ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/howto/visualising_graphs/visualising_graphs.rst b/docs/source/howto/visualising_graphs/visualising_graphs.rst index 3de6facea6..26163a4b9b 100644 --- a/docs/source/howto/visualising_graphs/visualising_graphs.rst +++ b/docs/source/howto/visualising_graphs/visualising_graphs.rst @@ -41,7 +41,7 @@ It can then be imported into the database: .. code:: ipython - !verdi import -n graph1.aiida + !verdi archive import -n graph1.aiida .. code:: python diff --git a/docs/source/howto/workchains_restart.rst b/docs/source/howto/workchains_restart.rst new file mode 100644 index 0000000000..12c3d7e9d9 --- /dev/null +++ b/docs/source/howto/workchains_restart.rst @@ -0,0 +1,402 @@ +.. _how-to:restart_workchain: + +************************************** +How to write error-resistant workflows +************************************** + +.. admonition:: Overview + + This how-to introduces the :py:class:`~aiida.engine.processes.workchains.restart.BaseRestartWorkChain`, and how it can be sub-classed to handle known failure modes of processes and calculations. + +In the :ref:`multi-step workflows how-to ` we discussed how to write a simple multi-step workflow using work chains. +However, there is one thing that we did not consider there: + + What if a calculation step fails? + +For example with the :py:class:`~aiida.workflows.arithmetic.multiply_add.MultiplyAddWorkChain`; it launches a :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation`. +If that were to fail, the work chain would except because the line ``self.ctx.addition.outputs.sum`` will raise an ``AttributeError``. +In this case, where the work chain just runs a single calculation, that is not such a big deal but for real-life work chains that run a number of calculations in sequence, having the work chain except will cause all the work up to that point to be lost. +Take as an example a workflow that computes the phonons of a crystal structure using Quantum ESPRESSO: + +.. figure:: include/images/workflow_error_handling_basic_success.png + + Schematic diagram of a workflow that computes the phonons of a crystal structure using Quantum ESPRESSO. 
+ The workflow consists of four consecutive calculations using the ``pw.x``, ``ph.x``, ``q2r.x`` and ``matdyn.x`` codes, respectively.
+If all calculations run without problems, the workflow itself will of course also run fine and produce the desired final result.
+But now imagine that the third calculation actually fails.
+If the workflow does not explicitly check for this failure, but instead blindly assumes that the calculation has produced the required results, it will fail itself, losing the progress it made with the first two calculations.
+
+.. figure:: include/images/workflow_error_handling_basic_failed.png
+
+ Example execution of the Quantum ESPRESSO phonon workflow where the third step, the ``q2r.x`` code, failed; because the workflow blindly assumed it would finish without errors, the workflow itself also fails.
+
+The solution seems simple, then.
+After each calculation, we simply add a check to verify that it finished successfully and produced the required outputs before continuing with the next calculation.
+What do we do, though, when the calculation fails?
+Depending on the cause of the failure, we might actually be able to fix the problem and re-run the calculation, potentially with corrected inputs.
+A common example is that the calculation ran out of wall time (requested time from the job scheduler) and was cancelled by the job scheduler.
+In this case, simply restarting the calculation (if the code supports restarts), and optionally giving the job more wall time or resources, may fix the problem.
+
+You might be tempted to add this error handling directly into the workflow.
+However, this requires implementing the same error-handling code many times in other workflows that just happen to run the same codes.
+For example, we could add the error handling for the ``pw.x`` code directly in our phonon workflow, but a structure optimization workflow will also have to run ``pw.x`` and will have to implement the same error-handling logic.
+Is there a way that we can implement this once and easily reuse it in various workflows?
+
+Yes! Instead of directly running a calculation in a workflow, one should rather run a work chain that is explicitly designed to run the calculation to completion.
+This *base* work chain knows about the various failure modes of the calculation and can try to fix the problem and restart the calculation whenever it fails, until it finishes successfully.
+The logic of such a base work chain is very generic and can be applied to any calculation, and actually any process:
+
+.. figure:: include/images/workflow_error_handling_flow_base.png
+ :align: center
+ :height: 500px
+
+ Schematic flow diagram of the logic of a *base* work chain, whose job it is to run a subprocess repeatedly, fixing any potential errors, until it finishes successfully.
+
+The work chain runs the subprocess.
+Once it has finished, it then inspects the status.
+If the subprocess finished successfully, the work chain returns the results and its job is done.
+If, instead, the subprocess failed, the work chain should inspect the cause of failure, and attempt to fix the problem and restart the subprocess.
+This cycle is repeated until the subprocess finishes successfully.
+Of course, this runs the risk of entering an infinite loop if the work chain never manages to fix the problem, so we want to build in a limit to the maximum number of calculations that can be re-run:
+
+.. _workflow-error-handling-flow-loop:
+..
figure:: include/images/workflow_error_handling_flow_loop.png + :align: center + :height: 500px + + An improved flow diagram for the base work chain that limits the maximum number of iterations that the work chain can try and get the calculation to finish successfully. + +Since this is such a common logical flow for a base work chain that is to wrap another :py:class:`~aiida.engine.processes.process.Process` and restart it until it is finished successfully, we have implemented it as an abstract base class in ``aiida-core``. +The :py:class:`~aiida.engine.processes.workchains.restart.BaseRestartWorkChain` implements the logic of the flow diagram shown above. +Although the ``BaseRestartWorkChain`` is a subclass of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` itself, you cannot launch it. +The reason is that it is completely general and so does not know which :py:class:`~aiida.engine.processes.process.Process` class it should run. +Instead, to make use of the base restart work chain, you should subclass it for the process class that you want to wrap. + + +Writing a base restart work chain +================================= + +In this how-to, we will show how to implement the ``BaseRestartWorkChain`` for the :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation`. +We start by importing the relevant base classes and create a subclass: + +.. code-block:: python + + from aiida.engine import BaseRestartWorkChain + from aiida.plugins import CalculationFactory + + ArithmeticAddCalculation = CalculationFactory('arithmetic.add') + + class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + + _process_class = ArithmeticAddCalculation + + +As you can see, all we had to do is create a subclass of the ``BaseRestartWorkChain`` class, which we called ``ArithmeticAddBaseWorkChain``, and set the ``_process_class`` class attribute to ``ArithmeticAddCalculation``. +The latter instructs the work chain what type of process it should launch. +Next, as with all work chains, we should *define* its process specification: + +.. code-block:: python + + from aiida import orm + from aiida.engine import while_ + + @classmethod + def define(cls, spec): + """Define the process specification.""" + super().define(spec) + spec.input('x', valid_type=(orm.Int, orm.Float), help='The left operand.') + spec.input('y', valid_type=(orm.Int, orm.Float), help='The right operand.') + spec.input('code', valid_type=orm.Code, help='The code to use to perform the summation.') + spec.output('sum', valid_type=(orm.Int, orm.Float), help='The sum of the left and right operand.') + spec.outline( + cls.setup, + while_(cls.should_run_process)( + cls.run_process, + cls.inspect_process, + ), + cls.results, + ) + +The inputs and output that we define are essentially determined by the sub process that the work chain will be running. +Since the ``ArithmeticAddCalculation`` requires the inputs ``x`` and ``y``, and produces the ``sum`` as output, we `mirror` those in the specification of the work chain, otherwise we wouldn't be able to pass the necessary inputs. +Finally, we define the logical outline, which if you look closely, resembles the logical flow chart presented in :numref:`workflow-error-handling-flow-loop` a lot. +We start by *setting up* the work chain and then enter a loop: *while* the subprocess has not yet finished successfully *and* we haven't exceeded the maximum number of iterations, we *run* another instance of the process and then *inspect* the results. 
+The while conditions are implemented in the ``should_run_process`` outline step.
+When the process finishes successfully, or we have to abandon after exceeding the maximum number of iterations, we report the *results*.
+Now, unlike with normal work chain implementations, we *do not* have to implement these outline steps ourselves: they have already been implemented by the ``BaseRestartWorkChain``.
+This is why the base restart work chain is so useful, as it saves us from writing and repeating a lot of `boilerplate code `__.
+
+.. warning::
+
+ This minimal outline definition is required for the work chain to work properly.
+ If you change the logic, change the names of the steps, or omit some steps, the work chain will not run.
+ Adding extra outline steps to add custom functionality, however, is fine and actually encouraged if it makes sense.
+
+The last piece of the puzzle is to define in the setup which inputs the work chain should pass to the subprocess.
+You might wonder why this is necessary, because we already define the inputs in the specification, but those are not the only inputs that will be passed.
+The ``BaseRestartWorkChain`` also defines some inputs of its own, such as ``max_iterations``, as you can see in its :py:meth:`~aiida.engine.processes.workchains.restart.BaseRestartWorkChain.define` method.
+To make it absolutely clear what inputs are intended for the subprocess, we define them as a dictionary in the context under the key ``inputs``.
+One way of doing this is to reuse the :py:meth:`~aiida.engine.processes.workchains.restart.BaseRestartWorkChain.setup` method:
+
+.. code-block:: python
+
+    def setup(self):
+        """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`.
+
+        This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the process in the
+        internal loop.
+        """
+        super().setup()
+        self.ctx.inputs = {'x': self.inputs.x, 'y': self.inputs.y, 'code': self.inputs.code}
+
+Note that, as explained before, the ``setup`` step forms a crucial part of the logical outline of any base restart work chain.
+Omitting it from the outline will break the work chain, and so will overriding it completely, unless we call ``super().setup()``.
+This is all the code we have to write to have a functional work chain.
+We can now launch it like any other work chain and the ``BaseRestartWorkChain`` will work its magic:
+
+.. code-block:: python
+
+    submit(ArithmeticAddBaseWorkChain, x=Int(3), y=Int(4), code=load_code('add@tutor'))
+
+Once the work chain has finished, we can inspect what has happened with, for example, ``verdi process status``:
+
+.. code-block:: console
+
+    $ verdi process status 1909
+    ArithmeticAddBaseWorkChain<1909> Finished [0] [2:results]
+    └── ArithmeticAddCalculation<1910> Finished [0]
+
+As you can see, the work chain launched a single instance of the ``ArithmeticAddCalculation``, which finished successfully, so the job of the work chain was done as well.
+
+.. note::
+
+    If the work chain excepted, make sure the directory containing the WorkChain definition is in the ``PYTHONPATH``.
+
+    You can add the folder in which you have your Python file defining the WorkChain to the ``PYTHONPATH`` through:
+
+    .. code-block:: bash
+
+        $ export PYTHONPATH=/path/to/workchain/directory/:$PYTHONPATH
+
+    After this, it is **very important** to restart the daemon:
+
+    ..
code-block:: bash + + $ verdi daemon restart --reset + + Indeed, when updating an existing work chain file or adding a new one, it is **necessary** to restart the daemon **every time** after all changes have taken place. + +Exposing inputs and outputs +=========================== + +Any base restart work chain *needs* to *expose* the inputs of the subprocess it wraps, and most likely *wants* to do the same for the outputs it produces, although the latter is not necessary. +For the simple example presented in the previous section, simply copy-pasting the input and output port definitions of the subprocess ``ArithmeticAddCalculation`` was not too troublesome. +However, this quickly becomes tedious, and more importantly, error-prone once you start to wrap processes with quite a few more inputs. +To prevent the copy-pasting of input and output specifications, the :class:`~aiida.engine.processes.process_spec.ProcessSpec` class provides the :meth:`~plumpy.ProcessSpec.expose_inputs` and :meth:`~plumpy.ProcessSpec.expose_outputs` methods: + +.. code-block:: python + + @classmethod + def define(cls, spec): + """Define the process specification.""" + super().define(spec) + spec.expose_inputs(ArithmeticAddCalculation, namespace='add') + spec.expose_outputs(ArithmeticAddCalculation) + ... + +.. seealso:: + + For more detail on exposing inputs and outputs, see the basic :ref:`Workchain usage section `. + +That takes care of exposing the port specification of the wrapped process class in a very efficient way. +To efficiently retrieve the inputs that have been passed to the process, one can use the :meth:`~aiida.engine.processes.process.Process.exposed_inputs` method. +Note the past tense of the method name. +The method takes a process class and an optional namespace as arguments, and will return the inputs that have been passed into that namespace when it was launched. +This utility now allows us to simplify the ``setup`` outline step that we have shown before: + +.. code-block:: python + + def setup(self): + """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`. + + This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the process in the + internal loop. + """ + super().setup() + self.ctx.inputs = self.exposed_inputs(ArithmeticAddCalculation, 'add') + +This way we don't have to manually fish out all the individual inputs from the ``self.inputs`` but have to just call this single method, saving a lot of time and lines of code. + +When submitting or running the work chain using namespaced inputs (``add`` in the example above), it is important to use the namespace: + +.. code-block:: python + + inputs = { + 'add': { + 'x': Int(3), + 'y': Int(4), + 'code': load_code('add@tutor') + } + } + submit(ArithmeticAddBaseWorkChain, **inputs) + +.. important:: + + Every time you make changes to the ``ArithmeticAddBaseWorkChain``, don't forget to restart the daemon with: + + .. code-block:: bash + + $ verdi daemon restart --reset + +Error handling +============== + +So far you have seen how easy it is to get a work chain up and running that will run a subprocess using the ``BaseRestartWorkChain``. +However, the whole point of this exercise, as described in the introduction, was for the work chain to be able to deal with *failing* processes, yet in the previous example it finished without any problems. + + What would have happened if the subprocess had failed? 
+
+If the computed sum of the inputs ``x`` and ``y`` is negative, the ``ArithmeticAddCalculation`` fails with exit code ``410``, which corresponds to ``ERROR_NEGATIVE_NUMBER``.
+
+.. seealso::
+
+    See the :ref:`exit code usage section` for a more detailed explanation of exit codes.
+
+Let's launch the work chain with inputs that will cause the calculation to fail, e.g. by making one of the operands negative, and see what happens:
+
+.. code-block:: python
+
+    submit(ArithmeticAddBaseWorkChain, add={'x': Int(3), 'y': Int(-4), 'code': load_code('add@tutor')})
+
+This time we will see that the work chain takes quite a different path:
+
+.. code-block:: console
+
+    $ verdi process status 1930
+    ArithmeticAddBaseWorkChain<1930> Finished [402] [1:while_(should_run_process)(1:inspect_process)]
+    ├── ArithmeticAddCalculation<1931> Finished [410]
+    └── ArithmeticAddCalculation<1934> Finished [410]
+
+As expected, the ``ArithmeticAddCalculation`` failed this time with a ``410``.
+The work chain noticed the failure when inspecting the result of the subprocess in ``inspect_process`` and, in keeping with its name and design, restarted the calculation.
+However, since the inputs were not changed, the calculation inevitably failed once more with the exact same exit code.
+Unlike after the first iteration, however, the work chain did not restart again, but gave up and returned the exit code ``402`` itself, which stands for ``ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE``.
+As the name suggests, the work chain tried to run the subprocess, but it failed twice in a row without the problem being *handled*.
+The obvious question, of course, is: "How exactly can we instruct the base work chain to handle certain problems?"
+
+Since such problems necessarily depend on the subprocess that the work chain runs, their handling cannot be implemented by the ``BaseRestartWorkChain`` class itself, but has to be implemented by the subclass.
+If the subprocess fails, the ``BaseRestartWorkChain`` calls a set of *process handlers* in the ``inspect_process`` step.
+Each process handler gets passed the node of the subprocess that was just run, such that it can inspect the results and potentially fix any problems that it finds.
+To "register" a process handler for a base restart work chain implementation, you simply define a method that takes a node as its single argument and decorate it with the :func:`~aiida.engine.processes.workchains.utils.process_handler` decorator:
+
+.. code-block:: python
+
+    from aiida.engine import process_handler, ProcessHandlerReport
+
+    class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
+
+        _process_class = ArithmeticAddCalculation
+
+        ...
+
+        @process_handler
+        def handle_negative_sum(self, node):
+            """Check if the calculation failed with `ERROR_NEGATIVE_NUMBER`.
+
+            If this is the case, simply make the inputs positive by taking the absolute value.
+
+            :param node: the node of the subprocess that was run in the current iteration.
+            :return: optional :class:`~aiida.engine.processes.workchains.utils.ProcessHandlerReport` instance to signal
+                that a problem was detected and potentially handled.
+            """
+            if node.exit_status == ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER.status:
+                self.ctx.inputs['x'] = orm.Int(abs(node.inputs.x.value))
+                self.ctx.inputs['y'] = orm.Int(abs(node.inputs.y.value))
+                return ProcessHandlerReport()
+
+The method name can be anything, as long as it is a valid Python method name and does not overlap with one of the base work chain's own methods.
+For better readability, however, it is recommended to have the method name start with ``handle_``.
+In this example, we want to check for a particular failure mode of the ``ArithmeticAddCalculation``, so we compare the :meth:`~aiida.orm.nodes.process.process.ProcessNode.exit_status` of the node with that of the corresponding exit code on the process spec.
+If the exit status matches, we know that the problem was due to the sum being negative.
+Fixing this fictitious problem is as simple as making sure that the inputs are all positive, which we can do by taking their absolute values.
+We assign the new values to ``self.ctx.inputs``, just as we did for the original inputs in the ``setup`` step.
+Finally, to indicate that we have handled the problem, we return an instance of :class:`~aiida.engine.processes.workchains.utils.ProcessHandlerReport`.
+This will instruct the work chain to restart the subprocess, taking the updated inputs from the context.
+With this simple addition, we can now launch the work chain again:
+
+.. code-block:: console
+
+    $ verdi process status 1941
+    ArithmeticAddBaseWorkChain<1941> Finished [0] [2:results]
+    ├── ArithmeticAddCalculation<1942> Finished [410]
+    └── ArithmeticAddCalculation<1947> Finished [0]
+
+This time around, although the first subprocess fails again with a ``410``, the new process handler is called.
+It "fixes" the inputs, and when the work chain restarts the subprocess with the new inputs, it finishes successfully.
+In the same way, you can add as many process handlers as you like to deal with any potential problem that might occur for the specific subprocess type of the work chain implementation.
+To make the code even more readable, the :func:`~aiida.engine.processes.workchains.utils.process_handler` decorator comes with some syntactic sugar.
+Instead of having a conditional at the start of each handler to compare the exit status of the node to a particular exit code of the subprocess, you can define it through the ``exit_codes`` keyword argument of the decorator:
+
+.. code-block:: python
+
+    @process_handler(exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
+    def handle_negative_sum(self, node):
+        """Handle the `ERROR_NEGATIVE_NUMBER` failure mode of the `ArithmeticAddCalculation`."""
+        self.ctx.inputs['x'] = orm.Int(abs(node.inputs.x.value))
+        self.ctx.inputs['y'] = orm.Int(abs(node.inputs.y.value))
+        return ProcessHandlerReport()
+
+If the ``exit_codes`` keyword is defined, which can be either a single instance of :class:`~aiida.engine.processes.exit_code.ExitCode` or a list thereof, the process handler will only be called if the exit status of the node corresponds to one of those exit codes; otherwise it will simply be skipped.
+
+Multiple process handlers
+=========================
+
+Since a base restart work chain implementation will typically have more than one process handler, one might want to control the order in which they are called.
+This can be done through the ``priority`` keyword:
+
+..
code-block:: python
+
+    @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
+    def handle_negative_sum(self, node):
+        """Handle the `ERROR_NEGATIVE_NUMBER` failure mode of the `ArithmeticAddCalculation`."""
+        self.ctx.inputs['x'] = orm.Int(abs(node.inputs.x.value))
+        self.ctx.inputs['y'] = orm.Int(abs(node.inputs.y.value))
+        return ProcessHandlerReport()
+
+Process handlers with a higher priority will be called first.
+In this scenario, in addition to controlling the order in which the handlers are called, you may also want to stop the process handling once you have identified the problem.
+This can be achieved by setting the ``do_break`` argument of the ``ProcessHandlerReport`` to ``True``:
+
+.. code-block:: python
+
+    @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
+    def handle_negative_sum(self, node):
+        """Handle the `ERROR_NEGATIVE_NUMBER` failure mode of the `ArithmeticAddCalculation`."""
+        self.ctx.inputs['x'] = orm.Int(abs(node.inputs.x.value))
+        self.ctx.inputs['y'] = orm.Int(abs(node.inputs.y.value))
+        return ProcessHandlerReport(do_break=True)
+
+Finally, sometimes one detects a problem that simply cannot or should not be corrected by the work chain.
+In this case, the handler can signal that the work chain should abort by setting an :class:`~aiida.engine.processes.exit_code.ExitCode` instance on the ``exit_code`` argument of the ``ProcessHandlerReport``:
+
+.. code-block:: python
+
+    from aiida.engine import ExitCode
+
+    @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
+    def handle_negative_sum(self, node):
+        """Handle the `ERROR_NEGATIVE_NUMBER` failure mode of the `ArithmeticAddCalculation`."""
+        return ProcessHandlerReport(exit_code=ExitCode(450, 'Inputs lead to a negative sum but I will not correct them'))
+
+The base restart work chain will detect this exit code and abort, setting the corresponding exit status and message on its node as usual:
+
+.. code-block:: console
+
+    $ verdi process status 1951
+    ArithmeticAddBaseWorkChain<1951> Finished [450] [1:while_(should_run_process)(1:inspect_process)]
+    └── ArithmeticAddCalculation<1952> Finished [410]
+
+With these basic tools, a broad range of use cases can be addressed while avoiding a lot of boilerplate code.
diff --git a/docs/source/howto/workflows.rst b/docs/source/howto/workflows.rst
index c9ccb0bd17..844fc7b2c1 100644
--- a/docs/source/howto/workflows.rst
+++ b/docs/source/howto/workflows.rst
@@ -21,6 +21,10 @@ Here we present a brief introduction on how to write both workflow types.
 For more details on the concept of a workflow, and the difference between a work function and a work chain, please see the corresponding :ref:`topics section`.
 
+.. note::
+
+   Developing workflows may involve running several lengthy calculations. Consider :ref:`enabling caching ` to help avoid repeating long workflow steps.
+
 Work function
 -------------
@@ -242,9 +246,9 @@ So, it is advisable to *submit* more complex or longer work chains to the daemon
 
     workchain_node = submit(MultiplyAddWorkChain, **inputs)
 
-Note that when using ``submit`` the work chain is not run in the local interpreter but is sent off to the daemon and you get back control instantly.
+Note that when using ``submit`` the work chain is not run in the local interpreter but is sent off to the daemon, and you get back control instantly.
This allows you to submit multiple work chains at the same time and the daemon will start working on them in parallel. -Once the ``submit`` call returns, you will not get the result as with ``run``, but you will get the **node** that represents the work chain. +Once the ``submit`` call returns, you will not get the result as with ``run``, but you will get the **node** representing the work chain. Submitting a work chain instead of directly running it not only makes it easier to execute multiple work chains in parallel, but also ensures that the progress of a workchain is not lost when you restart your computer. .. important:: diff --git a/docs/source/index.rst b/docs/source/index.rst index bfc29c1d21..1a21108347 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -110,11 +110,11 @@ How to cite If you use AiiDA for your research, please cite the following work: -.. highlights:: **AiiDA >= 1.0:** Sebastiaan. P. Huber, Spyros Zoupanos, Martin Uhrin, Leopold Talirz, Leonid Kahle, Rico Häuselmann, Dominik Gresch, Tiziano Müller, Aliaksandr V. Yakutovich, Casper W. Andersen, Francisco F. Ramirez, Carl S. Adorf, Fernando Gargiulo, Snehal Kumbhar, Elsa Passaro, Conrad Johnston, Andrius Merkys, Andrea Cepellotti, Nicolas Mounet, Nicola Marzari, Boris Kozinsky, Giovanni Pizzi, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) +.. highlights:: **AiiDA >= 1.0:** Sebastiaan. P. Huber, Spyros Zoupanos, Martin Uhrin, Leopold Talirz, Leonid Kahle, Rico Häuselmann, Dominik Gresch, Tiziano Müller, Aliaksandr V. Yakutovich, Casper W. Andersen, Francisco F. Ramirez, Carl S. Adorf, Fernando Gargiulo, Snehal Kumbhar, Elsa Passaro, Conrad Johnston, Andrius Merkys, Andrea Cepellotti, Nicolas Mounet, Nicola Marzari, Boris Kozinsky, and Giovanni Pizzi, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: `10.1038/s41597-020-00638-4 `_ -.. highlights:: **AiiDA < 1.0:** Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari, - and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database - for computational science*, Comp. Mat. Sci 111, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) +.. highlights:: **AiiDA >= 1.0:** Martin Uhrin, Sebastiaan. P. Huber, Jusong Yu, Nicola Marzari, and Giovanni Pizzi, *Workflows in AiiDA: Engineering a high-throughput, event-based engine for robust and modular computational workflows*, Computational Materials Science **187**, 110086 (2021); DOI: `10.1016/j.commatsci.2020.110086 `_ + +.. highlights:: **AiiDA < 1.0:** Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari, and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Computational Materials Science **111**, 218-230 (2016); DOI: `10.1016/j.commatsci.2015.09.013 `_ **************** diff --git a/docs/source/internals/data_storage.rst b/docs/source/internals/data_storage.rst index a3ca0e0e19..d60caf6d03 100644 --- a/docs/source/internals/data_storage.rst +++ b/docs/source/internals/data_storage.rst @@ -72,7 +72,7 @@ The corresponding entity names appear nested next to the properties to show this .. 
note:: - If you supply an old archive file that the current AiiDA code does not support, ``verdi import`` will automatically try to migrate the archive by calling ``verdi export migrate``. + If you supply an old archive file that the current AiiDA code does not support, ``verdi archive import`` will automatically try to migrate the archive by calling ``verdi archive migrate``. .. _internal_architecture:orm:archive:data-json: diff --git a/docs/source/internals/engine.rst b/docs/source/internals/engine.rst index 6a1c93c07e..0c4fd84f4b 100644 --- a/docs/source/internals/engine.rst +++ b/docs/source/internals/engine.rst @@ -1,11 +1,49 @@ -.. todo:: +.. _internal_architecture:engine: - .. _internal_architecture:engine: +****** +Engine +****** - ****** - Engine - ****** - `#4038`_ + +.. _internal_architecture:engine:caching: + +Controlling caching +------------------- + +.. important:: + + This section covers some details of the caching mechanism which are not discussed in the :ref:`topics section `. + If you are developing plugins and want to modify the caching behavior of your classes, we recommend you read that section first. + +There are several methods which the internal classes of AiiDA use to control the caching mechanism: + +On the level of the generic :class:`orm.Node ` class: + +* The :meth:`~aiida.orm.nodes.Node.is_valid_cache` property determines whether a particular node can be used as a cache. + This is used for example to disable caching from failed calculations. +* Node classes have a ``_cachable`` attribute, which can be set to ``False`` to completely switch off caching for nodes of that class. + This avoids performing queries for the hash altogether. + +On the level of the :class:`Process ` and :class:`orm.ProcessNode ` classes: + +* The :meth:`ProcessNode.is_valid_cache ` calls :meth:`Process.is_valid_cache `, passing the node itself. + This can be used in :class:`~aiida.engine.processes.process.Process` subclasses (e.g. in calculation plugins) to implement custom ways of invalidating the cache. +* The :meth:`ProcessNode._hash_ignored_inputs ` attribute lists the inputs that should be ignored when creating the hash. + This is checked by the :meth:`ProcessNode._get_objects_to_hash ` method. +* The :meth:`Process.is_valid_cache ` is where the :meth:`exit_codes ` that have been marked by ``invalidates_cache`` are checked. + + +The ``WorkflowNode`` example +............................ + +As discussed in the :ref:`topic section `, nodes which can have ``RETURN`` links cannot be cached. +This is enforced on two levels: + +* The ``_cachable`` property is set to ``False`` in the :class:`~aiida.orm.nodes.Node`, and only re-enabled in :class:`~aiida.orm.nodes.process.calculation.calculation.CalculationNode` (which affects CalcJobs and calcfunctions). + This means that a :class:`~aiida.orm.nodes.process.workflow.workflow.WorkflowNode` will not be cached. +* The ``_store_from_cache`` method, which is used to "clone" an existing node, will raise an error if the existing node has any ``RETURN`` links. + This extra safe-guard prevents cases where a user might incorrectly override the ``_cachable`` property on a ``WorkflowNode`` subclass. + .. _#4038: https://github.com/aiidateam/aiida-core/issues/4038 diff --git a/docs/source/internals/index.rst b/docs/source/internals/index.rst index 14f900f266..1bd33c1690 100644 --- a/docs/source/internals/index.rst +++ b/docs/source/internals/index.rst @@ -7,10 +7,10 @@ Internal architecture data_storage plugin_system + engine rest_api .. 
todo:: global_design orm - engine diff --git a/docs/source/intro/install_system.rst b/docs/source/intro/install_system.rst index 7d11b8e6d8..5df3023a12 100644 --- a/docs/source/intro/install_system.rst +++ b/docs/source/intro/install_system.rst @@ -15,7 +15,7 @@ This is the *recommended* installation method to setup AiiDA on a personal lapto **Install prerequisite services** - AiiDA is designed to run on `Unix `_ operating systems and requires a `bash `_ or `zsh `_ shell, and Python >= 3.6. + AiiDA is designed to run on `Unix `_ operating systems and requires a `bash `_ or `zsh `_ shell, and Python >= 3.7. .. tabbed:: Ubuntu @@ -216,6 +216,23 @@ This is the *recommended* installation method to setup AiiDA on a personal lapto --- + **Setup profile** + + Next, set up an AiiDA configuration profile and related data storage, with the ``verdi quicksetup`` command. + + .. code-block:: console + + (aiida) $ verdi quicksetup + Info: enter "?" for help + Info: enter "!" to ignore the default and set no value + Profile name: me + Email Address (for sharing data): me@user.com + First name: my + Last name: name + Institution: where-i-work + + --- + **Start verdi daemons** Start the verdi daemon(s) that are used to run AiiDA workflows. @@ -234,23 +251,6 @@ This is the *recommended* installation method to setup AiiDA on a personal lapto --- - **Setup profile** - - Next, set up an AiiDA configuration profile and related data storage, with the ``verdi quicksetup`` command. - - .. code-block:: console - - (aiida) $ verdi quicksetup - Info: enter "?" for help - Info: enter "!" to ignore the default and set no value - Profile name: me - Email Address (for sharing data): me@user.com - First name: my - Last name: name - Institution: where-i-work - - --- - **Check setup** To check that everything is set up correctly, execute: diff --git a/docs/source/intro/installation.rst b/docs/source/intro/installation.rst index 04e7c4c4f2..2bc09e1add 100644 --- a/docs/source/intro/installation.rst +++ b/docs/source/intro/installation.rst @@ -6,7 +6,7 @@ Advanced configuration ********************** This chapter covers topics that go beyond the :ref:`standard setup of AiiDA `. -If you are new to AiiDA, we recommed you first go through the :ref:`Basic Tutorial `, +If you are new to AiiDA, we recommend you first go through the :ref:`Basic Tutorial `, or see our :ref:`Next steps guide `. .. _intro:install:database: @@ -246,57 +246,40 @@ The AiiDA daemon is controlled using three simple commands: Using AiiDA in Jupyter ---------------------- -`Jupyter `_ is an open-source web application that allows you to create in-browser notebooks containing live code, visualizations and formatted text. + 1. Install the AiiDA ``notebook`` extra **inside** the AiiDA python environment, e.g. by running ``pip install aiida-core[notebook]``. -Originally born out of the iPython project, it now supports code written in many languages and customized iPython kernels. + 2. (optional) Register the ``%aiida`` IPython magic for loading the same environment as in the ``verdi shell``: -If you didn't already install AiiDA with the ``[notebook]`` option (during ``pip install``), run ``pip install jupyter`` **inside** the virtualenv, and then run **from within the virtualenv**: + Copy the following code snippet into ``/.ipython/profile_default/startup/aiida_magic_register.py`` -.. code-block:: console - - $ jupyter notebook - -This will open a tab in your browser. Click on ``New -> Python`` and type: - -.. 
code-block:: python - - import aiida - -followed by ``Shift-Enter``. If no exception is thrown, you can use AiiDA in Jupyter. + .. literalinclude:: ../../../aiida/tools/ipython/aiida_magic_register.py + :start-after: # DOCUMENTATION MARKER -If you want to set the same environment as in a ``verdi shell``, -add the following code to a ``.py`` file (create one if there isn't any) in ``/.ipython/profile_default/startup/``: + .. note:: Use ``ipython locate profile`` if you're unsure about the location of your ipython profile folder. -.. code-block:: python - - try: - import aiida - except ImportError: - pass - else: - import IPython - from aiida.tools.ipython.ipython_magics import load_ipython_extension - # Get the current Ipython session - ipython = IPython.get_ipython() +With this setup, you're ready to use AiiDA in Jupyter notebeooks. - # Register the line magic - load_ipython_extension(ipython) - -This file will be executed when the ipython kernel starts up and enable the line magic ``%aiida``. -Alternatively, if you have a ``aiida-core`` repository checked out locally, -you can just copy the file ``/aiida/tools/ipython/aiida_magic_register.py`` to the same folder. -The current ipython profile folder can be located using: +Start a Jupyter notebook server: .. code-block:: console - $ ipython locate profile + $ jupyter notebook + +This will open a tab in your browser. Click on ``New -> Python``. -After this, if you open a Jupyter notebook as explained above and type in a cell: +If you registered the ``%aiida`` IPython magic, simply run: .. code-block:: ipython %aiida -followed by ``Shift-Enter``. You should receive the message "Loaded AiiDA DB environment." -This line magic should also be enabled in standard ipython shells. +After executing the cell by ``Shift-Enter``, you should receive the message "Loaded AiiDA DB environment." +Otherwise, you can load the profile manually as you would in a Python script: + +.. code-block:: python + + from aiida import load_profile, orm + load_profile() + qb = orm.QueryBuilder() + # ... diff --git a/docs/source/intro/troubleshooting.rst b/docs/source/intro/troubleshooting.rst index c6a23d5960..f01519e161 100644 --- a/docs/source/intro/troubleshooting.rst +++ b/docs/source/intro/troubleshooting.rst @@ -326,9 +326,9 @@ Use in ipython/jupyter ---------------------- In order to use the AiiDA objects and functions in Jupyter, this latter has to be instructed to use the iPython kernel installed in the AiiDA virtual environment. -This happens by default if you install AiiDA with ``pip`` including the ``notebook`` option and run Jupyter from the AiiDA virtual environment. +This happens by default if you install AiiDA with ``pip`` including the ``notebook`` option, and run Jupyter from the AiiDA virtual environment. -If, for any reason, you do not want to install Jupyter in the virtual environment, you might consider to install it out of the virtual environment, if not already done: +If for any reason, you do not want to install Jupyter in the virtual environment, you might consider to install it out of the virtual environment, if not already done: .. code-block:: console diff --git a/docs/source/intro/tutorial.rst b/docs/source/intro/tutorial.rst index d1b7a3fd66..8fffa94945 100644 --- a/docs/source/intro/tutorial.rst +++ b/docs/source/intro/tutorial.rst @@ -13,7 +13,7 @@ Basic tutorial Welcome to the AiiDA tutorial! The goal of this tutorial is to give you a basic idea of how AiiDA helps you in executing data-driven workflows. 
-At the end of this tutorial you will know how to: +At the end of this tutorial, you will know how to: * Store data in the database and subsequently retrieve it. * Decorate a Python function such that its inputs and outputs are automatically tracked. @@ -22,7 +22,7 @@ At the end of this tutorial you will know how to: .. important:: - If you are working on your own machine, note that the tutorial assumes that you have a working AiiDA installation, and have set up your AiiDA profile in the current Python environment. + If you are working on your own machine, note that the tutorial assumes that you have a working AiiDA installation and have set up your AiiDA profile in the current Python environment. If this is not the case, consult the :ref:`getting started page`. Provenance @@ -44,7 +44,7 @@ In the provenance graph, you can see different types of *nodes* represented by d The green ellipses are ``Data`` nodes, the blue ellipse is a ``Code`` node, and the rectangles represent *processes*, i.e. the calculations performed in your *workflow*. The provenance graph allows us to not only see what data we have, but also how it was produced. -During this tutorial we will be using AiiDA to generate the provenance graph in :numref:`fig_intro_workchain_graph` step by step. +During this tutorial, we will be using AiiDA to generate the provenance graph in :numref:`fig_intro_workchain_graph` step by step. Data nodes ========== @@ -94,7 +94,7 @@ Use the PK only if you are working within a single database, i.e. in an interact The PK numbers shown throughout this tutorial assume that you start from a completely empty database. It is possible that the nodes' PKs will be different for your database! - The UUIDs are generated randomly and are therefore **guaranteed** to be different. + The UUIDs are generated randomly and are, therefore, **guaranteed** to be different. Next, let's leave the IPython shell by typing ``exit()`` and then enter. 
Back in the terminal, use the ``verdi`` command line interface (CLI) to check the data node we have just created: diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 19cfad02a8..06635a95c4 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -19,19 +19,56 @@ py:class builtins.str py:class builtins.dict # typing -py:class traceback +py:class asyncio.events.AbstractEventLoop +py:class EntityType +py:class function py:class IO +py:class traceback ### AiiDA # issues with order of object processing and type hinting -py:class WorkChainSpec +py:class aiida.engine.runners.ResultAndNode +py:class aiida.engine.runners.ResultAndPk +py:class aiida.engine.processes.workchains.workchain.WorkChainSpec +py:class aiida.manage.manager.Manager +py:class aiida.orm.nodes.node.WarnWhenNotEntered +py:class aiida.orm.utils.links.LinkQuadruple py:class aiida.tools.importexport.dbexport.ExportReport py:class aiida.tools.importexport.dbexport.ArchiveData - -py:class EntityType py:class aiida.tools.groups.paths.WalkNodeResult +py:class Backend +py:class BackendEntity +py:class BackendNode +py:class AuthInfo +py:class CalcJob +py:class CalcJobNode +py:class Data +py:class ExitCode +py:class File +py:class FolderData +py:class JobInfo +py:class JobState +py:class Node +py:class Parser +py:class PersistenceError +py:class Process +py:class ProcessBuilder +py:class ProcessNode +py:class ProcessSpec +py:class Port +py:class PortNamespace +py:class Runner +py:class Transport +py:class TransportQueue +py:class WorkChainSpec + +py:class kiwipy.communications.Communicator +py:class plumpy.process_states.State +py:class plumpy.workchains._If +py:class plumpy.workchains._While + ### python packages # Note: These exceptions are needed if # * the objects are referenced e.g. as param/return types types in method docstrings (without intersphinx mapping) @@ -60,10 +97,7 @@ py:class paramiko.proxy.ProxyCommand py:class plumpy.ports.PortNamespace py:class plumpy.utils.AttributesDict -py:class topika.Connection - -py:class tornado.ioloop.IOLoop -py:class tornado.concurrent.Future +py:class _asyncio.Future py:class tqdm.std.tqdm @@ -120,3 +154,5 @@ py:class alembic.config.Config py:class pgsu.PGSU py:meth pgsu.PGSU.__init__ + +py:class jsonschema.exceptions._Error diff --git a/docs/source/redirects.txt b/docs/source/redirects.txt index 1068be6541..94f833917f 100644 --- a/docs/source/redirects.txt +++ b/docs/source/redirects.txt @@ -21,6 +21,7 @@ datatypes/index.rst topics/data_types.rst datatypes/functionality.rst topics/data_types.rst datatypes/kpoints.rst topics/data_types.rst datatypes/bands.rst topics/data_types.rst +backup/index.rst howto/installation.rst # fix https://www.materialscloud.org/dmp apidoc/aiida.orm.rst reference/apidoc/aiida.orm.rst diff --git a/docs/source/reference/api/public.rst b/docs/source/reference/api/public.rst index 66d30f5c87..0464d7a1ac 100644 --- a/docs/source/reference/api/public.rst +++ b/docs/source/reference/api/public.rst @@ -102,6 +102,7 @@ If a module is mentioned, then all the resources defined in its ``__all__`` are load_code load_computer load_group + to_aiida_type ``aiida.parsers`` diff --git a/docs/source/reference/command_line.rst b/docs/source/reference/command_line.rst index 84f6989dad..88065b134f 100644 --- a/docs/source/reference/command_line.rst +++ b/docs/source/reference/command_line.rst @@ -10,6 +10,27 @@ Commands ======== Below is a list with all available subcommands. +.. 
_reference:command-line:verdi-archive: + +``verdi archive`` +----------------- + +.. code:: console + + Usage: [OPTIONS] COMMAND [ARGS]... + + Create, inspect and import AiiDA archives. + + Options: + --help Show this message and exit. + + Commands: + create Export subsets of the provenance graph to file for sharing. + import Import data from an AiiDA archive file. + inspect Inspect contents of an archive without importing it. + migrate Migrate an export archive to a more recent format version. + + .. _reference:command-line:verdi-calcjob: ``verdi calcjob`` @@ -133,14 +154,20 @@ Below is a list with all available subcommands. .. code:: console - Usage: [OPTIONS] OPTION_NAME OPTION_VALUE + Usage: [OPTIONS] COMMAND [ARGS]... - Configure profile-specific or global AiiDA options. + Manage the AiiDA configuration. Options: - --global Apply the option configuration wide. - --unset Remove the line matching the option name from the config file. - --help Show this message and exit. + --help Show this message and exit. + + Commands: + caching List caching-enabled process types for the current profile. + get Get the value of an AiiDA option for the current profile. + list List AiiDA options for the current profile. + set Set an AiiDA option. + show Show details of an AiiDA option for the current profile. + unset Unset an AiiDA option. .. _reference:command-line:verdi-daemon: @@ -199,6 +226,8 @@ Below is a list with all available subcommands. Commands: integrity Check the integrity of the database and fix potential issues. migrate Migrate the database to the latest schema version. + summary Summarise the entities in the database. + version Show the version of the database. .. _reference:command-line:verdi-devel: @@ -233,7 +262,7 @@ Below is a list with all available subcommands. Usage: [OPTIONS] COMMAND [ARGS]... - Create and manage export archives. + Deprecated, use `verdi archive`. Options: --help Show this message and exit. @@ -277,10 +306,10 @@ Below is a list with all available subcommands. --help Show this message and exit. Commands: - add-nodes Add nodes to the a group. + add-nodes Add nodes to a group. copy Duplicate a group. create Create an empty group with a given name. - delete Delete a group. + delete Delete a group and (optionally) the nodes it contains. description Change the description of a group. list Show a list of existing groups. path Inspect groups of nodes, with delimited label paths. @@ -313,9 +342,7 @@ Below is a list with all available subcommands. Usage: [OPTIONS] [--] [ARCHIVES]... - Import data from an AiiDA archive file. - - The archive can be specified by its relative or absolute file path, or its HTTP URL. + Deprecated, use `verdi archive import`. Options: -w, --webpages TEXT... Discover all URL targets pointing to files with the diff --git a/docs/source/topics/calculations/usage.rst b/docs/source/topics/calculations/usage.rst index 5226424c3d..de7c77b1d7 100644 --- a/docs/source/topics/calculations/usage.rst +++ b/docs/source/topics/calculations/usage.rst @@ -312,38 +312,171 @@ Note that the source path can point to a directory, in which case its contents w Retrieve list ~~~~~~~~~~~~~ -The retrieve list supports various formats to define what files should be retrieved. -The simplest is retrieving a single file, whose filename you know before hand and you simply want to copy with the same name in the retrieved folder. 
-Imagine you want to retrieve the files ``output1.out`` and ``output_folder/output2.out`` you would simply add them as strings to the retrieve list: +The retrieve list is a list of instructions of what files and folders should be retrieved by the engine once a calculation job has terminated. +Each instruction should have one of two formats: -.. code:: python + * a string representing a relative filepath in the remote working directory + * a tuple of length three that allows to control the name of the retrieved file or folder in the retrieved folder - calc_info.retrieve_list = ['output1.out', 'output_folder/output2.out'] +The retrieve list can contain any number of instructions and can use both formats at the same time. +The first format is obviously the simplest, however, this requires one knows the exact name of the file or folder to be retrieved and in addition any subdirectories will be ignored when it is retrieved. +If the exact filename is not known and `glob patterns `_ should be used, or if the original folder structure should be (partially) kept, one should use the tuple format, which has the following format: -The retrieved files will be copied over keeping the exact names and hierarchy. -If you require more control over the hierarchy and nesting, you can use tuples of length three instead, with the following items: + * `source relative path`: the relative path, with respect to the working directory on the remote, of the file or directory to retrieve. + * `target relative path`: the relative path of the directory in the retrieved folder in to which the content of the source will be copied. The string ``'.'`` indicates the top level in the retrieved folder. + * `depth`: the number of levels of nesting in the source path to maintain when copying, starting from the deepest file. - * `source relative path`: the relative path, with respect to the working directory on the remote, of the file or directory to retrieve - * `target relative path`: the relative path where to copy the files locally in the retrieved folder. The string `'.'` indicates the top level in the retrieved folder. - * `depth`: the number of levels of nesting in the folder hierarchy to maintain when copying, starting from the deepest file +To illustrate the various possibilities, consider the following example file hierarchy in the remote working directory: -For example, imagine the calculation will have written a file in the remote working directory with the folder hierarchy ``some/remote/path/files/output.dat``. -If you want to copy the file, with the final resulting path ``path/files/output.dat``, you would specify: +.. code:: bash -.. code:: python + ├─ path + | ├── sub + │ │ ├─ file_c.txt + │ │ └─ file_d.txt + | └─ file_b.txt + └─ file_a.txt - calc_info.retrieve_list = [('some/remote/path/files/output.dat', '.', 2)] +Below, you will find examples for various use cases of files and folders to be retrieved. +Each example starts with the format of the ``retrieve_list``, followed by a schematic depiction of the final file hierarchy that would be created in the retrieved folder. -The depth of two, ensures that only two levels of nesting are copied. -If the output files have dynamic names that one cannot know beforehand, the ``'*'`` glob pattern can be used. -For example, if the code will generate a number of XML files in the folder ``relative/path/output`` with filenames that follow the pattern ``file_*[0-9].xml``, you can instruct to retrieve all of them as follows: +Explicit file or folder +....................... -.. 
code:: python +Retrieving a single toplevel file or folder (with all its contents) where the final folder structure is not important. + +.. code:: bash + + retrieve_list = ['file_a.txt'] + + └─ file_a.txt + +.. code:: bash + + retrieve_list = ['path'] + + ├── sub + │ ├─ file_c.txt + │ └─ file_d.txt + └─ file_b.txt + + +Explicit nested file or folder +.............................. + +Retrieving a single file or folder (with all its contents) that is located in a subdirectory in the remote working directory, where the final folder structure is not important. + +.. code:: bash + + retrieve_list = ['path/file_b.txt'] + + └─ file_b.txt + +.. code:: bash + + retrieve_list = ['path/sub'] + + ├─ file_c.txt + └─ file_d.txt + + +Explicit nested file or folder keeping (partial) hierarchy +.......................................................... + +The following examples show how the file hierarchy of the retrieved files can be controlled. +By changing the ``depth`` parameter of the tuple, one can control what part of the remote folder hierarchy is kept. +In the given example, the maximum depth of the remote folder hierarchy is ``3``. +The following example shows that by specifying ``3``, the exact folder structure is kept: + +.. code:: bash + + retrieve_list = [('path/sub/file_c.txt', '.', 3)] + + └─ path + └─ sub + └─ file_c.txt + +For ``depth=2``, only two levels of nesting are kept (including the file itself) and so the ``path`` folder is discarded. + +.. code:: bash + + retrieve_list = [('path/sub/file_c.txt', '.', 2)] + + └─ sub + └─ file_c.txt + +The same applies for directories. +By specifying a directory for the first element, all its contents will be retrieved. +With ``depth=1``, only the first level ``sub`` is kept of the folder hierarchy. + +.. code:: bash + + retrieve_list = [('path/sub', '.', 1)] + + └── sub + ├─ file_c.txt + └─ file_d.txt - calc_info.retrieve_list = [('relative/path/output/file_*[0-9].xml', '.', 1)] -The second item when using globbing *has* to be ``'.'`` and the depth works just as before. -In this example, all files matching the globbing pattern will be copied in the directory ``output`` in the retrieved folder data node. +Pattern matching +................ + +If the exact file or folder name is not known beforehand, glob patterns can be used. +In the following examples, all files that match ``*c.txt`` in the directory ``path/sub`` will be retrieved. +Since ``depth=0`` the files will be copied without the ``path/sub`` subdirectory. + +.. code:: bash + + retrieve_list = [('path/sub/*c.txt', '.', 0)] + + └─ file_c.txt + +To keep the subdirectory structure, one can set the depth parameter, just as in the previous examples. + +.. code:: bash + + retrieve_list = [('path/sub/*c.txt', '.', 2)] + + └── sub + └─ file_c.txt + + +Specific target directory +......................... + +The final folder hierarchy of the retrieved files in the retrieved folder is not only determined by the hierarchy of the remote working directory, but can also be controlled through the second and third elements of the instructions tuples. +The final ``depth`` element controls what level of hierarchy of the source is maintained, where the second element specifies the base path in the retrieved folder into which the remote files should be retrieved. +For example, to retrieve a nested file, maintaining the remote hierarchy and storing it locally in the ``target`` directory, one can do the following: + +.. 
code:: bash
+
+    retrieve_list = [('path/sub/file_c.txt', 'target', 3)]
+
+    └─ target
+       └─ path
+          └─ sub
+             └─ file_c.txt
+
+The same applies for folders that are to be retrieved:
+
+.. code:: bash
+
+    retrieve_list = [('path/sub', 'target', 1)]
+
+    └─ target
+       └── sub
+          ├─ file_c.txt
+          └─ file_d.txt
+
+Note that `target` here is not used to rename the retrieved file or folder, but indicates the path of the directory into which the source is copied.
+The target relative path is also compatible with glob patterns in the source relative paths:
+
+.. code:: bash
+
+    retrieve_list = [('path/sub/*c.txt', 'target', 0)]
+
+    └─ target
+       └─ file_c.txt
 
 Retrieve temporary list
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -360,6 +493,54 @@ The parser implementation can then parse these files and store the relevant info
 After the parser terminates, the engine will take care to automatically clean up the sandbox folder with the temporarily retrieved files.
 The contract of the 'retrieve temporary list' is essentially that the files will be available during parsing and will be destroyed immediately afterwards.
 
+.. _topics:calculations:usage:calcjobs:stashing:
+
+Stashing files on the remote
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 1.6.0
+
+The ``stash`` option namespace allows a user to specify files, created by the calculation job, that should be stashed somewhere on the remote.
+This can be useful for files that need to be kept for longer than the typical lifetime of the scratch space in which the job ran, but that should remain on the remote machine rather than being retrieved.
+Examples are files that are necessary to restart a calculation but are too big to be retrieved and stored permanently in the local file repository.
+
+The files that are to be stashed are specified through their relative filepaths within the working directory in the ``stash.source_list`` option.
+Using the ``COPY`` mode, the target path defines another location (on the same filesystem as the calculation) to copy these files to, and is set through the ``stash.target_base`` option, for example:
+
+.. code-block:: python
+
+    from aiida.common.datastructures import StashMode
+
+    inputs = {
+        'code': ....,
+        ...
+        'metadata': {
+            'options': {
+                'stash': {
+                    'source_list': ['aiida.out', 'output.txt'],
+                    'target_base': '/storage/project/stash_folder',
+                    'stash_mode': StashMode.COPY.value,
+                }
+            }
+        }
+    }
+
+.. note::
+
+    In the future, other methods for stashing may be implemented, such as placing all files in a (compressed) tarball or even stashing files on tape.
+
+.. important::
+
+    If the ``stash`` option namespace is defined for a calculation job, the daemon will perform the stashing operations before the files are retrieved.
+    This means that the stashing happens before the parsing of the output files (which occurs after the retrieving step), so the files will be stashed regardless of the final exit status that the parser will assign to the calculation job.
+    This may cause files to be stashed for calculations that will later be considered to have failed.
+
+The stashed files are represented by an output node that is attached to the calculation node through the label ``remote_stash``, as a ``RemoteStashFolderData`` node.
+Just like the ``remote_folder`` node, this node represents a location of files on a remote machine, which makes it essentially a "symbolic link".
+
+.. important::
+
+    AiiDA does not actually own the files in the remote stash, so their contents may disappear at some point.
 
..
_topics:calculations:usage:calcjobs:options: diff --git a/docs/source/topics/cli.rst b/docs/source/topics/cli.rst index f3743b4dbf..26d4159a64 100644 --- a/docs/source/topics/cli.rst +++ b/docs/source/topics/cli.rst @@ -25,11 +25,11 @@ Multi-value options Some ``verdi`` commands provide *options* that can take multiple values. This allows to avoid repetition and e.g. write:: - verdi export create -N 10 11 12 -- archive.aiida + verdi archive create -N 10 11 12 -- archive.aiida instead of the more lengthy:: - verdi export create -N 10 -N 11 -N 12 archive.aiida + verdi archive create -N 10 -N 11 -N 12 archive.aiida Note the use of the so-called 'endopts' marker ``--`` that is necessary to mark the end of the ``-N`` option and distinguish it from the ``archive.aiida`` argument. @@ -68,7 +68,7 @@ The ``Usage:`` line encodes information on the command's parameters, e.g.: Multi-value options are followed by ``...`` in the help string and the ``Usage:`` line of the corresponding command will contain the 'endopts' marker. For example:: - Usage: verdi export create [OPTIONS] [--] OUTPUT_FILE + Usage: verdi archive create [OPTIONS] [--] OUTPUT_FILE Export various entities, such as Codes, Computers, Groups and Nodes, to an archive file for backup or sharing purposes. diff --git a/docs/source/topics/processes/concepts.rst b/docs/source/topics/processes/concepts.rst index 166c601406..a0a2b50700 100644 --- a/docs/source/topics/processes/concepts.rst +++ b/docs/source/topics/processes/concepts.rst @@ -105,6 +105,7 @@ When you load a calculation node from the database, you can use these property m Process exit codes ================== + The previous section about the process state showed that a process that is ``Finished`` does not say anything about whether the result is 'successful' or 'failed'. The ``Finished`` state means nothing more than that the engine succeeded in running the process to the end of execution, without it encountering exceptions or being killed. To distinguish between a 'successful' and 'failed' process, an 'exit status' can be defined. @@ -112,8 +113,10 @@ The `exit status is a common concept in programming ` and :ref:`workflow` development sections. +.. seealso:: + + For how exit codes can be defined and returned see the :ref:`exit code usage section `. .. _topics:processes:concepts:lifetime: @@ -156,6 +159,10 @@ Processes, whose task is in the queue and not with any runner, though technicall While a process is not actually being run, i.e. it is not in memory with a runner, one cannot interact with it. Similarly, as soon as the task disappears, either because the process was intentionally terminated (or unintentionally), the process will never continue running again. +.. figure:: include/images/submit_sysml.png + + A systems modelling representation of submitting a process. + .. _topics:processes:concepts:checkpoints: diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index c62ee97b6b..deda0601b9 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -128,6 +128,7 @@ As always, all the values returned by a calculation function have to be storable Because of the calculation/workflow duality in AiiDA, a ``calcfunction``, which is a calculation-like process, can only *create* and not *return* data nodes. This means that if a node is returned from a ``calcfunction`` that *is already stored*, the engine will throw an exception. +.. 
_topics:processes:functions:exit_codes: Exit codes ========== diff --git a/docs/source/topics/processes/include/images/submit_sysml.png b/docs/source/topics/processes/include/images/submit_sysml.png new file mode 100644 index 0000000000..df2d361dcf Binary files /dev/null and b/docs/source/topics/processes/include/images/submit_sysml.png differ diff --git a/docs/source/topics/processes/include/images/submit_sysml.pptx b/docs/source/topics/processes/include/images/submit_sysml.pptx new file mode 100644 index 0000000000..a66e96d379 Binary files /dev/null and b/docs/source/topics/processes/include/images/submit_sysml.pptx differ diff --git a/docs/source/topics/processes/include/snippets/serialize/workchain_serialize.py b/docs/source/topics/processes/include/snippets/serialize/workchain_serialize.py index c8ce2b81dd..5980207626 100644 --- a/docs/source/topics/processes/include/snippets/serialize/workchain_serialize.py +++ b/docs/source/topics/processes/include/snippets/serialize/workchain_serialize.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- from aiida.engine import WorkChain -from aiida.orm.nodes.data import to_aiida_type -# The basic types need to be loaded such that they are registered with -# the 'to_aiida_type' function. -from aiida.orm.nodes.data.base import * +from aiida.orm import to_aiida_type class SerializeWorkChain(WorkChain): diff --git a/docs/source/topics/processes/usage.rst b/docs/source/topics/processes/usage.rst index a883be92f1..e9653be16c 100644 --- a/docs/source/topics/processes/usage.rst +++ b/docs/source/topics/processes/usage.rst @@ -201,12 +201,14 @@ This function, passed as ``serializer`` parameter to ``spec.input``, is invoked For inputs which are stored in the database (``non_db=False``), the serialization function should return an AiiDA data type. For ``non_db`` inputs, the function must be idempotent because it might be applied more than once. -The following example work chain takes three inputs ``a``, ``b``, ``c``, and simply returns the given inputs. The :func:`aiida.orm.nodes.data.base.to_aiida_type` function is used as serialization function. +The following example work chain takes three inputs ``a``, ``b``, ``c``, and simply returns the given inputs. +The :func:`~aiida.orm.nodes.data.base.to_aiida_type` function is used as serialization function. .. include:: include/snippets/serialize/workchain_serialize.py :code: python -This work chain can now be called with native Python types, which will automatically be converted to AiiDA types by the :func:`aiida.orm.nodes.data.base.to_aiida_type` function. Note that the module which defines the corresponding AiiDA type must be loaded for it to be recognized by :func:`aiida.orm.nodes.data.base.to_aiida_type`. +This work chain can now be called with native Python types, which will automatically be converted to AiiDA types by the :func:`~aiida.orm.nodes.data.base.to_aiida_type` function. +Note that the module which defines the corresponding AiiDA type must be loaded for it to be recognized by :func:`~aiida.orm.nodes.data.base.to_aiida_type`. .. include:: include/snippets/serialize/run_workchain_serialize.py :code: python @@ -223,7 +225,7 @@ To clearly communicate to the caller what went wrong, the ``Process`` supports s This ``exit_status``, a positive integer, is an attribute of the process node and by convention, when it is zero means the process was successful, whereas any other value indicates failure. 
This concept of an exit code, with a positive integer as the exit status, `is a common concept in programming `_ and a standard way for programs to communicate the result of their execution. -Potential exit codes for the ``Process`` can be defined through the ``ProcessSpec``, just like inputs and ouputs. +Potential exit codes for the ``Process`` can be defined through the ``ProcessSpec``, just like inputs and outputs. Any exit code consists of a positive non-zero integer, a string label to reference it and a more detailed description of the problem that triggers the exit code. Consider the following example: @@ -250,6 +252,16 @@ This is useful, because the caller can now programmatically, based on the ``exit This is an infinitely more robust way of communicating specific errors to a non-human than parsing text-based logs or reports. Additionally, the exit codes make it very easy to query for failed processes with specific error codes. +.. seealso:: + + Additional documentation, specific to certain process types, can be found in the following sections: + + - :ref:`Process functions` + - :ref:`Work functions` + - :ref:`CalcJob parsers` + - :ref:`Workchain exit code specification` + - :ref:`External code plugins` + - :ref:`Restart workchains` .. _topics:processes:usage:exit_code_conventions: diff --git a/docs/source/topics/provenance/caching.rst b/docs/source/topics/provenance/caching.rst index 00d977eded..2cb4e45e1e 100644 --- a/docs/source/topics/provenance/caching.rst +++ b/docs/source/topics/provenance/caching.rst @@ -4,6 +4,11 @@ Caching and hashing =================== +This section covers the more general considerations of the hashing/caching mechanism. +For a more practical guide on how to enable and disable this feature, please visit the corresponding :ref:`how-to section `. +If you want to know more about how the internal design of the mechanism is implemented, you can check the :ref:`internals section ` instead. + + .. _topics:provenance:caching:hashing: How are nodes hashed @@ -23,7 +28,7 @@ The hash of a :class:`~aiida.orm.ProcessNode` includes, on top of this, the hash Once a node is stored in the database, its hash is stored in the ``_aiida_hash`` extra, and this extra is used to find matching nodes. If a node of the same class with the same hash already exists in the database, this is considered a cache match. -Use the :meth:`~aiida.orm.nodes.Node.get_hash` method to check the hash of any node. +You can use the :meth:`~aiida.orm.nodes.Node.get_hash` method to check the hash of any node. In order to figure out why a calculation is *not* being reused, the :meth:`~aiida.orm.nodes.Node._get_objects_to_hash` method may be useful: .. code-block:: ipython @@ -53,10 +58,47 @@ In order to figure out why a calculation is *not* being reused, the :meth:`~aiid ] +.. _topics:provenance:caching:control-hashing: + +Controlling hashing +------------------- + +Data nodes +.......... + +The hashing of *Data nodes* can be customized both when implementing a new data node class and during runtime. + +In the :py:class:`~aiida.orm.nodes.Node` subclass: + +* Use the ``_hash_ignored_attributes`` to exclude a list of node attributes ``['attr1', 'attr2']`` from computing the hash. +* Include extra information in computing the hash by overriding the :meth:`~aiida.orm.nodes.Node._get_objects_to_hash` method. + Use the ``super()`` method, and then append to the list of objects to hash. 
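+
+As a rough illustration of the two points above, a hypothetical ``Data`` subclass (the class name and attribute names below are made up for this example) could customize its hash along these lines:
+
+.. code-block:: python
+
+    from aiida import orm
+
+    class ExampleData(orm.Data):
+        """Hypothetical data class that tweaks how its hash is computed."""
+
+        # Attributes listed here are skipped when the hash is computed.
+        _hash_ignored_attributes = ('creation_comment',)
+
+        def _get_objects_to_hash(self):
+            """Append extra information to the objects collected by the parent class."""
+            objects = super()._get_objects_to_hash()
+            objects.append(self.get_attribute('important_flag', None))
+            return objects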
+ +You can also modify hashing behavior during runtime by passing keyword arguments to :meth:`~aiida.orm.nodes.Node.get_hash`, which are forwarded to :meth:`~aiida.common.hashing.make_hash`. + +Process nodes +............. + +The hashing of *Process nodes* is fixed and can only be influenced indirectly via the hashes of their inputs. +For implementation details of the hashing mechanism for process nodes, see :ref:`here `. + +.. _topics:provenance:caching:control-caching: + +Controlling Caching +------------------- + +Caching can be configured at runtime (see :ref:`how-to:run-codes:caching:configure`) and when implementing a new process class: + +* The :meth:`spec.exit_code ` has a keyword argument ``invalidates_cache``. + If this is set to ``True``, that means that a calculation with this exit code will not be used as a cache source for another one, even if their hashes match. +* The :class:`Process ` parent class from which calcjobs inherit has an :meth:`is_valid_cache ` method, which can be overridden in the plugin to implement custom ways of invalidating the cache. + When doing this, make sure to call :meth:`super().is_valid_cache(node)` and respect its output: if it is `False`, your implementation should also return `False`. + If you do not comply with this, the 'invalidates_cache' keyword on exit codes will not work. + .. _topics:provenance:caching:limitations: -Limitations ------------ +Limitations and Guidelines +-------------------------- #. Workflow nodes are not cached. In the current design this follows from the requirement that the provenance graph be independent of whether caching is enabled or not: @@ -72,4 +114,11 @@ While AiiDA's hashes include the version of the Python package containing the calculation/data classes, it cannot detect cases where the underlying Python code was changed without increasing the version number. Another scenario that can lead to an erroneous cache hit is if the parser and calculation are not implemented as part of the same Python package, because the calculation nodes store only the name, but not the version of the used parser. -#. Finally, while caching saves unnecessary computations, it does not save disk space: the output nodes of the cached calculation are full copies of the original outputs. +#. Note that while caching saves unnecessary computations, it does not save disk space: the output nodes of the cached calculation are full copies of the original outputs. + +#. Finally, when modifying the hashing/caching behaviour of your classes, keep in mind that cache matches can go wrong in two ways: + + * False negatives, where two nodes *should* have the same hash but do not + * False positives, where two different nodes get the same hash by mistake + + False negatives are **highly preferable** because they only increase the runtime of your calculations, while false positives can lead to wrong results. diff --git a/docs/source/topics/provenance/consistency.rst b/docs/source/topics/provenance/consistency.rst index 2c76733393..a79b57815d 100644 --- a/docs/source/topics/provenance/consistency.rst +++ b/docs/source/topics/provenance/consistency.rst @@ -33,7 +33,7 @@ In the following section we will explain in more detail the criteria for includi Traversal Rules =============== -When you run ``verdi node delete [NODE_IDS]`` or ``verdi export create -N [NODE_IDS]``, AiiDA will look at the links incoming or outgoing from the nodes that you specified and decide if there are other nodes that are critical to keep.
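The two controls listed under "Controlling Caching" above can be sketched as follows; the exit code status and label, the extra condition, and the classmethod form of ``is_valid_cache`` are assumptions for illustration.

.. code-block:: python

    # Hedged sketch of the caching controls described above; names and numbers are illustrative.
    from aiida.engine import CalcJob

    class MyCalculation(CalcJob):

        @classmethod
        def define(cls, spec):
            super().define(spec)
            # A calculation that failed with this exit code is never used as a cache
            # source, even if another calculation has a matching hash.
            spec.exit_code(
                310, 'ERROR_READING_OUTPUT',
                message='The output file could not be read.',
                invalidates_cache=True,
            )

        @classmethod
        def is_valid_cache(cls, node):
            # Always respect the parent implementation first: if it returns False
            # (e.g. because of an `invalidates_cache` exit code), so must we.
            if not super().is_valid_cache(node):
                return False
            # Hypothetical extra rule: skip nodes explicitly flagged by the user.
            return not node.get_extra('do_not_cache', False)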
+When you run ``verdi node delete [NODE_IDS]`` or ``verdi archive create -N [NODE_IDS]``, AiiDA will look at the links incoming or outgoing from the nodes that you specified and decide if there are other nodes that are critical to keep. For this decision, it is not only important to consider the type of link, but also if we are following it along its direction (we will call this ``forward`` direction) or in the reversed direction (``backward`` direction). To clarify this, in the example above, when deleting data node |D_1|, AiiDA will follow the ``input_calc`` link in the ``forward`` direction (in this case, it will decide that the linked node (|C_1|) must then also be deleted). diff --git a/docs/source/topics/workflows/usage.rst b/docs/source/topics/workflows/usage.rst index c83f863117..4de24a9f7c 100644 --- a/docs/source/topics/workflows/usage.rst +++ b/docs/source/topics/workflows/usage.rst @@ -272,7 +272,7 @@ Converting the resulting flow diagram in a one-to-one fashion into an outline, o Exit codes ---------- There is one more property of a work chain that is specified through its process specification, in addition to its inputs, outputs and outline. -Any work chain may have one to multiple failure modes, which are modeled by :ref:`exit codes`. +Any work chain may have one to multiple failure modes, which are modelled by :ref:`exit codes`. A work chain can be stopped at any time, simply by returning an exit code from an outline method. To retrieve an exit code that is defined on the spec, one can use the :py:meth:`~aiida.engine.processes.process.Process.exit_codes` property. This returns an attribute dictionary where the exit code labels map to their corresponding exit code. @@ -587,6 +587,9 @@ Of course, we then need to explicitly pass the input ``a``. Finally, we use :meth:`~aiida.engine.processes.process.Process.exposed_outputs` and :meth:`~aiida.engine.processes.process.Process.out_many` to forward the outputs of the children to the outputs of the parent. Again, the ``namespace`` and ``agglomerate`` options can be used to select which outputs are returned by the :meth:`~aiida.engine.processes.process.Process.exposed_outputs` method. +.. seealso:: + + For further practical examples of creating workflows, see the :ref:`how to run multi-step workflows` and :ref:`how to write error resistant workflows ` sections. .. 
rubric:: Footnotes diff --git a/environment.yml b/environment.yml index 04411ca482..095809e39e 100644 --- a/environment.yml +++ b/environment.yml @@ -9,22 +9,23 @@ dependencies: - aldjemy~=0.9.1 - alembic~=1.2 - archive-path~=0.2.1 -- circus~=0.16.1 +- aio-pika~=6.6 +- circus~=0.17.1 - click-completion~=0.5.1 - click-config-file~=0.6.0 - click-spinner~=0.1.8 -- click~=7.0 -- dataclasses~=0.7 +- click~=7.1 - django~=2.2 - ete3~=3.1 - python-graphviz~=0.13 -- ipython~=7.0 +- ipython~=7.20 - jinja2~=2.10 -- kiwipy[rmq]~=0.5.5 +- jsonschema~=3.0 +- kiwipy[rmq]~=0.7.4 - numpy~=1.17 -- paramiko~=2.7 -- pika~=1.1 -- plumpy~=0.15.1 +- pamqp~=2.3 +- paramiko>=2.7.2,~=2.7 +- plumpy~=0.19.0 - pgsu~=0.1.0 - psutil~=5.6 - psycopg2>=2.8.3,~=2.8 @@ -33,11 +34,9 @@ dependencies: - pyyaml~=5.1.2 - reentry~=1.3 - simplejson~=3.16 -- sqlalchemy-utils~=0.34.2 -- sqlalchemy>=1.3.10,~=1.3 +- sqlalchemy-utils~=0.36.0 +- sqlalchemy~=1.3.10 - tabulate~=0.8.5 -- tornado<5.0 -- topika~=0.2.2 - tqdm~=4.45 - tzlocal~=2.0 - upf_to_json~=0.9.2 diff --git a/mypy.ini b/mypy.ini index 22d349aacf..073863bcca 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,8 @@ # Global options [mypy] -python_version = 3.6 + +show_error_codes = True check_untyped_defs = True scripts_are_modules = True @@ -38,13 +39,16 @@ follow_imports = skip [mypy-tests.*] check_untyped_defs = False +[mypy-circus.*] +ignore_missing_imports = True + [mypy-django.*] ignore_missing_imports = True -[mypy-numpy.*] +[mypy-kiwipy.*] ignore_missing_imports = True -[mypy-plumpy.*] +[mypy-numpy.*] ignore_missing_imports = True [mypy-scipy.*] @@ -55,3 +59,6 @@ ignore_missing_imports = True [mypy-tqdm.*] ignore_missing_imports = True + +[mypy-wrapt.*] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 939a30965f..d264c55af5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools>=40.8.0,<50", "wheel", "reentry~=1.3", "fastentrypoints~=0.12"] -build-backend = "setuptools.build_meta:__legacy__" +requires = ["setuptools>=40.8.0", "wheel", "reentry~=1.3", "fastentrypoints~=0.12"] +build-backend = "setuptools.build_meta" [tool.pylint.master] load-plugins = "pylint_django" @@ -63,8 +63,11 @@ filterwarnings = [ "ignore::DeprecationWarning:jsonbackend:", "ignore::DeprecationWarning:reentry:", "ignore::DeprecationWarning:pkg_resources:", + "ignore::pytest.PytestCollectionWarning", + "default::ResourceWarning", ] markers = [ + "requires_rmq: requires a connection (on port 5672) to RabbitMQ", "sphinx: set parameters for the sphinx `app` fixture" ] @@ -75,27 +78,29 @@ envlist = py37-django [testenv] usedevelop=True - -[testenv:py{36,37,38,39}-{django,sqla}] deps = - py36: -rrequirements/requirements-py-3.6.txt py37: -rrequirements/requirements-py-3.7.txt py38: -rrequirements/requirements-py-3.8.txt py39: -rrequirements/requirements-py-3.9.txt + +[testenv:py{36,37,38,39}-{django,sqla}] +passenv = + PYTHONASYNCIODEBUG setenv = django: AIIDA_TEST_BACKEND = django sqla: AIIDA_TEST_BACKEND = sqlalchemy commands = pytest {posargs} +[testenv:py{36,37,38,39}-verdi] +setenv = + AIIDA_TEST_BACKEND = django + AIIDA_PATH = {toxinidir}/.tox/.aiida +commands = verdi {posargs} + [testenv:py{36,37,38,39}-docs-{clean,update}] description = clean: Build the documentation (remove any existing build) update: Build the documentation (modify any existing build) -deps = - py36: -rrequirements/requirements-py-3.6.txt - py37: -rrequirements/requirements-py-3.7.txt - py38: -rrequirements/requirements-py-3.8.txt - py38: 
-rrequirements/requirements-py-3.9.txt passenv = RUN_APIDOC setenv = update: RUN_APIDOC = False @@ -109,7 +114,6 @@ commands = # tip: remove apidocs before using this feature (`cd docs; make clean`) description = Build the documentation and launch browser (with live updates) deps = - py36: -rrequirements/requirements-py-3.6.txt py37: -rrequirements/requirements-py-3.7.txt py38: -rrequirements/requirements-py-3.8.txt py39: -rrequirements/requirements-py-3.9.txt @@ -124,6 +128,22 @@ commands = [testenv:py{36,37,38,39}-pre-commit] description = Run the pre-commit checks -extras = all +extras = pre-commit commands = pre-commit run {posargs} + +[testenv:molecule-{django,sqla}] +description = Run the molecule containerised tests +skip_install = true +parallel_show_output = true +deps = + ansible~=2.10.0 + docker~=4.2 + molecule[docker]~=3.1.0 +setenv = + MOLECULE_GLOB = .molecule/*/config_local.yml + django: AIIDA_TEST_BACKEND = django + sqla: AIIDA_TEST_BACKEND = sqlalchemy +passenv = + AIIDA_TEST_WORKERS +commands = molecule {posargs:test} """ diff --git a/requirements/requirements-py-3.6.txt b/requirements/requirements-py-3.6.txt deleted file mode 100644 index f3fb98a8cc..0000000000 --- a/requirements/requirements-py-3.6.txt +++ /dev/null @@ -1,170 +0,0 @@ -aiida-export-migration-tests==0.9.0 -alabaster==0.7.12 -aldjemy==0.9.1 -alembic==1.4.1 -aniso8601==8.0.0 -appdirs==1.4.4 -appnope==0.1.0 -archive-path==0.2.1 -ase==3.19.0 -attrs==19.3.0 -Babel==2.8.0 -backcall==0.1.0 -bcrypt==3.1.7 -bleach==3.1.4 -certifi==2019.11.28 -cffi==1.14.0 -chardet==3.0.4 -circus==0.16.1 -Click==7.0 -click-completion==0.5.2 -click-config-file==0.6.0 -click-spinner==0.1.8 -configobj==5.0.6 -coverage==4.5.4 -cryptography==2.8 -cycler==0.10.0 -dataclasses==0.7 -decorator==4.4.2 -defusedxml==0.6.0 -Django==2.2.11 -docutils==0.15.2 -entrypoints==0.3 -ete3==3.1.1 -flake8==3.8.3 -Flask==1.1.1 -Flask-Cors==3.0.8 -Flask-RESTful==0.3.8 -frozendict==1.2 -furl==2.1.0 -future==0.18.2 -graphviz==0.13.2 -idna==2.9 -imagesize==1.2.0 -importlib-metadata==1.5.0 -iniconfig==1.1.1 -ipykernel==5.1.4 -ipython==7.13.0 -ipython-genutils==0.2.0 -ipywidgets==7.5.1 -itsdangerous==1.1.0 -jedi==0.16.0 -Jinja2==2.11.1 -jsonschema==3.2.0 -jupyter==1.0.0 -jupyter-client==6.0.0 -jupyter-console==6.1.0 -jupyter-core==4.6.3 -kiwipy==0.5.5 -kiwisolver==1.1.0 -Mako==1.1.2 -MarkupSafe==1.1.1 -matplotlib==3.2.0 -mccabe==0.6.1 -mistune==0.8.4 -monty==3.0.2 -more-itertools==8.2.0 -mpmath==1.1.0 -nbconvert==5.6.1 -nbformat==5.0.4 -networkx==2.4 -notebook==5.7.8 -numpy==1.17.5 -orderedmultidict==1.0.1 -packaging==20.3 -palettable==3.3.0 -pandas==0.25.3 -pandocfilters==1.4.2 -paramiko==2.7.1 -parso==0.6.2 -pathspec==0.8.0 -pexpect==4.8.0 -pg8000==1.13.2 -pgsu==0.1.0 -pgtest==1.3.2 -pickleshare==0.7.5 -pika==1.1.0 -pluggy==0.13.1 -plumpy==0.15.1 -prometheus-client==0.7.1 -prompt-toolkit==3.0.4 -psutil==5.7.0 -psycopg2-binary==2.8.4 -ptyprocess==0.6.0 -py==1.9.0 -py-cpuinfo==7.0.0 -PyCifRW==4.4.1 -pycparser==2.20 -pydata-sphinx-theme==0.4.1 -PyDispatcher==2.0.5 -pyflakes==2.2.0 -Pygments==2.6.1 -pymatgen==2020.3.2 -PyMySQL==0.9.3 -PyNaCl==1.3.0 -pyparsing==2.4.6 -pyrsistent==0.15.7 -pytest==6.0.0 -pytest-benchmark==3.2.3 -pytest-cov==2.8.1 -pytest-rerunfailures==9.1.1 -pytest-timeout==1.3.4 -python-dateutil==2.8.1 -python-editor==1.0.4 -python-memcached==1.59 -pytz==2019.3 -PyYAML==5.1.2 -pyzmq==19.0.0 -qtconsole==4.7.1 -QtPy==1.9.0 -reentry==1.3.1 -regex==2020.7.14 -requests==2.23.0 -ruamel.yaml==0.16.10 -ruamel.yaml.clib==0.2.0 -scipy==1.4.1 -scramp==1.1.0 
-seekpath==1.9.4 -Send2Trash==1.5.0 -shellingham==1.3.2 -shortuuid==1.0.1 -simplejson==3.17.0 -six==1.14.0 -snowballstemmer==2.0.0 -spglib==1.14.1.post0 -Sphinx==3.2.1 -sphinx-copybutton==0.3.0 -sphinx-notfound-page==0.5 -sphinx-panels==0.5.2 -sphinxcontrib-applehelp==1.0.2 -sphinxcontrib-contentui==0.2.4 -sphinxcontrib-details-directive==0.1.0 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==1.0.3 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.4 -sphinxext-rediraffe==0.2.4 -SQLAlchemy==1.3.13 -sqlalchemy-diff==0.1.3 -SQLAlchemy-Utils==0.34.2 -sqlparse==0.3.1 -sympy==1.5.1 -tabulate==0.8.6 -terminado==0.8.3 -testpath==0.4.4 -toml==0.10.1 -topika==0.2.2 -tornado==4.5.3 -tqdm==4.45.0 -traitlets==4.3.3 -typed-ast==1.4.1 -tzlocal==2.0.0 -upf-to-json==0.9.2 -urllib3==1.25.8 -wcwidth==0.1.8 -webencodings==0.5.1 -Werkzeug==1.0.0 -widgetsnbextension==3.5.1 -wrapt==1.11.2 -zipp==3.1.0 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 2cb9e68289..b1decfa255 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -1,170 +1,174 @@ aiida-export-migration-tests==0.9.0 +aio-pika==6.7.1 +aiormq==3.3.1 alabaster==0.7.12 aldjemy==0.9.1 -alembic==1.4.1 -aniso8601==8.0.0 -appdirs==1.4.4 -appnope==0.1.0 +alembic==1.5.4 +aniso8601==8.1.1 archive-path==0.2.1 -ase==3.19.0 -attrs==19.3.0 -Babel==2.8.0 -backcall==0.1.0 -bcrypt==3.1.7 -bleach==3.1.4 -certifi==2019.11.28 -cffi==1.14.0 -chardet==3.0.4 -circus==0.16.1 -Click==7.0 +argon2-cffi==20.1.0 +ase==3.21.1 +async-generator==1.10 +attrs==20.3.0 +Babel==2.9.0 +backcall==0.2.0 +bcrypt==3.2.0 +bleach==3.3.0 +certifi==2020.12.5 +cffi==1.14.4 +chardet==4.0.0 +circus==0.17.1 +click==7.1.2 click-completion==0.5.2 click-config-file==0.6.0 -click-spinner==0.1.8 +click-spinner==0.1.10 configobj==5.0.6 coverage==4.5.4 -cryptography==2.8 +cryptography==3.4.3 cycler==0.10.0 decorator==4.4.2 defusedxml==0.6.0 -Django==2.2.11 +deprecation==2.1.0 +Django==2.2.18 docutils==0.15.2 entrypoints==0.3 -ete3==3.1.1 -flake8==3.8.3 -Flask==1.1.1 -Flask-Cors==3.0.8 +ete3==3.1.2 +Flask==1.1.2 +Flask-Cors==3.0.10 Flask-RESTful==0.3.8 frozendict==1.2 -furl==2.1.0 future==0.18.2 -graphviz==0.13.2 -idna==2.9 +graphviz==0.16 +idna==2.10 imagesize==1.2.0 -importlib-metadata==1.5.0 +importlib-metadata==3.4.0 iniconfig==1.1.1 -ipykernel==5.1.4 -ipython==7.13.0 +ipykernel==5.4.3 +ipython==7.20.0 ipython-genutils==0.2.0 -ipywidgets==7.5.1 +ipywidgets==7.6.3 itsdangerous==1.1.0 -jedi==0.16.0 -Jinja2==2.11.1 +jedi==0.18.0 +Jinja2==2.11.3 jsonschema==3.2.0 jupyter==1.0.0 -jupyter-client==6.0.0 -jupyter-console==6.1.0 -jupyter-core==4.6.3 -kiwipy==0.5.5 -kiwisolver==1.1.0 -Mako==1.1.2 +jupyter-client==6.1.11 +jupyter-console==6.2.0 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwipy==0.7.4 +kiwisolver==1.3.1 +Mako==1.1.4 MarkupSafe==1.1.1 -matplotlib==3.2.0 -mccabe==0.6.1 +matplotlib==3.3.4 mistune==0.8.4 -monty==3.0.2 -more-itertools==8.2.0 +monty==4.0.2 mpmath==1.1.0 -nbconvert==5.6.1 -nbformat==5.0.4 -networkx==2.4 -notebook==5.7.8 -numpy==1.17.5 -orderedmultidict==1.0.1 -packaging==20.3 +multidict==5.1.0 +nbclient==0.5.1 +nbconvert==6.0.7 +nbformat==5.1.2 +nest-asyncio==1.4.3 +networkx==2.5 +notebook==6.2.0 +numpy==1.20.1 +packaging==20.9 palettable==3.3.0 -pandas==0.25.3 -pandocfilters==1.4.2 -paramiko==2.7.1 -parso==0.6.2 -pathspec==0.8.0 +pamqp==2.3.0 +pandas==1.2.2 +pandocfilters==1.4.3 +paramiko==2.7.2 +parso==0.8.1 
pexpect==4.8.0 -pg8000==1.13.2 +pg8000==1.17.0 pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 -pika==1.1.0 +Pillow==8.1.0 +plotly==4.14.3 pluggy==0.13.1 -plumpy==0.15.1 -prometheus-client==0.7.1 -prompt-toolkit==3.0.4 -psutil==5.7.0 -psycopg2-binary==2.8.4 -ptyprocess==0.6.0 -py==1.9.0 +plumpy==0.19.0 +prometheus-client==0.9.0 +prompt-toolkit==3.0.14 +psutil==5.8.0 +psycopg2-binary==2.8.6 +ptyprocess==0.7.0 +py==1.10.0 py-cpuinfo==7.0.0 -PyCifRW==4.4.1 -pycodestyle==2.6.0 +PyCifRW==4.4.2 pycparser==2.20 -pydata-sphinx-theme==0.4.1 -PyDispatcher==2.0.5 -pyflakes==2.2.0 -Pygments==2.6.1 -pymatgen==2020.3.2 +pydata-sphinx-theme==0.4.3 +Pygments==2.7.4 +pymatgen==2021.2.8.1 +pympler==0.9 PyMySQL==0.9.3 -PyNaCl==1.3.0 -pyparsing==2.4.6 -pyrsistent==0.15.7 -pytest==6.0.0 +PyNaCl==1.4.0 +pyparsing==2.4.7 +pyrsistent==0.17.3 +pytest==6.2.2 +pytest-asyncio==0.14.0 pytest-benchmark==3.2.3 -pytest-cov==2.8.1 +pytest-cov==2.10.1 pytest-rerunfailures==9.1.1 -pytest-timeout==1.3.4 +pytest-timeout==1.4.2 python-dateutil==2.8.1 python-editor==1.0.4 python-memcached==1.59 +pytray==0.3.1 pytz==2019.3 PyYAML==5.1.2 -pyzmq==19.0.0 -qtconsole==4.7.1 +pyzmq==22.0.2 +qtconsole==5.0.2 QtPy==1.9.0 -reentry==1.3.1 -regex==2020.7.14 -requests==2.23.0 -ruamel.yaml==0.16.10 -ruamel.yaml.clib==0.2.0 -scipy==1.4.1 -scramp==1.1.0 -seekpath==1.9.4 +reentry==1.3.2 +requests==2.25.1 +retrying==1.3.3 +ruamel.yaml==0.16.12 +ruamel.yaml.clib==0.2.2 +scipy==1.6.0 +scramp==1.2.0 +seekpath==1.9.7 Send2Trash==1.5.0 -shellingham==1.3.2 +shellingham==1.4.0 shortuuid==1.0.1 -simplejson==3.17.0 -six==1.14.0 -snowballstemmer==2.0.0 -spglib==1.14.1.post0 +simplejson==3.17.2 +six==1.15.0 +snowballstemmer==2.1.0 +spglib==1.16.1 Sphinx==3.2.1 -sphinx-copybutton==0.3.0 -sphinx-notfound-page==0.5 +sphinx-copybutton==0.3.1 +sphinx-notfound-page==0.6 sphinx-panels==0.5.2 sphinxcontrib-applehelp==1.0.2 -sphinxcontrib-contentui==0.2.4 sphinxcontrib-details-directive==0.1.0 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==1.0.3 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.4 -sphinxext-rediraffe==0.2.4 -SQLAlchemy==1.3.13 +sphinxext-rediraffe==0.2.5 +SQLAlchemy==1.3.23 sqlalchemy-diff==0.1.3 -SQLAlchemy-Utils==0.34.2 -sqlparse==0.3.1 -sympy==1.5.1 -tabulate==0.8.6 -terminado==0.8.3 +SQLAlchemy-Utils==0.36.8 +sqlparse==0.4.1 +sympy==1.7.1 +tabulate==0.8.7 +terminado==0.9.2 testpath==0.4.4 -toml==0.10.1 -topika==0.2.2 -tornado==4.5.3 -tqdm==4.45.0 -traitlets==4.3.3 -typed-ast==1.4.1 -tzlocal==2.0.0 +toml==0.10.2 +tornado==6.1 +tqdm==4.56.0 +traitlets==5.0.5 +typing-extensions==3.7.4.3 +tzlocal==2.1 +uncertainties==3.1.5 upf-to-json==0.9.2 -urllib3==1.25.8 -wcwidth==0.1.8 +urllib3==1.26.3 +wcwidth==0.2.5 webencodings==0.5.1 -Werkzeug==1.0.0 +Werkzeug==1.0.1 widgetsnbextension==3.5.1 wrapt==1.11.2 -zipp==3.1.0 +yarl==1.6.3 +zipp==3.4.0 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index b83043f2d6..4d2326794d 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -1,161 +1,171 @@ aiida-export-migration-tests==0.9.0 +aio-pika==6.7.1 +aiormq==3.3.1 alabaster==0.7.12 aldjemy==0.9.1 -alembic==1.4.1 -aniso8601==8.0.0 -appnope==0.1.0 +alembic==1.5.4 +aniso8601==8.1.1 archive-path==0.2.1 -ase==3.19.0 -attrs==19.3.0 -Babel==2.8.0 -backcall==0.1.0 -bcrypt==3.1.7 -bleach==3.1.4 -certifi==2019.11.28 -cffi==1.14.0 -chardet==3.0.4 -circus==0.16.1 -Click==7.0 +argon2-cffi==20.1.0 +ase==3.21.1 +async-generator==1.10 +attrs==20.3.0 
+Babel==2.9.0 +backcall==0.2.0 +bcrypt==3.2.0 +bleach==3.3.0 +certifi==2020.12.5 +cffi==1.14.4 +chardet==4.0.0 +circus==0.17.1 +click==7.1.2 click-completion==0.5.2 click-config-file==0.6.0 -click-spinner==0.1.8 +click-spinner==0.1.10 configobj==5.0.6 coverage==4.5.4 -cryptography==2.8 +cryptography==3.4.3 cycler==0.10.0 decorator==4.4.2 defusedxml==0.6.0 -Django==2.2.11 +deprecation==2.1.0 +Django==2.2.18 docutils==0.15.2 entrypoints==0.3 -ete3==3.1.1 -Flask==1.1.1 -Flask-Cors==3.0.8 +ete3==3.1.2 +Flask==1.1.2 +Flask-Cors==3.0.10 Flask-RESTful==0.3.8 frozendict==1.2 -furl==2.1.0 future==0.18.2 -graphviz==0.13.2 -idna==2.9 +graphviz==0.16 +idna==2.10 imagesize==1.2.0 iniconfig==1.1.1 -ipykernel==5.1.4 -ipython==7.13.0 +ipykernel==5.4.3 +ipython==7.20.0 ipython-genutils==0.2.0 -ipywidgets==7.5.1 +ipywidgets==7.6.3 itsdangerous==1.1.0 -jedi==0.16.0 -Jinja2==2.11.1 +jedi==0.18.0 +Jinja2==2.11.3 jsonschema==3.2.0 jupyter==1.0.0 -jupyter-client==6.0.0 -jupyter-console==6.1.0 -jupyter-core==4.6.3 -kiwipy==0.5.5 -kiwisolver==1.1.0 -Mako==1.1.2 +jupyter-client==6.1.11 +jupyter-console==6.2.0 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwipy==0.7.4 +kiwisolver==1.3.1 +Mako==1.1.4 MarkupSafe==1.1.1 -matplotlib==3.2.0 +matplotlib==3.3.4 mistune==0.8.4 -monty==3.0.2 -more-itertools==8.2.0 +monty==4.0.2 mpmath==1.1.0 -nbconvert==5.6.1 -nbformat==5.0.4 -networkx==2.4 -notebook==5.7.8 -numpy==1.17.5 -orderedmultidict==1.0.1 -packaging==20.3 +multidict==5.1.0 +nbclient==0.5.1 +nbconvert==6.0.7 +nbformat==5.1.2 +nest-asyncio==1.4.3 +networkx==2.5 +notebook==6.2.0 +numpy==1.20.1 +packaging==20.9 palettable==3.3.0 -pandas==0.25.3 -pandocfilters==1.4.2 -paramiko==2.7.1 -parso==0.6.2 +pamqp==2.3.0 +pandas==1.2.2 +pandocfilters==1.4.3 +paramiko==2.7.2 +parso==0.8.1 pexpect==4.8.0 -pg8000==1.13.2 +pg8000==1.17.0 pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 -pika==1.1.0 +Pillow==8.1.0 +plotly==4.14.3 pluggy==0.13.1 -plumpy==0.15.1 -prometheus-client==0.7.1 -prompt-toolkit==3.0.4 -psutil==5.7.0 -psycopg2-binary==2.8.4 -ptyprocess==0.6.0 -py==1.9.0 +plumpy==0.19.0 +prometheus-client==0.9.0 +prompt-toolkit==3.0.14 +psutil==5.8.0 +psycopg2-binary==2.8.6 +ptyprocess==0.7.0 +py==1.10.0 py-cpuinfo==7.0.0 -PyCifRW==4.4.1 +PyCifRW==4.4.2 pycparser==2.20 -pydata-sphinx-theme==0.4.1 -PyDispatcher==2.0.5 -Pygments==2.6.1 -pymatgen==2020.3.2 +pydata-sphinx-theme==0.4.3 +Pygments==2.7.4 +pymatgen==2021.2.8.1 +pympler==0.9 PyMySQL==0.9.3 -PyNaCl==1.3.0 -pyparsing==2.4.6 -pyrsistent==0.15.7 -pytest==6.0.0 +PyNaCl==1.4.0 +pyparsing==2.4.7 +pyrsistent==0.17.3 +pytest==6.2.2 +pytest-asyncio==0.14.0 pytest-benchmark==3.2.3 -pytest-cov==2.8.1 +pytest-cov==2.10.1 pytest-rerunfailures==9.1.1 -pytest-timeout==1.3.4 +pytest-timeout==1.4.2 python-dateutil==2.8.1 python-editor==1.0.4 python-memcached==1.59 +pytray==0.3.1 pytz==2019.3 PyYAML==5.1.2 -pyzmq==19.0.0 -qtconsole==4.7.1 +pyzmq==22.0.2 +qtconsole==5.0.2 QtPy==1.9.0 -reentry==1.3.1 -requests==2.23.0 -rope==0.17.0 -ruamel.yaml==0.16.10 -ruamel.yaml.clib==0.2.0 -scipy==1.4.1 -scramp==1.1.0 -seekpath==1.9.4 +reentry==1.3.2 +requests==2.25.1 +retrying==1.3.3 +ruamel.yaml==0.16.12 +ruamel.yaml.clib==0.2.2 +scipy==1.6.0 +scramp==1.2.0 +seekpath==1.9.7 Send2Trash==1.5.0 -shellingham==1.3.2 +shellingham==1.4.0 shortuuid==1.0.1 -simplejson==3.17.0 -six==1.14.0 -snowballstemmer==2.0.0 -spglib==1.14.1.post0 +simplejson==3.17.2 +six==1.15.0 +snowballstemmer==2.1.0 +spglib==1.16.1 Sphinx==3.2.1 -sphinx-copybutton==0.3.0 -sphinx-notfound-page==0.5 
+sphinx-copybutton==0.3.1 +sphinx-notfound-page==0.6 sphinx-panels==0.5.2 sphinxcontrib-applehelp==1.0.2 -sphinxcontrib-contentui==0.2.4 sphinxcontrib-details-directive==0.1.0 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==1.0.3 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.4 -sphinxext-rediraffe==0.2.4 -SQLAlchemy==1.3.13 +sphinxext-rediraffe==0.2.5 +SQLAlchemy==1.3.23 sqlalchemy-diff==0.1.3 -SQLAlchemy-Utils==0.34.2 -sqlparse==0.3.1 -sympy==1.5.1 -tabulate==0.8.6 -terminado==0.8.3 +SQLAlchemy-Utils==0.36.8 +sqlparse==0.4.1 +sympy==1.7.1 +tabulate==0.8.7 +terminado==0.9.2 testpath==0.4.4 -toml==0.10.1 -topika==0.2.2 -tornado==4.5.3 -tqdm==4.45.0 -traitlets==4.3.3 -tzlocal==2.0.0 +toml==0.10.2 +tornado==6.1 +tqdm==4.56.0 +traitlets==5.0.5 +tzlocal==2.1 +uncertainties==3.1.5 upf-to-json==0.9.2 -urllib3==1.25.8 -wcwidth==0.1.8 +urllib3==1.26.3 +wcwidth==0.2.5 webencodings==0.5.1 -Werkzeug==1.0.0 +Werkzeug==1.0.1 widgetsnbextension==3.5.1 wrapt==1.11.2 +yarl==1.6.3 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 5404a60c2f..5bcba80782 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -1,102 +1,112 @@ aiida-export-migration-tests==0.9.0 +aio-pika==6.7.1 +aiormq==3.3.1 alabaster==0.7.12 aldjemy==0.9.1 -alembic==1.4.3 -aniso8601==8.0.0 +alembic==1.5.4 +aniso8601==8.1.1 archive-path==0.2.1 -ase==3.20.1 -attrs==20.2.0 -Babel==2.8.0 +argon2-cffi==20.1.0 +ase==3.21.1 +async-generator==1.10 +attrs==20.3.0 +Babel==2.9.0 backcall==0.2.0 bcrypt==3.2.0 -bleach==3.2.1 -certifi==2020.6.20 -cffi==1.14.3 -chardet==3.0.4 -circus==0.16.1 +bleach==3.3.0 +certifi==2020.12.5 +cffi==1.14.4 +chardet==4.0.0 +circus==0.17.1 click==7.1.2 click-completion==0.5.2 click-config-file==0.6.0 click-spinner==0.1.10 configobj==5.0.6 coverage==4.5.4 -cryptography==3.2.1 +cryptography==3.4.3 cycler==0.10.0 decorator==4.4.2 defusedxml==0.6.0 -Django==2.2.17 +deprecation==2.1.0 +Django==2.2.18 docutils==0.15.2 entrypoints==0.3 ete3==3.1.2 Flask==1.1.2 -Flask-Cors==3.0.9 +Flask-Cors==3.0.10 Flask-RESTful==0.3.8 frozendict==1.2 -furl==2.1.0 future==0.18.2 -graphviz==0.14.2 +graphviz==0.16 idna==2.10 imagesize==1.2.0 iniconfig==1.1.1 -ipykernel==5.3.4 -ipython==7.19.0 +ipykernel==5.4.3 +ipython==7.20.0 ipython-genutils==0.2.0 -ipywidgets==7.5.1 +ipywidgets==7.6.3 itsdangerous==1.1.0 -jedi==0.17.2 -Jinja2==2.11.2 +jedi==0.18.0 +Jinja2==2.11.3 jsonschema==3.2.0 jupyter==1.0.0 -jupyter-client==6.1.7 +jupyter-client==6.1.11 jupyter-console==6.2.0 -jupyter-core==4.6.3 -kiwipy==0.5.5 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwipy==0.7.4 kiwisolver==1.3.1 -Mako==1.1.3 +Mako==1.1.4 MarkupSafe==1.1.1 -matplotlib==3.3.2 +matplotlib==3.3.4 mistune==0.8.4 monty==4.0.2 mpmath==1.1.0 -nbconvert==5.6.1 -nbformat==5.0.8 +multidict==5.1.0 +nbclient==0.5.1 +nbconvert==6.0.7 +nbformat==5.1.2 +nest-asyncio==1.4.3 networkx==2.5 -notebook==5.7.10 -numpy==1.19.4 -orderedmultidict==1.0.1 -packaging==20.4 +notebook==6.2.0 +numpy==1.20.1 +packaging==20.9 palettable==3.3.0 -pandas==1.1.4 +pamqp==2.3.0 +pandas==1.2.2 pandocfilters==1.4.3 paramiko==2.7.2 -parso==0.7.1 +parso==0.8.1 pexpect==4.8.0 -pg8000==1.16.6 +pg8000==1.17.0 pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 -pika==1.1.0 -Pillow==8.0.1 -plotly==4.12.0 +Pillow==8.1.0 +plotly==4.14.3 pluggy==0.13.1 -plumpy==0.15.1 -prometheus-client==0.8.0 -prompt-toolkit==3.0.8 -psutil==5.7.3 +plumpy==0.19.0 +prometheus-client==0.9.0 
+prompt-toolkit==3.0.14 +psutil==5.8.0 psycopg2-binary==2.8.6 -ptyprocess==0.6.0 -py==1.9.0 +ptyprocess==0.7.0 +py==1.10.0 py-cpuinfo==7.0.0 -PyCifRW==4.4.1 +PyCifRW==4.4.2 pycparser==2.20 -pydata-sphinx-theme==0.4.1 -Pygments==2.7.2 -pymatgen==2020.10.20 +pydata-sphinx-theme==0.4.3 +Pygments==2.7.4 +pymatgen==2021.2.8.1 +pympler==0.9 PyMySQL==0.9.3 PyNaCl==1.4.0 pyparsing==2.4.7 pyrsistent==0.17.3 -pytest==6.1.2 +pytest==6.2.2 +pytest-asyncio==0.14.0 pytest-benchmark==3.2.3 pytest-cov==2.10.1 pytest-rerunfailures==9.1.1 @@ -104,28 +114,29 @@ pytest-timeout==1.4.2 python-dateutil==2.8.1 python-editor==1.0.4 python-memcached==1.59 +pytray==0.3.1 pytz==2019.3 PyYAML==5.1.2 -pyzmq==19.0.2 -qtconsole==4.7.7 +pyzmq==22.0.2 +qtconsole==5.0.2 QtPy==1.9.0 -reentry==1.3.1 -requests==2.24.0 +requests==2.25.1 +reentry==1.3.2 retrying==1.3.3 ruamel.yaml==0.16.12 -scipy==1.5.3 +scipy==1.6.0 scramp==1.2.0 seekpath==1.9.7 Send2Trash==1.5.0 -shellingham==1.3.2 +shellingham==1.4.0 shortuuid==1.0.1 simplejson==3.17.2 six==1.15.0 -snowballstemmer==2.0.0 -spglib==1.16.0 -Sphinx==3.3.0 +snowballstemmer==2.1.0 +spglib==1.16.1 +Sphinx==3.2.1 sphinx-copybutton==0.3.1 -sphinx-notfound-page==0.5 +sphinx-notfound-page==0.6 sphinx-panels==0.5.2 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-details-directive==0.1.0 @@ -134,26 +145,26 @@ sphinxcontrib-htmlhelp==1.0.3 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.4 -sphinxext-rediraffe==0.2.4 -SQLAlchemy==1.3.20 +sphinxext-rediraffe==0.2.5 +SQLAlchemy==1.3.23 sqlalchemy-diff==0.1.3 -SQLAlchemy-Utils==0.34.2 +SQLAlchemy-Utils==0.36.8 sqlparse==0.4.1 -sympy==1.6.2 +sympy==1.7.1 tabulate==0.8.7 -terminado==0.9.1 +terminado==0.9.2 testpath==0.4.4 toml==0.10.2 -topika==0.2.2 -tornado==4.5.3 -tqdm==4.51.0 +tornado==6.1 +tqdm==4.56.0 traitlets==5.0.5 tzlocal==2.1 -uncertainties==3.1.4 +uncertainties==3.1.5 upf-to-json==0.9.2 -urllib3==1.25.11 +urllib3==1.26.3 wcwidth==0.2.5 webencodings==0.5.1 Werkzeug==1.0.1 widgetsnbextension==3.5.1 wrapt==1.11.2 +yarl==1.6.3 diff --git a/setup.json b/setup.json index 7e8750cc63..b70e9ce465 100644 --- a/setup.json +++ b/setup.json @@ -1,20 +1,19 @@ { "name": "aiida-core", - "version": "1.5.2", + "version": "1.6.0", "url": "http://www.aiida.net/", "license": "MIT License", "author": "The AiiDA team", "author_email": "developers@aiida.net", "description": "AiiDA is a workflow manager for computational science with a strong focus on provenance, performance and extensibility.", "include_package_data": true, - "python_requires": ">=3.6", + "python_requires": ">=3.7", "classifiers": [ "Framework :: AiiDA", "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS :: MacOS X", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -24,22 +23,23 @@ "aldjemy~=0.9.1", "alembic~=1.2", "archive-path~=0.2.1", - "circus~=0.16.1", + "aio-pika~=6.6", + "circus~=0.17.1", "click-completion~=0.5.1", "click-config-file~=0.6.0", "click-spinner~=0.1.8", - "click~=7.0", - "dataclasses~=0.7; python_version < '3.7.0'", + "click~=7.1", "django~=2.2", "ete3~=3.1", "graphviz~=0.13", - "ipython~=7.0", + "ipython~=7.20", "jinja2~=2.10", - "kiwipy[rmq]~=0.5.5", + "jsonschema~=3.0", + "kiwipy[rmq]~=0.7.4", "numpy~=1.17", - "paramiko~=2.7", - "pika~=1.1", - "plumpy~=0.15.1", + "pamqp~=2.3", + "paramiko~=2.7,>=2.7.2", + "plumpy~=0.19.0", 
"pgsu~=0.1.0", "psutil~=5.6", "psycopg2-binary~=2.8,>=2.8.3", @@ -48,11 +48,9 @@ "pyyaml~=5.1.2", "reentry~=1.3", "simplejson~=3.16", - "sqlalchemy-utils~=0.34.2", - "sqlalchemy~=1.3,>=1.3.10", + "sqlalchemy-utils~=0.36.0", + "sqlalchemy~=1.3.10", "tabulate~=0.8.5", - "tornado<5.0", - "topika~=0.2.2", "tqdm~=4.45", "tzlocal~=2.0", "upf_to_json~=0.9.2", @@ -75,7 +73,7 @@ "docutils==0.15.2", "pygments~=2.5", "pydata-sphinx-theme~=0.4.0", - "sphinx~=3.2", + "sphinx~=3.2.1", "sphinxcontrib-details-directive~=0.1.0", "sphinx-panels~=0.5.0", "sphinx-copybutton~=0.3.0", @@ -85,21 +83,22 @@ "atomic_tools": [ "PyCifRW~=4.4", "ase~=3.18", - "pymatgen>=2019.7.2", + "pymatgen>=2019.7.2,<=2022.02.03,!=2019.9.7", "pymysql~=0.9.3", "seekpath~=1.9,>=1.9.3", "spglib~=1.14" ], "notebook": [ - "jupyter==1.0.0", - "notebook<6" + "jupyter~=1.0", + "notebook~=6.1,>=6.1.5" ], "pre-commit": [ + "astroid<2.5", "mypy==0.790", "packaging==20.3", "pre-commit~=2.2", "pylint~=2.5.0", - "pylint-django~=2.0", + "pylint-django>=2.0,<2.4.0", "tomlkit~=0.7.0" ], "tests": [ @@ -107,10 +106,12 @@ "pg8000~=1.13", "pgtest~=1.3,>=1.3.1", "pytest~=6.0", + "pytest-asyncio~=0.12", "pytest-timeout~=1.3", "pytest-cov~=2.7", "pytest-rerunfailures~=9.1,>=9.1.1", "pytest-benchmark~=3.2", + "pympler~=0.9", "coverage<5.0", "sqlalchemy-diff~=0.1.3" ], @@ -125,6 +126,7 @@ "runaiida=aiida.cmdline.commands.cmd_run:run" ], "aiida.calculations": [ + "core.transfer = aiida.calculations.transfer:TransferCalculation", "arithmetic.add = aiida.calculations.arithmetic.add:ArithmeticAddCalculation", "templatereplacer = aiida.calculations.templatereplacer:TemplatereplacerCalculation" ], @@ -161,7 +163,9 @@ "list = aiida.orm.nodes.data.list:List", "numeric = aiida.orm.nodes.data.numeric:NumericType", "orbital = aiida.orm.nodes.data.orbital:OrbitalData", - "remote = aiida.orm.nodes.data.remote:RemoteData", + "remote = aiida.orm.nodes.data.remote.base:RemoteData", + "remote.stash = aiida.orm.nodes.data.remote.stash.base:RemoteStashData", + "remote.stash.folder = aiida.orm.nodes.data.remote.stash.folder:RemoteStashFolderData", "singlefile = aiida.orm.nodes.data.singlefile:SinglefileData", "str = aiida.orm.nodes.data.str:Str", "structure = aiida.orm.nodes.data.structure:StructureData", diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 6b621b4763..2bbb3e3c2f 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -120,7 +120,7 @@ def _reset_database_and_schema(self): It is important to also reset the database content to avoid hanging of tests. """ - self.reset_database() + self.clean_db() self.migrate_db_up('head') @property diff --git a/tests/benchmark/test_engine.py b/tests/benchmark/test_engine.py index 77009e5c96..ef9b996e19 100644 --- a/tests/benchmark/test_engine.py +++ b/tests/benchmark/test_engine.py @@ -13,12 +13,11 @@ The purpose of these tests is to benchmark and compare processes, which are executed *via* both a local runner and the daemon. 
""" -import datetime +import asyncio -from tornado import gen import pytest -from aiida.engine import run_get_node, submit, ToContext, while_, WorkChain +from aiida.engine import run_get_node, submit, while_, WorkChain from aiida.manage.manager import get_manager from aiida.orm import Code, Int from aiida.plugins.factories import CalculationFactory @@ -55,7 +54,7 @@ class WorkchainLoopWcSerial(WorkchainLoop): def run_task(self): future = self.submit(WorkchainLoop, iterations=Int(1)) - return ToContext(**{f'wkchain{str(self.ctx.counter)}': future}) + return self.to_context(**{f'wkchain{str(self.ctx.counter)}': future}) class WorkchainLoopWcThreaded(WorkchainLoop): @@ -71,7 +70,7 @@ def run_task(self): f'wkchain{str(i)}': self.submit(WorkchainLoop, iterations=Int(1)) for i in range(self.inputs.iterations.value) } - return ToContext(**context) + return self.to_context(**context) class WorkchainLoopCalcSerial(WorkchainLoop): @@ -84,7 +83,7 @@ def run_task(self): 'code': self.inputs.code, } future = self.submit(ArithmeticAddCalculation, **inputs) - return ToContext(addition=future) + return self.to_context(addition=future) class WorkchainLoopCalcThreaded(WorkchainLoop): @@ -103,7 +102,7 @@ def run_task(self): 'code': self.inputs.code, } futures[f'addition{str(i)}'] = self.submit(ArithmeticAddCalculation, **inputs) - return ToContext(**futures) + return self.to_context(**futures) WORKCHAINS = { @@ -131,17 +130,15 @@ def _run(): assert len(result.node.get_outgoing().all()) == outgoing -@gen.coroutine -def with_timeout(what, timeout=60): - """Coroutine return with timeout.""" - raise gen.Return((yield gen.with_timeout(datetime.timedelta(seconds=timeout), what))) +async def with_timeout(what, timeout=60): + result = await asyncio.wait_for(what, timeout) + return result -@gen.coroutine -def wait_for_process(runner, calc_node, timeout=60): - """Coroutine block with timeout.""" +async def wait_for_process(runner, calc_node, timeout=60): future = runner.get_process_future(calc_node.pk) - raise gen.Return((yield with_timeout(future, timeout))) + result = await with_timeout(future, timeout) + return result @pytest.fixture() @@ -159,13 +156,12 @@ def submit_get_node(): def _submit(_process, timeout=60, **kwargs): - @gen.coroutine - def _do_submit(): + async def _do_submit(): node = submit(_process, **kwargs) - yield wait_for_process(runner, node) + await wait_for_process(runner, node) return node - result = runner.loop.run_sync(_do_submit, timeout=timeout) + result = runner.loop.run_until_complete(_do_submit()) return result diff --git a/tests/calculations/arithmetic/test_add.py b/tests/calculations/arithmetic/test_add.py index f976945d75..af4e479716 100644 --- a/tests/calculations/arithmetic/test_add.py +++ b/tests/calculations/arithmetic/test_add.py @@ -15,6 +15,7 @@ from aiida.calculations.arithmetic.add import ArithmeticAddCalculation +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_add_default(fixture_sandbox, aiida_localhost, generate_calc_job): """Test a default `ArithmeticAddCalculation`.""" @@ -43,6 +44,7 @@ def test_add_default(fixture_sandbox, aiida_localhost, generate_calc_job): assert input_written == f"echo $(({inputs['x'].value} + {inputs['y'].value}))\n" +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_add_custom_filenames(fixture_sandbox, aiida_localhost, generate_calc_job): """Test an `ArithmeticAddCalculation` with non-default input and output filenames.""" diff --git 
a/tests/calculations/test_templatereplacer.py b/tests/calculations/test_templatereplacer.py index 3330a3ae46..61e5d85046 100644 --- a/tests/calculations/test_templatereplacer.py +++ b/tests/calculations/test_templatereplacer.py @@ -15,6 +15,7 @@ from aiida.common import datastructures +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_base_template(fixture_sandbox, aiida_localhost, generate_calc_job): """Test a base template that emulates the arithmetic add.""" @@ -70,6 +71,7 @@ def test_base_template(fixture_sandbox, aiida_localhost, generate_calc_job): assert input_written == f"echo $(({inputs['parameters']['x']} + {inputs['parameters']['y']}))" +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_file_usage(fixture_sandbox, aiida_localhost, generate_calc_job): """Test a base template that uses two files.""" diff --git a/tests/calculations/test_transfer.py b/tests/calculations/test_transfer.py new file mode 100644 index 0000000000..63650e40c0 --- /dev/null +++ b/tests/calculations/test_transfer.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the `TransferCalculation` plugin.""" +import os +import pytest + +from aiida import orm +from aiida.common import datastructures + + +@pytest.mark.requires_rmq +@pytest.mark.usefixtures('clear_database_before_test') +def test_get_transfer(fixture_sandbox, aiida_localhost, generate_calc_job, tmp_path): + """Test a default `TransferCalculation`.""" + + file1 = tmp_path / 'file1.txt' + file1.write_text('file 1 content') + folder = tmp_path / 'folder' + folder.mkdir() + file2 = folder / 'file2.txt' + file2.write_text('file 2 content') + data_source = orm.RemoteData(computer=aiida_localhost, remote_path=str(tmp_path)) + + entry_point_name = 'core.transfer' + list_of_files = [ + ('data_source', 'file1.txt', 'folder/file1.txt'), + ('data_source', 'folder/file2.txt', 'file2.txt'), + ] + list_of_nodes = {'data_source': data_source} + instructions = orm.Dict(dict={'retrieve_files': True, 'symlink_files': list_of_files}) + inputs = {'instructions': instructions, 'source_nodes': list_of_nodes, 'metadata': {'computer': aiida_localhost}} + + # Generate calc_info and verify basics + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + assert isinstance(calc_info, datastructures.CalcInfo) + assert isinstance(calc_info.codes_info, list) + assert len(calc_info.codes_info) == 0 + assert calc_info.skip_submit + + # Check that the lists were set correctly + copy_list = [ + (aiida_localhost.uuid, os.path.join(data_source.get_remote_path(), 'file1.txt'), 'folder/file1.txt'), + (aiida_localhost.uuid, os.path.join(data_source.get_remote_path(), 'folder/file2.txt'), 'file2.txt'), + ] + retrieve_list = [('folder/file1.txt'), ('file2.txt')] + assert sorted(calc_info.remote_symlink_list) == sorted(copy_list) + assert sorted(calc_info.remote_copy_list) == sorted(list()) + assert sorted(calc_info.local_copy_list) == sorted(list()) + assert sorted(calc_info.retrieve_list) == sorted(retrieve_list) + + # Now 
without symlinks + instructions = orm.Dict(dict={'retrieve_files': True, 'remote_files': list_of_files}) + inputs = {'instructions': instructions, 'source_nodes': list_of_nodes, 'metadata': {'computer': aiida_localhost}} + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + assert sorted(calc_info.remote_symlink_list) == sorted(list()) + assert sorted(calc_info.remote_copy_list) == sorted(copy_list) + assert sorted(calc_info.local_copy_list) == sorted(list()) + assert sorted(calc_info.retrieve_list) == sorted(retrieve_list) + + +@pytest.mark.requires_rmq +@pytest.mark.usefixtures('clear_database_before_test') +def test_put_transfer(fixture_sandbox, aiida_localhost, generate_calc_job, tmp_path): + """Test a default `TransferCalculation`.""" + + file1 = tmp_path / 'file1.txt' + file1.write_text('file 1 content') + folder = tmp_path / 'folder' + folder.mkdir() + file2 = folder / 'file2.txt' + file2.write_text('file 2 content') + data_source = orm.FolderData(tree=str(tmp_path)) + + entry_point_name = 'core.transfer' + list_of_files = [ + ('data_source', 'file1.txt', 'folder/file1.txt'), + ('data_source', 'folder/file2.txt', 'file2.txt'), + ] + list_of_nodes = {'data_source': data_source} + instructions = orm.Dict(dict={'retrieve_files': False, 'local_files': list_of_files}) + inputs = {'instructions': instructions, 'source_nodes': list_of_nodes, 'metadata': {'computer': aiida_localhost}} + + # Generate calc_info and verify basics + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + assert isinstance(calc_info, datastructures.CalcInfo) + assert isinstance(calc_info.codes_info, list) + assert len(calc_info.codes_info) == 0 + assert calc_info.skip_submit + + # Check that the lists were set correctly + copy_list = [ + (data_source.uuid, 'file1.txt', 'folder/file1.txt'), + (data_source.uuid, 'folder/file2.txt', 'file2.txt'), + ] + assert sorted(calc_info.remote_symlink_list) == sorted(list()) + assert sorted(calc_info.remote_copy_list) == sorted(list()) + assert sorted(calc_info.local_copy_list) == sorted(copy_list) + assert sorted(calc_info.retrieve_list) == sorted(list()) + + +def test_validate_instructions(): + """Test the `TransferCalculation` validators.""" + from aiida.calculations.transfer import validate_instructions + + instructions = orm.Dict(dict={}).store() + result = validate_instructions(instructions, None) + expected = ( + '\n\nno indication of what to do in the instruction node:\n' + f' > {instructions.uuid}\n' + '(to store the files in the repository set retrieve_files=True,\n' + 'to copy them to the specified folder on the remote computer,\n' + 'set it to False)\n' + ) + assert result == expected + + instructions = orm.Dict(dict={'retrieve_files': 12}).store() + result = validate_instructions(instructions, None) + expected = ( + 'entry for retrieve files inside of instruction node:\n' + f' > {instructions.uuid}\n' + 'must be either True or False; instead, it is:\n > 12\n' + ) + assert result == expected + + instructions = orm.Dict(dict={'retrieve_files': True}).store() + result = validate_instructions(instructions, None) + expected = ( + 'no indication of which files to copy were found in the instruction node:\n' + f' > {instructions.uuid}\n' + 'Please include at least one of `local_files`, `remote_files`, or `symlink_files`.\n' + 'These should be lists containing 3-tuples with the following format:\n' + ' (source_node_key, source_relpath, target_relpath)\n' + ) + assert result == expected + + +def 
test_validate_transfer_inputs(aiida_localhost, tmp_path, temp_dir): + """Test the `TransferCalculation` validators.""" + from aiida.orm import Computer + from aiida.calculations.transfer import check_node_type, validate_transfer_inputs + + fake_localhost = Computer( + label='localhost-fake', + description='extra localhost computer set up by test', + hostname='localhost-fake', + workdir=temp_dir, + transport_type='local', + scheduler_type='direct' + ) + fake_localhost.store() + fake_localhost.set_minimum_job_poll_interval(0.) + fake_localhost.configure() + + inputs = { + 'source_nodes': { + 'unused_node': orm.RemoteData(computer=aiida_localhost, remote_path=str(tmp_path)), + }, + 'instructions': + orm.Dict( + dict={ + 'local_files': [('inexistent_node', None, None)], + 'remote_files': [('inexistent_node', None, None)], + 'symlink_files': [('inexistent_node', None, None)], + } + ), + 'metadata': { + 'computer': fake_localhost + }, + } + expected_list = [] + expected_list.append(( + f' > remote node `unused_node` points to computer `{aiida_localhost}`, ' + f'not the one being used (`{fake_localhost}`)' + )) + expected_list.append(check_node_type('local_files', 'inexistent_node', None, orm.FolderData)) + expected_list.append(check_node_type('remote_files', 'inexistent_node', None, orm.RemoteData)) + expected_list.append(check_node_type('symlink_files', 'inexistent_node', None, orm.RemoteData)) + expected_list.append(' > node `unused_node` provided as inputs is not being used') + + expected = '\n\n' + for addition in expected_list: + expected = expected + addition + '\n' + + result = validate_transfer_inputs(inputs, None) + assert result == expected + + result = check_node_type('list_name', 'node_label', None, orm.RemoteData) + expected = ' > node `node_label` requested on list `list_name` not found among inputs' + assert result == expected + + result = check_node_type('list_name', 'node_label', orm.FolderData(), orm.RemoteData) + expected_type = orm.RemoteData.class_node_type + expected = f' > node `node_label`, requested on list `list_name` should be of type `{expected_type}`' + assert result == expected + + +@pytest.mark.requires_rmq +def test_integration_transfer(aiida_localhost, tmp_path): + """Test a default `TransferCalculation`.""" + from aiida.calculations.transfer import TransferCalculation + from aiida.engine import run + + content_local = 'Content of local file' + srcfile_local = tmp_path / 'file_local.txt' + srcfile_local.write_text(content_local) + srcnode_local = orm.FolderData(tree=str(tmp_path)) + + content_remote = 'Content of remote file' + srcfile_remote = tmp_path / 'file_remote.txt' + srcfile_remote.write_text(content_remote) + srcnode_remote = orm.RemoteData(computer=aiida_localhost, remote_path=str(tmp_path)) + + list_of_nodes = {} + list_of_nodes['source_local'] = srcnode_local + list_for_local = [('source_local', 'file_local.txt', 'file_local.txt')] + list_of_nodes['source_remote'] = srcnode_remote + list_for_remote = [('source_remote', 'file_remote.txt', 'file_remote.txt')] + + instructions = orm.Dict( + dict={ + 'retrieve_files': True, + 'local_files': list_for_local, + 'remote_files': list_for_remote, + } + ) + inputs = {'instructions': instructions, 'source_nodes': list_of_nodes, 'metadata': {'computer': aiida_localhost}} + + output_nodes = run(TransferCalculation, **inputs) + + output_remotedir = output_nodes['remote_folder'] + output_retrieved = output_nodes['retrieved'] + + # Check the retrieved folder + assert sorted(output_retrieved.list_object_names()) == 
sorted(['file_local.txt', 'file_remote.txt']) + assert output_retrieved.get_object_content('file_local.txt') == content_local + assert output_retrieved.get_object_content('file_remote.txt') == content_remote + + # Check the remote folder + assert 'file_local.txt' in output_remotedir.listdir() + assert 'file_remote.txt' in output_remotedir.listdir() + output_remotedir.getfile(relpath='file_local.txt', destpath=str(tmp_path / 'retrieved_local.txt')) + output_remotedir.getfile(relpath='file_remote.txt', destpath=str(tmp_path / 'retrieved_remote.txt')) + assert (tmp_path / 'retrieved_local.txt').read_text() == content_local + assert (tmp_path / 'retrieved_remote.txt').read_text() == content_remote diff --git a/tests/cmdline/commands/test_export.py b/tests/cmdline/commands/test_archive_export.py similarity index 87% rename from tests/cmdline/commands/test_export.py rename to tests/cmdline/commands/test_archive_export.py index 4b2cfa16bf..3f0e9bdb6a 100644 --- a/tests/cmdline/commands/test_export.py +++ b/tests/cmdline/commands/test_archive_export.py @@ -19,7 +19,7 @@ from click.testing import CliRunner from aiida.backends.testbase import AiidaTestCase -from aiida.cmdline.commands import cmd_export +from aiida.cmdline.commands import cmd_archive from aiida.tools.importexport import EXPORT_VERSION, ReaderJsonZip from tests.utils.archives import get_archive_file @@ -40,12 +40,20 @@ def delete_temporary_file(filepath): pass +def test_cmd_export_deprecation(): + """Test that the deprecated `verdi export` commands can still be called.""" + from aiida.cmdline.commands import cmd_export + for command in [cmd_export.inspect, cmd_export.create, cmd_export.migrate]: + result = CliRunner().invoke(command, '--help') + assert result.exit_code == 0 + + class TestVerdiExport(AiidaTestCase): """Tests for `verdi export`.""" @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + def setUpClass(cls): + super().setUpClass() from aiida import orm cls.computer = orm.Computer( @@ -68,7 +76,7 @@ def setUpClass(cls, *args, **kwargs): cls.penultimate_archive = 'export_v0.6_simple.aiida' @classmethod - def tearDownClass(cls, *args, **kwargs): + def tearDownClass(cls): os.chdir(cls.old_cwd) shutil.rmtree(cls.cwd, ignore_errors=True) @@ -79,7 +87,7 @@ def test_create_file_already_exists(self): """Test that using a file that already exists, which is the case when using NamedTemporaryFile, will raise.""" with tempfile.NamedTemporaryFile() as handle: options = [handle.name] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) self.assertIsNotNone(result.exception) def test_create_force(self): @@ -89,11 +97,11 @@ def test_create_force(self): """ with tempfile.NamedTemporaryFile() as handle: options = ['-f', handle.name] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) self.assertIsNone(result.exception, result.output) options = ['--force', handle.name] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) self.assertIsNone(result.exception, result.output) def test_create_zip(self): @@ -104,7 +112,7 @@ def test_create_zip(self): '-X', self.code.pk, '-Y', self.computer.pk, '-G', self.group.pk, '-N', self.node.pk, '-F', 'zip', filename ] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) 
self.assertIsNone(result.exception, ''.join(traceback.format_exception(*result.exc_info))) self.assertTrue(os.path.isfile(filename)) self.assertFalse(zipfile.ZipFile(filename).testzip(), None) @@ -119,7 +127,7 @@ def test_create_zip_uncompressed(self): '-X', self.code.pk, '-Y', self.computer.pk, '-G', self.group.pk, '-N', self.node.pk, '-F', 'zip-uncompressed', filename ] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) self.assertIsNone(result.exception, ''.join(traceback.format_exception(*result.exc_info))) self.assertTrue(os.path.isfile(filename)) self.assertFalse(zipfile.ZipFile(filename).testzip(), None) @@ -134,7 +142,7 @@ def test_create_tar_gz(self): '-X', self.code.pk, '-Y', self.computer.pk, '-G', self.group.pk, '-N', self.node.pk, '-F', 'tar.gz', filename ] - result = self.cli_runner.invoke(cmd_export.create, options) + result = self.cli_runner.invoke(cmd_archive.create, options) self.assertIsNone(result.exception, ''.join(traceback.format_exception(*result.exc_info))) self.assertTrue(os.path.isfile(filename)) self.assertTrue(tarfile.is_tarfile(filename)) @@ -154,7 +162,7 @@ def test_migrate_versions_old(self): try: options = ['--verbosity', 'DEBUG', filename_input, filename_output] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_output)) self.assertEqual(zipfile.ZipFile(filename_output).testzip(), None) @@ -171,7 +179,7 @@ def test_migrate_version_specific(self): try: options = [filename_input, filename_output, '--version', target_version] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_output)) self.assertEqual(zipfile.ZipFile(filename_output).testzip(), None) @@ -188,7 +196,7 @@ def test_migrate_force(self): # Using the context manager will create the file and so the command should fail with tempfile.NamedTemporaryFile() as file_output: options = [filename_input, file_output.name] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNotNone(result.exception) for option in ['-f', '--force']: @@ -196,7 +204,7 @@ def test_migrate_force(self): with tempfile.NamedTemporaryFile() as file_output: filename_output = file_output.name options = [option, filename_input, filename_output] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_output)) self.assertEqual(zipfile.ZipFile(filename_output).testzip(), None) @@ -214,17 +222,17 @@ def test_migrate_in_place(self): # specifying both output and in-place should except options = [filename_tmp, '--in-place', '--output-file', 'test.aiida'] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNotNone(result.exception, result.output) # specifying neither output nor in-place should except options = [filename_tmp] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNotNone(result.exception, result.output) # check that 
in-place migration produces a valid archive in place of the old file options = [filename_tmp, '--in-place', '--version', target_version] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_tmp)) # check that files in zip file are ok @@ -244,7 +252,7 @@ def test_migrate_low_verbosity(self): for option in ['--verbosity']: try: options = [option, 'WARNING', filename_input, filename_output] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertEqual(result.output, '') self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_output)) @@ -260,7 +268,7 @@ def test_migrate_tar_gz(self): for option in ['-F', '--archive-format']: try: options = [option, 'tar.gz', filename_input, filename_output] - result = self.cli_runner.invoke(cmd_export.migrate, options) + result = self.cli_runner.invoke(cmd_archive.migrate, options) self.assertIsNone(result.exception, result.output) self.assertTrue(os.path.isfile(filename_output)) self.assertTrue(tarfile.is_tarfile(filename_output)) @@ -280,12 +288,12 @@ def test_inspect(self): # Testing the options that will print the meta data and data respectively for option in ['-m', '-d']: options = [option, filename_input] - result = self.cli_runner.invoke(cmd_export.inspect, options) + result = self.cli_runner.invoke(cmd_archive.inspect, options) self.assertIsNone(result.exception, result.output) # Test the --version option which should print the archive format version options = ['--version', filename_input] - result = self.cli_runner.invoke(cmd_export.inspect, options) + result = self.cli_runner.invoke(cmd_archive.inspect, options) self.assertIsNone(result.exception, result.output) self.assertEqual(result.output.strip()[-len(version_number):], version_number) @@ -294,6 +302,6 @@ def test_inspect_empty_archive(self): filename_input = get_archive_file('empty.aiida', filepath=self.fixture_archive) options = [filename_input] - result = self.cli_runner.invoke(cmd_export.inspect, options) + result = self.cli_runner.invoke(cmd_archive.inspect, options) self.assertIsNotNone(result.exception, result.output) self.assertIn('corrupt archive', result.output) diff --git a/tests/cmdline/commands/test_import.py b/tests/cmdline/commands/test_archive_import.py similarity index 87% rename from tests/cmdline/commands/test_import.py rename to tests/cmdline/commands/test_archive_import.py index d14ec96f06..7523a0cacf 100644 --- a/tests/cmdline/commands/test_import.py +++ b/tests/cmdline/commands/test_archive_import.py @@ -14,13 +14,20 @@ import pytest from aiida.backends.testbase import AiidaTestCase -from aiida.cmdline.commands import cmd_import +from aiida.cmdline.commands import cmd_archive from aiida.orm import Group from aiida.tools.importexport import EXPORT_VERSION from tests.utils.archives import get_archive_file +def test_cmd_import_deprecation(): + """Test that the deprecated `verdi import` command can still be called.""" + from aiida.cmdline.commands import cmd_import + result = CliRunner().invoke(cmd_import.cmd_import, '--help') + assert result.exit_code == 0 + + class TestVerdiImport(AiidaTestCase): """Tests for `verdi import`.""" @@ -40,7 +47,7 @@ def setUp(self): def test_import_no_archives(self): """Test that passing no valid archives will lead to command failure.""" options = [] - result = 
self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNotNone(result.exception, result.output) self.assertIn('Critical', result.output) @@ -49,7 +56,7 @@ def test_import_no_archives(self): def test_import_non_existing_archives(self): """Test that passing a non-existing archive will lead to command failure.""" options = ['non-existing-archive.aiida'] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNotNone(result.exception, result.output) self.assertNotEqual(result.exit_code, 0, result.output) @@ -64,7 +71,7 @@ def test_import_archive(self): ] options = [] + archives - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, result.output) self.assertEqual(result.exit_code, 0, result.output) @@ -86,7 +93,7 @@ def test_import_to_group(self): # Invoke `verdi import`, making sure there are no exceptions options = ['-G', group.label] + [archives[0]] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -96,7 +103,7 @@ def test_import_to_group(self): # Invoke `verdi import` again, making sure Group count doesn't change options = ['-G', group.label] + [archives[0]] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -108,7 +115,7 @@ def test_import_to_group(self): # Invoke `verdi import` again with new archive, making sure Group count is upped options = ['-G', group.label] + [archives[1]] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -134,7 +141,7 @@ def test_import_make_new_group(self): # Invoke `verdi import`, making sure there are no exceptions options = ['-G', group_label] + archives - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -143,7 +150,7 @@ def test_import_make_new_group(self): self.assertFalse(new_group, msg='The Group should not have been created now, but instead when it was imported.') self.assertFalse(group.is_empty, msg='The Group should not be empty.') - @pytest.mark.skip('Due to summary being logged, this can not be checked against `results.output`.') + @pytest.mark.skip('Due to summary being logged, this can not be checked against `results.output`.') # pylint: disable=not-callable def test_comment_mode(self): """Test toggling comment mode flag""" import re @@ -151,7 +158,7 @@ def test_comment_mode(self): for mode in ['newest', 'overwrite']: options = ['--comment-mode', mode] + archives - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, result.output) self.assertTrue( 
any([re.fullmatch(r'Comment rules[\s]*{}'.format(mode), line) for line in result.output.split('\n')]), @@ -169,7 +176,7 @@ def test_import_old_local_archives(self): for archive, version in archives: options = [get_archive_file(archive, filepath=self.archive_path)] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -184,7 +191,7 @@ def test_import_old_url_archives(self): version = '0.3' options = [self.url_path + archive] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -200,7 +207,7 @@ def test_import_url_and_local_archives(self): get_archive_file(local_archive, filepath=self.archive_path), self.url_path + url_archive, get_archive_file(local_archive, filepath=self.archive_path) ] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, result.output) self.assertEqual(result.exit_code, 0, result.output) @@ -222,7 +229,7 @@ def test_raise_malformed_url(self): """Test the correct error is raised when supplying a malformed URL""" malformed_url = 'htp://www.aiida.net' - result = self.cli_runner.invoke(cmd_import.cmd_import, [malformed_url]) + result = self.cli_runner.invoke(cmd_archive.import_archive, [malformed_url]) self.assertIsNotNone(result.exception, result.output) self.assertNotEqual(result.exit_code, 0, result.output) @@ -243,7 +250,7 @@ def test_non_interactive_and_migration(self): # Import "normally", but explicitly specifying `--migration`, make sure confirm message is present # `migration` = True (default), `non_interactive` = False (default), Expected: Query user, migrate options = ['--migration', archive] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -254,7 +261,7 @@ def test_non_interactive_and_migration(self): # Import using non-interactive, make sure confirm message has gone # `migration` = True (default), `non_interactive` = True, Expected: No query, migrate options = ['--non-interactive', archive] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNone(result.exception, msg=result.output) self.assertEqual(result.exit_code, 0, msg=result.output) @@ -264,7 +271,7 @@ def test_non_interactive_and_migration(self): # Import using `--no-migration`, make sure confirm message has gone # `migration` = False, `non_interactive` = False (default), Expected: No query, no migrate options = ['--no-migration', archive] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNotNone(result.exception, msg=result.output) self.assertNotEqual(result.exit_code, 0, msg=result.output) @@ -275,7 +282,7 @@ def test_non_interactive_and_migration(self): # Import using `--no-migration` and `--non-interactive`, make sure confirm message has gone # `migration` = False, `non_interactive` = True, Expected: 
No query, no migrate options = ['--no-migration', '--non-interactive', archive] - result = self.cli_runner.invoke(cmd_import.cmd_import, options) + result = self.cli_runner.invoke(cmd_archive.import_archive, options) self.assertIsNotNone(result.exception, msg=result.output) self.assertNotEqual(result.exit_code, 0, msg=result.output) diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index b62485a790..c243fe7618 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -130,10 +130,13 @@ def test_calcjob_inputls(self): options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputls, options) self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 3) + # There is also an additional fourth file added by hand to test retrieval of binary content + # see comments in test_calcjob_inputcat + self.assertEqual(len(get_result_lines(result)), 4) self.assertIn('.aiida', get_result_lines(result)) self.assertIn('aiida.in', get_result_lines(result)) self.assertIn('_aiidasubmit.sh', get_result_lines(result)) + self.assertIn('in_gzipped_data', get_result_lines(result)) options = [self.arithmetic_job.uuid, '.aiida'] result = self.cli_runner.invoke(command.calcjob_inputls, options) @@ -156,10 +159,13 @@ def test_calcjob_outputls(self): options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputls, options) self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 3) + # There is also an additional fourth file added by hand to test retrieval of binary content + # see comments in test_calcjob_outputcat + self.assertEqual(len(get_result_lines(result)), 4) self.assertIn('_scheduler-stderr.txt', get_result_lines(result)) self.assertIn('_scheduler-stdout.txt', get_result_lines(result)) self.assertIn('aiida.out', get_result_lines(result)) + self.assertIn('gzipped_data', get_result_lines(result)) options = [self.arithmetic_job.uuid, 'non-existing-folder'] result = self.cli_runner.invoke(command.calcjob_inputls, options) @@ -186,16 +192,13 @@ def test_calcjob_inputcat(self): self.assertEqual(get_result_lines(result)[0], '2 3') # Test cat binary files - with self.arithmetic_job.open('aiida.in', 'wb') as fh_out: - fh_out.write(gzip.compress(b'COMPRESS')) - - options = [self.arithmetic_job.uuid, 'aiida.in'] + # I manually added, in the export file, in the files of the arithmetic_job, + # a file called 'in_gzipped_data' whose content has been generated with + # with open('in_gzipped_data', 'wb') as f: + # f.write(gzip.compress(b'COMPRESS-INPUT')) + options = [self.arithmetic_job.uuid, 'in_gzipped_data'] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - assert gzip.decompress(result.stdout_bytes) == b'COMPRESS' - - # Replace the file - with self.arithmetic_job.open('aiida.in', 'w') as fh_out: - fh_out.write('2 3\n') + assert gzip.decompress(result.stdout_bytes) == b'COMPRESS-INPUT' def test_calcjob_outputcat(self): """Test verdi calcjob outputcat""" @@ -217,18 +220,14 @@ def test_calcjob_outputcat(self): self.assertEqual(get_result_lines(result)[0], '5') # Test cat binary files - retrieved = self.arithmetic_job.outputs.retrieved - with retrieved.open('aiida.out', 'wb') as fh_out: - fh_out.write(gzip.compress(b'COMPRESS')) - - options = [self.arithmetic_job.uuid, 'aiida.out'] + # I manually added, in the export file, in the files of the output retrieved node of the 
arithmetic_job, + # a file called 'gzipped_data' whose content has been generated with + # with open('gzipped_data', 'wb') as f: + # f.write(gzip.compress(b'COMPRESS')) + options = [self.arithmetic_job.uuid, 'gzipped_data'] result = self.cli_runner.invoke(command.calcjob_outputcat, options) assert gzip.decompress(result.stdout_bytes) == b'COMPRESS' - # Replace the file - with retrieved.open('aiida.out', 'w') as fh_out: - fh_out.write('5\n') - def test_calcjob_cleanworkdir(self): """Test verdi calcjob cleanworkdir""" diff --git a/tests/cmdline/commands/test_config.py b/tests/cmdline/commands/test_config.py index 5aec1ed3c5..93fb03fb46 100644 --- a/tests/cmdline/commands/test_config.py +++ b/tests/cmdline/commands/test_config.py @@ -7,25 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for `verdi config`.""" +import pytest -from click.testing import CliRunner - -from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_verdi from aiida.manage.configuration import get_config -from tests.utils.configuration import with_temporary_config_instance - -class TestVerdiConfig(AiidaTestCase): - """Tests for `verdi config`.""" +class TestVerdiConfigDeprecated: + """Tests for deprecated `verdi config `.""" - def setUp(self): - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def setup_fixture(self, config_with_profile_factory): + config_with_profile_factory() - @with_temporary_config_instance - def test_config_set_option(self): + def test_config_set_option(self, run_cli_command): """Test the `verdi config` command when setting an option.""" config = get_config() @@ -34,67 +31,187 @@ def test_config_set_option(self): for option_value in option_values: options = ['config', option_name, str(option_value)] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) + run_cli_command(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertEqual(str(config.get_option(option_name, scope=config.current_profile.name)), option_value) + assert str(config.get_option(option_name, scope=config.current_profile.name)) == option_value - @with_temporary_config_instance - def test_config_get_option(self): + def test_config_get_option(self, run_cli_command): """Test the `verdi config` command when getting an option.""" option_name = 'daemon.timeout' option_value = str(30) options = ['config', option_name, option_value] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertClickResultNoException(result) + result = run_cli_command(cmd_verdi.verdi, options) options = ['config', option_name] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertIn(option_value, result.output.strip()) + result = run_cli_command(cmd_verdi.verdi, options) + + assert option_value in result.output.strip() - @with_temporary_config_instance - def test_config_unset_option(self): + def test_config_unset_option(self, run_cli_command): """Test the `verdi config` command when unsetting an option.""" option_name = 'daemon.timeout' option_value = str(30) options = ['config', option_name, str(option_value)] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) + result = run_cli_command(cmd_verdi.verdi, options) options = ['config', option_name] - 
result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertIn(option_value, result.output.strip()) + result = run_cli_command(cmd_verdi.verdi, options) + + assert option_value in result.output.strip() options = ['config', option_name, '--unset'] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertIn(f'{option_name} unset', result.output.strip()) + result = run_cli_command(cmd_verdi.verdi, options) + + assert f"'{option_name}' unset" in result.output.strip() options = ['config', option_name] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) - self.assertEqual(result.output, '') + result = run_cli_command(cmd_verdi.verdi, options) + + # assert result.output == '' # now has deprecation warning - @with_temporary_config_instance - def test_config_set_option_global_only(self): + def test_config_set_option_global_only(self, run_cli_command): """Test that `global_only` options are only set globally even if the `--global` flag is not set.""" config = get_config() - option_name = 'user.email' + option_name = 'autofill.user.email' option_value = 'some@email.com' options = ['config', option_name, str(option_value)] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) - self.assertClickSuccess(result) + result = run_cli_command(cmd_verdi.verdi, options) options = ['config', option_name] - result = self.cli_runner.invoke(cmd_verdi.verdi, options) + result = run_cli_command(cmd_verdi.verdi, options) + + # Check that the current profile name is not in the output + + assert option_value in result.output.strip() + assert config.current_profile.name not in result.output.strip() + + +class TestVerdiConfig: + """Tests for `verdi config`.""" + + @pytest.fixture(autouse=True) + def setup_fixture(self, config_with_profile_factory): + config_with_profile_factory() + + def test_config_set_option(self, run_cli_command): + """Test the `verdi config set` command when setting an option.""" + config = get_config() + + option_name = 'daemon.timeout' + option_values = [str(10), str(20)] + + for option_value in option_values: + options = ['config', 'set', option_name, str(option_value)] + run_cli_command(cmd_verdi.verdi, options) + assert str(config.get_option(option_name, scope=config.current_profile.name)) == option_value + + def test_config_append_option(self, run_cli_command): + """Test the `verdi config set --append` command when appending an option value.""" + config = get_config() + option_name = 'caching.enabled_for' + for value in ['x', 'y']: + options = ['config', 'set', '--append', option_name, value] + run_cli_command(cmd_verdi.verdi, options) + assert config.get_option(option_name, scope=config.current_profile.name) == ['x', 'y'] + + def test_config_remove_option(self, run_cli_command): + """Test the `verdi config set --remove` command when removing an option value.""" + config = get_config() + + option_name = 'caching.disabled_for' + config.set_option(option_name, ['x', 'y'], scope=config.current_profile.name) + + options = ['config', 'set', '--remove', option_name, 'x'] + run_cli_command(cmd_verdi.verdi, options) + assert config.get_option(option_name, scope=config.current_profile.name) == ['y'] + + def test_config_get_option(self, run_cli_command): + """Test the `verdi config show` command when getting an option.""" + option_name = 'daemon.timeout' + option_value = str(30) + + options = ['config', 'set', option_name, option_value] + result = 
run_cli_command(cmd_verdi.verdi, options) + + options = ['config', 'get', option_name] + result = run_cli_command(cmd_verdi.verdi, options) + assert option_value in result.output.strip() + + def test_config_unset_option(self, run_cli_command): + """Test the `verdi config` command when unsetting an option.""" + option_name = 'daemon.timeout' + option_value = str(30) + + options = ['config', 'set', option_name, str(option_value)] + result = run_cli_command(cmd_verdi.verdi, options) + + options = ['config', 'get', option_name] + result = run_cli_command(cmd_verdi.verdi, options) + assert option_value in result.output.strip() + + options = ['config', 'unset', option_name] + result = run_cli_command(cmd_verdi.verdi, options) + assert f"'{option_name}' unset" in result.output.strip() + + options = ['config', 'get', option_name] + result = run_cli_command(cmd_verdi.verdi, options) + assert result.output.strip() == str(20) # back to the default + + def test_config_set_option_global_only(self, run_cli_command): + """Test that `global_only` options are only set globally even if the `--global` flag is not set.""" + config = get_config() + option_name = 'autofill.user.email' + option_value = 'some@email.com' + + options = ['config', 'set', option_name, str(option_value)] + result = run_cli_command(cmd_verdi.verdi, options) + + options = ['config', 'get', option_name] + result = run_cli_command(cmd_verdi.verdi, options) # Check that the current profile name is not in the output - self.assertClickSuccess(result) - self.assertIn(option_value, result.output.strip()) - self.assertNotIn(config.current_profile.name, result.output.strip()) + assert option_value in result.output.strip() + assert config.current_profile.name not in result.output.strip() + + def test_config_list(self, run_cli_command): + """Test `verdi config list`""" + options = ['config', 'list'] + result = run_cli_command(cmd_verdi.verdi, options) + + assert 'daemon.timeout' in result.output + assert 'Timeout in seconds' not in result.output + + def test_config_list_description(self, run_cli_command): + """Test `verdi config list --description`""" + for flag in ['-d', '--description']: + options = ['config', 'list', flag] + result = run_cli_command(cmd_verdi.verdi, options) + + assert 'daemon.timeout' in result.output + assert 'Timeout in seconds' in result.output + + def test_config_show(self, run_cli_command): + """Test `verdi config show`""" + options = ['config', 'show', 'daemon.timeout'] + result = run_cli_command(cmd_verdi.verdi, options) + assert 'schema' in result.output + + def test_config_caching(self, run_cli_command): + """Test `verdi config caching`""" + result = run_cli_command(cmd_verdi.verdi, ['config', 'caching']) + assert result.output.strip() == '' + + result = run_cli_command(cmd_verdi.verdi, ['config', 'caching', '--disabled']) + assert 'arithmetic.add' in result.output.strip() + + config = get_config() + config.set_option('caching.default_enabled', True, scope=config.current_profile.name) + + result = run_cli_command(cmd_verdi.verdi, ['config', 'caching']) + assert 'arithmetic.add' in result.output.strip() + + result = run_cli_command(cmd_verdi.verdi, ['config', 'caching', '--disabled']) + assert result.output.strip() == '' diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py index c280f4bbc5..b27fcbcd93 100644 --- a/tests/cmdline/commands/test_data.py +++ b/tests/cmdline/commands/test_data.py @@ -10,15 +10,17 @@ # pylint: disable=no-member, too-many-lines """Test data-related verdi 
commands.""" +import asyncio import io import os import shutil import unittest import tempfile import subprocess as sp -import numpy as np from click.testing import CliRunner +import numpy as np +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase @@ -228,6 +230,7 @@ def test_arrayshow(self): self.assertEqual(res.exit_code, 0, 'The command did not finish correctly') +@pytest.mark.requires_rmq class TestVerdiDataBands(AiidaTestCase, DummyVerdiDataListable): """Testing verdi data bands.""" @@ -298,8 +301,17 @@ def connect_structure_bands(strct): # pylint: disable=unused-argument @classmethod def setUpClass(cls): # pylint: disable=arguments-differ super().setUpClass() + + # create a new event loop since the privious one is closed by other test case + cls.loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls.loop) cls.ids = cls.create_structure_bands() + @classmethod + def tearDownClass(cls): # pylint: disable=arguments-differ + cls.loop.close() + super().tearDownClass() + def setUp(self): self.cli_runner = CliRunner() diff --git a/tests/cmdline/commands/test_database.py b/tests/cmdline/commands/test_database.py index 4269cf6c7e..019ab40737 100644 --- a/tests/cmdline/commands/test_database.py +++ b/tests/cmdline/commands/test_database.py @@ -12,6 +12,7 @@ import enum from click.testing import CliRunner +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_database @@ -48,11 +49,9 @@ def setUpClass(cls, *args, **kwargs): data_output.add_incoming(workflow_parent, link_label='output', link_type=LinkType.RETURN) def setUp(self): + self.refurbish_db() self.cli_runner = CliRunner() - def tearDown(self): - self.reset_database() - def test_detect_invalid_links_workflow_create(self): """Test `verdi database integrity detect-invalid-links` outgoing `create` from `workflow`.""" result = self.cli_runner.invoke(cmd_database.detect_invalid_links, []) @@ -170,3 +169,22 @@ def test_detect_invalid_nodes_unknown_node_type(self): result = self.cli_runner.invoke(cmd_database.detect_invalid_nodes, []) self.assertNotEqual(result.exit_code, 0) self.assertIsNotNone(result.exception) + + +@pytest.mark.usefixtures('aiida_profile') +def tests_database_version(run_cli_command, manager): + """Test the ``verdi database version`` command.""" + backend_manager = manager.get_backend_manager() + result = run_cli_command(cmd_database.database_version) + assert result.output_lines[0].endswith(backend_manager.get_schema_generation_database()) + assert result.output_lines[1].endswith(backend_manager.get_schema_version_database()) + + +@pytest.mark.usefixtures('clear_database_before_test') +def tests_database_summary(aiida_localhost, run_cli_command): + """Test the ``verdi database summary -v`` command.""" + from aiida import orm + node = orm.Dict().store() + result = run_cli_command(cmd_database.database_summary, ['--verbose']) + assert aiida_localhost.label in result.output + assert node.node_type in result.output diff --git a/tests/cmdline/commands/test_graph.py b/tests/cmdline/commands/test_graph.py index 322a6f5730..f9054943e3 100644 --- a/tests/cmdline/commands/test_graph.py +++ b/tests/cmdline/commands/test_graph.py @@ -39,7 +39,7 @@ class TestVerdiGraph(AiidaTestCase): """Tests for verdi graph""" @classmethod - def setUpClass(cls, *args, **kwargs): + def setUpClass(cls): super().setUpClass() from aiida.orm import Data @@ -52,7 +52,7 @@ def setUpClass(cls, *args, **kwargs): os.chdir(cls.cwd) @classmethod - def tearDownClass(cls, *args, 
**kwargs): + def tearDownClass(cls): os.chdir(cls.old_cwd) os.rmdir(cls.cwd) diff --git a/tests/cmdline/commands/test_group.py b/tests/cmdline/commands/test_group.py index a8b53370a5..db0cf51949 100644 --- a/tests/cmdline/commands/test_group.py +++ b/tests/cmdline/commands/test_group.py @@ -12,6 +12,7 @@ from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions from aiida.cmdline.commands import cmd_group +from aiida.cmdline.utils.echo import ExitCode class TestVerdiGroup(AiidaTestCase): @@ -131,6 +132,12 @@ def test_delete(self): """Test `verdi group delete` command.""" orm.Group(label='group_test_delete_01').store() orm.Group(label='group_test_delete_02').store() + orm.Group(label='group_test_delete_03').store() + + # dry run + result = self.cli_runner.invoke(cmd_group.group_delete, ['--dry-run', 'group_test_delete_01']) + self.assertClickResultNoException(result) + orm.load_group(label='group_test_delete_01') result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', 'group_test_delete_01']) self.assertClickResultNoException(result) @@ -142,6 +149,7 @@ def test_delete(self): node_01 = orm.CalculationNode().store() node_02 = orm.CalculationNode().store() + node_pks = {node_01.pk, node_02.pk} # Add some nodes and then use `verdi group delete` to delete a group that contains nodes group = orm.load_group(label='group_test_delete_02') @@ -149,11 +157,27 @@ def test_delete(self): self.assertEqual(group.count(), 2) result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', 'group_test_delete_02']) - self.assertClickResultNoException(result) with self.assertRaises(exceptions.NotExistent): orm.load_group(label='group_test_delete_02') + # check nodes still exist + for pk in node_pks: + orm.load_node(pk) + + # delete the group and the nodes it contains + group = orm.load_group(label='group_test_delete_03') + group.add_nodes([node_01, node_02]) + result = self.cli_runner.invoke(cmd_group.group_delete, ['--force', '--delete-nodes', 'group_test_delete_03']) + self.assertClickResultNoException(result) + + # check group and nodes no longer exist + with self.assertRaises(exceptions.NotExistent): + orm.load_group(label='group_test_delete_03') + for pk in node_pks: + with self.assertRaises(exceptions.NotExistent): + orm.load_node(pk) + def test_show(self): """Test `verdi group show` command.""" result = self.cli_runner.invoke(cmd_group.group_show, ['dummygroup1']) @@ -241,7 +265,7 @@ def test_add_remove_nodes(self): result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--force', '--group=dummygroup1', node_01.uuid]) self.assertIsNone(result.exception, result.output) - # Check if node is added in group using group show command + # Check that the node is no longer in the group result = self.cli_runner.invoke(cmd_group.group_show, ['-r', 'dummygroup1']) self.assertClickResultNoException(result) self.assertNotIn('CalculationNode', result.output) @@ -256,6 +280,35 @@ def test_add_remove_nodes(self): self.assertClickResultNoException(result) self.assertEqual(group.count(), 0) + # Try to remove node that isn't in the group + result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid]) + self.assertEqual(result.exit_code, ExitCode.CRITICAL) + + # Try to remove no nodes nor clear the group + result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1']) + self.assertEqual(result.exit_code, ExitCode.CRITICAL) + + # Try to remove both nodes and clear the group + result = 
self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear', node_01.uuid]) + self.assertEqual(result.exit_code, ExitCode.CRITICAL) + + # Add a node with confirmation + result = self.cli_runner.invoke(cmd_group.group_add_nodes, ['--group=dummygroup1', node_01.uuid], input='y') + self.assertEqual(group.count(), 1) + + # Try to remove two nodes, one that isn't in the group, but abort + result = self.cli_runner.invoke( + cmd_group.group_remove_nodes, ['--group=dummygroup1', node_01.uuid, node_02.uuid], input='N' + ) + self.assertIn('Warning', result.output) + self.assertEqual(group.count(), 1) + + # Try to clear all nodes from the group, but abort + result = self.cli_runner.invoke(cmd_group.group_remove_nodes, ['--group=dummygroup1', '--clear'], input='N') + self.assertIn('Are you sure you want to remove ALL', result.output) + self.assertIn('Aborted', result.output) + self.assertEqual(group.count(), 1) + def test_copy_existing_group(self): """Test user is prompted to continue if destination group exists and is not empty""" source_label = 'source_copy_existing_group' diff --git a/tests/cmdline/commands/test_help.py b/tests/cmdline/commands/test_help.py index be2231c2f2..82310bbdb6 100644 --- a/tests/cmdline/commands/test_help.py +++ b/tests/cmdline/commands/test_help.py @@ -8,15 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi help`.""" - from click.testing import CliRunner +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_verdi -from tests.utils.configuration import with_temporary_config_instance - +@pytest.mark.usefixtures('config_with_profile') class TestVerdiHelpCommand(AiidaTestCase): """Tests for `verdi help`.""" @@ -24,7 +23,6 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() - @with_temporary_config_instance def test_without_arg(self): """ Ensure we get the same help for `verdi` (which gives the same as `verdi --help`) @@ -36,7 +34,6 @@ def test_without_arg(self): result_verdi = self.cli_runner.invoke(cmd_verdi.verdi, [], catch_exceptions=False) self.assertEqual(result_help.output, result_verdi.output) - @with_temporary_config_instance def test_cmd_help(self): """Ensure we get the same help for `verdi user --help` and `verdi help user`""" result_help = self.cli_runner.invoke(cmd_verdi.verdi, ['help', 'user'], catch_exceptions=False) diff --git a/tests/cmdline/commands/test_node.py b/tests/cmdline/commands/test_node.py index a99de1b532..874f18ca7b 100644 --- a/tests/cmdline/commands/test_node.py +++ b/tests/cmdline/commands/test_node.py @@ -15,6 +15,7 @@ import pathlib import tempfile import gzip +import warnings from click.testing import CliRunner @@ -22,6 +23,7 @@ from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_node from aiida.common.utils import Capturing +from aiida.common.warnings import AiidaDeprecationWarning def get_result_lines(result): @@ -55,7 +57,15 @@ def setUpClass(cls, *args, **kwargs): cls.node = node - # Set up a FolderData for the node repo cp tests. + def setUp(self): + self.cli_runner = CliRunner() + + @classmethod + def get_unstored_folder_node(cls): + """Get a "default" folder node with some data. + + The node is unstored so one can add more content to it before storing it. 
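The unstored-folder-node helper introduced here is what lets individual tests add binary payloads before storing, which the repository `cat` checks then decode with gzip. A minimal sketch of that flow outside the test class (it requires a loaded profile; the file names are illustrative and the `'rb'` read-back mode is an assumption, not taken from the diff):

    import gzip
    import io

    from aiida import orm

    # Build the folder node while it is still unstored, so extra objects can be added.
    folder_node = orm.FolderData()
    folder_node.put_object_from_filelike(io.StringIO('nobody expects'), 'some_file.txt')

    # Binary content goes through a bytes buffer with mode='wb', exactly as in the hunks above.
    folder_node.put_object_from_filelike(io.BytesIO(gzip.compress(b'COMPRESS')), 'payload.gz', mode='wb')

    # Only now freeze the repository contents.
    folder_node.store()

    # Read the object back and check it round-trips (the 'rb' read mode here is an assumption).
    with folder_node.open('payload.gz', mode='rb') as handle:
        assert gzip.decompress(handle.read()) == b'COMPRESS'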
+ """ folder_node = orm.FolderData() cls.content_file1 = 'nobody expects' cls.content_file2 = 'the minister of silly walks' @@ -63,27 +73,31 @@ def setUpClass(cls, *args, **kwargs): cls.key_file2 = 'some_other_file.txt' folder_node.put_object_from_filelike(io.StringIO(cls.content_file1), cls.key_file1) folder_node.put_object_from_filelike(io.StringIO(cls.content_file2), cls.key_file2) - folder_node.store() - cls.folder_node = folder_node - - def setUp(self): - self.cli_runner = CliRunner() + return folder_node def test_node_tree(self): """Test `verdi node tree`""" options = [str(self.node.pk)] - result = self.cli_runner.invoke(cmd_node.tree, options) + + # This command (and so the test as well) will go away in 2.0 + # Note: I cannot use simply pytest.mark.filterwarnings as below, as the warning is issued in an invoked command + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + result = self.cli_runner.invoke(cmd_node.tree, options) self.assertClickResultNoException(result) + # This command (and so this test as well) will go away in 2.0 def test_node_tree_printer(self): """Test the `NodeTreePrinter` utility.""" from aiida.cmdline.utils.ascii_vis import NodeTreePrinter + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) - with Capturing(): - NodeTreePrinter.print_node_tree(self.node, max_depth=1) + with Capturing(): + NodeTreePrinter.print_node_tree(self.node, max_depth=1) - with Capturing(): - NodeTreePrinter.print_node_tree(self.node, max_depth=1, follow_links=()) + with Capturing(): + NodeTreePrinter.print_node_tree(self.node, max_depth=1, follow_links=()) def test_node_show(self): """Test `verdi node show`""" @@ -187,12 +201,14 @@ def test_node_extras(self): def test_node_repo_ls(self): """Test 'verdi node repo ls' command.""" - options = [str(self.folder_node.pk), 'some/nested/folder'] + folder_node = self.get_unstored_folder_node().store() + + options = [str(folder_node.pk), 'some/nested/folder'] result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) self.assertClickResultNoException(result) self.assertIn('filename.txt', result.output) - options = [str(self.folder_node.pk), 'some/non-existing-folder'] + options = [str(folder_node.pk), 'some/non-existing-folder'] result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) self.assertIsNotNone(result.exception) self.assertIn('does not exist for the given node', result.output) @@ -200,19 +216,21 @@ def test_node_repo_ls(self): def test_node_repo_cat(self): """Test 'verdi node repo cat' command.""" # Test cat binary files - with self.folder_node.open('filename.txt.gz', 'wb') as fh_out: - fh_out.write(gzip.compress(b'COMPRESS')) + folder_node = self.get_unstored_folder_node() + folder_node.put_object_from_filelike(io.BytesIO(gzip.compress(b'COMPRESS')), 'filename.txt.gz', mode='wb') + folder_node.store() - options = [str(self.folder_node.pk), 'filename.txt.gz'] + options = [str(folder_node.pk), 'filename.txt.gz'] result = self.cli_runner.invoke(cmd_node.repo_cat, options) assert gzip.decompress(result.stdout_bytes) == b'COMPRESS' def test_node_repo_dump(self): """Test 'verdi node repo dump' command.""" + folder_node = self.get_unstored_folder_node().store() with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' - options = [str(self.folder_node.uuid), str(out_path)] + options = [str(folder_node.uuid), str(out_path)] res = 
self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) self.assertFalse(res.stdout) @@ -226,10 +244,11 @@ def test_node_repo_dump(self): def test_node_repo_dump_to_nested_folder(self): """Test 'verdi node repo dump' command, with an output folder whose parent does not exist.""" + folder_node = self.get_unstored_folder_node().store() with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' / 'nested' / 'path' - options = [str(self.folder_node.uuid), str(out_path)] + options = [str(folder_node.uuid), str(out_path)] res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) self.assertFalse(res.stdout) @@ -243,6 +262,7 @@ def test_node_repo_dump_to_nested_folder(self): def test_node_repo_existing_out_dir(self): """Test 'verdi node repo dump' command, check that an existing output directory is not overwritten.""" + folder_node = self.get_unstored_folder_node().store() with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' @@ -252,7 +272,7 @@ def test_node_repo_existing_out_dir(self): some_file_content = 'ni!' with some_file.open('w') as file_handle: file_handle.write(some_file_content) - options = [str(self.folder_node.uuid), str(out_path)] + options = [str(folder_node.uuid), str(out_path)] res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) self.assertIn('exists', res.stdout) self.assertIn('Critical:', res.stdout) @@ -282,7 +302,7 @@ class TestVerdiGraph(AiidaTestCase): """Tests for the ``verdi node graph`` command.""" @classmethod - def setUpClass(cls, *args, **kwargs): + def setUpClass(cls): super().setUpClass() from aiida.orm import Data @@ -295,7 +315,7 @@ def setUpClass(cls, *args, **kwargs): os.chdir(cls.cwd) @classmethod - def tearDownClass(cls, *args, **kwargs): + def tearDownClass(cls): os.chdir(cls.old_cwd) os.rmdir(cls.cwd) @@ -651,3 +671,8 @@ def test_basics(self): with self.assertRaises(NotExistent): orm.load_node(newnodepk) + + def test_missing_pk(self): + """Check that no exception is raised when a non-existent pk is given (just warns).""" + result = self.cli_runner.invoke(cmd_node.node_delete, ['999']) + self.assertClickResultNoException(result) diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py index af1ae5384b..bea5ed8069 100644 --- a/tests/cmdline/commands/test_process.py +++ b/tests/cmdline/commands/test_process.py @@ -8,22 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi process`.""" -import datetime -import subprocess -import sys import time +import asyncio from concurrent.futures import Future from click.testing import CliRunner -from tornado import gen import kiwipy import plumpy +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_process from aiida.common.links import LinkType from aiida.common.log import LOG_LEVEL_REPORT -from aiida.manage.manager import get_manager from aiida.orm import CalcJobNode, WorkflowNode, WorkFunctionNode, WorkChainNode from tests.utils import processes as test_processes @@ -33,98 +30,6 @@ def get_result_lines(result): return [e for e in result.output.split('\n') if e] -class TestVerdiProcessDaemon(AiidaTestCase): - """Tests for `verdi process` that require a running daemon.""" - - TEST_TIMEOUT = 5. 
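For orientation, the `verdi node repo dump` checks above reduce to the following flow. This is a minimal sketch rather than the test itself: it requires a loaded profile, the stored node and file name are illustrative, and the assumption that the dump mirrors the repository layout under the target directory is mine (the tests above only show that a nested, non-existing target works and that nothing is printed on success):

    import io
    import pathlib
    import tempfile

    from click.testing import CliRunner

    from aiida import orm
    from aiida.cmdline.commands import cmd_node

    # A small folder node to dump.
    folder_node = orm.FolderData()
    folder_node.put_object_from_filelike(io.StringIO('nobody expects'), 'some_file.txt')
    folder_node.store()

    with tempfile.TemporaryDirectory() as tmp_dir:
        # The target may be nested: missing parent directories are created by the command.
        out_path = pathlib.Path(tmp_dir) / 'out_dir' / 'nested'
        options = [str(folder_node.uuid), str(out_path)]
        result = CliRunner().invoke(cmd_node.repo_dump, options, catch_exceptions=False)

        assert not result.stdout  # a successful dump prints nothing
        # Assumption: the dump mirrors the repository layout under ``out_path``.
        assert (out_path / 'some_file.txt').read_text() == 'nobody expects'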
- - def setUp(self): - super().setUp() - from aiida.cmdline.utils.common import get_env_with_venv_bin - from aiida.engine.daemon.client import DaemonClient - from aiida.manage.configuration import get_config - - # Add the current python path to the environment that will be used for the daemon sub process. This is necessary - # to guarantee the daemon can also import all the classes that are defined in this `tests` module. - env = get_env_with_venv_bin() - env['PYTHONPATH'] = ':'.join(sys.path) - - profile = get_config().current_profile - self.daemon_client = DaemonClient(profile) - self.daemon_pid = subprocess.Popen( - self.daemon_client.cmd_string.split(), stderr=sys.stderr, stdout=sys.stdout, env=env - ).pid - self.runner = get_manager().create_runner(rmq_submit=True) - self.cli_runner = CliRunner() - - def tearDown(self): - import os - import signal - - os.kill(self.daemon_pid, signal.SIGTERM) - super().tearDown() - - def test_pause_play_kill(self): - """ - Test the pause/play/kill commands - """ - # pylint: disable=no-member - from aiida.orm import load_node - - calc = self.runner.submit(test_processes.WaitProcess) - start_time = time.time() - while calc.process_state is not plumpy.ProcessState.WAITING: - if time.time() - start_time >= self.TEST_TIMEOUT: - self.fail('Timed out waiting for process to enter waiting state') - - # Make sure that calling any command on a non-existing process id will not except but print an error - # To simulate a process without a corresponding task, we simply create a node and store it. This node will not - # have an associated task at RabbitMQ, but it will be a valid `ProcessNode` so it will pass the initial - # filtering of the `verdi process` commands - orphaned_node = WorkFunctionNode().store() - non_existing_process_id = str(orphaned_node.pk) - for command in [cmd_process.process_pause, cmd_process.process_play, cmd_process.process_kill]: - result = self.cli_runner.invoke(command, [non_existing_process_id]) - self.assertClickResultNoException(result) - self.assertIn('Error:', result.output) - - self.assertFalse(calc.paused) - result = self.cli_runner.invoke(cmd_process.process_pause, [str(calc.pk)]) - self.assertIsNone(result.exception, result.output) - - # We need to make sure that the process is picked up by the daemon and put in the Waiting state before we start - # running the CLI commands, so we add a broadcast subscriber for the state change, which when hit will set the - # future to True. This will be our signal that we can start testing - waiting_future = Future() - filters = kiwipy.BroadcastFilter( - lambda *args, **kwargs: waiting_future.set_result(True), sender=calc.pk, subject='state_changed.*.waiting' - ) - self.runner.communicator.add_broadcast_subscriber(filters) - - # The process may already have been picked up by the daemon and put in the waiting state, before the subscriber - # got the chance to attach itself, making it have missed the broadcast. That's why check if the state is already - # waiting, and if not, we run the loop of the runner to start waiting for the broadcast message. To make sure - # that we have the latest state of the node as it is in the database, we force refresh it by reloading it. 
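The broadcast-subscriber trick described in the comments above is worth spelling out on its own. A sketch of the same pattern, assuming a loaded profile and a reachable RabbitMQ broker as the surrounding test does; the `asyncio.wrap_future` call is my addition to make the thread-safe future awaitable and is not taken from the diff:

    import asyncio
    from concurrent.futures import Future

    import kiwipy

    from aiida.manage.manager import get_manager
    from tests.utils import processes as test_processes  # provides WaitProcess, as in the tests above

    # Requires a loaded profile and a reachable RabbitMQ broker, just like the test itself.
    runner = get_manager().create_runner(rmq_submit=True)
    calc = runner.submit(test_processes.WaitProcess)

    waiting_future = Future()

    # Resolve the future as soon as the 'waiting' state broadcast for this node arrives.
    filters = kiwipy.BroadcastFilter(
        lambda *args, **kwargs: waiting_future.set_result(True),
        sender=calc.pk,
        subject='state_changed.*.waiting',
    )
    runner.communicator.add_broadcast_subscriber(filters)

    async def _wait_for_waiting_state():
        # wrap_future makes the thread-safe future awaitable on the loop running this coroutine.
        await asyncio.wait_for(asyncio.wrap_future(waiting_future), timeout=5.0)

    runner.loop.run_until_complete(_wait_for_waiting_state())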
- calc = load_node(calc.pk) - if calc.process_state != plumpy.ProcessState.WAITING: - self.runner.loop.run_sync(lambda: with_timeout(waiting_future)) - - # Here we now that the process is with the daemon runner and in the waiting state so we can starting running - # the `verdi process` commands that we want to test - result = self.cli_runner.invoke(cmd_process.process_pause, ['--wait', str(calc.pk)]) - self.assertIsNone(result.exception, result.output) - self.assertTrue(calc.paused) - - result = self.cli_runner.invoke(cmd_process.process_play, ['--wait', str(calc.pk)]) - self.assertIsNone(result.exception, result.output) - self.assertFalse(calc.paused) - - result = self.cli_runner.invoke(cmd_process.process_kill, ['--wait', str(calc.pk)]) - self.assertIsNone(result.exception, result.output) - self.assertTrue(calc.is_terminated) - self.assertTrue(calc.is_killed) - - class TestVerdiProcess(AiidaTestCase): """Tests for `verdi process`.""" @@ -490,6 +395,74 @@ def test_multiple_processes(self): self.assertIn(str(self.node_root.pk), get_result_lines(result)[2]) -@gen.coroutine -def with_timeout(what, timeout=5.0): - raise gen.Return((yield gen.with_timeout(datetime.timedelta(seconds=timeout), what))) +@pytest.mark.skip(reason='fails to complete randomly (see issue #4731)') +@pytest.mark.requires_rmq +@pytest.mark.usefixtures('with_daemon', 'clear_database_before_test') +@pytest.mark.parametrize('cmd_try_all', (True, False)) +def test_pause_play_kill(cmd_try_all, run_cli_command): + """ + Test the pause/play/kill commands + """ + # pylint: disable=no-member, too-many-locals + from aiida.cmdline.commands.cmd_process import process_pause, process_play, process_kill + from aiida.manage.manager import get_manager + from aiida.engine import ProcessState + from aiida.orm import load_node + + runner = get_manager().create_runner(rmq_submit=True) + calc = runner.submit(test_processes.WaitProcess) + + test_daemon_timeout = 5. + start_time = time.time() + while calc.process_state is not plumpy.ProcessState.WAITING: + if time.time() - start_time >= test_daemon_timeout: + raise RuntimeError('Timed out waiting for process to enter waiting state') + + # Make sure that calling any command on a non-existing process id will not except but print an error + # To simulate a process without a corresponding task, we simply create a node and store it. This node will not + # have an associated task at RabbitMQ, but it will be a valid `ProcessNode` with and active state, so it will + # pass the initial filtering of the `verdi process` commands + orphaned_node = WorkFunctionNode() + orphaned_node.set_process_state(ProcessState.RUNNING) + orphaned_node.store() + non_existing_process_id = str(orphaned_node.pk) + for command in [process_pause, process_play, process_kill]: + result = run_cli_command(command, [non_existing_process_id]) + assert 'Error:' in result.output + + assert not calc.paused + result = run_cli_command(process_pause, [str(calc.pk)]) + + # We need to make sure that the process is picked up by the daemon and put in the Waiting state before we start + # running the CLI commands, so we add a broadcast subscriber for the state change, which when hit will set the + # future to True. 
This will be our signal that we can start testing + waiting_future = Future() + filters = kiwipy.BroadcastFilter( + lambda *args, **kwargs: waiting_future.set_result(True), sender=calc.pk, subject='state_changed.*.waiting' + ) + runner.communicator.add_broadcast_subscriber(filters) + + # The process may already have been picked up by the daemon and put in the waiting state, before the subscriber + # got the chance to attach itself, making it have missed the broadcast. That's why check if the state is already + # waiting, and if not, we run the loop of the runner to start waiting for the broadcast message. To make sure + # that we have the latest state of the node as it is in the database, we force refresh it by reloading it. + calc = load_node(calc.pk) + if calc.process_state != plumpy.ProcessState.WAITING: + runner.loop.run_until_complete(asyncio.wait_for(waiting_future, timeout=5.0)) + + # Here we now that the process is with the daemon runner and in the waiting state so we can starting running + # the `verdi process` commands that we want to test + result = run_cli_command(process_pause, ['--wait', str(calc.pk)]) + assert calc.paused + + if cmd_try_all: + cmd_option = '--all' + else: + cmd_option = str(calc.pk) + + result = run_cli_command(process_play, ['--wait', cmd_option]) + assert not calc.paused + + result = run_cli_command(process_kill, ['--wait', str(calc.pk)]) + assert calc.is_terminated + assert calc.is_killed diff --git a/tests/cmdline/commands/test_profile.py b/tests/cmdline/commands/test_profile.py index 120b09d762..7eb262eeb3 100644 --- a/tests/cmdline/commands/test_profile.py +++ b/tests/cmdline/commands/test_profile.py @@ -10,14 +10,16 @@ """Tests for `verdi profile`.""" from click.testing import CliRunner +import pytest from aiida.backends.testbase import AiidaPostgresTestCase from aiida.cmdline.commands import cmd_profile, cmd_verdi from aiida.manage import configuration -from tests.utils.configuration import create_mock_profile, with_temporary_config_instance +from tests.utils.configuration import create_mock_profile +@pytest.mark.usefixtures('config_with_profile') class TestVerdiProfileSetup(AiidaPostgresTestCase): """Tests for `verdi profile`.""" @@ -32,8 +34,7 @@ def mock_profiles(self, **kwargs): """Create mock profiles and a runner object to invoke the CLI commands. Note: this cannot be done in the `setUp` or `setUpClass` methods, because the temporary configuration instance - is not generated until the test function is entered, which calls the `with_temporary_config_instance` - decorator. + is not generated until the test function is entered, which calls the `config_with_profile` test fixture. 
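Stepping back to the process test: the tornado `with_timeout` coroutine deleted above maps directly onto `asyncio.wait_for`, which is what the new inline call uses. A minimal sketch of the equivalent helper, with illustrative names:

    import asyncio

    async def with_timeout(what, timeout=5.0):
        """asyncio counterpart of the tornado ``with_timeout`` helper deleted above."""
        return await asyncio.wait_for(what, timeout=timeout)

    # Self-contained check with a trivial coroutine; in the test the awaited object is the
    # broadcast future and the loop is ``runner.loop``.
    async def _ready():
        return True

    assert asyncio.run(with_timeout(_ready())) is True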
""" self.config = configuration.get_config() self.profile_list = ['mock_profile1', 'mock_profile2', 'mock_profile3', 'mock_profile4'] @@ -44,7 +45,6 @@ def mock_profiles(self, **kwargs): self.config.set_default_profile(self.profile_list[0], overwrite=True).store() - @with_temporary_config_instance def test_help(self): """Tests help text for all `verdi profile` commands.""" self.mock_profiles() @@ -67,7 +67,6 @@ def test_help(self): self.assertClickSuccess(result) self.assertIn('Usage', result.output) - @with_temporary_config_instance def test_list(self): """Test the `verdi profile list` command.""" self.mock_profiles() @@ -78,7 +77,6 @@ def test_list(self): self.assertIn(f'* {self.profile_list[0]}', result.output) self.assertIn(self.profile_list[1], result.output) - @with_temporary_config_instance def test_setdefault(self): """Test the `verdi profile setdefault` command.""" self.mock_profiles() @@ -93,7 +91,6 @@ def test_setdefault(self): self.assertIn(f'* {self.profile_list[1]}', result.output) self.assertClickSuccess(result) - @with_temporary_config_instance def test_show(self): """Test the `verdi profile show` command.""" self.mock_profiles() @@ -109,7 +106,6 @@ def test_show(self): self.assertIn(key.lower(), result.output) self.assertIn(value, result.output) - @with_temporary_config_instance def test_show_with_profile_option(self): """Test the `verdi profile show` command in combination with `-p/--profile.""" self.mock_profiles() @@ -126,12 +122,11 @@ def test_show_with_profile_option(self): self.assertClickSuccess(result) self.assertTrue(profile_name_non_default not in result.output) - @with_temporary_config_instance def test_delete_partial(self): """Test the `verdi profile delete` command. .. note:: we skip deleting the database as this might require sudo rights and this is tested in the CI tests - defined in the file `.ci/test_profile.py` + defined in the file `.github/system_tests/test_profile.py` """ self.mock_profiles() @@ -142,7 +137,6 @@ def test_delete_partial(self): self.assertClickSuccess(result) self.assertNotIn(self.profile_list[1], result.output) - @with_temporary_config_instance def test_delete(self): """Test for verdi profile delete command.""" from aiida.cmdline.commands.cmd_profile import profile_delete, profile_list diff --git a/tests/cmdline/commands/test_restapi.py b/tests/cmdline/commands/test_restapi.py index ab3d54eca0..9b0c2cb46a 100644 --- a/tests/cmdline/commands/test_restapi.py +++ b/tests/cmdline/commands/test_restapi.py @@ -10,6 +10,7 @@ """Tests for `verdi restapi`.""" from click.testing import CliRunner +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands.cmd_restapi import restapi @@ -22,8 +23,12 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() + @pytest.mark.filterwarnings('ignore::aiida.common.warnings.AiidaDeprecationWarning') def test_run_restapi(self): - """Test `verdi restapi`.""" + """Test `verdi restapi`. + + Note: This test will need to be changed/removed once the hookup parameter is dropped from the CLI. 
+ """ options = ['--no-hookup', '--hostname', 'localhost', '--port', '6000', '--debug', '--wsgi-profile'] diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index b151f49639..5e1c49e867 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -13,6 +13,7 @@ import warnings from click.testing import CliRunner +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_run @@ -25,6 +26,7 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() + @pytest.mark.requires_rmq def test_run_workfunction(self): """Regression test for #2165 @@ -181,6 +183,7 @@ def test_no_autogroup(self): all_auto_groups = queryb.all() self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') + @pytest.mark.requires_rmq def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" from aiida.orm import Code, QueryBuilder, Node, AutoGroup, load_node diff --git a/tests/cmdline/commands/test_setup.py b/tests/cmdline/commands/test_setup.py index d48061e3ee..8515be996e 100644 --- a/tests/cmdline/commands/test_setup.py +++ b/tests/cmdline/commands/test_setup.py @@ -20,9 +20,8 @@ from aiida.manage import configuration from aiida.manage.external.postgres import Postgres -from tests.utils.configuration import with_temporary_config_instance - +@pytest.mark.usefixtures('config_with_profile') class TestVerdiSetup(AiidaPostgresTestCase): """Tests for `verdi setup` and `verdi quicksetup`.""" @@ -34,7 +33,6 @@ def setUp(self): self.backend = configuration.PROFILE.database_backend self.cli_runner = CliRunner() - @with_temporary_config_instance def test_help(self): """Check that the `--help` option is eager, is not overruled and will properly display the help message. 
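The class-level `usefixtures` mark applied to `TestVerdiSetup` (and to the help, verdi and profile test classes) generalises as follows. A short sketch with an illustrative class name; it assumes the `config_with_profile` fixture added to `tests/conftest.py` in this same diff is available:

    import pytest

    from aiida.backends.testbase import AiidaTestCase
    from aiida.manage.configuration import get_config

    # One class-level mark replaces the per-method ``@with_temporary_config_instance`` decorators:
    # every test in the class runs against the temporary configuration prepared by the fixture.
    @pytest.mark.usefixtures('config_with_profile')
    class TestNeedsTemporaryProfile(AiidaTestCase):
        """Illustrative class; mirrors the unittest-style classes converted above."""

        def test_profile_is_loaded(self):
            assert get_config().current_profile is not None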
@@ -44,7 +42,6 @@ def test_help(self): self.cli_runner.invoke(cmd_setup.setup, ['--help'], catch_exceptions=False) self.cli_runner.invoke(cmd_setup.quicksetup, ['--help'], catch_exceptions=False) - @with_temporary_config_instance def test_quicksetup(self): """Test `verdi quicksetup`.""" configuration.reset_profile() @@ -79,7 +76,6 @@ def test_quicksetup(self): self.assertEqual(user.last_name, user_last_name) self.assertEqual(user.institution, user_institution) - @with_temporary_config_instance def test_quicksetup_from_config_file(self): """Test `verdi quicksetup` from configuration file.""" import tempfile @@ -99,7 +95,6 @@ def test_quicksetup_from_config_file(self): result = self.cli_runner.invoke(cmd_setup.quicksetup, ['--config', os.path.realpath(handle.name)]) self.assertClickResultNoException(result) - @with_temporary_config_instance def test_quicksetup_wrong_port(self): """Test `verdi quicksetup` exits if port is wrong.""" configuration.reset_profile() @@ -119,7 +114,6 @@ def test_quicksetup_wrong_port(self): result = self.cli_runner.invoke(cmd_setup.quicksetup, options) self.assertIsNotNone(result.exception, ''.join(traceback.format_exception(*result.exc_info))) - @with_temporary_config_instance def test_setup(self): """Test `verdi setup` (non-interactive).""" postgres = Postgres(interactive=False, quiet=True, dbinfo=self.pg_test.dsn) diff --git a/tests/cmdline/commands/test_status.py b/tests/cmdline/commands/test_status.py index 4818be3d39..4e7f2ca02d 100644 --- a/tests/cmdline/commands/test_status.py +++ b/tests/cmdline/commands/test_status.py @@ -14,6 +14,7 @@ from aiida.cmdline.utils.echo import ExitCode +@pytest.mark.requires_rmq def test_status(run_cli_command): """Test `verdi status`.""" options = [] @@ -27,7 +28,7 @@ def test_status(run_cli_command): assert string in result.output -@pytest.mark.usefixtures('create_empty_config_instance') +@pytest.mark.usefixtures('empty_config') def test_status_no_profile(run_cli_command): """Test `verdi status` when there is no profile.""" options = [] diff --git a/tests/cmdline/commands/test_verdi.py b/tests/cmdline/commands/test_verdi.py index 0791150dca..ed3aa88204 100644 --- a/tests/cmdline/commands/test_verdi.py +++ b/tests/cmdline/commands/test_verdi.py @@ -9,14 +9,14 @@ ########################################################################### """Tests for `verdi`.""" from click.testing import CliRunner +import pytest from aiida import get_version from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_verdi -from tests.utils.configuration import with_temporary_config_instance - +@pytest.mark.usefixtures('config_with_profile') class TestVerdi(AiidaTestCase): """Tests for `verdi`.""" @@ -30,7 +30,6 @@ def test_verdi_version(self): self.assertIsNone(result.exception, result.output) self.assertIn(get_version(), result.output) - @with_temporary_config_instance def test_verdi_with_empty_profile_list(self): """Regression test for #2424: verify that verdi remains operable even if profile list is empty""" from aiida.manage.configuration import CONFIG @@ -40,7 +39,6 @@ def test_verdi_with_empty_profile_list(self): result = self.cli_runner.invoke(cmd_verdi.verdi, []) self.assertIsNone(result.exception, result.output) - @with_temporary_config_instance def test_invalid_cmd_matches(self): """Test that verdi with an invalid command will return matches if somewhat close""" result = self.cli_runner.invoke(cmd_verdi.verdi, ['usr']) @@ -49,7 +47,6 @@ def test_invalid_cmd_matches(self): self.assertIn('user', 
result.output) self.assertNotEqual(result.exit_code, 0) - @with_temporary_config_instance def test_invalid_cmd_no_matches(self): """Test that verdi with an invalid command with no matches returns an appropriate message""" result = self.cli_runner.invoke(cmd_verdi.verdi, ['foobar']) diff --git a/tests/common/test_hashing.py b/tests/common/test_hashing.py index 05f6f66260..a3d1db3dbd 100644 --- a/tests/common/test_hashing.py +++ b/tests/common/test_hashing.py @@ -24,6 +24,7 @@ except ImportError: import unittest +from aiida.common.exceptions import HashingError from aiida.common.hashing import make_hash, float_to_text from aiida.common.folders import SandboxFolder from aiida.backends.testbase import AiidaTestCase @@ -36,6 +37,7 @@ class FloatToTextTest(unittest.TestCase): """ def test_subnormal(self): + self.assertEqual(float_to_text(-0.00, sig=2), '0') # 0 is always printed as '0' self.assertEqual(float_to_text(3.555, sig=2), '3.6') self.assertEqual(float_to_text(3.555, sig=3), '3.56') self.assertEqual(float_to_text(3.141592653589793238462643383279502884197, sig=14), '3.1415926535898') @@ -177,7 +179,7 @@ def test_unhashable_type(self): class MadeupClass: pass - with self.assertRaises(ValueError): + with self.assertRaises(HashingError): make_hash(MadeupClass()) def test_folder(self): diff --git a/tests/conftest.py b/tests/conftest.py index 2000f7500c..4a0c5862fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ import pytest -from aiida.manage.configuration import Config, Profile, get_config +from aiida.manage.configuration import Config, Profile, get_config, load_profile pytest_plugins = ['aiida.manage.tests.pytest_fixtures', 'sphinx.testing.fixtures'] # pylint: disable=invalid-name @@ -152,7 +152,7 @@ def _generate_calculation_node(process_state=ProcessState.FINISHED, exit_status= @pytest.fixture -def create_empty_config_instance(tmp_path) -> Config: +def empty_config(tmp_path) -> Config: """Create a temporary configuration instance. This creates a temporary directory with a clean `.aiida` folder and basic configuration file. The currently loaded @@ -162,7 +162,7 @@ def create_empty_config_instance(tmp_path) -> Config: """ from aiida.common.utils import Capturing from aiida.manage import configuration - from aiida.manage.configuration import settings, load_profile, reset_profile + from aiida.manage.configuration import settings, reset_profile # Store the current configuration instance and config directory path current_config = configuration.CONFIG @@ -193,7 +193,7 @@ def create_empty_config_instance(tmp_path) -> Config: @pytest.fixture -def create_profile() -> Profile: +def profile_factory() -> Profile: """Create a new profile instance. :return: the profile instance. @@ -221,10 +221,85 @@ def _create_profile(name, **kwargs): @pytest.fixture -def backend(): - """Get the ``Backend`` instance of the currently loaded profile.""" +def config_with_profile_factory(empty_config, profile_factory) -> Config: + """Create a temporary configuration instance with one profile. + + This fixture builds on the `empty_config` fixture, to add a single profile. + + The defaults of the profile can be overridden in the callable, as well as whether it should be set as default. 
+ + Example:: + + def test_config_with_profile(config_with_profile_factory): + config = config_with_profile_factory(set_as_default=True, name='default', database_backend='django') + assert config.current_profile.name == 'default' + + As with `empty_config`, the currently loaded configuration and profile are stored in memory, + and are automatically restored at the end of this context manager. + + This fixture should be used by tests that modify aspects of the AiiDA configuration or profile + and require a preconfigured profile, but do not require an actual configured database. + """ + + def _config_with_profile_factory(set_as_default=True, load=True, name='default', **kwargs): + """Create a temporary configuration instance with one profile. + + :param set_as_default: whether to set the one profile as the default. + :param load: whether to load the profile. + :param name: the profile name + :param kwargs: parameters that are forwarded to the `Profile` constructor. + + :return: a config instance with a configured profile. + """ + profile = profile_factory(name=name, **kwargs) + config = empty_config + config.add_profile(profile) + + if set_as_default: + config.set_default_profile(profile.name, overwrite=True) + + config.store() + + if load: + load_profile(profile.name) + + return config + + return _config_with_profile_factory + + +@pytest.fixture +def config_with_profile(config_with_profile_factory): + """Create a temporary configuration instance with one default, loaded profile.""" + yield config_with_profile_factory() + + +@pytest.fixture +def manager(aiida_profile): # pylint: disable=unused-argument + """Get the ``Manager`` instance of the currently loaded profile.""" from aiida.manage.manager import get_manager - return get_manager().get_backend() + return get_manager() + + +@pytest.fixture +def event_loop(manager): + """Get the event loop instance of the currently loaded profile. + + This is automatically called as a fixture for any test marked with ``@pytest.mark.asyncio``. + """ + yield manager.get_runner().loop + + +@pytest.fixture +def backend(manager): + """Get the ``Backend`` instance of the currently loaded profile.""" + return manager.get_backend() + + +@pytest.fixture +def communicator(manager): + """Get the ``Communicator`` instance of the currently loaded profile to communicate with RabbitMQ.""" + return manager.get_communicator() @pytest.fixture @@ -259,3 +334,33 @@ def override_logging(): config.unset_option('logging.aiida_loglevel') config.unset_option('logging.db_loglevel') configure_logging(with_orm=True) + + +@pytest.fixture +def with_daemon(): + """Starts the daemon process and then makes sure to kill it once the test is done.""" + import sys + import signal + import subprocess + + from aiida.engine.daemon.client import DaemonClient + from aiida.cmdline.utils.common import get_env_with_venv_bin + + # Add the current python path to the environment that will be used for the daemon sub process. + # This is necessary to guarantee the daemon can also import all the classes that are defined + # in this `tests` module. + env = get_env_with_venv_bin() + env['PYTHONPATH'] = ':'.join(sys.path) + + profile = get_config().current_profile + daemon = subprocess.Popen( + DaemonClient(profile).cmd_string.split(), + stderr=sys.stderr, + stdout=sys.stdout, + env=env, + ) + + yield + + # Note this will always be executed after the yield no matter what happened in the test that used this fixture. 
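    # SIGTERM only requests a shutdown. If a hung daemon process ever becomes an issue, the
    # teardown could poll and escalate using nothing but the standard ``subprocess.Popen`` API
    # of the ``daemon`` object created above; a hedged sketch, not what this fixture does:
    #
    #     daemon.terminate()                # equivalent to sending SIGTERM on POSIX
    #     try:
    #         daemon.wait(timeout=10)       # give the daemon time to exit cleanly
    #     except subprocess.TimeoutExpired:
    #         daemon.kill()                 # escalate to SIGKILL as a last resort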
+ os.kill(daemon.pid, signal.SIGTERM) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index 62c496ba97..2cce4eebca 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -7,54 +7,127 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-outer-name """Tests for the :mod:`aiida.engine.daemon.execmanager` module.""" import io import os +import pathlib +import typing + import pytest from aiida.engine.daemon import execmanager from aiida.transports.plugins.local import LocalTransport -@pytest.mark.usefixtures('clear_database_before_test') -def test_retrieve_files_from_list(tmp_path_factory, generate_calculation_node): - """Test the `retrieve_files_from_list` function.""" - node = generate_calculation_node() +def serialize_file_hierarchy(dirpath: pathlib.Path) -> typing.Dict: + """Serialize the file hierarchy at ``dirpath``. - retrieve_list = [ - 'file_a.txt', - ('sub/folder', 'sub/folder', 0), - ] + .. note:: empty directories are ignored. - source = tmp_path_factory.mktemp('source') - target = tmp_path_factory.mktemp('target') + :param dirpath: the base path. + :return: a mapping representing the file hierarchy, where keys are filenames. The leafs correspond to files and the + values are the text contents. + """ + serialized = {} - content_a = b'content_a' - content_b = b'content_b' + for root, _, files in os.walk(dirpath): + for filepath in files: - with open(str(source / 'file_a.txt'), 'wb') as handle: - handle.write(content_a) - handle.flush() + relpath = pathlib.Path(root).relative_to(dirpath) + subdir = serialized + if relpath.parts: + for part in relpath.parts: + subdir = subdir.setdefault(part, {}) + subdir[filepath] = (pathlib.Path(root) / filepath).read_text() - os.makedirs(str(source / 'sub' / 'folder')) + return serialized - with open(str(source / 'sub' / 'folder' / 'file_b.txt'), 'wb') as handle: - handle.write(content_b) - handle.flush() - with LocalTransport() as transport: - transport.chdir(str(source)) - execmanager.retrieve_files_from_list(node, transport, str(target), retrieve_list) +def create_file_hierarchy(hierarchy: typing.Dict, basepath: pathlib.Path) -> None: + """Create the file hierarchy represented by the hierarchy created by ``serialize_file_hierarchy``. + + .. note:: empty directories are ignored and are not created explicitly on disk. + + :param hierarchy: mapping with structure returned by ``serialize_file_hierarchy``. + :param basepath: the basepath where to write the hierarchy to disk. 
+ """ + for filename, value in hierarchy.items(): + if isinstance(value, dict): + create_file_hierarchy(value, basepath / filename) + else: + basepath.mkdir(parents=True, exist_ok=True) + (basepath / filename).write_text(value) + + +@pytest.fixture +def file_hierarchy(): + """Return a sample nested file hierarchy.""" + return { + 'file_a.txt': 'file_a', + 'path': { + 'file_b.txt': 'file_b', + 'sub': { + 'file_c.txt': 'file_c', + 'file_d.txt': 'file_d' + } + } + } - assert sorted(os.listdir(str(target))) == sorted(['file_a.txt', 'sub']) - assert os.listdir(str(target / 'sub')) == ['folder'] - assert os.listdir(str(target / 'sub' / 'folder')) == ['file_b.txt'] - with open(str(target / 'sub' / 'folder' / 'file_b.txt'), 'rb') as handle: - assert handle.read() == content_b +def test_hierarchy_utility(file_hierarchy, tmp_path): + """Test that the ``create_file_hierarchy`` and ``serialize_file_hierarchy`` function as intended. + + This is tested by performing a round-trip. + """ + create_file_hierarchy(file_hierarchy, tmp_path) + assert serialize_file_hierarchy(tmp_path) == file_hierarchy + + +# yapf: disable +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('retrieve_list, expected_hierarchy', ( + # Single file or folder, either toplevel or nested + (['file_a.txt'], {'file_a.txt': 'file_a'}), + (['path/sub/file_c.txt'], {'file_c.txt': 'file_c'}), + (['path'], {'path': {'file_b.txt': 'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + (['path/sub'], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + # Single nested file that is retrieved keeping a varying level of depth of original hierarchy + ([('path/sub/file_c.txt', '.', 3)], {'path': {'sub': {'file_c.txt': 'file_c'}}}), + ([('path/sub/file_c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), + ([('path/sub/file_c.txt', '.', 1)], {'file_c.txt': 'file_c'}), + ([('path/sub/file_c.txt', '.', 0)], {'file_c.txt': 'file_c'}), + # Single nested folder that is retrieved keeping a varying level of depth of original hierarchy + ([('path/sub', '.', 2)], {'path': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + ([('path/sub', '.', 1)], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + # Using globbing patterns + ([('path/*', '.', 0)], {'file_b.txt': 'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + ([('path/sub/*', '.', 0)], {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}), # This is identical to ['path/sub'] + ([('path/sub/*c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), + ([('path/sub/*c.txt', '.', 0)], {'file_c.txt': 'file_c'}), + # Different target directory + ([('path/sub/file_c.txt', 'target', 3)], {'target': {'path': {'sub': {'file_c.txt': 'file_c'}}}}), + ([('path/sub', 'target', 1)], {'target': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + ([('path/sub/*c.txt', 'target', 2)], {'target': {'sub': {'file_c.txt': 'file_c'}}}), + # Missing files should be ignored and not cause the retrieval to except + (['file_a.txt', 'file_u.txt', 'path/file_u.txt', ('path/sub/file_u.txt', '.', 3)], {'file_a.txt': 'file_a'}), +)) +# yapf: enable +def test_retrieve_files_from_list( + tmp_path_factory, generate_calculation_node, file_hierarchy, retrieve_list, expected_hierarchy +): + """Test the `retrieve_files_from_list` function.""" + source = tmp_path_factory.mktemp('source') + target = tmp_path_factory.mktemp('target') + + create_file_hierarchy(file_hierarchy, source) + + with LocalTransport() as transport: + node = 
generate_calculation_node() + transport.chdir(source) + execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) - with open(str(target / 'file_a.txt'), 'rb') as handle: - assert handle.read() == content_a + assert serialize_file_hierarchy(target) == expected_hierarchy @pytest.mark.usefixtures('clear_database_before_test') diff --git a/tests/engine/daemon/test_runner.py b/tests/engine/daemon/test_runner.py new file mode 100644 index 0000000000..044ec3349f --- /dev/null +++ b/tests/engine/daemon/test_runner.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Unit tests for the :mod:`aiida.engine.daemon.runner` module.""" +import pytest + +from aiida.engine.daemon.runner import shutdown_runner + + +@pytest.mark.requires_rmq +@pytest.mark.asyncio +async def test_shutdown_runner(manager): + """Test the ``shutdown_runner`` method.""" + runner = manager.get_runner() + await shutdown_runner(runner) + + try: + assert runner.is_closed() + finally: + # Reset the runner of the manager, because once closed it cannot be reused by other tests. + manager._runner = None # pylint: disable=protected-access diff --git a/tests/engine/processes/test_builder.py b/tests/engine/processes/test_builder.py index aa7ad19b0b..239bc4984d 100644 --- a/tests/engine/processes/test_builder.py +++ b/tests/engine/processes/test_builder.py @@ -28,29 +28,29 @@ def test_access_methods(): builder = ProcessBuilder(ArithmeticAddCalculation) builder['x'] = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} del builder['x'] - assert dict(builder) == {'metadata': {'options': {}}} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}} with pytest.raises(ValueError): builder['x'] = node_dict builder['x'] = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} # AS ATTRIBUTES del builder builder = ProcessBuilder(ArithmeticAddCalculation) builder.x = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} del builder.x - assert dict(builder) == {'metadata': {'options': {}}} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}} with pytest.raises(ValueError): builder.x = node_dict builder.x = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py index dbea7970b4..034a244bba 100644 --- a/tests/engine/processes/workchains/test_restart.py +++ b/tests/engine/processes/workchains/test_restart.py @@ -49,6 +49,7 @@ def test_get_process_handler(): assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a', 'handler_b'] +@pytest.mark.requires_rmq 
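# The ``requires_rmq`` marker added throughout this diff tags tests that need a reachable
# RabbitMQ broker so they can be deselected where no broker is available. A minimal sketch of
# how such a marker is registered and filtered with stock pytest; the hooks below are standard
# pytest, but the exact registration used by aiida-core is not shown in this diff and is
# assumed here for illustration:
#
#     # conftest.py
#     def pytest_configure(config):
#         config.addinivalue_line('markers', 'requires_rmq: test requires a RabbitMQ broker')
#
#     # command line: run only the tests that do not need the broker
#     #     pytest -m 'not requires_rmq'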
@pytest.mark.usefixtures('clear_database_before_test') def test_excepted_process(generate_work_chain, generate_calculation_node): """Test that the workchain aborts if the sub process was excepted.""" @@ -58,6 +59,7 @@ def test_excepted_process(generate_work_chain, generate_calculation_node): assert process.inspect_process() == engine.BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_EXCEPTED +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_killed_process(generate_work_chain, generate_calculation_node): """Test that the workchain aborts if the sub process was killed.""" @@ -67,6 +69,7 @@ def test_killed_process(generate_work_chain, generate_calculation_node): assert process.inspect_process() == engine.BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_KILLED +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_unhandled_failure(generate_work_chain, generate_calculation_node): """Test the unhandled failure mechanism. @@ -85,6 +88,7 @@ def test_unhandled_failure(generate_work_chain, generate_calculation_node): ) == engine.BaseRestartWorkChain.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE # pylint: disable=no-member +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_unhandled_reset_after_success(generate_work_chain, generate_calculation_node): """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a successful process.""" @@ -99,6 +103,7 @@ def test_unhandled_reset_after_success(generate_work_chain, generate_calculation assert process.ctx.unhandled_failure is False +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_unhandled_reset_after_handled(generate_work_chain, generate_calculation_node): """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a handled failed process.""" @@ -120,6 +125,7 @@ def test_unhandled_reset_after_handled(generate_work_chain, generate_calculation assert process.ctx.unhandled_failure is False +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_run_process(generate_work_chain, generate_calculation_node, monkeypatch): """Test the `run_process` method.""" diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index efcc28537b..5da7f6e156 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -9,6 +9,8 @@ ########################################################################### # pylint: disable=no-self-use,unused-argument,unused-variable,function-redefined,missing-class-docstring,missing-function-docstring """Tests for `aiida.engine.processes.workchains.utils` module.""" +import pytest + from aiida.backends.testbase import AiidaTestCase from aiida.engine import ExitCode, ProcessState from aiida.engine.processes.workchains.restart import BaseRestartWorkChain @@ -19,6 +21,7 @@ ArithmeticAddCalculation = CalculationFactory('arithmetic.add') +@pytest.mark.requires_rmq class TestRegisterProcessHandler(AiidaTestCase): """Tests for the `process_handler` decorator.""" diff --git a/tests/engine/test_calc_job.py b/tests/engine/test_calc_job.py index 08429c9a56..75a611d8b9 100644 --- a/tests/engine/test_calc_job.py +++ b/tests/engine/test_calc_job.py @@ -11,6 +11,7 @@ """Test for the `CalcJob` process sub class.""" from copy import deepcopy from functools import partial +import io import os from unittest.mock 
import patch @@ -18,9 +19,10 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase -from aiida.common import exceptions, LinkType, CalcJobState +from aiida.common import exceptions, LinkType, CalcJobState, StashMode from aiida.engine import launch, CalcJob, Process, ExitCode from aiida.engine.processes.ports import PortNamespace +from aiida.engine.processes.calcjobs.calcjob import validate_stash_options from aiida.plugins import CalculationFactory ArithmeticAddCalculation = CalculationFactory('arithmetic.add') # pylint: disable=invalid-name @@ -34,6 +36,7 @@ def raise_exception(exception): raise exception() +@pytest.mark.requires_rmq class FileCalcJob(CalcJob): """Example `CalcJob` implementation to test the `provenance_exclude_list` functionality. @@ -71,6 +74,7 @@ def prepare_for_submission(self, folder): return calcinfo +@pytest.mark.requires_rmq class TestCalcJob(AiidaTestCase): """Test for the `CalcJob` process sub class.""" @@ -356,36 +360,46 @@ def test_parse_retrieved_folder(self): @pytest.fixture -def process(aiida_local_code_factory): +def generate_process(aiida_local_code_factory): """Instantiate a process with default inputs and return the `Process` instance.""" from aiida.engine.utils import instantiate_process from aiida.manage.manager import get_manager - inputs = { - 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash'), - 'x': orm.Int(1), - 'y': orm.Int(2), - 'metadata': { - 'options': {} + def _generate_process(inputs=None): + + base_inputs = { + 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash'), + 'x': orm.Int(1), + 'y': orm.Int(2), + 'metadata': { + 'options': {} + } } - } - manager = get_manager() - runner = manager.get_runner() + if inputs is not None: + base_inputs = {**base_inputs, **inputs} + + manager = get_manager() + runner = manager.get_runner() + + process_class = CalculationFactory('arithmetic.add') + process = instantiate_process(runner, process_class, **base_inputs) + process.node.set_state(CalcJobState.PARSING) - process_class = CalculationFactory('arithmetic.add') - process = instantiate_process(runner, process_class, **inputs) - process.node.set_state(CalcJobState.PARSING) + return process - return process + return _generate_process +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test', 'override_logging') -def test_parse_insufficient_data(process): +def test_parse_insufficient_data(generate_process): """Test the scheduler output parsing logic in `CalcJob.parse`. Here we check explicitly that the parsing does not except even if the required information is not available. """ + process = generate_process() + retrieved = orm.FolderData().store() retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) process.parse() @@ -408,13 +422,16 @@ def test_parse_insufficient_data(process): assert log in logs +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test', 'override_logging') -def test_parse_non_zero_retval(process): +def test_parse_non_zero_retval(generate_process): """Test the scheduler output parsing logic in `CalcJob.parse`. This is testing the case where the `detailed_job_info` is incomplete because the call failed. This is checked through the return value that is stored within the attribute dictionary. 
""" + process = generate_process() + retrieved = orm.FolderData().store() retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) @@ -425,25 +442,24 @@ def test_parse_non_zero_retval(process): assert 'could not parse scheduler output: return value of `detailed_job_info` is non-zero' in logs +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test', 'override_logging') -def test_parse_not_implemented(process): +def test_parse_not_implemented(generate_process): """Test the scheduler output parsing logic in `CalcJob.parse`. Here we check explicitly that the parsing does not except even if the scheduler does not implement the method. """ - retrieved = orm.FolderData().store() - retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - - process.node.set_attribute('detailed_job_info', {}) - + process = generate_process() filename_stderr = process.node.get_option('scheduler_stderr') filename_stdout = process.node.get_option('scheduler_stdout') - with retrieved.open(filename_stderr, 'w') as handle: - handle.write('\n') + retrieved = orm.FolderData() + retrieved.put_object_from_filelike(io.StringIO('\n'), filename_stderr, mode='w') + retrieved.put_object_from_filelike(io.StringIO('\n'), filename_stdout, mode='w') + retrieved.store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - with retrieved.open(filename_stdout, 'w') as handle: - handle.write('\n') + process.node.set_attribute('detailed_job_info', {}) process.parse() @@ -456,27 +472,26 @@ def test_parse_not_implemented(process): assert log in logs +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test', 'override_logging') -def test_parse_scheduler_excepted(process, monkeypatch): +def test_parse_scheduler_excepted(generate_process, monkeypatch): """Test the scheduler output parsing logic in `CalcJob.parse`. 
Here we check explicitly the case where the `Scheduler.parse_output` method excepts """ from aiida.schedulers.plugins.direct import DirectScheduler - retrieved = orm.FolderData().store() - retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - - process.node.set_attribute('detailed_job_info', {}) - + process = generate_process() filename_stderr = process.node.get_option('scheduler_stderr') filename_stdout = process.node.get_option('scheduler_stdout') - with retrieved.open(filename_stderr, 'w') as handle: - handle.write('\n') + retrieved = orm.FolderData() + retrieved.put_object_from_filelike(io.StringIO('\n'), filename_stderr, mode='w') + retrieved.put_object_from_filelike(io.StringIO('\n'), filename_stdout, mode='w') + retrieved.store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - with retrieved.open(filename_stdout, 'w') as handle: - handle.write('\n') + process.node.set_attribute('detailed_job_info', {}) msg = 'crash' @@ -493,6 +508,7 @@ def raise_exception(*args, **kwargs): assert log in logs +@pytest.mark.requires_rmq @pytest.mark.parametrize(('exit_status_scheduler', 'exit_status_retrieved', 'final'), ( (None, None, 0), (100, None, 100), @@ -558,3 +574,84 @@ def parse_retrieved_output(_, __): result = process.parse() assert isinstance(result, ExitCode) assert result.status == final + + +@pytest.mark.requires_rmq +@pytest.mark.usefixtures('clear_database_before_test') +def test_additional_retrieve_list(generate_process, fixture_sandbox): + """Test the ``additional_retrieve_list`` option.""" + process = generate_process() + process.presubmit(fixture_sandbox) + retrieve_list = process.node.get_attribute('retrieve_list') + + # Keep reference of the base contents of the retrieve list. 
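    # (The base entries come from the ``CalcJob`` defaults for this calculation, typically the
    # scheduler stdout/stderr files and the declared output file; the assertions below only
    # rely on the list being non-empty, not on its exact contents.)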
+ base_retrieve_list = retrieve_list + + # Test that the code works if no explicit additional retrieve list is specified + assert len(retrieve_list) != 0 + assert isinstance(process.node.get_attribute('retrieve_list'), list) + + # Defining explicit additional retrieve list that is disjoint with the base retrieve list + additional_retrieve_list = ['file.txt', 'folder/file.txt'] + process = generate_process({'metadata': {'options': {'additional_retrieve_list': additional_retrieve_list}}}) + process.presubmit(fixture_sandbox) + retrieve_list = process.node.get_attribute('retrieve_list') + + # Check that the `retrieve_list` is a list and contains the union of the base and additional retrieve list + assert isinstance(process.node.get_attribute('retrieve_list'), list) + assert set(retrieve_list) == set(base_retrieve_list).union(set(additional_retrieve_list)) + + # Defining explicit additional retrieve list with elements that overlap with `base_retrieve_list + additional_retrieve_list = ['file.txt', 'folder/file.txt'] + base_retrieve_list + process = generate_process({'metadata': {'options': {'additional_retrieve_list': additional_retrieve_list}}}) + process.presubmit(fixture_sandbox) + retrieve_list = process.node.get_attribute('retrieve_list') + + # Check that the `retrieve_list` is a list and contains the union of the base and additional retrieve list + assert isinstance(process.node.get_attribute('retrieve_list'), list) + assert set(retrieve_list) == set(base_retrieve_list).union(set(additional_retrieve_list)) + + # Test the validator + with pytest.raises(ValueError, match=r'`additional_retrieve_list` should only contain relative filepaths.*'): + process = generate_process({'metadata': {'options': {'additional_retrieve_list': [None]}}}) + + with pytest.raises(ValueError, match=r'`additional_retrieve_list` should only contain relative filepaths.*'): + process = generate_process({'metadata': {'options': {'additional_retrieve_list': ['/abs/path']}}}) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize(('stash_options', 'expected'), ( + ({ + 'target_base': None + }, '`metadata.options.stash.target_base` should be'), + ({ + 'target_base': 'relative/path' + }, '`metadata.options.stash.target_base` should be'), + ({ + 'target_base': '/path' + }, '`metadata.options.stash.source_list` should be'), + ({ + 'target_base': '/path', + 'source_list': ['/abspath'] + }, '`metadata.options.stash.source_list` should be'), + ({ + 'target_base': '/path', + 'source_list': ['rel/path'], + 'mode': 'test' + }, '`metadata.options.stash.mode` should be'), + ({ + 'target_base': '/path', + 'source_list': ['rel/path'] + }, None), + ({ + 'target_base': '/path', + 'source_list': ['rel/path'], + 'mode': StashMode.COPY.value + }, None), +)) +def test_validate_stash_options(stash_options, expected): + """Test the ``validate_stash_options`` function.""" + if expected is None: + assert validate_stash_options(stash_options, None) is expected + else: + assert expected in validate_stash_options(stash_options, None) diff --git a/tests/engine/test_calcfunctions.py b/tests/engine/test_calcfunctions.py index 5ff6329db2..95e2b55bf3 100644 --- a/tests/engine/test_calcfunctions.py +++ b/tests/engine/test_calcfunctions.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the calcfunction decorator and CalcFunctionNode.""" +import pytest from aiida.backends.testbase import AiidaTestCase from 
aiida.common import exceptions @@ -37,6 +38,7 @@ def execution_counter_calcfunction(data): return Int(data.value + 1) +@pytest.mark.requires_rmq class TestCalcFunction(AiidaTestCase): """Tests for calcfunctions. diff --git a/tests/engine/test_daemon.py b/tests/engine/test_daemon.py index fd9c64ff7b..53f6b4a20b 100644 --- a/tests/engine/test_daemon.py +++ b/tests/engine/test_daemon.py @@ -8,8 +8,35 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test daemon module.""" -from aiida.backends.testbase import AiidaTestCase +import asyncio +from plumpy.process_states import ProcessState +import pytest -class TestDaemon(AiidaTestCase): - """Testing the daemon.""" +from aiida.manage.manager import get_manager +from tests.utils import processes as test_processes + + +async def reach_waiting_state(process): + while process.state != ProcessState.WAITING: + await asyncio.sleep(0.1) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_cancel_process_task(): + """This test is designed to replicate how processes are cancelled in the current `shutdown_runner` callback. + + The `CancelledError` should bubble up to the caller, and not be caught and transition the process to excepted. + """ + runner = get_manager().get_runner() + # create the process and start it running + process = runner.instantiate_process(test_processes.WaitProcess) + task = runner.loop.create_task(process.step_until_terminated()) + # wait for the process to reach a WAITING state + runner.loop.run_until_complete(asyncio.wait_for(reach_waiting_state(process), 5.0)) + # cancel the task and wait for the cancellation + task.cancel() + with pytest.raises(asyncio.CancelledError): + runner.loop.run_until_complete(asyncio.wait_for(task, 5.0)) + # the node should still record a waiting state, not excepted + assert process.node.process_state == ProcessState.WAITING diff --git a/tests/engine/test_futures.py b/tests/engine/test_futures.py index 521693137d..dba89e6c94 100644 --- a/tests/engine/test_futures.py +++ b/tests/engine/test_futures.py @@ -8,9 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to test process futures.""" -import datetime +import asyncio -from tornado import gen +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.engine import processes, run @@ -19,9 +19,10 @@ from tests.utils import processes as test_processes +@pytest.mark.requires_rmq class TestWf(AiidaTestCase): """Test process futures.""" - TIMEOUT = datetime.timedelta(seconds=5.0) + TIMEOUT = 5.0 # seconds def test_calculation_future_broadcasts(self): """Test calculation future broadcasts.""" @@ -31,11 +32,11 @@ def test_calculation_future_broadcasts(self): # No polling future = processes.futures.ProcessFuture( - pk=process.pid, poll_interval=None, communicator=manager.get_communicator() + pk=process.pid, loop=runner.loop, communicator=manager.get_communicator() ) run(process) - calc_node = runner.run_until_complete(gen.with_timeout(self.TIMEOUT, future)) + calc_node = runner.run_until_complete(asyncio.wait_for(future, self.TIMEOUT)) self.assertEqual(process.node.pk, calc_node.pk) @@ -49,6 +50,6 @@ def test_calculation_future_polling(self): future = processes.futures.ProcessFuture(pk=process.pid, loop=runner.loop, poll_interval=0) runner.run(process) - calc_node = runner.run_until_complete(gen.with_timeout(self.TIMEOUT, future)) + 
calc_node = runner.run_until_complete(asyncio.wait_for(future, self.TIMEOUT)) self.assertEqual(process.node.pk, calc_node.pk) diff --git a/tests/engine/test_launch.py b/tests/engine/test_launch.py index d259ee5121..7a53712127 100644 --- a/tests/engine/test_launch.py +++ b/tests/engine/test_launch.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to test processess launch.""" +import pytest + from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions @@ -62,6 +64,7 @@ def add(self): self.out('result', orm.Int(self.inputs.term_a + self.inputs.term_b).store()) +@pytest.mark.requires_rmq class TestLaunchers(AiidaTestCase): """Class to test process launchers.""" @@ -142,6 +145,7 @@ def test_submit_store_provenance_false(self): launch.submit(AddWorkChain, term_a=self.term_a, term_b=self.term_b, metadata={'store_provenance': False}) +@pytest.mark.requires_rmq class TestLaunchersDryRun(AiidaTestCase): """Test the launchers when performing a dry-run.""" diff --git a/tests/engine/test_manager.py b/tests/engine/test_manager.py index 4e2748e901..574f30713f 100644 --- a/tests/engine/test_manager.py +++ b/tests/engine/test_manager.py @@ -10,8 +10,7 @@ """Tests for the classes in `aiida.engine.processes.calcjobs.manager`.""" import time - -import tornado +import asyncio from aiida.orm import AuthInfo, User from aiida.backends.testbase import AiidaTestCase @@ -24,7 +23,7 @@ class TestJobManager(AiidaTestCase): def setUp(self): super().setUp() - self.loop = tornado.ioloop.IOLoop() + self.loop = asyncio.get_event_loop() self.transport_queue = TransportQueue(self.loop) self.user = User.objects.get_default() self.auth_info = AuthInfo(self.computer, self.user).store() @@ -45,7 +44,7 @@ def test_get_jobs_list(self): def test_request_job_info_update(self): """Test the `JobManager.request_job_info_update` method.""" with self.manager.request_job_info_update(self.auth_info, job_id=1) as request: - self.assertIsInstance(request, tornado.concurrent.Future) + self.assertIsInstance(request, asyncio.Future) class TestJobsList(AiidaTestCase): @@ -53,7 +52,7 @@ class TestJobsList(AiidaTestCase): def setUp(self): super().setUp() - self.loop = tornado.ioloop.IOLoop() + self.loop = asyncio.get_event_loop() self.transport_queue = TransportQueue(self.loop) self.user = User.objects.get_default() self.auth_info = AuthInfo(self.computer, self.user).store() diff --git a/tests/engine/test_persistence.py b/tests/engine/test_persistence.py index 343bd868b3..7a451c1d0f 100644 --- a/tests/engine/test_persistence.py +++ b/tests/engine/test_persistence.py @@ -9,6 +9,7 @@ ########################################################################### """Test persisting via the AiiDAPersister.""" import plumpy +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.engine.persistence import AiiDAPersister @@ -17,6 +18,7 @@ from tests.utils.processes import DummyProcess +@pytest.mark.requires_rmq class TestProcess(AiidaTestCase): """Test the basic saving and loading of process states.""" @@ -40,6 +42,7 @@ def test_save_load(self): self.assertEqual(loaded_process.state, plumpy.ProcessState.FINISHED) +@pytest.mark.requires_rmq class TestAiiDAPersister(AiidaTestCase): """Test AiiDAPersister.""" maxDiff = 1024 diff --git a/tests/engine/test_ports.py b/tests/engine/test_ports.py index d4cd6d6246..796571d407 100644 --- a/tests/engine/test_ports.py +++ 
b/tests/engine/test_ports.py @@ -92,7 +92,7 @@ def test_serialize_type_check(self): port_namespace.create_port_namespace(nested_namespace) with self.assertRaisesRegex(TypeError, f'.*{base_namespace}.*{nested_namespace}.*'): - port_namespace.serialize({'some': {'nested': {'namespace': {Dict()}}}}) + port_namespace.serialize({'some': {'nested': {'namespace': Dict()}}}) def test_lambda_default(self): """Test that an input port can specify a lambda as a default.""" diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index 22f2c0391e..8116fbc6e0 100644 --- a/tests/engine/test_process.py +++ b/tests/engine/test_process.py @@ -13,6 +13,7 @@ import plumpy from plumpy.utils import AttributesFrozendict +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase @@ -36,6 +37,7 @@ def define(cls, spec): spec.input('some.name.space.a', valid_type=orm.Int) +@pytest.mark.requires_rmq class TestProcessNamespace(AiidaTestCase): """Test process namespace""" @@ -91,6 +93,7 @@ def on_stop(self): assert self._thread_id is threading.current_thread().ident +@pytest.mark.requires_rmq class TestProcess(AiidaTestCase): """Test AiiDA process.""" diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index b07d8e941a..fe911685ea 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the process_function decorator.""" +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase @@ -22,6 +23,7 @@ CUSTOM_DESCRIPTION = 'Custom description' +@pytest.mark.requires_rmq class TestProcessFunction(AiidaTestCase): """ Note that here we use `@workfunctions` and `@calculations`, the concrete versions of the @@ -352,7 +354,7 @@ def test_launchers(self): self.assertTrue(isinstance(node, orm.CalcFunctionNode)) # Process function can be submitted and will be run by a daemon worker as long as the function is importable - # Note that the actual running is not tested here but is done so in `.ci/test_daemon.py`. + # Note that the actual running is not tested here but is done so in `.github/system_tests/test_daemon.py`. 
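        # (For the purposes of this unit test, ``submit`` only has to create and store the
        # process node and hand the continuation task over to the broker; no daemon worker
        # needs to be running for the assertion below, which is why actual execution is only
        # verified in the system tests referenced above.)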
node = submit(add_multiply, x=orm.Int(1), y=orm.Int(2), z=orm.Int(3)) assert isinstance(node, orm.WorkFunctionNode) diff --git a/tests/engine/test_rmq.py b/tests/engine/test_rmq.py index 670b971d2e..aa9f863220 100644 --- a/tests/engine/test_rmq.py +++ b/tests/engine/test_rmq.py @@ -8,19 +8,20 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to test RabbitMQ.""" -import datetime +import asyncio -from tornado import gen import plumpy +import pytest from aiida.backends.testbase import AiidaTestCase -from aiida.engine import ProcessState, submit +from aiida.engine import ProcessState from aiida.manage.manager import get_manager from aiida.orm import Int from tests.utils import processes as test_processes +@pytest.mark.requires_rmq class TestProcessControl(AiidaTestCase): """Test AiiDA's RabbitMQ functionalities.""" @@ -29,141 +30,152 @@ class TestProcessControl(AiidaTestCase): def setUp(self): super().setUp() - # These two need to share a common event loop otherwise the first will never send - # the message while the daemon is running listening to intercept + # The coroutine defined in testcase should run in runner's loop + # and process need submit by runner.submit rather than `submit` import from + # aiida.engine, since the broad one will create its own loop manager = get_manager() self.runner = manager.get_runner() - self.daemon_runner = manager.create_daemon_runner(loop=self.runner.loop) def tearDown(self): - self.daemon_runner.close() + self.runner.close() super().tearDown() def test_submit_simple(self): """"Launch the process.""" - @gen.coroutine - def do_submit(): - calc_node = submit(test_processes.DummyProcess) - yield self.wait_for_process(calc_node) + async def do_submit(): + calc_node = self.runner.submit(test_processes.DummyProcess) + await self.wait_for_process(calc_node) self.assertTrue(calc_node.is_finished_ok) self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value) - self.runner.loop.run_sync(do_submit) + self.runner.loop.run_until_complete(do_submit()) def test_launch_with_inputs(self): """Test launch with inputs.""" - @gen.coroutine - def do_launch(): + async def do_launch(): term_a = Int(5) term_b = Int(10) - calc_node = submit(test_processes.AddProcess, a=term_a, b=term_b) - yield self.wait_for_process(calc_node) + calc_node = self.runner.submit(test_processes.AddProcess, a=term_a, b=term_b) + await self.wait_for_process(calc_node) self.assertTrue(calc_node.is_finished_ok) self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value) - self.runner.loop.run_sync(do_launch) + self.runner.loop.run_until_complete(do_launch()) def test_submit_bad_input(self): with self.assertRaises(ValueError): - submit(test_processes.AddProcess, a=Int(5)) + self.runner.submit(test_processes.AddProcess, a=Int(5)) def test_exception_process(self): """Test process excpetion.""" - @gen.coroutine - def do_exception(): - calc_node = submit(test_processes.ExceptionProcess) - yield self.wait_for_process(calc_node) + async def do_exception(): + calc_node = self.runner.submit(test_processes.ExceptionProcess) + await self.wait_for_process(calc_node) self.assertFalse(calc_node.is_finished_ok) self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.EXCEPTED.value) - self.runner.loop.run_sync(do_exception) + self.runner.loop.run_until_complete(do_exception()) def test_pause(self): """Testing sending a pause message to the process.""" controller = 
get_manager().get_process_controller() - @gen.coroutine - def do_pause(): - calc_node = submit(test_processes.WaitProcess) + async def do_pause(): + calc_node = self.runner.submit(test_processes.WaitProcess) while calc_node.process_state != ProcessState.WAITING: - yield + await asyncio.sleep(0.1) self.assertFalse(calc_node.paused) - future = yield with_timeout(controller.pause_process(calc_node.pk)) - result = yield self.wait_future(future) + pause_future = controller.pause_process(calc_node.pk) + future = await with_timeout(asyncio.wrap_future(pause_future)) + result = await self.wait_future(asyncio.wrap_future(future)) self.assertTrue(result) self.assertTrue(calc_node.paused) - self.runner.loop.run_sync(do_pause) + kill_message = 'Sorry, you have to go mate' + kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + future = await with_timeout(asyncio.wrap_future(kill_future)) + result = await self.wait_future(asyncio.wrap_future(future)) + self.assertTrue(result) + + self.runner.loop.run_until_complete(do_pause()) def test_pause_play(self): """Test sending a pause and then a play message.""" controller = get_manager().get_process_controller() - @gen.coroutine - def do_pause_play(): - calc_node = submit(test_processes.WaitProcess) + async def do_pause_play(): + calc_node = self.runner.submit(test_processes.WaitProcess) self.assertFalse(calc_node.paused) while calc_node.process_state != ProcessState.WAITING: - yield + await asyncio.sleep(0.1) pause_message = 'Take a seat' - future = yield with_timeout(controller.pause_process(calc_node.pk, msg=pause_message)) - result = yield self.wait_future(future) + pause_future = controller.pause_process(calc_node.pk, msg=pause_message) + future = await with_timeout(asyncio.wrap_future(pause_future)) + result = await self.wait_future(asyncio.wrap_future(future)) self.assertTrue(calc_node.paused) self.assertEqual(calc_node.process_status, pause_message) - future = yield with_timeout(controller.play_process(calc_node.pk)) - result = yield self.wait_future(future) + play_future = controller.play_process(calc_node.pk) + future = await with_timeout(asyncio.wrap_future(play_future)) + result = await self.wait_future(asyncio.wrap_future(future)) + self.assertTrue(result) self.assertFalse(calc_node.paused) self.assertEqual(calc_node.process_status, None) - self.runner.loop.run_sync(do_pause_play) + kill_message = 'Sorry, you have to go mate' + kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + future = await with_timeout(asyncio.wrap_future(kill_future)) + result = await self.wait_future(asyncio.wrap_future(future)) + self.assertTrue(result) + + self.runner.loop.run_until_complete(do_pause_play()) def test_kill(self): """Test sending a kill message.""" controller = get_manager().get_process_controller() - @gen.coroutine - def do_kill(): - calc_node = submit(test_processes.WaitProcess) + async def do_kill(): + calc_node = self.runner.submit(test_processes.WaitProcess) self.assertFalse(calc_node.is_killed) while calc_node.process_state != ProcessState.WAITING: - yield + await asyncio.sleep(0.1) kill_message = 'Sorry, you have to go mate' - future = yield with_timeout(controller.kill_process(calc_node.pk, msg=kill_message)) - result = yield self.wait_future(future) + kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + future = await with_timeout(asyncio.wrap_future(kill_future)) + result = await self.wait_future(asyncio.wrap_future(future)) self.assertTrue(result) - self.wait_for_process(calc_node) + await 
self.wait_for_process(calc_node) self.assertTrue(calc_node.is_killed) self.assertEqual(calc_node.process_status, kill_message) - self.runner.loop.run_sync(do_kill) + self.runner.loop.run_until_complete(do_kill()) - @gen.coroutine - def wait_for_process(self, calc_node, timeout=2.): + async def wait_for_process(self, calc_node, timeout=2.): future = self.runner.get_process_future(calc_node.pk) - raise gen.Return((yield with_timeout(future, timeout))) + result = await with_timeout(future, timeout) + return result @staticmethod - @gen.coroutine - def wait_future(future, timeout=2.): - raise gen.Return((yield with_timeout(future, timeout))) + async def wait_future(future, timeout=2.): + result = await with_timeout(future, timeout) + return result -@gen.coroutine -def with_timeout(what, timeout=5.0): - raise gen.Return((yield gen.with_timeout(datetime.timedelta(seconds=timeout), what))) +async def with_timeout(what, timeout=5.0): + result = await asyncio.wait_for(what, timeout) + return result diff --git a/tests/engine/test_run.py b/tests/engine/test_run.py index 84981a6b7a..36535c64b7 100644 --- a/tests/engine/test_run.py +++ b/tests/engine/test_run.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `run` functions.""" +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.engine import run, run_get_node @@ -16,6 +17,7 @@ from tests.utils.processes import DummyProcess +@pytest.mark.requires_rmq class TestRun(AiidaTestCase): """Tests for the `run` functions.""" diff --git a/tests/engine/test_runners.py b/tests/engine/test_runners.py index 41b8887c6f..774f3ee9a1 100644 --- a/tests/engine/test_runners.py +++ b/tests/engine/test_runners.py @@ -10,6 +10,7 @@ # pylint: disable=redefined-outer-name """Module to test process runners.""" import threading +import asyncio import plumpy import pytest @@ -24,7 +25,8 @@ def create_runner(): """Construct and return a `Runner`.""" def _create_runner(poll_interval=0.5): - return get_manager().create_runner(poll_interval=poll_interval) + loop = asyncio.new_event_loop() + return get_manager().create_runner(poll_interval=poll_interval, loop=loop) return _create_runner @@ -42,6 +44,7 @@ def the_hans_klok_comeback(loop): loop.stop() +@pytest.mark.requires_rmq @pytest.mark.usefixtures('clear_database_before_test') def test_call_on_process_finish(create_runner): """Test call on calculation finish.""" @@ -53,7 +56,7 @@ def test_call_on_process_finish(create_runner): def calc_done(): if event.is_set(): - future.set_exc_info(AssertionError('the callback was called twice, which should never happen')) + future.set_exception(AssertionError('the callback was called twice, which should never happen')) future.set_result(True) event.set() @@ -62,9 +65,9 @@ def calc_done(): runner.call_on_process_finish(proc.node.pk, calc_done) # Run the calculation - runner.loop.add_callback(proc.step_until_terminated) + runner.loop.create_task(proc.step_until_terminated()) loop.call_later(5, the_hans_klok_comeback, runner.loop) - loop.start() + loop.run_forever() - assert not future.exc_info() + assert not future.exception() assert future.result() diff --git a/tests/engine/test_transport.py b/tests/engine/test_transport.py index 9974fc1d9f..cae5b4e895 100644 --- a/tests/engine/test_transport.py +++ b/tests/engine/test_transport.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # 
########################################################################### """Module to test transport.""" -from tornado.gen import coroutine, Return +import asyncio from aiida.backends.testbase import AiidaTestCase from aiida.engine.transports import TransportQueue @@ -30,70 +30,65 @@ def tearDown(self, *args, **kwargs): # pylint: disable=arguments-differ def test_simple_request(self): """ Test a simple transport request """ queue = TransportQueue() - loop = queue.loop() + loop = queue.loop - @coroutine - def test(): + async def test(): trans = None with queue.request_transport(self.authinfo) as request: - trans = yield request + trans = await request self.assertTrue(trans.is_open) self.assertFalse(trans.is_open) - loop.run_sync(lambda: test()) # pylint: disable=unnecessary-lambda + loop.run_until_complete(test()) def test_get_transport_nested(self): """Test nesting calls to get the same transport.""" transport_queue = TransportQueue() - loop = transport_queue.loop() + loop = transport_queue.loop - @coroutine - def nested(queue, authinfo): + async def nested(queue, authinfo): with queue.request_transport(authinfo) as request1: - trans1 = yield request1 + trans1 = await request1 self.assertTrue(trans1.is_open) with queue.request_transport(authinfo) as request2: - trans2 = yield request2 + trans2 = await request2 self.assertIs(trans1, trans2) self.assertTrue(trans2.is_open) - loop.run_sync(lambda: nested(transport_queue, self.authinfo)) + loop.run_until_complete(nested(transport_queue, self.authinfo)) def test_get_transport_interleaved(self): """Test interleaved calls to get the same transport.""" transport_queue = TransportQueue() - loop = transport_queue.loop() + loop = transport_queue.loop - @coroutine - def interleaved(authinfo): + async def interleaved(authinfo): with transport_queue.request_transport(authinfo) as trans_future: - yield trans_future + await trans_future - loop.run_sync(lambda: [interleaved(self.authinfo), interleaved(self.authinfo)]) + loop.run_until_complete(asyncio.gather(interleaved(self.authinfo), interleaved(self.authinfo))) def test_return_from_context(self): """Test raising a Return from coroutine context.""" queue = TransportQueue() - loop = queue.loop() + loop = queue.loop - @coroutine - def test(): + async def test(): with queue.request_transport(self.authinfo) as request: - trans = yield request - raise Return(trans.is_open) + trans = await request + return trans.is_open - retval = loop.run_sync(lambda: test()) # pylint: disable=unnecessary-lambda + retval = loop.run_until_complete(test()) self.assertTrue(retval) def test_open_fail(self): """Test that if opening fails.""" queue = TransportQueue() - loop = queue.loop() + loop = queue.loop - @coroutine - def test(): + async def test(): with queue.request_transport(self.authinfo) as request: - yield request + await request def broken_open(trans): raise RuntimeError('Could not open transport') @@ -104,7 +99,7 @@ def broken_open(trans): original = self.authinfo.get_transport().__class__.open self.authinfo.get_transport().__class__.open = broken_open with self.assertRaises(RuntimeError): - loop.run_sync(lambda: test()) # pylint: disable=unnecessary-lambda + loop.run_until_complete(test()) finally: self.authinfo.get_transport().__class__.open = original @@ -120,22 +115,21 @@ def test_safe_interval(self): import time queue = TransportQueue() - loop = queue.loop() + loop = queue.loop time_start = time.time() - @coroutine - def test(iteration): + async def test(iteration): trans = None with 
queue.request_transport(self.authinfo) as request: - trans = yield request + trans = await request time_current = time.time() time_elapsed = time_current - time_start time_minimum = trans.get_safe_open_interval() * (iteration + 1) self.assertTrue(time_elapsed > time_minimum, 'transport safe interval was violated') - for i in range(5): - loop.run_sync(lambda iteration=i: test(iteration)) + for iteration in range(5): + loop.run_until_complete(test(iteration)) finally: transport_class._DEFAULT_SAFE_OPEN_INTERVAL = original_interval # pylint: disable=protected-access diff --git a/tests/engine/test_utils.py b/tests/engine/test_utils.py index b0121c9325..6b29e6321a 100644 --- a/tests/engine/test_utils.py +++ b/tests/engine/test_utils.py @@ -9,13 +9,15 @@ ########################################################################### # pylint: disable=global-statement """Test engine utilities such as the exponential backoff mechanism.""" -from tornado.ioloop import IOLoop -from tornado.gen import coroutine +import asyncio + +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.engine import calcfunction, workfunction -from aiida.engine.utils import exponential_backoff_retry, is_process_function +from aiida.engine.utils import exponential_backoff_retry, is_process_function, \ + InterruptableFuture, interruptable_task ITERATION = 0 MAX_ITERATIONS = 3 @@ -36,10 +38,9 @@ def test_exp_backoff_success(): """Test that exponential backoff will successfully catch exceptions as long as max_attempts is not exceeded.""" global ITERATION ITERATION = 0 - loop = IOLoop() + loop = asyncio.get_event_loop() - @coroutine - def coro(): + async def coro(): """A function that will raise RuntimeError as long as ITERATION is smaller than MAX_ITERATIONS.""" global ITERATION ITERATION += 1 @@ -47,15 +48,14 @@ def coro(): raise RuntimeError max_attempts = MAX_ITERATIONS + 1 - loop.run_sync(lambda: exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) + loop.run_until_complete(exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) def test_exp_backoff_max_attempts_exceeded(self): """Test that exponential backoff will finally raise if max_attempts is exceeded""" global ITERATION ITERATION = 0 - loop = IOLoop() + loop = asyncio.get_event_loop() - @coroutine def coro(): """A function that will raise RuntimeError as long as ITERATION is smaller than MAX_ITERATIONS.""" global ITERATION @@ -65,7 +65,11 @@ def coro(): max_attempts = MAX_ITERATIONS - 1 with self.assertRaises(RuntimeError): - loop.run_sync(lambda: exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) + loop.run_until_complete(exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) + + +class TestUtils(AiidaTestCase): + """ Tests for engine utils.""" def test_is_process_function(self): """Test the `is_process_function` utility.""" @@ -84,3 +88,141 @@ def work_function(): self.assertEqual(is_process_function(normal_function), False) self.assertEqual(is_process_function(calc_function), True) self.assertEqual(is_process_function(work_function), True) + + def test_is_process_scoped(self): + pass + + def test_loop_scope(self): + pass + + +class TestInterruptable(AiidaTestCase): + """ Tests for InterruptableFuture and interruptable_task.""" + + def test_normal_future(self): + """Test interrupt future not being interrupted""" + loop = asyncio.get_event_loop() + + interruptable = InterruptableFuture() + fut = asyncio.Future() + 
+ async def task(): + fut.set_result('I am done') + + loop.run_until_complete(interruptable.with_interrupt(task())) + self.assertFalse(interruptable.done()) + self.assertEqual(fut.result(), 'I am done') + + def test_interrupt(self): + """Test interrupt future being interrupted""" + loop = asyncio.get_event_loop() + + interruptable = InterruptableFuture() + loop.call_soon(interruptable.interrupt, RuntimeError('STOP')) + try: + loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(10.))) + except RuntimeError as err: + self.assertEqual(str(err), 'STOP') + else: + self.fail('ExpectedException not raised') + + self.assertTrue(interruptable.done()) + + def test_inside_interrupted(self): + """Test interrupt future being interrupted from inside of coroutine""" + loop = asyncio.get_event_loop() + + interruptable = InterruptableFuture() + fut = asyncio.Future() + + async def task(): + await asyncio.sleep(1.) + interruptable.interrupt(RuntimeError('STOP')) + fut.set_result('I got set.') + + try: + loop.run_until_complete(interruptable.with_interrupt(task())) + except RuntimeError as err: + self.assertEqual(str(err), 'STOP') + else: + self.fail('ExpectedException not raised') + + self.assertTrue(interruptable.done()) + self.assertEqual(fut.result(), 'I got set.') + + def test_interruptable_future_set(self): + """Test interrupt future being set before coroutine is done""" + loop = asyncio.get_event_loop() + + interruptable = InterruptableFuture() + + async def task(): + interruptable.set_result('NOT ME!!!') + + loop.create_task(task()) + try: + loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(20.))) + except RuntimeError as err: + self.assertEqual(str(err), "This interruptible future had it's result set unexpectedly to 'NOT ME!!!'") + else: + self.fail('ExpectedException not raised') + + self.assertTrue(interruptable.done()) + + +@pytest.mark.requires_rmq +class TestInterruptableTask(): + """ Tests for InterruptableFuture and interruptable_task.""" + + @pytest.mark.asyncio + async def test_task(self): + """Test coroutine run and succed""" + + async def task_fn(cancellable): + fut = asyncio.Future() + + async def coro(): + fut.set_result('I am done') + + await cancellable.with_interrupt(coro()) + return fut.result() + + task_fut = interruptable_task(task_fn) + result = await task_fut + assert isinstance(task_fut, InterruptableFuture) + assert task_fut.done() + assert result == 'I am done' + + @pytest.mark.asyncio + async def test_interrupted(self): + """Test interrupt future being interrupted""" + + async def task_fn(cancellable): + cancellable.interrupt(RuntimeError('STOP')) + + task_fut = interruptable_task(task_fn) + try: + await task_fut + except RuntimeError as err: + assert str(err) == 'STOP' + else: + raise AssertionError('ExpectedException not raised') + + @pytest.mark.asyncio + async def test_future_already_set(self): + """Test interrupt future being set before coroutine is done""" + + async def task_fn(cancellable): + fut = asyncio.Future() + + async def coro(): + fut.set_result('I am done') + + await cancellable.with_interrupt(coro()) + cancellable.set_result('NOT ME!!!') + return fut.result() + + task_fut = interruptable_task(task_fn) + + result = await task_fut + assert result == 'NOT ME!!!' 
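The ``InterruptableFuture`` tests above all exercise the same asyncio pattern: race a piece of work against an interrupt future and bail out as soon as the interrupt wins. The snippet below is a self-contained sketch of that pattern in plain asyncio, included for orientation only; the names (``race_against_interrupt``) are illustrative and this is not the implementation found in ``aiida.engine.utils``.

    import asyncio


    async def race_against_interrupt(work_coro, interrupt_future):
        """Await ``work_coro`` unless ``interrupt_future`` completes (or raises) first."""
        work_task = asyncio.ensure_future(work_coro)
        done, _ = await asyncio.wait({work_task, interrupt_future}, return_when=asyncio.FIRST_COMPLETED)

        if interrupt_future in done and not work_task.done():
            work_task.cancel()
            # ``result()`` re-raises whatever exception the interrupter set, e.g. RuntimeError('STOP').
            return interrupt_future.result()

        return work_task.result()


    async def main():
        loop = asyncio.get_running_loop()
        interrupt = loop.create_future()
        # Trigger the interrupt almost immediately, mirroring ``test_interrupted`` above.
        loop.call_soon(interrupt.set_exception, RuntimeError('STOP'))
        try:
            await race_against_interrupt(asyncio.sleep(10), interrupt)
        except RuntimeError as exc:
            print(f'interrupted: {exc}')


    asyncio.run(main())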
diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 040176894d..49166c9872 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -11,9 +11,9 @@ """Tests for the `WorkChain` class.""" import inspect import unittest +import asyncio import plumpy -from tornado import gen import pytest from aiida import orm @@ -28,7 +28,7 @@ def run_until_paused(proc): - """ Set up a future that will be resolved on entering the WAITING state """ + """ Set up a future that will be resolved when process is paused""" listener = plumpy.ProcessListener() paused = plumpy.Future() @@ -95,7 +95,7 @@ def define(cls, spec): spec.outputs.dynamic = True spec.outline( cls.step1, - if_(cls.is_a)(cls.step2).elif_(cls.is_b)(cls.step3).else_(cls.step4), + if_(cls.is_a)(cls.step2).elif_(cls.is_b)(cls.step3).else_(cls.step4), # pylint: disable=no-member cls.step5, while_(cls.larger_then_n)(cls.step6,), ) @@ -111,37 +111,37 @@ def on_create(self): } def step1(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def step2(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def step3(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def step4(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def step5(self): self.ctx.counter = 0 - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def step6(self): self.ctx.counter = self.ctx.counter + 1 - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) def is_a(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) return self.inputs.value.value == 'A' def is_b(self): - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) return self.inputs.value.value == 'B' def larger_then_n(self): keep_looping = self.ctx.counter < self.inputs.n.value if not keep_looping: - self._set_finished(inspect.stack()[0][3]) + self._set_finished(inspect.stack()[0].function) return keep_looping def _set_finished(self, function_name): @@ -187,6 +187,7 @@ def success(self): self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE).store()) +@pytest.mark.requires_rmq class TestExitStatus(AiidaTestCase): """ This class should test the various ways that one can exit from the outline flow of a WorkChain, other than @@ -252,8 +253,8 @@ def define(cls, spec): super().define(spec) spec.outline(if_(cls.condition)(cls.step1, cls.step2)) - def on_create(self, *args, **kwargs): - super().on_create(*args, **kwargs) + def on_create(self): + super().on_create() self.ctx.s1 = False self.ctx.s2 = False @@ -268,6 +269,7 @@ def step2(self): self.ctx.s2 = True +@pytest.mark.requires_rmq class TestContext(AiidaTestCase): def test_attributes(self): @@ -289,6 +291,7 @@ def test_dict(self): wc.ctx['new_attr'] # pylint: disable=pointless-statement +@pytest.mark.requires_rmq class TestWorkchain(AiidaTestCase): # pylint: disable=too-many-public-methods @@ -677,20 +680,22 @@ def do_run(self): run_and_check_success(MainWorkChain) def test_if_block_persistence(self): - """ + """Test a reloaded `If` conditional can be resumed. 
+ This test was created to capture issue #902 """ runner = get_manager().get_runner() wc = IfTest() runner.schedule(wc) - @gen.coroutine - def run_async(workchain): - yield run_until_paused(workchain) + async def run_async(workchain): + + # run the original workchain until paused + await run_until_paused(workchain) self.assertTrue(workchain.ctx.s1) self.assertFalse(workchain.ctx.s2) - # Now bundle the thing + # Now bundle the workchain bundle = plumpy.Bundle(workchain) # Need to close the process before recreating a new instance workchain.close() @@ -700,15 +705,22 @@ def run_async(workchain): self.assertTrue(workchain2.ctx.s1) self.assertFalse(workchain2.ctx.s2) + # check bundling again creates the same saved state bundle2 = plumpy.Bundle(workchain2) self.assertDictEqual(bundle, bundle2) - workchain.play() - yield workchain.future() - self.assertTrue(workchain.ctx.s1) - self.assertTrue(workchain.ctx.s2) + # run the loaded workchain to completion + runner.schedule(workchain2) + workchain2.play() + await workchain2.future() + self.assertTrue(workchain2.ctx.s1) + self.assertTrue(workchain2.ctx.s2) - runner.loop.run_sync(lambda: run_async(wc)) # pylint: disable=unnecessary-lambda + # ensure the original paused workchain future is finalised + # to avoid warnings + workchain.future().set_result(None) + + runner.loop.run_until_complete(run_async(wc)) def test_report_dbloghandler(self): """ @@ -851,6 +863,7 @@ def _run_with_checkpoints(wf_class, inputs=None): return proc.finished_steps +@pytest.mark.requires_rmq class TestWorkChainAbort(AiidaTestCase): """ Test the functionality to abort a workchain @@ -885,18 +898,17 @@ def test_simple_run(self): runner = get_manager().get_runner() process = TestWorkChainAbort.AbortableWorkChain() - @gen.coroutine - def run_async(): - yield run_until_paused(process) + async def run_async(): + await run_until_paused(process) process.play() with Capturing(): with self.assertRaises(RuntimeError): - yield process.future() + await process.future() runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda + runner.loop.run_until_complete(run_async()) self.assertEqual(process.node.is_finished_ok, False) self.assertEqual(process.node.is_excepted, True) @@ -911,9 +923,8 @@ def test_simple_kill_through_process(self): runner = get_manager().get_runner() process = TestWorkChainAbort.AbortableWorkChain() - @gen.coroutine - def run_async(): - yield run_until_paused(process) + async def run_async(): + await run_until_paused(process) self.assertTrue(process.paused) process.kill() @@ -922,13 +933,14 @@ def run_async(): launch.run(process) runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda + runner.loop.run_until_complete(run_async()) self.assertEqual(process.node.is_finished_ok, False) self.assertEqual(process.node.is_excepted, False) self.assertEqual(process.node.is_killed, True) +@pytest.mark.requires_rmq class TestWorkChainAbortChildren(AiidaTestCase): """ Test the functionality to abort a workchain and verify that children @@ -998,17 +1010,18 @@ def test_simple_kill_through_process(self): runner = get_manager().get_runner() process = TestWorkChainAbortChildren.MainWorkChain(inputs={'kill': Bool(True)}) - @gen.coroutine - def run_async(): - yield run_until_waiting(process) + async def run_async(): + await run_until_waiting(process) - process.kill() + result = process.kill() + if asyncio.isfuture(result): + await result with self.assertRaises(plumpy.KilledError): - yield 
process.future() + await process.future() runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda + runner.loop.run_until_complete(run_async()) child = process.node.get_outgoing(link_type=LinkType.CALL_WORK).first().node self.assertEqual(child.is_finished_ok, False) @@ -1020,6 +1033,7 @@ def run_async(): self.assertEqual(process.node.is_killed, True) +@pytest.mark.requires_rmq class TestImmutableInputWorkchain(AiidaTestCase): """ Test that inputs cannot be modified @@ -1125,6 +1139,7 @@ def do_test(self): assert self.inputs.test == self.inputs.reference +@pytest.mark.requires_rmq class TestSerializeWorkChain(AiidaTestCase): """ Test workchains with serialized input / output. @@ -1251,6 +1266,7 @@ def do_run(self): self.out('c', self.inputs.c) +@pytest.mark.requires_rmq class TestWorkChainExpose(AiidaTestCase): """ Test the expose inputs / outputs functionality @@ -1362,6 +1378,7 @@ def step1(self): launch.run(Child) +@pytest.mark.requires_rmq class TestWorkChainMisc(AiidaTestCase): class PointlessWorkChain(WorkChain): @@ -1398,6 +1415,7 @@ def test_global_submit_raises(self): launch.run(TestWorkChainMisc.IllegalSubmitWorkChain) +@pytest.mark.requires_rmq class TestDefaultUniqueness(AiidaTestCase): """Test that default inputs of exposed nodes will get unique UUIDS.""" diff --git a/tests/engine/test_workfunctions.py b/tests/engine/test_workfunctions.py index 9f62d54a77..16b4d0f83a 100644 --- a/tests/engine/test_workfunctions.py +++ b/tests/engine/test_workfunctions.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the workfunction decorator and WorkFunctionNode.""" +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.common.links import LinkType @@ -16,6 +17,7 @@ from aiida.orm import Int, WorkFunctionNode, CalcFunctionNode +@pytest.mark.requires_rmq class TestWorkFunction(AiidaTestCase): """Tests for workfunctions. 
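Editor's note: the work chain test changes above all follow the same tornado-to-asyncio migration pattern: `@gen.coroutine` generators that `yield` futures become native `async def` coroutines that `await` them, and `loop.run_sync(lambda: coro())` becomes `loop.run_until_complete(coro())`. A minimal generic illustration of that pattern (plain asyncio, no AiiDA objects; the helper below is a placeholder) follows.

import asyncio


async def run_async(process_name):
    """Native coroutine; previously this would have been a tornado ``@gen.coroutine``
    generator using ``yield`` instead of ``await``."""
    await asyncio.sleep(0)  # stand-in for e.g. ``await run_until_paused(process)``
    return f'{process_name} finished'


loop = asyncio.get_event_loop()
# tornado:  runner.loop.run_sync(lambda: run_async('wc'))
# asyncio:  pass the coroutine object directly, no lambda wrapper required
print(loop.run_until_complete(run_async('wc')))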
diff --git a/tests/manage/configuration/migrations/test_migrations.py b/tests/manage/configuration/migrations/test_migrations.py index 4251b3116a..d4e43203e7 100644 --- a/tests/manage/configuration/migrations/test_migrations.py +++ b/tests/manage/configuration/migrations/test_migrations.py @@ -70,3 +70,10 @@ def test_3_4_migration(self): config_reference = self.load_config_sample('reference/4.json') config_migrated = _MIGRATION_LOOKUP[3].apply(config_initial) self.assertEqual(config_migrated, config_reference) + + def test_4_5_migration(self): + """Test the step between config versions 4 and 5.""" + config_initial = self.load_config_sample('input/4.json') + config_reference = self.load_config_sample('reference/5.json') + config_migrated = _MIGRATION_LOOKUP[4].apply(config_initial) + self.assertEqual(config_migrated, config_reference) diff --git a/tests/manage/configuration/migrations/test_samples/input/4.json b/tests/manage/configuration/migrations/test_samples/input/4.json new file mode 100644 index 0000000000..0f3df0ec40 --- /dev/null +++ b/tests/manage/configuration/migrations/test_samples/input/4.json @@ -0,0 +1 @@ +{"CONFIG_VERSION": {"CURRENT": 4, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", "broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} diff --git a/tests/manage/configuration/migrations/test_samples/reference/5.json b/tests/manage/configuration/migrations/test_samples/reference/5.json new file mode 100644 index 0000000000..14f0942535 --- /dev/null +++ b/tests/manage/configuration/migrations/test_samples/reference/5.json @@ -0,0 +1 @@ +{"CONFIG_VERSION": {"CURRENT": 5, "OLDEST_COMPATIBLE": 5}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", "broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} diff --git a/tests/manage/configuration/migrations/test_samples/reference/final.json b/tests/manage/configuration/migrations/test_samples/reference/final.json index 0f3df0ec40..14f0942535 100644 --- a/tests/manage/configuration/migrations/test_samples/reference/final.json +++ b/tests/manage/configuration/migrations/test_samples/reference/final.json @@ -1 +1 @@ -{"CONFIG_VERSION": {"CURRENT": 4, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", 
"default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", "broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} +{"CONFIG_VERSION": {"CURRENT": 5, "OLDEST_COMPATIBLE": 5}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", "broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} diff --git a/tests/manage/configuration/test_config.py b/tests/manage/configuration/test_config.py index 71c02e1528..50a91f1643 100644 --- a/tests/manage/configuration/test_config.py +++ b/tests/manage/configuration/test_config.py @@ -374,7 +374,7 @@ def test_option(self): def test_option_global_only(self): """Test that `global_only` options are only set globally even if a profile specific scope is set.""" - option_name = 'user.email' + option_name = 'autofill.user.email' option_value = 'some@email.com' config = Config(self.config_filepath, self.config_dictionary) @@ -390,7 +390,7 @@ def test_option_global_only(self): def test_set_option_override(self): """Test that `global_only` options are only set globally even if a profile specific scope is set.""" - option_name = 'user.email' + option_name = 'autofill.user.email' option_value_one = 'first@email.com' option_value_two = 'second@email.com' diff --git a/tests/manage/configuration/test_options.py b/tests/manage/configuration/test_options.py index 972b7fb280..b498922e3a 100644 --- a/tests/manage/configuration/test_options.py +++ b/tests/manage/configuration/test_options.py @@ -8,12 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the configuration options.""" +import pytest from aiida.backends.testbase import AiidaTestCase -from aiida.manage.configuration.options import get_option, get_option_names, parse_option, Option, CONFIG_OPTIONS -from aiida.manage.configuration import get_config, get_config_option - -from tests.utils.configuration import with_temporary_config_instance +from aiida.common.exceptions import ConfigurationError +from aiida.manage.configuration.options import get_option, get_option_names, parse_option, Option +from aiida.manage.configuration import get_config, get_config_option, ConfigValidationError class TestConfigurationOptions(AiidaTestCase): @@ -21,14 +21,15 @@ class TestConfigurationOptions(AiidaTestCase): def test_get_option_names(self): """Test `get_option_names` function.""" - self.assertEqual(get_option_names(), CONFIG_OPTIONS.keys()) + self.assertIsInstance(get_option_names(), list) + self.assertEqual(len(get_option_names()), 26) def test_get_option(self): """Test `get_option` function.""" - with self.assertRaises(ValueError): + with self.assertRaises(ConfigurationError): get_option('no_existing_option') - option_name = list(CONFIG_OPTIONS)[0] + option_name = get_option_names()[0] option = 
get_option(option_name) self.assertIsInstance(option, Option) self.assertEqual(option.name, option_name) @@ -36,24 +37,22 @@ def test_get_option(self): def test_parse_option(self): """Test `parse_option` function.""" - with self.assertRaises(ValueError): + with self.assertRaises(ConfigValidationError): parse_option('logging.aiida_loglevel', 1) - with self.assertRaises(ValueError): + with self.assertRaises(ConfigValidationError): parse_option('logging.aiida_loglevel', 'INVALID_LOG_LEVEL') def test_options(self): """Test that all defined options can be converted into Option namedtuples.""" - for option_name, set_optiontings in CONFIG_OPTIONS.items(): + for option_name in get_option_names(): option = get_option(option_name) self.assertEqual(option.name, option_name) - self.assertEqual(option.key, set_optiontings['key']) - self.assertEqual(option.valid_type, set_optiontings['valid_type']) - self.assertEqual(option.valid_values, set_optiontings['valid_values']) - self.assertEqual(option.default, set_optiontings['default']) - self.assertEqual(option.description, set_optiontings['description']) + self.assertIsInstance(option.description, str) + option.valid_type # pylint: disable=pointless-statement + option.default # pylint: disable=pointless-statement - @with_temporary_config_instance + @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_default(self): """Tests that `get_option` return option default if not specified globally or for current profile.""" option_name = 'logging.aiida_loglevel' @@ -63,7 +62,7 @@ def test_get_config_option_default(self): option_value = get_config_option(option_name) self.assertEqual(option_value, option.default) - @with_temporary_config_instance + @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_profile_specific(self): """Tests that `get_option` correctly gets a configuration option if specified for the current profile.""" config = get_config() @@ -77,7 +76,7 @@ def test_get_config_option_profile_specific(self): option_value = get_config_option(option_name) self.assertEqual(option_value, option_value_profile) - @with_temporary_config_instance + @pytest.mark.usefixtures('config_with_profile') def test_get_config_option_global(self): """Tests that `get_option` correctly agglomerates upwards and so retrieves globally set config options.""" config = get_config() diff --git a/tests/manage/configuration/test_profile.py b/tests/manage/configuration/test_profile.py index f60140b2d6..9009ef25cd 100644 --- a/tests/manage/configuration/test_profile.py +++ b/tests/manage/configuration/test_profile.py @@ -68,9 +68,9 @@ def test_is_test_profile(self): def test_set_option(self): """Test the `set_option` method.""" - option_key = 'user_email' - option_value_one = 'first@email.com' - option_value_two = 'second@email.com' + option_key = 'daemon.timeout' + option_value_one = 999 + option_value_two = 666 # Setting an option if it does not exist should work self.profile.set_option(option_key, option_value_one) diff --git a/tests/manage/external/test_rmq.py b/tests/manage/external/test_rmq.py index 7733334fbd..940f4f5681 100644 --- a/tests/manage/external/test_rmq.py +++ b/tests/manage/external/test_rmq.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `aiida.manage.external.rmq` module.""" +from kiwipy.rmq import RmqThreadCommunicator import pytest from aiida.manage.external import rmq @@ -34,3 +35,25 @@ def 
test_get_rmq_url(args, kwargs, expected): else: with pytest.raises(expected): rmq.get_rmq_url(*args, **kwargs) + + +@pytest.mark.requires_rmq +@pytest.mark.parametrize('url', ('amqp://guest:guest@127.0.0.1:5672',)) +def test_communicator(url): + """Test the instantiation of a ``kiwipy.rmq.RmqThreadCommunicator``. + + This class is used by all runners to communicate with the RabbitMQ server. + """ + RmqThreadCommunicator.connect(connection_params={'url': url}) + + +@pytest.mark.requires_rmq +def test_add_rpc_subscriber(communicator): + """Test ``add_rpc_subscriber``.""" + communicator.add_rpc_subscriber(None) + + +@pytest.mark.requires_rmq +def test_add_broadcast_subscriber(communicator): + """Test ``add_broadcast_subscriber``.""" + communicator.add_broadcast_subscriber(None) diff --git a/tests/manage/test_caching_config.py b/tests/manage/test_caching_config.py index 1e6881d7e4..265d6a3d5e 100644 --- a/tests/manage/test_caching_config.py +++ b/tests/manage/test_caching_config.py @@ -11,74 +11,87 @@ # pylint: disable=redefined-outer-name -import tempfile import contextlib +import json +from pathlib import Path import yaml import pytest from aiida.common import exceptions -from aiida.manage.configuration import get_profile -from aiida.manage.caching import configure, get_use_cache, enable_caching, disable_caching +from aiida.manage.caching import get_use_cache, enable_caching, disable_caching @pytest.fixture -def configure_caching(): +def configure_caching(config_with_profile_factory): """ Fixture to set the caching configuration in the test profile to a specific dictionary. This is done by creating a temporary caching configuration file. """ + config = config_with_profile_factory() @contextlib.contextmanager def inner(config_dict): - with tempfile.NamedTemporaryFile() as handle: - yaml.dump({get_profile().name: config_dict}, handle, encoding='utf-8') - configure(config_file=handle.name) + for key, value in config_dict.items(): + config.set_option(f'caching.{key}', value) yield # reset the configuration - configure() + for key in config_dict.keys(): + config.unset_option(f'caching.{key}') return inner -@pytest.fixture -def use_default_configuration(configure_caching): # pylint: disable=redefined-outer-name - """ - Fixture to load a default caching configuration. - """ - with configure_caching( - config_dict={ - 'default': True, - 'enabled': ['aiida.calculations:arithmetic.add'], - 'disabled': ['aiida.calculations:templatereplacer'] - } - ): - yield +def test_merge_deprecated_yaml(tmp_path): + """Test that an existing 'cache_config.yml' is correctly merged into the main config. - -def test_empty_enabled_disabled(configure_caching): - """Test that `aiida.manage.caching.configure` does not except when either `enabled` or `disabled` is `None`. - - This will happen when the configuration file specifies either one of the keys but no actual values, e.g.:: - - profile_name: - default: False - enabled: - - In this case, the dictionary parsed by yaml will contain `None` for the `enabled` key. - Now this will be unlikely, but the same holds when all values are commented:: - - profile_name: - default: False - enabled: - # - aiida.calculations:templatereplacer - - which is not unlikely to occurr in the wild. + An AiidaDeprecationWarning should also be raised. 
""" - with configure_caching(config_dict={'default': True, 'enabled': None, 'disabled': None}): - # Check that `get_use_cache` also does not except, and works as expected - assert get_use_cache(identifier='aiida.calculations:templatereplacer') + from aiida.common.warnings import AiidaDeprecationWarning + from aiida.manage import configuration + from aiida.manage.configuration import settings, load_profile, reset_profile, get_config_option + + # Store the current configuration instance and config directory path + current_config = configuration.CONFIG + current_config_path = current_config.dirpath + current_profile_name = configuration.PROFILE.name + + try: + reset_profile() + configuration.CONFIG = None + + # Create a temporary folder, set it as the current config directory path + settings.AIIDA_CONFIG_FOLDER = str(tmp_path) + config_dictionary = json.loads( + Path(__file__).parent.joinpath('configuration/migrations/test_samples/reference/5.json').read_text() + ) + config_dictionary['profiles']['default']['AIIDADB_REPOSITORY_URI'] = f"file:///{tmp_path/'repo'}" + cache_dictionary = { + 'default': { + 'default': True, + 'enabled': ['aiida.calculations:quantumespresso.pw'], + 'disabled': ['aiida.calculations:templatereplacer'] + } + } + tmp_path.joinpath('config.json').write_text(json.dumps(config_dictionary)) + tmp_path.joinpath('cache_config.yml').write_text(yaml.dump(cache_dictionary)) + with pytest.warns(AiidaDeprecationWarning, match='cache_config.yml'): + configuration.CONFIG = configuration.load_config() + load_profile('default') + + assert get_config_option('caching.default_enabled') is True + assert get_config_option('caching.enabled_for') == ['aiida.calculations:quantumespresso.pw'] + assert get_config_option('caching.disabled_for') == ['aiida.calculations:templatereplacer'] + # should have now been moved to cache_config.yml. + assert not tmp_path.joinpath('cache_config.yml').exists() + finally: + # Reset the config folder path and the config instance. Note this will always be executed after the yield no + # matter what happened in the test that used this fixture. 
+ reset_profile() + settings.AIIDA_CONFIG_FOLDER = current_config_path + configuration.CONFIG = current_config + load_profile(current_profile_name) def test_no_enabled_disabled(configure_caching): @@ -89,7 +102,7 @@ def test_no_enabled_disabled(configure_caching): profile_name: default: False """ - with configure_caching(config_dict={'default': False}): + with configure_caching(config_dict={'default_enabled': False}): # Check that `get_use_cache` also does not except, and works as expected assert not get_use_cache(identifier='aiida.calculations:templatereplacer') @@ -98,24 +111,24 @@ def test_no_enabled_disabled(configure_caching): 'config_dict', [{ 'wrong_key': ['foo'] }, { - 'default': 2 + 'default_enabled': 'x' }, { - 'enabled': 4 + 'enabled_for': 4 }, { - 'default': 'string' + 'default_enabled': 'string' }, { - 'enabled': ['aiida.spam:Ni'] + 'enabled_for': ['aiida.spam:Ni'] }, { - 'default': True, - 'enabled': ['aiida.calculations:With:second_separator'] + 'default_enabled': True, + 'enabled_for': ['aiida.calculations:With:second_separator'] }, { - 'enabled': ['aiida.sp*:Ni'] + 'enabled_for': ['aiida.sp*:Ni'] }, { - 'disabled': ['aiida.sp*!bar'] + 'disabled_for': ['aiida.sp*!bar'] }, { - 'enabled': ['startswith.number.2bad'] + 'enabled_for': ['startswith.number.2bad'] }, { - 'enabled': ['some.thing.in.this.is.a.keyword'] + 'enabled_for': ['some.thing.in.this.is.a.keyword'] }] ) def test_invalid_configuration_dict(configure_caching, config_dict): @@ -126,71 +139,73 @@ def test_invalid_configuration_dict(configure_caching, config_dict): pass -def test_invalid_identifier(use_default_configuration): # pylint: disable=unused-argument +def test_invalid_identifier(configure_caching): """Test `get_use_cache` raises a `TypeError` if identifier is not a string.""" - with pytest.raises(TypeError): - get_use_cache(identifier=int) + with configure_caching({}): + with pytest.raises(TypeError): + get_use_cache(identifier=int) -def test_default(use_default_configuration): # pylint: disable=unused-argument +def test_default(configure_caching): """Verify that when not specifying any specific identifier, the `default` is used, which is set to `True`.""" - assert get_use_cache() + with configure_caching({'default_enabled': True}): + assert get_use_cache() -@pytest.mark.parametrize(['config_dict', 'enabled', 'disabled'], [ +@pytest.mark.parametrize(['config_dict', 'enabled_for', 'disabled_for'], [ ({ - 'default': True, - 'enabled': ['aiida.calculations:arithmetic.add'], - 'disabled': ['aiida.calculations:templatereplacer'] + 'default_enabled': True, + 'enabled_for': ['aiida.calculations:arithmetic.add'], + 'disabled_for': ['aiida.calculations:templatereplacer'] }, ['some_identifier', 'aiida.calculations:arithmetic.add', 'aiida.calculations:TEMPLATEREPLACER' ], ['aiida.calculations:templatereplacer']), ({ - 'default': False, - 'enabled': ['aiida.calculations:arithmetic.add'], - 'disabled': ['aiida.calculations:templatereplacer'] + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:arithmetic.add'], + 'disabled_for': ['aiida.calculations:templatereplacer'] }, ['aiida.calculations:arithmetic.add'], ['aiida.calculations:templatereplacer', 'some_identifier']), ({ - 'default': False, - 'enabled': ['aiida.calculations:*'], + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:*'], }, ['aiida.calculations:templatereplacer', 'aiida.calculations:arithmetic.add'], ['some_identifier']), ({ - 'default': False, - 'enabled': ['aiida.calcul*'], + 'default_enabled': False, + 'enabled_for': 
['aiida.calcul*'], }, ['aiida.calculations:templatereplacer', 'aiida.calculations:arithmetic.add'], ['some_identifier']), ({ - 'default': False, - 'enabled': ['aiida.calculations:*'], - 'disabled': ['aiida.calculations:arithmetic.add'] + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:*'], + 'disabled_for': ['aiida.calculations:arithmetic.add'] }, ['aiida.calculations:templatereplacer', 'aiida.calculations:ARIthmetic.add' ], ['some_identifier', 'aiida.calculations:arithmetic.add']), ({ - 'default': False, - 'enabled': ['aiida.calculations:ar*thmetic.add'], - 'disabled': ['aiida.calculations:*'], + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:ar*thmetic.add'], + 'disabled_for': ['aiida.calculations:*'], }, ['aiida.calculations:arithmetic.add', 'aiida.calculations:arblarghthmetic.add' ], ['some_identifier', 'aiida.calculations:templatereplacer']), ]) -def test_configuration(configure_caching, config_dict, enabled, disabled): +def test_configuration(configure_caching, config_dict, enabled_for, disabled_for): """Check that different caching configurations give the expected result. """ with configure_caching(config_dict=config_dict): - for identifier in enabled: + for identifier in enabled_for: assert get_use_cache(identifier=identifier) - for identifier in disabled: + for identifier in disabled_for: assert not get_use_cache(identifier=identifier) @pytest.mark.parametrize( ['config_dict', 'valid_identifiers', 'invalid_identifiers'], [({ - 'default': False, - 'enabled': ['aiida.calculations:*thmetic.add'], - 'disabled': ['aiida.calculations:arith*ic.add'] + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:*thmetic.add'], + 'disabled_for': ['aiida.calculations:arith*ic.add'] }, ['some_identifier', 'aiida.calculations:templatereplacer'], ['aiida.calculations:arithmetic.add']), ({ - 'default': False, - 'enabled': ['aiida.calculations:arithmetic.add'], - 'disabled': ['aiida.calculations:arithmetic.add'] + 'default_enabled': False, + 'enabled_for': ['aiida.calculations:arithmetic.add'], + 'disabled_for': ['aiida.calculations:arithmetic.add'] }, ['some_identifier', 'aiida.calculations:templatereplacer'], ['aiida.calculations:arithmetic.add'])] ) def test_ambiguous_configuration(configure_caching, config_dict, valid_identifiers, invalid_identifiers): @@ -211,7 +226,7 @@ def test_enable_caching_specific(configure_caching): Check that using enable_caching for a specific identifier works. """ identifier = 'some_ident' - with configure_caching({'default': False}): + with configure_caching({'default_enabled': False}): with enable_caching(identifier=identifier): assert get_use_cache(identifier=identifier) @@ -221,7 +236,7 @@ def test_enable_caching_global(configure_caching): Check that using enable_caching for a specific identifier works. """ specific_identifier = 'some_ident' - with configure_caching(config_dict={'default': False, 'disabled': [specific_identifier]}): + with configure_caching(config_dict={'default_enabled': False, 'disabled_for': [specific_identifier]}): with enable_caching(): assert get_use_cache(identifier='some_other_ident') assert get_use_cache(identifier=specific_identifier) @@ -232,7 +247,7 @@ def test_disable_caching_specific(configure_caching): Check that using disable_caching for a specific identifier works. 
""" identifier = 'some_ident' - with configure_caching({'default': True}): + with configure_caching({'default_enabled': True}): with disable_caching(identifier=identifier): assert not get_use_cache(identifier=identifier) @@ -242,7 +257,7 @@ def test_disable_caching_global(configure_caching): Check that using disable_caching for a specific identifier works. """ specific_identifier = 'some_ident' - with configure_caching(config_dict={'default': True, 'enabled': [specific_identifier]}): + with configure_caching(config_dict={'default_enabled': True, 'enabled_for': [specific_identifier]}): with disable_caching(): assert not get_use_cache(identifier='some_other_ident') assert not get_use_cache(identifier=specific_identifier) diff --git a/tests/orm/data/test_code.py b/tests/orm/data/test_code.py index 6bc5ab1cf3..515f7eaa8d 100644 --- a/tests/orm/data/test_code.py +++ b/tests/orm/data/test_code.py @@ -9,8 +9,11 @@ ########################################################################### """Tests for the `Code` class.""" # pylint: disable=redefined-outer-name +import warnings + import pytest +from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import Code @@ -32,24 +35,27 @@ def create_codes(tmpdir, aiida_localhost): @pytest.mark.usefixtures('clear_database_before_test') def test_get_full_text_info(create_codes): """Test the `Code.get_full_text_info` method.""" - for code in create_codes: - full_text_info = code.get_full_text_info() - - assert isinstance(full_text_info, list) - assert ['PK', code.pk] in full_text_info - assert ['UUID', code.uuid] in full_text_info - assert ['Label', code.label] in full_text_info - assert ['Description', code.description] in full_text_info - - if code.is_local(): - assert ['Type', 'local'] in full_text_info - assert ['Exec name', code.get_execname()] in full_text_info - assert ['List of files/folders:', ''] in full_text_info - else: - assert ['Type', 'remote'] in full_text_info - assert ['Remote machine', code.computer.label] in full_text_info - assert ['Remote absolute path', code.get_remote_exec_path()] in full_text_info - - for code in create_codes: - full_text_info = code.get_full_text_info(verbose=True) - assert ['Calculations', 0] in full_text_info + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + + for code in create_codes: + full_text_info = code.get_full_text_info() + + assert isinstance(full_text_info, list) + assert ['PK', code.pk] in full_text_info + assert ['UUID', code.uuid] in full_text_info + assert ['Label', code.label] in full_text_info + assert ['Description', code.description] in full_text_info + + if code.is_local(): + assert ['Type', 'local'] in full_text_info + assert ['Exec name', code.get_execname()] in full_text_info + assert ['List of files/folders:', ''] in full_text_info + else: + assert ['Type', 'remote'] in full_text_info + assert ['Remote machine', code.computer.label] in full_text_info + assert ['Remote absolute path', code.get_remote_exec_path()] in full_text_info + + for code in create_codes: + full_text_info = code.get_full_text_info(verbose=True) + assert ['Calculations', 0] in full_text_info diff --git a/tests/orm/data/test_folder.py b/tests/orm/data/test_folder.py index 9c4ec972eb..17aade1904 100644 --- a/tests/orm/data/test_folder.py +++ b/tests/orm/data/test_folder.py @@ -34,7 +34,7 @@ def setUpClass(cls, *args, **kwargs): handle.write(content) @classmethod - def tearDownClass(cls, *args, **kwargs): + def tearDownClass(cls): 
shutil.rmtree(cls.tempdir) def test_constructor_tree(self): diff --git a/tests/orm/data/test_remote_stash.py b/tests/orm/data/test_remote_stash.py new file mode 100644 index 0000000000..45318ca1b3 --- /dev/null +++ b/tests/orm/data/test_remote_stash.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the :mod:`aiida.orm.nodes.data.remote.stash` module.""" +import pytest + +from aiida.common.datastructures import StashMode +from aiida.common.exceptions import StoringNotAllowed +from aiida.orm import RemoteStashData, RemoteStashFolderData + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_base_class(): + """Verify that base class cannot be stored.""" + node = RemoteStashData(stash_mode=StashMode.COPY) + + with pytest.raises(StoringNotAllowed): + node.store() + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('store', (False, True)) +def test_constructor(store): + """Test the constructor and storing functionality.""" + stash_mode = StashMode.COPY + target_basepath = '/absolute/path' + source_list = ['relative/folder', 'relative/file'] + + data = RemoteStashFolderData(stash_mode, target_basepath, source_list) + + assert data.stash_mode == stash_mode + assert data.target_basepath == target_basepath + assert data.source_list == source_list + + if store: + data.store() + assert data.is_stored + assert data.stash_mode == stash_mode + assert data.target_basepath == target_basepath + assert data.source_list == source_list + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize( + 'argument, value', ( + ('stash_mode', 'copy'), + ('target_basepath', ['list']), + ('source_list', 'relative/path'), + ('source_list', ('/absolute/path')), + ) +) +def test_constructor_invalid(argument, value): + """Test the constructor for invalid argument types.""" + kwargs = { + 'stash_mode': StashMode.COPY, + 'target_basepath': '/absolute/path', + 'source_list': ('relative/folder', 'relative/file'), + } + + with pytest.raises(TypeError): + kwargs[argument] = value + RemoteStashFolderData(**kwargs) diff --git a/tests/orm/data/test_singlefile.py b/tests/orm/data/test_singlefile.py index 0815749f22..d4cfac3edc 100644 --- a/tests/orm/data/test_singlefile.py +++ b/tests/orm/data/test_singlefile.py @@ -10,151 +10,195 @@ """Tests for the `SinglefileData` class.""" import os -import tempfile import io +import tempfile +import pathlib -from aiida.backends.testbase import AiidaTestCase -from aiida.orm import SinglefileData, load_node - - -class TestSinglefileData(AiidaTestCase): - """Tests for the `SinglefileData` class.""" - - def test_reload_singlefile_data(self): - """Test writing and reloading a `SinglefileData` instance.""" - content_original = 'some text ABCDE' - - with tempfile.NamedTemporaryFile(mode='w+') as handle: - filepath = handle.name - basename = os.path.basename(filepath) - handle.write(content_original) - handle.flush() - node = SinglefileData(file=filepath) - - uuid = node.uuid - - with node.open() as handle: - content_written = handle.read() - - 
self.assertEqual(node.list_object_names(), [basename]) - self.assertEqual(content_written, content_original) - - node.store() - - with node.open() as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [basename]) - - node_loaded = load_node(uuid) - self.assertTrue(isinstance(node_loaded, SinglefileData)) - - with node.open() as handle: - content_loaded = handle.read() - - self.assertEqual(content_loaded, content_original) - self.assertEqual(node_loaded.list_object_names(), [basename]) - - with node_loaded.open() as handle: - self.assertEqual(handle.read(), content_original) - - def test_construct_from_filelike(self): - """Test constructing an instance from filelike instead of filepath.""" - content_original = 'some testing text\nwith a newline' - - with tempfile.NamedTemporaryFile(mode='wb+') as handle: - basename = os.path.basename(handle.name) - handle.write(content_original.encode('utf-8')) - handle.flush() - handle.seek(0) - node = SinglefileData(file=handle) - - with node.open() as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [basename]) - - node.store() - - with node.open() as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [basename]) - - def test_construct_from_string(self): - """Test constructing an instance from a string.""" - content_original = 'some testing text\nwith a newline' - - with io.BytesIO(content_original.encode('utf-8')) as handle: - node = SinglefileData(file=handle) - - with node.open() as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [SinglefileData.DEFAULT_FILENAME]) - - node.store() - - with node.open() as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [SinglefileData.DEFAULT_FILENAME]) - - def test_construct_with_filename(self): - """Test constructing an instance, providing a filename.""" - content_original = 'some testing text\nwith a newline' - filename = 'myfile.txt' +import pytest - # test creating from string - with io.BytesIO(content_original.encode('utf-8')) as handle: - node = SinglefileData(file=handle, filename=filename) +from aiida.orm import SinglefileData, load_node - with node.open() as handle: - content_stored = handle.read() - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [filename]) +@pytest.fixture +def check_singlefile_content(): + """Fixture to check the content of a SinglefileData. - # test creating from file - with tempfile.NamedTemporaryFile(mode='wb+') as handle: - handle.write(content_original.encode('utf-8')) - handle.flush() - handle.seek(0) - node = SinglefileData(file=handle, filename=filename) + Checks the content of a SinglefileData node against the given + reference content and filename. 
+ """ - with node.open() as handle: - content_stored = handle.read() + def inner(node, content_reference, filename, open_mode='r'): + with node.open(mode=open_mode) as handle: + assert handle.read() == content_reference - self.assertEqual(content_stored, content_original) - self.assertEqual(node.list_object_names(), [filename]) + assert node.list_object_names() == [filename] - def test_binary_file(self): - """Test that the constructor accepts binary files.""" - byte_array = [120, 3, 255, 0, 100] - content_binary = bytearray(byte_array) + return inner - with tempfile.NamedTemporaryFile(mode='wb+') as handle: - basename = os.path.basename(handle.name) - handle.write(bytearray(content_binary)) - handle.flush() - handle.seek(0) - node = SinglefileData(handle.name) - with node.open(mode='rb') as handle: - content_stored = handle.read() +@pytest.fixture +def check_singlefile_content_with_store(check_singlefile_content): # pylint: disable=redefined-outer-name + """Fixture to check the content of a SinglefileData before and after .store(). - self.assertEqual(content_stored, content_binary) - self.assertEqual(node.list_object_names(), [basename]) + Checks the content of a SinglefileData node against the given reference + content and filename twice, before and after calling .store(). + """ + def inner(node, content_reference, filename, open_mode='r'): + check_singlefile_content( + node=node, + content_reference=content_reference, + filename=filename, + open_mode=open_mode, + ) node.store() - - with node.open(mode='rb') as handle: - content_stored = handle.read() - - self.assertEqual(content_stored, content_binary) - self.assertEqual(node.list_object_names(), [basename]) + check_singlefile_content( + node=node, + content_reference=content_reference, + filename=filename, + open_mode=open_mode, + ) + + return inner + + +def test_reload_singlefile_data( + clear_database_before_test, # pylint: disable=unused-argument + check_singlefile_content_with_store, # pylint: disable=redefined-outer-name + check_singlefile_content # pylint: disable=redefined-outer-name +): + """Test writing and reloading a `SinglefileData` instance.""" + content_original = 'some text ABCDE' + + with tempfile.NamedTemporaryFile(mode='w+') as handle: + filepath = handle.name + basename = os.path.basename(filepath) + handle.write(content_original) + handle.flush() + node = SinglefileData(file=filepath) + + check_singlefile_content_with_store( + node=node, + content_reference=content_original, + filename=basename, + ) + + node_loaded = load_node(node.uuid) + assert isinstance(node_loaded, SinglefileData) + + check_singlefile_content( + node=node, + content_reference=content_original, + filename=basename, + ) + check_singlefile_content( + node=node_loaded, + content_reference=content_original, + filename=basename, + ) + + +def test_construct_from_filelike( + clear_database_before_test, # pylint: disable=unused-argument + check_singlefile_content_with_store # pylint: disable=redefined-outer-name +): + """Test constructing an instance from filelike instead of filepath.""" + content_original = 'some testing text\nwith a newline' + + with tempfile.NamedTemporaryFile(mode='wb+') as handle: + basename = os.path.basename(handle.name) + handle.write(content_original.encode('utf-8')) + handle.flush() + handle.seek(0) + node = SinglefileData(file=handle) + + check_singlefile_content_with_store( + node=node, + content_reference=content_original, + filename=basename, + ) + + +def test_construct_from_string( + clear_database_before_test, # pylint: 
disable=unused-argument + check_singlefile_content_with_store # pylint: disable=redefined-outer-name +): + """Test constructing an instance from a string.""" + content_original = 'some testing text\nwith a newline' + + with io.BytesIO(content_original.encode('utf-8')) as handle: + node = SinglefileData(file=handle) + + check_singlefile_content_with_store( + node=node, + content_reference=content_original, + filename=SinglefileData.DEFAULT_FILENAME, + ) + + +def test_construct_with_path( + clear_database_before_test, # pylint: disable=unused-argument + check_singlefile_content_with_store # pylint: disable=redefined-outer-name +): + """Test constructing an instance from a pathlib.Path.""" + content_original = 'please report to the ministry of silly walks' + + with tempfile.NamedTemporaryFile(mode='w+') as handle: + filepath = pathlib.Path(handle.name).resolve() + filename = filepath.name + handle.write(content_original) + handle.flush() + node = SinglefileData(file=filepath) + + check_singlefile_content_with_store( + node=node, + content_reference=content_original, + filename=filename, + ) + + +def test_construct_with_filename( + clear_database_before_test, # pylint: disable=unused-argument + check_singlefile_content # pylint: disable=redefined-outer-name +): + """Test constructing an instance, providing a filename.""" + content_original = 'some testing text\nwith a newline' + filename = 'myfile.txt' + + # test creating from string + with io.BytesIO(content_original.encode('utf-8')) as handle: + node = SinglefileData(file=handle, filename=filename) + + check_singlefile_content(node=node, content_reference=content_original, filename=filename) + + # test creating from file + with tempfile.NamedTemporaryFile(mode='wb+') as handle: + handle.write(content_original.encode('utf-8')) + handle.flush() + handle.seek(0) + node = SinglefileData(file=handle, filename=filename) + + check_singlefile_content(node=node, content_reference=content_original, filename=filename) + + +def test_binary_file( + clear_database_before_test, # pylint: disable=unused-argument + check_singlefile_content_with_store # pylint: disable=redefined-outer-name +): + """Test that the constructor accepts binary files.""" + byte_array = [120, 3, 255, 0, 100] + content_binary = bytearray(byte_array) + + with tempfile.NamedTemporaryFile(mode='wb+') as handle: + basename = os.path.basename(handle.name) + handle.write(bytearray(content_binary)) + handle.flush() + handle.seek(0) + node = SinglefileData(handle.name) + + check_singlefile_content_with_store( + node=node, + content_reference=content_binary, + filename=basename, + open_mode='rb', + ) diff --git a/tests/orm/data/test_to_aiida_type.py b/tests/orm/data/test_to_aiida_type.py index ea60034bc0..dd16d7f2e8 100644 --- a/tests/orm/data/test_to_aiida_type.py +++ b/tests/orm/data/test_to_aiida_type.py @@ -10,7 +10,7 @@ """ This module contains tests for the to_aiida_type serializer """ -from aiida.orm.nodes.data.base import to_aiida_type +from aiida.orm import to_aiida_type from aiida.orm import Dict, Int, Float, Bool, Str from aiida.backends.testbase import AiidaTestCase diff --git a/tests/orm/data/test_upf.py b/tests/orm/data/test_upf.py index 70094f46ed..34d37b1dd0 100644 --- a/tests/orm/data/test_upf.py +++ b/tests/orm/data/test_upf.py @@ -28,7 +28,7 @@ def isnumeric(vector): """Check if elements of iterable `x` are numbers.""" # pylint: disable=invalid-name - numeric_types = (float, int, numpy.float, numpy.float64, numpy.int64, numpy.int) + numeric_types = (float, int, 
numpy.float64, numpy.int64) for xi in vector: if isinstance(xi, numeric_types): yield True @@ -327,7 +327,8 @@ def test_upf1_to_json_carbon(self): # pylint: disable=protected-access json_string, _ = self.pseudo_carbon._prepare_json() filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) - reference_dict = json.load(open(os.path.join(filepath_base, 'C.json'), 'r')) + with open(os.path.join(filepath_base, 'C.json'), 'r') as fhandle: + reference_dict = json.load(fhandle) pp_dict = json.loads(json_string.decode('utf-8')) # remove path information pp_dict['pseudo_potential']['header']['original_upf_file'] = '' @@ -339,7 +340,8 @@ def test_upf2_to_json_barium(self): # pylint: disable=protected-access json_string, _ = self.pseudo_barium._prepare_json() filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) - reference_dict = json.load(open(os.path.join(filepath_base, 'Ba.json'), 'r')) + with open(os.path.join(filepath_base, 'Ba.json'), 'r') as fhandle: + reference_dict = json.load(fhandle) pp_dict = json.loads(json_string.decode('utf-8')) # remove path information pp_dict['pseudo_potential']['header']['original_upf_file'] = '' diff --git a/tests/orm/node/test_calcjob.py b/tests/orm/node/test_calcjob.py index c457c98dc4..2ad0844043 100644 --- a/tests/orm/node/test_calcjob.py +++ b/tests/orm/node/test_calcjob.py @@ -8,8 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `CalcJobNode` node sub class.""" - -import tempfile +import io from aiida.backends.testbase import AiidaTestCase from aiida.common import LinkType, CalcJobState @@ -40,30 +39,24 @@ def test_get_scheduler_stdout(self): option_value = '_scheduler-output.txt' stdout = 'some\nstandard output' - node = CalcJobNode(computer=self.computer,) - node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - retrieved = FolderData() - - # No scheduler output filename option so should return `None` - self.assertEqual(node.get_scheduler_stdout(), None) - - # No retrieved folder so should return `None` - node.set_option(option_key, option_value) - self.assertEqual(node.get_scheduler_stdout(), None) - - # Now it has retrieved folder, but file does not actually exist in it, should not except but return `None - node.store() - retrieved.store() - retrieved.add_incoming(node, link_type=LinkType.CREATE, link_label='retrieved') - self.assertEqual(node.get_scheduler_stdout(), None) - - # Add the file to the retrieved folder - with tempfile.NamedTemporaryFile(mode='w+') as handle: - handle.write(stdout) - handle.flush() - handle.seek(0) - retrieved.put_object_from_filelike(handle, option_value, force=True) - self.assertEqual(node.get_scheduler_stdout(), stdout) + # Note: cannot use pytest.mark.parametrize in unittest classes, so I just do a loop here + for with_file in [True, False]: + for with_option in [True, False]: + node = CalcJobNode(computer=self.computer,) + node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) + retrieved = FolderData() + + if with_file: + retrieved.put_object_from_filelike(io.StringIO(stdout), option_value) + if with_option: + node.set_option(option_key, option_value) + node.store() + retrieved.store() + retrieved.add_incoming(node, link_type=LinkType.CREATE, link_label='retrieved') + + # It should return `None` if no scheduler output is there (file not there, or option not set), + # while it should return the content if both are set + 
self.assertEqual(node.get_scheduler_stdout(), stdout if with_file and with_option else None) def test_get_scheduler_stderr(self): """Verify that the repository sandbox folder is cleaned after the node instance is garbage collected.""" @@ -71,27 +64,21 @@ def test_get_scheduler_stderr(self): option_value = '_scheduler-error.txt' stderr = 'some\nstandard error' - node = CalcJobNode(computer=self.computer,) - node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - retrieved = FolderData() - - # No scheduler error filename option so should return `None` - self.assertEqual(node.get_scheduler_stderr(), None) - - # No retrieved folder so should return `None` - node.set_option(option_key, option_value) - self.assertEqual(node.get_scheduler_stderr(), None) - - # Now it has retrieved folder, but file does not actually exist in it, should not except but return `None - node.store() - retrieved.store() - retrieved.add_incoming(node, link_type=LinkType.CREATE, link_label='retrieved') - self.assertEqual(node.get_scheduler_stderr(), None) - - # Add the file to the retrieved folder - with tempfile.NamedTemporaryFile(mode='w+') as handle: - handle.write(stderr) - handle.flush() - handle.seek(0) - retrieved.put_object_from_filelike(handle, option_value, force=True) - self.assertEqual(node.get_scheduler_stderr(), stderr) + # Note: cannot use pytest.mark.parametrize in unittest classes, so I just do a loop here + for with_file in [True, False]: + for with_option in [True, False]: + node = CalcJobNode(computer=self.computer,) + node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) + retrieved = FolderData() + + if with_file: + retrieved.put_object_from_filelike(io.StringIO(stderr), option_value) + if with_option: + node.set_option(option_key, option_value) + node.store() + retrieved.store() + retrieved.add_incoming(node, link_type=LinkType.CREATE, link_label='retrieved') + + # It should return `None` if no scheduler output is there (file not there, or option not set), + # while it should return the content if both are set + self.assertEqual(node.get_scheduler_stderr(), stderr if with_file and with_option else None) diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index 34cc60940b..afa318772b 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -7,9 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-public-methods +# pylint: disable=too-many-public-methods,no-self-use """Tests for the Node ORM class.""" import io +import logging import os import tempfile @@ -824,6 +825,51 @@ def test_delete_collection_outgoing_link(self): Node.objects.delete(calculation.pk) +@pytest.mark.usefixtures('clear_database_before_test') +class TestNodeComments: + """Tests for creating comments on nodes.""" + + def test_add_comment(self): + """Test comment addition.""" + data = Data().store() + content = 'whatever Trevor' + comment = data.add_comment(content) + assert comment.content == content + assert comment.node.pk == data.pk + + def test_get_comment(self): + """Test retrieve single comment.""" + data = Data().store() + content = 'something something dark side' + add_comment = data.add_comment(content) + get_comment = data.get_comment(add_comment.pk) + assert get_comment.content == content + assert get_comment.pk == add_comment.pk + + def 
test_get_comments(self): + """Test retrieve multiple comments.""" + data = Data().store() + data.add_comment('one') + data.add_comment('two') + comments = data.get_comments() + assert {c.content for c in comments} == {'one', 'two'} + + def test_update_comment(self): + """Test update a comment.""" + data = Data().store() + comment = data.add_comment('original') + data.update_comment(comment.pk, 'new') + assert comment.content == 'new' + + def test_remove_comment(self): + """Test remove a comment.""" + data = Data().store() + comment = data.add_comment('original') + assert len(data.get_comments()) == 1 + data.remove_comment(comment.pk) + assert len(data.get_comments()) == 0 + + @pytest.mark.usefixtures('clear_database_before_test') def test_store_from_cache(): """Regression test for storing a Node with (nested) repository content with caching.""" @@ -845,6 +891,23 @@ def test_store_from_cache(): assert data.get_hash() == clone.get_hash() +@pytest.mark.usefixtures('clear_database_before_test') +def test_hashing_errors(aiida_caplog): + """Tests that ``get_hash`` fails in an expected manner.""" + node = Data().store() + node.__module__ = 'unknown' # this will inhibit package version determination + result = node.get_hash(ignore_errors=True) + assert result is None + assert aiida_caplog.record_tuples == [(node.logger.name, logging.ERROR, 'Node hashing failed')] + + with pytest.raises(exceptions.HashingError, match='package version could not be determined'): + result = node.get_hash(ignore_errors=False) + assert result is None + + +# Ignoring the resource errors as we are indeed testing the wrong way of using these (for backward-compatibility) +@pytest.mark.filterwarnings('ignore::ResourceWarning') +@pytest.mark.filterwarnings('ignore::aiida.common.warnings.AiidaDeprecationWarning') @pytest.mark.usefixtures('clear_database_before_test') def test_open_wrapper(): """Test the wrapper around the return value of ``Node.open``. @@ -860,3 +923,21 @@ def test_open_wrapper(): iter(node.open(filename)) node.open(filename).__next__() node.open(filename).__iter__() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_uuid_equality_fallback(): + """Tests the fallback mechanism of checking equality by comparing uuids and hash.""" + node_0 = Data().store() + + nodepk = Data().store().pk + node_a = load_node(pk=nodepk) + node_b = load_node(pk=nodepk) + + assert node_a == node_b + assert node_a != node_0 + assert node_b != node_0 + + assert hash(node_a) == hash(node_b) + assert hash(node_a) != hash(node_0) + assert hash(node_b) != hash(node_0) diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index e2833967c9..342806331a 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -342,7 +342,13 @@ def test_loading_unregistered(): assert isinstance(loaded, orm.Group) + # Removing it as other methods might get a warning instead + group_pk = group.pk + del group + orm.Group.objects.delete(id=group_pk) + @staticmethod + @pytest.mark.filterwarnings('ignore::UserWarning') def test_explicit_type_string(): """Test that passing explicit `type_string` to `Group` constructor is still possible despite being deprecated. 
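Editor's note: the new `test_uuid_equality_fallback` added above asserts that two node instances loaded from the same pk compare equal and hash identically, while differing from an unrelated node. The sketch below shows that fallback behaviour with a hypothetical wrapper class; it is illustrative only, not AiiDA's `Node` implementation.

import uuid


class NodeLike:
    """Hypothetical stand-in whose equality and hash fall back to the wrapped UUID."""

    def __init__(self, node_uuid: str):
        self.uuid = node_uuid

    def __eq__(self, other):
        return isinstance(other, NodeLike) and self.uuid == other.uuid

    def __hash__(self):
        return hash(self.uuid)


shared = str(uuid.uuid4())
node_a, node_b, node_0 = NodeLike(shared), NodeLike(shared), NodeLike(str(uuid.uuid4()))

assert node_a == node_b and hash(node_a) == hash(node_b)
assert node_a != node_0 and hash(node_a) != hash(node_0)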
@@ -369,6 +375,11 @@ def test_explicit_type_string(): assert queried.pk == group.pk assert queried.type_string == group.type_string + # Removing it as other methods might get a warning instead + group_pk = group.pk + del group + orm.Group.objects.delete(id=group_pk) + @staticmethod def test_querying(): """Test querying for groups with and without subclassing.""" @@ -386,6 +397,11 @@ def test_querying(): assert orm.QueryBuilder().append(orm.Group).count() == 3 assert orm.QueryBuilder().append(orm.Group, filters={'type_string': 'custom.group'}).count() == 1 + # Removing it as other methods might get a warning instead + group_pk = group.pk + del group + orm.Group.objects.delete(id=group_pk) + @staticmethod def test_querying_node_subclasses(): """Test querying for groups with multiple types for nodes it contains.""" @@ -407,7 +423,7 @@ def test_querying_node_subclasses(): @staticmethod def test_query_with_group(): - """Docs.""" + """Test that querying a data node in a group works.""" group = orm.Group(label='group').store() data = orm.Data().store() diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 8c320bdb6d..62bc8925c4 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -22,15 +22,16 @@ class TestQueryBuilder(AiidaTestCase): def setUp(self): super().setUp() - self.clean_db() - self.insert_data() + self.refurbish_db() def test_date_filters_support(self): """Verify that `datetime.date` is supported in filters.""" - from datetime import datetime, date, timedelta + from datetime import date, timedelta + from aiida.common import timezone - orm.Data(ctime=datetime.now() - timedelta(days=3)).store() - orm.Data(ctime=datetime.now() - timedelta(days=1)).store() + # Using timezone.now() rather than datetime.now() to get a timezone-aware object rather than a naive one + orm.Data(ctime=timezone.now() - timedelta(days=3)).store() + orm.Data(ctime=timezone.now() - timedelta(days=1)).store() builder = orm.QueryBuilder().append(orm.Node, filters={'ctime': {'>': date.today() - timedelta(days=1)}}) self.assertEqual(builder.count(), 1) @@ -133,6 +134,7 @@ def test_get_group_type_filter(self): # Tracked in issue #4281 @pytest.mark.flaky(reruns=2) + @pytest.mark.requires_rmq def test_process_query(self): """ Test querying for a process class. @@ -733,6 +735,8 @@ def test_queryhelp(self): qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) self.assertEqual(qb.count(), 1) + # populate computer + self.computer # pylint:disable=pointless-statement qb = orm.QueryBuilder().append(orm.Computer,) self.assertEqual(qb.count(), 1) diff --git a/tests/parsers/test_parser.py b/tests/parsers/test_parser.py index 934a7c8230..a9b625d5f6 100644 --- a/tests/parsers/test_parser.py +++ b/tests/parsers/test_parser.py @@ -11,6 +11,8 @@ import io +import pytest + from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import LinkType @@ -86,6 +88,7 @@ def test_parser_get_outputs_for_parsing(self): self.assertIn('output', outputs_for_parsing) self.assertEqual(outputs_for_parsing['output'].uuid, output.uuid) + @pytest.mark.requires_rmq def test_parse_from_node(self): """Test that the `parse_from_node` returns a tuple of the parsed output nodes and a calculation node. 
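The date-filter hunk above swaps datetime.now() for aiida.common.timezone.now() so that the stored ctime values are timezone-aware rather than naive. A minimal standard-library sketch of the difference, not part of the diff:

    from datetime import datetime, timezone

    naive = datetime.now()               # tzinfo is None: a "naive" timestamp
    aware = datetime.now(timezone.utc)   # carries tzinfo, like aiida.common.timezone.now()

    assert naive.tzinfo is None
    assert aware.tzinfo is not None

    # Mixing naive and aware datetimes in comparisons raises TypeError; storing
    # aware values, as the updated test does, avoids that class of problem.
    try:
        naive < aware
    except TypeError as exc:
        print(f'cannot compare naive and aware datetimes: {exc}')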
diff --git a/tests/restapi/test_config.py b/tests/restapi/test_config.py new file mode 100644 index 0000000000..9742640730 --- /dev/null +++ b/tests/restapi/test_config.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the configuration options from `aiida.restapi.common.config` when running the REST API.""" +# pylint: disable=redefined-outer-name +import pytest + + +@pytest.fixture +def create_app(): + """Set up Flask App""" + from aiida.restapi.run_api import configure_api + + def _create_app(**kwargs): + catch_internal_server = kwargs.pop('catch_internal_server', True) + api = configure_api(catch_internal_server=catch_internal_server, **kwargs) + api.app.config['TESTING'] = True + return api.app + + return _create_app + + +def test_posting(create_app): + """Test CLI_DEFAULTS['POSTING'] configuration""" + from aiida.restapi.common.config import API_CONFIG + + app = create_app(posting=False) + + url = f'{API_CONFIG["PREFIX"]}/querybuilder' + for method in ('get', 'post'): + with app.test_client() as client: + response = getattr(client, method)(url) + + assert response.status_code == 404 + assert response.status == '404 NOT FOUND' + + del app + app = create_app(posting=True) + + url = f'{API_CONFIG["PREFIX"]}/querybuilder' + for method in ('get', 'post'): + with app.test_client() as client: + response = getattr(client, method)(url) + + assert response.status_code != 404 + assert response.status != '404 NOT FOUND' diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index ca6eada052..f2a3c249ff 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -30,7 +30,7 @@ class RESTApiTestCase(AiidaTestCase): _LIMIT_DEFAULT = 400 @classmethod - def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-many-statements + def setUpClass(cls): # pylint: disable=too-many-locals, too-many-statements """ Add objects to the database for different requests/filters/orderings etc. 
""" @@ -82,7 +82,7 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma handle.write(aiida_in) handle.flush() handle.seek(0) - calc.put_object_from_filelike(handle, 'calcjob_inputs/aiida.in', force=True) + calc.put_object_from_filelike(handle, 'calcjob_inputs/aiida.in') calc.store() # create log message for calcjob @@ -110,7 +110,7 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma handle.write(aiida_out) handle.flush() handle.seek(0) - retrieved_outputs.put_object_from_filelike(handle, 'calcjob_outputs/aiida.out', force=True) + retrieved_outputs.put_object_from_filelike(handle, 'calcjob_outputs/aiida.out') retrieved_outputs.store() retrieved_outputs.add_incoming(calc, link_type=LinkType.CREATE, link_label='retrieved') @@ -1147,3 +1147,225 @@ def test_download_formats(self): for key in ['cif', 'xsf', 'xyz']: self.assertIn(key, response['data']['data.structure.StructureData.|']) self.assertIn('cif', response['data']['data.cif.CifData.|']) + + ############### querybuilder ############### + def test_querybuilder(self): + """Test POSTing a queryhelp dictionary as JSON to /querybuilder + + This also checks that `full_type` is _not_ included in the result no matter the entity. + """ + queryhelp = orm.QueryBuilder().append( + orm.CalculationNode, + tag='calc', + project=['id', 'uuid', 'user_id'], + ).order_by({ + 'calc': [{ + 'id': { + 'order': 'desc' + } + }] + }).queryhelp + + expected_node_uuids = [] + # dummy data already ordered 'desc' by 'id' + for calc in self.get_dummy_data()['calculations']: + if calc['node_type'].startswith('process.calculation.'): + expected_node_uuids.append(calc['uuid']) + + with self.app.test_client() as client: + response = client.post(f'{self.get_url_prefix()}/querybuilder', json=queryhelp).json + + self.assertEqual('POST', response.get('method', '')) + self.assertEqual('QueryBuilder', response.get('resource_type', '')) + + self.assertEqual( + len(expected_node_uuids), + len(response.get('data', {}).get('calc', [])), + msg=json.dumps(response, indent=2), + ) + self.assertListEqual( + expected_node_uuids, + [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])], + ) + for entities in response.get('data', {}).values(): + for entity in entities: + # All are Nodes, but neither `node_type` or `process_type` are requested, + # hence `full_type` should not be present. + self.assertFalse('full_type' in entity) + + def test_get_querybuilder(self): + """Test GETting the /querybuilder endpoint + + This should return with 405 Method Not Allowed. + Otherwise, a "conventional" JSON response should be returned with a helpful message. + """ + with self.app.test_client() as client: + response_value = client.get(f'{self.get_url_prefix()}/querybuilder') + response = response_value.json + + self.assertEqual(response_value.status_code, 405) + self.assertEqual(response_value.status, '405 METHOD NOT ALLOWED') + + self.assertEqual('GET', response.get('method', '')) + self.assertEqual('QueryBuilder', response.get('resource_type', '')) + + message = ( + 'Method Not Allowed. Use HTTP POST requests to use the AiiDA QueryBuilder. ' + 'POST JSON data, which MUST be a valid QueryBuilder.queryhelp dictionary as a JSON object. ' + 'See the documentation at https://aiida.readthedocs.io/projects/aiida-core/en/latest/topics/' + 'database.html?highlight=QueryBuilder#the-queryhelp for more information.' 
+ ) + self.assertEqual(message, response.get('data', {}).get('message', '')) + + def test_querybuilder_user(self): + """Retrieve a User through the use of the /querybuilder endpoint + + This also checks that `full_type` is _not_ included in the result no matter the entity. + """ + queryhelp = orm.QueryBuilder().append( + orm.CalculationNode, + tag='calc', + project=['id', 'user_id'], + ).append( + orm.User, + tag='users', + with_node='calc', + project=['id', 'email'], + ).order_by({ + 'calc': [{ + 'id': { + 'order': 'desc' + } + }] + }).queryhelp + + expected_user_ids = [] + for calc in self.get_dummy_data()['calculations']: + if calc['node_type'].startswith('process.calculation.'): + expected_user_ids.append(calc['user_id']) + + with self.app.test_client() as client: + response = client.post(f'{self.get_url_prefix()}/querybuilder', json=queryhelp).json + + self.assertEqual('POST', response.get('method', '')) + self.assertEqual('QueryBuilder', response.get('resource_type', '')) + + self.assertEqual( + len(expected_user_ids), + len(response.get('data', {}).get('users', [])), + msg=json.dumps(response, indent=2), + ) + self.assertListEqual( + expected_user_ids, + [_.get('id', '') for _ in response.get('data', {}).get('users', [])], + ) + self.assertListEqual( + expected_user_ids, + [_.get('user_id', '') for _ in response.get('data', {}).get('calc', [])], + ) + for entities in response.get('data', {}).values(): + for entity in entities: + # User is not a Node (no full_type) + self.assertFalse('full_type' in entity) + + def test_querybuilder_project_explicit(self): + """Explicitly project everything from the resulting entities + + Here "project" will use the wildcard (*). + This should result in both CalculationNodes and Data being returned. + """ + queryhelp = orm.QueryBuilder().append( + orm.CalculationNode, + tag='calc', + project='*', + ).append( + orm.Data, + tag='data', + with_incoming='calc', + project='*', + ).order_by({'data': [{ + 'id': { + 'order': 'desc' + } + }]}) + + expected_calc_uuids = [] + expected_data_uuids = [] + for calc, data in queryhelp.all(): + expected_calc_uuids.append(calc.uuid) + expected_data_uuids.append(data.uuid) + + queryhelp = queryhelp.queryhelp + + with self.app.test_client() as client: + response = client.post(f'{self.get_url_prefix()}/querybuilder', json=queryhelp).json + + self.assertEqual('POST', response.get('method', '')) + self.assertEqual('QueryBuilder', response.get('resource_type', '')) + + self.assertEqual( + len(expected_calc_uuids), + len(response.get('data', {}).get('calc', [])), + msg=json.dumps(response, indent=2), + ) + self.assertEqual( + len(expected_data_uuids), + len(response.get('data', {}).get('data', [])), + msg=json.dumps(response, indent=2), + ) + self.assertListEqual( + expected_calc_uuids, + [_.get('uuid', '') for _ in response.get('data', {}).get('calc', [])], + ) + self.assertListEqual( + expected_data_uuids, + [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])], + ) + for entities in response.get('data', {}).values(): + for entity in entities: + # All are Nodes, and all properties are projected, full_type should be present + self.assertTrue('full_type' in entity) + self.assertTrue('attributes' in entity) + + def test_querybuilder_project_implicit(self): + """Implicitly project everything from the resulting entities + + Here "project" will be an empty list, resulting in only the Data node being returned. 
+ """ + queryhelp = orm.QueryBuilder().append(orm.CalculationNode, tag='calc').append( + orm.Data, + tag='data', + with_incoming='calc', + ).order_by({'data': [{ + 'id': { + 'order': 'desc' + } + }]}) + + expected_data_uuids = [] + for data in queryhelp.all(flat=True): + expected_data_uuids.append(data.uuid) + + queryhelp = queryhelp.queryhelp + + with self.app.test_client() as client: + response = client.post(f'{self.get_url_prefix()}/querybuilder', json=queryhelp).json + + self.assertEqual('POST', response.get('method', '')) + self.assertEqual('QueryBuilder', response.get('resource_type', '')) + + self.assertListEqual(['data'], list(response.get('data', {}).keys())) + self.assertEqual( + len(expected_data_uuids), + len(response.get('data', {}).get('data', [])), + msg=json.dumps(response, indent=2), + ) + self.assertListEqual( + expected_data_uuids, + [_.get('uuid', '') for _ in response.get('data', {}).get('data', [])], + ) + for entities in response.get('data', {}).values(): + for entity in entities: + # All are Nodes, and all properties are projected, full_type should be present + self.assertTrue('full_type' in entity) + self.assertTrue('attributes' in entity) diff --git a/tests/restapi/test_threaded_restapi.py b/tests/restapi/test_threaded_restapi.py index 02af3839ef..ff77342367 100644 --- a/tests/restapi/test_threaded_restapi.py +++ b/tests/restapi/test_threaded_restapi.py @@ -62,8 +62,7 @@ def test_run_threaded_server(restapi_server, server_url, aiida_localhost): pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') -# Tracked in issue #4281 -@pytest.mark.flaky(reruns=2) +@pytest.mark.skip('Is often failing on Python 3.8 and 3.9: see https://github.com/aiidateam/aiida-core/issues/4281') @pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd): """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" diff --git a/tests/sphinxext/reference_results/workchain.xml b/tests/sphinxext/reference_results/workchain.xml index e7fae40799..8d487bb2f8 100644 --- a/tests/sphinxext/reference_results/workchain.xml +++ b/tests/sphinxext/reference_results/workchain.xml @@ -1,7 +1,7 @@ - +
sphinx-aiida demo This is a demo documentation to show off the features of the sphinx-aiida extension. @@ -74,16 +74,32 @@ finalize This module defines an example workchain for the aiida-workchain documentation directive. - class demo_workchain.DemoWorkChain*args**kwargs + class demo_workchain.DemoWorkChain*args: Any**kwargs: Any A demo workchain to show how the workchain auto-documentation works. + + + classmethod definespec + + Define the specification of the process, including its inputs, outputs and known exit codes. + A metadata input namespace is defined, with optional ports that are not stored in the database. + + - class demo_workchain.EmptyOutlineWorkChain*args**kwargs + class demo_workchain.EmptyOutlineWorkChain*args: Any**kwargs: Any Here we check that the directive works even if the outline is empty. + + + classmethod definespec + + Define the specification of the process, including its inputs, outputs and known exit codes. + A metadata input namespace is defined, with optional ports that are not stored in the database. + + diff --git a/tests/static/calcjob/arithmetic.add.aiida b/tests/static/calcjob/arithmetic.add.aiida index 21319e1b66..7d988a6839 100644 Binary files a/tests/static/calcjob/arithmetic.add.aiida and b/tests/static/calcjob/arithmetic.add.aiida differ diff --git a/tests/test_calculation_node.py b/tests/test_calculation_node.py index 7fd83e89e6..0e80d1eef5 100644 --- a/tests/test_calculation_node.py +++ b/tests/test_calculation_node.py @@ -11,6 +11,7 @@ from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import ModificationNotAllowed +from aiida.common.datastructures import CalcJobState from aiida.orm import CalculationNode, CalcJobNode @@ -117,7 +118,9 @@ def test_process_node_updatable_attribute(self): node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) def test_get_description(self): - self.assertEqual(self.calcjob.get_description(), self.calcjob.get_state()) + self.assertEqual(self.calcjob.get_description(), '') + self.calcjob.set_state(CalcJobState.PARSING) + self.assertEqual(self.calcjob.get_description(), CalcJobState.PARSING.value) def test_get_authinfo(self): """Test that we can get the AuthInfo object from the calculation instance.""" diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 75f73c0d21..cc1b9933dc 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -13,11 +13,16 @@ import tempfile import unittest +import pytest + from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import ModificationNotAllowed from aiida.common.utils import Capturing from aiida.orm import load_node from aiida.orm import CifData, StructureData, KpointsData, BandsData, ArrayData, TrajectoryData, Dict +from aiida.orm.nodes.data.structure import ( + ase_refine_cell, get_formula, get_pymatgen_version, has_ase, has_pymatgen, has_spglib +) from aiida.orm.nodes.data.structure import Kind, Site @@ -49,11 +54,15 @@ def simplify(string): return '\n'.join(s.strip() for s in string.split()) +@pytest.mark.skipif(not has_pymatgen(), reason='pymatgen not installed') +def test_get_pymatgen_version(): + assert isinstance(get_pymatgen_version(), str) + + class TestCifData(AiidaTestCase): """Tests for CifData class.""" from distutils.version import StrictVersion from aiida.orm.nodes.data.cif import has_pycifrw - from aiida.orm.nodes.data.structure import has_ase, has_pymatgen, has_spglib, get_pymatgen_version valid_sample_cif_str = ''' data_test @@ -188,6 +197,7 @@ def 
test_change_cifdata_file(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.requires_rmq def test_get_structure(self): """Test `CifData.get_structure`.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -225,6 +235,7 @@ def test_get_structure(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_ase(self): """Checking the number of atoms per primitive/conventional cell returned by ASE ase.io.read() method. Test input is @@ -270,6 +281,7 @@ def test_ase_primitive_and_conventional_cells_ase(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_pymatgen(self): """Checking the number of atoms per primitive/conventional cell returned by ASE ase.io.read() method. Test input is @@ -530,6 +542,7 @@ def test_attached_hydrogens(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') @unittest.skipIf(not has_spglib(), 'Unable to import spglib') + @pytest.mark.requires_rmq def test_refine(self): """ Test case for refinement (space group determination) for a @@ -1032,7 +1045,6 @@ class TestStructureData(AiidaTestCase): Tests the creation of StructureData objects (cell and pbc). """ # pylint: disable=too-many-public-methods - from aiida.orm.nodes.data.structure import has_ase, has_spglib from aiida.orm.nodes.data.cif import has_pycifrw def test_cell_ok_and_atoms(self): @@ -1506,7 +1518,6 @@ def test_kind_8(self): Test the ase_refine_cell() function """ # pylint: disable=too-many-statements - from aiida.orm.nodes.data.structure import ase_refine_cell import ase import math import numpy @@ -1587,8 +1598,6 @@ def test_get_formula(self): """ Tests the generation of formula """ - from aiida.orm.nodes.data.structure import get_formula - self.assertEqual(get_formula(['Ba', 'Ti'] + ['O'] * 3), 'BaO3Ti') self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['O'] * 3, separator=' '), 'C Ba O3 Ti') self.assertEqual(get_formula(['H'] * 6 + ['C'] * 6), 'C6H6') @@ -1616,8 +1625,6 @@ def test_get_formula_unknown(self): """ Tests the generation of formula, including unknown entry. """ - from aiida.orm.nodes.data.structure import get_formula - self.assertEqual(get_formula(['Ba', 'Ti'] + ['X'] * 3), 'BaTiX3') self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['X'] * 3, separator=' '), 'C Ba Ti X3') self.assertEqual(get_formula(['X'] * 6 + ['C'] * 6), 'C6X6') @@ -1643,6 +1650,7 @@ def test_get_formula_unknown(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.requires_rmq def test_get_cif(self): """ Tests the conversion to CifData @@ -1909,7 +1917,6 @@ def test_clone(self): class TestStructureDataFromAse(AiidaTestCase): """Tests the creation of Sites from/to a ASE object.""" - from aiida.orm.nodes.data.structure import has_ase @unittest.skipIf(not has_ase(), 'Unable to import ase') def test_ase(self): @@ -2096,7 +2103,6 @@ class TestStructureDataFromPymatgen(AiidaTestCase): Tests the creation of StructureData from a pymatgen Structure and Molecule objects. 
""" - from aiida.orm.nodes.data.structure import has_pymatgen, get_pymatgen_version @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') def test_1(self): @@ -2220,15 +2226,17 @@ def test_partial_occ_and_spin(self): Tests pymatgen -> StructureData, with partial occupancies and spins. This should raise a ValueError. """ - import pymatgen - - Fe_spin_up = pymatgen.Specie('Fe', 0, properties={'spin': 1}) - Mn_spin_up = pymatgen.Specie('Mn', 0, properties={'spin': 1}) - Fe_spin_down = pymatgen.Specie('Fe', 0, properties={'spin': -1}) - Mn_spin_down = pymatgen.Specie('Mn', 0, properties={'spin': -1}) - FeMn1 = pymatgen.Composition({Fe_spin_up: 0.5, Mn_spin_up: 0.5}) - FeMn2 = pymatgen.Composition({Fe_spin_down: 0.5, Mn_spin_down: 0.5}) - a = pymatgen.Structure( + from pymatgen.core.periodic_table import Specie + from pymatgen.core.composition import Composition + from pymatgen.core.structure import Structure + + Fe_spin_up = Specie('Fe', 0, properties={'spin': 1}) + Mn_spin_up = Specie('Mn', 0, properties={'spin': 1}) + Fe_spin_down = Specie('Fe', 0, properties={'spin': -1}) + Mn_spin_down = Specie('Mn', 0, properties={'spin': -1}) + FeMn1 = Composition({Fe_spin_up: 0.5, Mn_spin_up: 0.5}) + FeMn2 = Composition({Fe_spin_down: 0.5, Mn_spin_down: 0.5}) + a = Structure( lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[FeMn1, FeMn2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) @@ -2236,9 +2244,9 @@ def test_partial_occ_and_spin(self): StructureData(pymatgen=a) # same, with vacancies - Fe1 = pymatgen.Composition({Fe_spin_up: 0.5}) - Fe2 = pymatgen.Composition({Fe_spin_down: 0.5}) - a = pymatgen.Structure( + Fe1 = Composition({Fe_spin_up: 0.5}) + Fe2 = Composition({Fe_spin_down: 0.5}) + a = Structure( lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[Fe1, Fe2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) @@ -2250,12 +2258,13 @@ def test_partial_occ_and_spin(self): def test_multiple_kinds_partial_occupancies(): """Tests that a structure with multiple sites with the same element but different partial occupancies, get their own unique kind name.""" - import pymatgen + from pymatgen.core.composition import Composition + from pymatgen.core.structure import Structure - Mg1 = pymatgen.Composition({'Mg': 0.50}) - Mg2 = pymatgen.Composition({'Mg': 0.25}) + Mg1 = Composition({'Mg': 0.50}) + Mg2 = Composition({'Mg': 0.25}) - a = pymatgen.Structure( + a = Structure( lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[Mg1, Mg2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) @@ -2268,12 +2277,13 @@ def test_multiple_kinds_alloy(): Tests that a structure with multiple sites with the same alloy symbols but different weights, get their own unique kind name """ - import pymatgen + from pymatgen.core.composition import Composition + from pymatgen.core.structure import Structure - alloy_one = pymatgen.Composition({'Mg': 0.25, 'Al': 0.75}) - alloy_two = pymatgen.Composition({'Mg': 0.45, 'Al': 0.55}) + alloy_one = Composition({'Mg': 0.25, 'Al': 0.75}) + alloy_two = Composition({'Mg': 0.45, 'Al': 0.55}) - a = pymatgen.Structure( + a = Structure( lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[alloy_one, alloy_two], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] @@ -2285,7 +2295,6 @@ def test_multiple_kinds_alloy(): class TestPymatgenFromStructureData(AiidaTestCase): """Tests the creation of pymatgen Structure and Molecule objects from StructureData.""" - from aiida.orm.nodes.data.structure import has_ase, has_pymatgen, get_pymatgen_version @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') def test_1(self): @@ -2823,6 
+2832,7 @@ def test_creation(self): # Step 66 does not exist n.get_index_from_stepid(66) + @pytest.mark.requires_rmq def test_conversion_to_structure(self): """ Check the methods to export a given time step to a StructureData node. diff --git a/tests/test_nodes.py b/tests/test_nodes.py index f032336afb..530ecad83b 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -18,8 +18,7 @@ from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import InvalidOperation, ModificationNotAllowed, StoringNotAllowed, ValidationError from aiida.common.links import LinkType -from aiida.common.utils import Capturing -from aiida.manage.database.delete.nodes import delete_nodes +from aiida.tools import delete_nodes, delete_group_nodes class TestNodeIsStorable(AiidaTestCase): @@ -1544,8 +1543,34 @@ def _check_existence(self, uuids_check_existence, uuids_check_deleted): def test_deletion_non_existing_pk(): """Verify that passing a non-existing pk should not raise.""" non_existing_pk = -1 - with Capturing(): - delete_nodes([non_existing_pk], force=True) + delete_nodes([non_existing_pk], dry_run=False) + + def test_deletion_dry_run_true(self): + """Verify that a dry run should not delete the node.""" + node = orm.Data().store() + node_pk = node.pk + deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=True) + self.assertTrue(not was_deleted) + self.assertSetEqual(deleted_pks, {node_pk}) + orm.load_node(node_pk) + + def test_deletion_dry_run_callback(self): + """Verify that a dry_run callback works.""" + from aiida.common.exceptions import NotExistent + node = orm.Data().store() + node_pk = node.pk + callback_pks = [] + + def _callback(pks): + callback_pks.extend(pks) + return False + + deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=_callback) + self.assertTrue(was_deleted) + self.assertSetEqual(deleted_pks, {node_pk}) + with self.assertRaises(NotExistent): + orm.load_node(node_pk) + self.assertListEqual(callback_pks, [node_pk]) # TEST BASIC CASES @@ -1642,15 +1667,13 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([w1.pk], force=True) + delete_nodes([w1.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([w2.pk], force=True) + delete_nodes([w2.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # By default, targetting a calculation will have the same effect because @@ -1659,15 +1682,13 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([c1.pk], force=True) + delete_nodes([c1.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([c2.pk], force=True) + delete_nodes([c2.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # By default, targetting a data node will also 
have the same effect because @@ -1676,22 +1697,19 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in []] uuids_check_deleted = [n.uuid for n in [di, dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([di.pk], force=True) + delete_nodes([di.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([dm.pk], force=True) + delete_nodes([dm.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di]] uuids_check_deleted = [n.uuid for n in [dm, do, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([do.pk], force=True) + delete_nodes([do.pk], dry_run=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # Data deletion within the highest level workflow can be prevented by @@ -1700,15 +1718,13 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, do]] uuids_check_deleted = [n.uuid for n in [c1, c2, w1, w2]] - with Capturing(): - delete_nodes([w2.pk], force=True, create_forward=False) + delete_nodes([w2.pk], dry_run=False, create_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [dm, do]] uuids_check_deleted = [n.uuid for n in [di, c1, c2, w1, w2]] - with Capturing(): - delete_nodes([di.pk], force=True, create_forward=False) + delete_nodes([di.pk], dry_run=False, create_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # On the other hand, the whole data provenance can be protected by @@ -1719,15 +1735,13 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, do, c1, c2]] uuids_check_deleted = [n.uuid for n in [w1, w2]] - with Capturing(): - delete_nodes([w2.pk], force=True, call_calc_forward=False) + delete_nodes([w2.pk], dry_run=False, call_calc_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, c1]] uuids_check_deleted = [n.uuid for n in [do, c2, w1, w2]] - with Capturing(): - delete_nodes([c2.pk], force=True, call_calc_forward=False) + delete_nodes([c2.pk], dry_run=False, call_calc_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # Another posibility which also exists, though may have more limited @@ -1739,22 +1753,19 @@ def test_delete_cases(self): di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, c1, w1]] uuids_check_deleted = [n.uuid for n in [do, c2, w2]] - with Capturing(): - delete_nodes([w2.pk], force=True, call_work_forward=False) + delete_nodes([w2.pk], dry_run=False, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, do, c1, w1]] uuids_check_deleted = [n.uuid for n in [c2, w2]] - with Capturing(): - delete_nodes([w2.pk], force=True, 
call_work_forward=False, create_forward=False) + delete_nodes([w2.pk], dry_run=False, call_work_forward=False, create_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) di, dm, do, c1, c2, w1, w2 = self._create_simple_graph() uuids_check_existence = [n.uuid for n in [di, dm, do, c1, c2, w1]] uuids_check_deleted = [n.uuid for n in [w2]] - with Capturing(): - delete_nodes([w2.pk], force=True, call_work_forward=False, call_calc_forward=False) + delete_nodes([w2.pk], dry_run=False, call_work_forward=False, call_calc_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) @staticmethod @@ -1850,15 +1861,13 @@ def test_indep2w(self): dia, doa, pca, pwa, dib, dob, pcb, pwb, pw0 = self._create_indep2w_graph() uuids_check_existence = [n.uuid for n in [dia, dib]] uuids_check_deleted = [n.uuid for n in [doa, pca, pwa, dob, pcb, pwb, pw0]] - with Capturing(): - delete_nodes((pca.pk,), force=True, create_forward=True, call_calc_forward=True, call_work_forward=True) + delete_nodes((pca.pk,), dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=True) self._check_existence(uuids_check_existence, uuids_check_deleted) dia, doa, pca, pwa, dib, dob, pcb, pwb, pw0 = self._create_indep2w_graph() uuids_check_existence = [n.uuid for n in [dia, dib]] uuids_check_deleted = [n.uuid for n in [doa, pca, pwa, dob, pcb, pwb, pw0]] - with Capturing(): - delete_nodes((pwa.pk,), force=True, create_forward=True, call_calc_forward=True, call_work_forward=True) + delete_nodes((pwa.pk,), dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=True) self._check_existence(uuids_check_existence, uuids_check_deleted) # In this particular case where the workflow (pwa) only calls a single @@ -1872,15 +1881,13 @@ def test_indep2w(self): dia, doa, pca, pwa, dib, dob, pcb, pwb, pw0 = self._create_indep2w_graph() uuids_check_existence = [n.uuid for n in [dia, dib, dob, pcb, pwb]] uuids_check_deleted = [n.uuid for n in [doa, pca, pwa, pw0]] - with Capturing(): - delete_nodes((pca.pk,), force=True, create_forward=True, call_calc_forward=True, call_work_forward=False) + delete_nodes((pca.pk,), dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) dia, doa, pca, pwa, dib, dob, pcb, pwb, pw0 = self._create_indep2w_graph() uuids_check_existence = [n.uuid for n in [dia, dib, dob, pcb, pwb]] uuids_check_deleted = [n.uuid for n in [doa, pca, pwa, pw0]] - with Capturing(): - delete_nodes((pwa.pk,), force=True, create_forward=True, call_calc_forward=True, call_work_forward=False) + delete_nodes((pwa.pk,), dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # The best and most controlled way to deal with this situation would be @@ -1897,12 +1904,10 @@ def test_indep2w(self): uuids_check_existence2 = [n.uuid for n in [dia, dib, dob, pcb, pwb]] uuids_check_deleted2 = [n.uuid for n in [doa, pca, pwa, pw0]] - with Capturing(): - delete_nodes((pw0.pk,), force=True, create_forward=False, call_calc_forward=False, call_work_forward=False) + delete_nodes((pw0.pk,), dry_run=False, create_forward=False, call_calc_forward=False, call_work_forward=False) self._check_existence(uuids_check_existence1, uuids_check_deleted1) - with Capturing(): - delete_nodes((pwa.pk,), force=True, create_forward=True, call_calc_forward=True, call_work_forward=True) + 
delete_nodes((pwa.pk,), dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=True) self._check_existence(uuids_check_existence2, uuids_check_deleted2) @staticmethod @@ -1972,8 +1977,7 @@ def test_loop_cases(self): di1, di2, di3, do1, pws, pcs, pwm = self._create_looped_graph() uuids_check_existence = [n.uuid for n in [di1, di2]] uuids_check_deleted = [n.uuid for n in [di3, do1, pcs, pws, pwm]] - with Capturing(): - delete_nodes([di3.pk], force=True, create_forward=True, call_calc_forward=True, call_work_forward=True) + delete_nodes([di3.pk], dry_run=False, create_forward=True, call_calc_forward=True, call_work_forward=True) self._check_existence(uuids_check_existence, uuids_check_deleted) # When disabling the call_calc and call_work forward rules, deleting @@ -1982,8 +1986,7 @@ def test_loop_cases(self): di1, di2, di3, do1, pws, pcs, pwm = self._create_looped_graph() uuids_check_existence = [n.uuid for n in [di1, di2, do1, pcs]] uuids_check_deleted = [n.uuid for n in [di3, pws, pwm]] - with Capturing(): - delete_nodes([di3.pk], force=True, create_forward=True, call_calc_forward=False, call_work_forward=False) + delete_nodes([di3.pk], dry_run=False, create_forward=True, call_calc_forward=False, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # Of course, deleting the selected input will cause all the procedure to @@ -1991,8 +1994,7 @@ def test_loop_cases(self): di1, di2, di3, do1, pws, pcs, pwm = self._create_looped_graph() uuids_check_existence = [n.uuid for n in [di2, di3]] uuids_check_deleted = [n.uuid for n in [di1, do1, pws, pcs, pwm]] - with Capturing(): - delete_nodes([di1.pk], force=True, create_forward=True, call_calc_forward=False, call_work_forward=False) + delete_nodes([di1.pk], dry_run=False, create_forward=True, call_calc_forward=False, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) # Deleting with these settings the workflow that chooses inputs should @@ -2000,8 +2002,7 @@ def test_loop_cases(self): di1, di2, di3, do1, pws, pcs, pwm = self._create_looped_graph() uuids_check_existence = [n.uuid for n in [di1, di2, di3, do1, pcs]] uuids_check_deleted = [n.uuid for n in [pws, pwm]] - with Capturing(): - delete_nodes([pws.pk], force=True, create_forward=True, call_calc_forward=False, call_work_forward=False) + delete_nodes([pws.pk], dry_run=False, create_forward=True, call_calc_forward=False, call_work_forward=False) self._check_existence(uuids_check_existence, uuids_check_deleted) @staticmethod @@ -2042,6 +2043,29 @@ def test_long_case(self): node_list = self._create_long_graph(10) uuids_check_existence = [n.uuid for n in node_list[:3]] uuids_check_deleted = [n.uuid for n in node_list[3:]] - with Capturing(): - delete_nodes((node_list[3].pk,), force=True, create_forward=True) + delete_nodes((node_list[3].pk,), dry_run=False, create_forward=True) self._check_existence(uuids_check_existence, uuids_check_deleted) + + def test_delete_group_nodes(self): + """Test deleting all nodes in a group.""" + group = orm.Group(label='agroup').store() + nodes = [orm.Data().store() for _ in range(2)] + node_pks = {node.pk for node in nodes} + node_uuids = {node.uuid for node in nodes} + group.add_nodes(nodes) + deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=False) + self.assertTrue(was_deleted) + self.assertSetEqual(deleted_pks, node_pks) + self._check_existence([], node_uuids) + + def test_delete_group_nodes_dry_run_true(self): + """Verify that a dry run should not 
delete the node.""" + group = orm.Group(label='agroup2').store() + nodes = [orm.Data().store() for _ in range(2)] + node_pks = {node.pk for node in nodes} + node_uuids = {node.uuid for node in nodes} + group.add_nodes(nodes) + deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=True) + self.assertTrue(not was_deleted) + self.assertSetEqual(deleted_pks, node_pks) + self._check_existence(node_uuids, []) diff --git a/tests/tools/graph/test_age.py b/tests/tools/graph/test_age.py index 369040036e..2d531e967a 100644 --- a/tests/tools/graph/test_age.py +++ b/tests/tools/graph/test_age.py @@ -87,7 +87,7 @@ class TestAiidaGraphExplorer(AiidaTestCase): def setUp(self): super().setUp() - self.reset_database() + self.refurbish_db() @staticmethod def _create_basic_graph(): @@ -670,7 +670,7 @@ class TestAiidaEntitySet(AiidaTestCase): def setUp(self): super().setUp() - self.reset_database() + self.refurbish_db() def test_class_mismatch(self): """ diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index ba9061b156..c735ec2e53 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -116,6 +116,7 @@ def test_walk(setup_groups): assert [c.path for c in sorted(group_path.walk())] == ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g', 'a/f'] +@pytest.mark.filterwarnings('ignore::UserWarning') def test_walk_with_invalid_path(clear_database_before_test): """Test the ``GroupPath.walk`` method with invalid paths.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f', 'bad//group', 'bad/other']: diff --git a/tests/tools/importexport/__init__.py b/tests/tools/importexport/__init__.py index 2776a55f97..acd0d20bf6 100644 --- a/tests/tools/importexport/__init__.py +++ b/tests/tools/importexport/__init__.py @@ -7,3 +7,32 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tests for AiiDA archive files (import, export).""" +from aiida.backends.testbase import AiidaTestCase +from aiida.tools.importexport import EXPORT_LOGGER, IMPORT_LOGGER + + +class AiidaArchiveTestCase(AiidaTestCase): + """Testcase for tests of archive-related functionality (import, export).""" + + def setUp(self): + super().setUp() + self.refurbish_db() + + @classmethod + def setUpClass(cls): + """Only run to prepare an archive file""" + super().setUpClass() + + # don't want output + EXPORT_LOGGER.setLevel('CRITICAL') + IMPORT_LOGGER.setLevel('CRITICAL') + + @classmethod + def tearDownClass(cls): + """Only run to prepare an archive file""" + super().tearDownClass() + + # don't want output + EXPORT_LOGGER.setLevel('INFO') + IMPORT_LOGGER.setLevel('INFO') diff --git a/tests/tools/importexport/orm/__init__.py b/tests/tools/importexport/orm/__init__.py index 2776a55f97..2763ad3bfc 100644 --- a/tests/tools/importexport/orm/__init__.py +++ b/tests/tools/importexport/orm/__init__.py @@ -7,3 +7,4 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tests for archive-related operations (import, export) on ORM entities.""" diff --git a/tests/tools/importexport/orm/test_attributes.py b/tests/tools/importexport/orm/test_attributes.py index d0ea7bf19d..bdd13cf811 100644 --- a/tests/tools/importexport/orm/test_attributes.py +++ b/tests/tools/importexport/orm/test_attributes.py 
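The test_nodes.py hunks earlier in this diff replace the old delete_nodes(..., force=True) calls (wrapped in Capturing) with the relocated aiida.tools.delete_nodes and delete_group_nodes functions, which take a dry_run argument and return a (deleted_pks, was_deleted) tuple. A minimal usage sketch based only on how the updated tests call them; the stored Data node and the group label are placeholders:

    from aiida import orm
    from aiida.tools import delete_nodes, delete_group_nodes

    node = orm.Data().store()

    # Dry run: report what would be deleted, but keep everything in place.
    pks, was_deleted = delete_nodes([node.pk], dry_run=True)
    assert not was_deleted and pks == {node.pk}

    # dry_run can also be a callback receiving the candidate pks; returning False
    # lets the deletion proceed (the tests rely on this), while returning True is
    # expected to behave like a dry run.
    def confirm(pks_to_delete):
        print(f'about to delete {len(pks_to_delete)} nodes')
        return False

    pks, was_deleted = delete_nodes([node.pk], dry_run=confirm)
    assert was_deleted

    # Delete every node contained in a group, analogous to test_delete_group_nodes.
    group = orm.Group(label='example-group').store()
    group.add_nodes([orm.Data().store()])
    deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=False)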
@@ -13,13 +13,13 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. import AiidaArchiveTestCase -class TestAttributes(AiidaTestCase): +class TestAttributes(AiidaArchiveTestCase): """Test ex-/import cases related to Attributes""" def create_data_with_attr(self): @@ -30,7 +30,7 @@ def create_data_with_attr(self): def import_attributes(self): """Import an AiiDA database""" - import_data(self.export_file, silent=True) + import_data(self.export_file) builder = orm.QueryBuilder().append(orm.Data, filters={'label': 'my_test_data_node'}) @@ -48,10 +48,10 @@ def test_import_of_attributes(self, temp_dir): # Export self.export_file = os.path.join(temp_dir, 'export.aiida') - export([self.data], filename=self.export_file, silent=True) + export([self.data], filename=self.export_file) # Clean db - self.reset_database() + self.clean_db() self.import_attributes() diff --git a/tests/tools/importexport/orm/test_calculations.py b/tests/tools/importexport/orm/test_calculations.py index 74f1007bd4..87118b6da6 100644 --- a/tests/tools/importexport/orm/test_calculations.py +++ b/tests/tools/importexport/orm/test_calculations.py @@ -12,22 +12,19 @@ import os +import pytest + from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. import AiidaArchiveTestCase -class TestCalculations(AiidaTestCase): +class TestCalculations(AiidaArchiveTestCase): """Test ex-/import cases related to Calculations""" - def setUp(self): - self.reset_database() - - def tearDown(self): - self.reset_database() - + @pytest.mark.requires_rmq @with_temp_dir def test_calcfunction(self, temp_dir): """Test @calcfunction""" @@ -54,10 +51,9 @@ def max_(**kwargs): not_wanted_uuids = [v.uuid for v in (b, c, d)] # At this point we export the generated data filename1 = os.path.join(temp_dir, 'export1.aiida') - export([res], filename=filename1, silent=True, return_backward=True) - self.clean_db() - self.insert_data() - import_data(filename1, silent=True) + export([res], filename=filename1, return_backward=True) + self.refurbish_db() + import_data(filename1) # Check that the imported nodes are correctly imported and that the value is preserved for uuid, value in uuids_values: self.assertEqual(orm.load_node(uuid).value, value) @@ -91,10 +87,9 @@ def test_workcalculation(self, temp_dir): uuids_values = [(v.uuid, v.value) for v in (output_1,)] filename1 = os.path.join(temp_dir, 'export1.aiida') - export([output_1], filename=filename1, silent=True) - self.clean_db() - self.insert_data() - import_data(filename1, silent=True) + export([output_1], filename=filename1) + self.refurbish_db() + import_data(filename1) for uuid, value in uuids_values: self.assertEqual(orm.load_node(uuid).value, value) diff --git a/tests/tools/importexport/orm/test_codes.py b/tests/tools/importexport/orm/test_codes.py index 2e26e4f392..7c16d46f53 100644 --- a/tests/tools/importexport/orm/test_codes.py +++ b/tests/tools/importexport/orm/test_codes.py @@ -12,25 +12,17 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.links import LinkType from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir from tests.tools.importexport.utils import get_all_node_links +from .. 
import AiidaArchiveTestCase -class TestCode(AiidaTestCase): +class TestCode(AiidaArchiveTestCase): """Test ex-/import cases related to Codes""" - def setUp(self): - super().setUp() - self.reset_database() - - def tearDown(self): - super().tearDown() - self.reset_database() - @with_temp_dir def test_that_solo_code_is_exported_correctly(self, temp_dir): """ @@ -47,11 +39,11 @@ def test_that_solo_code_is_exported_correctly(self, temp_dir): code_uuid = code.uuid export_file = os.path.join(temp_dir, 'export.aiida') - export([code], filename=export_file, silent=True) + export([code], filename=export_file) - self.reset_database() + self.clean_db() - import_data(export_file, silent=True) + import_data(export_file) self.assertEqual(orm.load_node(code_uuid).label, code_label) @@ -83,11 +75,11 @@ def test_input_code(self, temp_dir): export_links = get_all_node_links() export_file = os.path.join(temp_dir, 'export.aiida') - export([calc], filename=export_file, silent=True) + export([calc], filename=export_file) - self.reset_database() + self.clean_db() - import_data(export_file, silent=True) + import_data(export_file) # Check that the code node is there self.assertEqual(orm.load_node(code_uuid).label, code_label) @@ -120,11 +112,10 @@ def test_solo_code(self, temp_dir): code_uuid = code.uuid export_file = os.path.join(temp_dir, 'export.aiida') - export([code], filename=export_file, silent=True) + export([code], filename=export_file) - self.clean_db() - self.insert_data() + self.refurbish_db() - import_data(export_file, silent=True) + import_data(export_file) self.assertEqual(orm.load_node(code_uuid).label, code_label) diff --git a/tests/tools/importexport/orm/test_comments.py b/tests/tools/importexport/orm/test_comments.py index 921b51b1bf..6a5a1a09ca 100644 --- a/tests/tools/importexport/orm/test_comments.py +++ b/tests/tools/importexport/orm/test_comments.py @@ -13,27 +13,22 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. 
import AiidaArchiveTestCase -class TestComments(AiidaTestCase): +class TestComments(AiidaArchiveTestCase): """Test ex-/import cases related to Comments""" def setUp(self): super().setUp() - self.reset_database() self.comments = [ "We're no strangers to love", 'You know the rules and so do I', "A full commitment's what I'm thinking of", "You wouldn't get this from any other guy" ] - def tearDown(self): - super().tearDown() - self.reset_database() - @with_temp_dir def test_multiple_imports_for_single_node(self, temp_dir): """Test multiple imports for single node with different comments are imported correctly""" @@ -49,7 +44,7 @@ def test_multiple_imports_for_single_node(self, temp_dir): # Export as "EXISTING" DB export_file_existing = os.path.join(temp_dir, 'export_EXISTING.aiida') - export([node], filename=export_file_existing, silent=True) + export([node], filename=export_file_existing) # Add 2 more Comments and save UUIDs prior to export comment_three = orm.Comment(node, user, self.comments[2]).store() @@ -58,11 +53,11 @@ def test_multiple_imports_for_single_node(self, temp_dir): # Export as "FULL" DB export_file_full = os.path.join(temp_dir, 'export_FULL.aiida') - export([node], filename=export_file_full, silent=True) + export([node], filename=export_file_full) # Clean database and reimport "EXISTING" DB - self.reset_database() - import_data(export_file_existing, silent=True) + self.clean_db() + import_data(export_file_existing) # Check correct import builder = orm.QueryBuilder().append(orm.Node, tag='node', project=['uuid']) @@ -81,7 +76,7 @@ def test_multiple_imports_for_single_node(self, temp_dir): self.assertIn(imported_comment_content, self.comments[0:2]) # Import "FULL" DB - import_data(export_file_full, silent=True) + import_data(export_file_full) # Since the UUID of the node is identical with the node already in the DB, # the Comments should be added to the existing node, avoiding the addition @@ -125,11 +120,11 @@ def test_exclude_comments_flag(self, temp_dir): # Export nodes, excluding comments export_file = os.path.join(temp_dir, 'export.aiida') - export([node], filename=export_file, silent=True, include_comments=False) + export([node], filename=export_file, include_comments=False) # Clean database and reimport exported file - self.reset_database() - import_data(export_file, silent=True) + self.clean_db() + import_data(export_file) # Get node, users, and comments import_nodes = orm.QueryBuilder().append(orm.Node, project=['uuid']).all() @@ -169,11 +164,11 @@ def test_calc_and_data_nodes_with_comments(self, temp_dir): # Export nodes export_file = os.path.join(temp_dir, 'export.aiida') - export([calc_node, data_node], filename=export_file, silent=True) + export([calc_node, data_node], filename=export_file) # Clean database and reimport exported file - self.reset_database() - import_data(export_file, silent=True) + self.clean_db() + import_data(export_file) # Get nodes and comments builder = orm.QueryBuilder() @@ -220,11 +215,11 @@ def test_multiple_user_comments_single_node(self, temp_dir): # Export node, along with comments and users recursively export_file = os.path.join(temp_dir, 'export.aiida') - export([node], filename=export_file, silent=True) + export([node], filename=export_file) # Clean database and reimport exported file - self.reset_database() - import_data(export_file, silent=True) + self.clean_db() + import_data(export_file) # Get node, users, and comments builder = orm.QueryBuilder() @@ -308,9 +303,9 @@ def test_mtime_of_imported_comments(self, temp_dir): # Export, 
reset database and reimport export_file = os.path.join(temp_dir, 'export.aiida') - export([calc], filename=export_file, silent=True) - self.reset_database() - import_data(export_file, silent=True) + export([calc], filename=export_file) + self.clean_db() + import_data(export_file) # Retrieve node and comment builder = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid', 'mtime']) @@ -358,7 +353,7 @@ def test_import_arg_comment_mode(self, temp_dir): # Export calc and comment export_file = os.path.join(temp_dir, 'export_file.aiida') - export([calc], filename=export_file, silent=True) + export([calc], filename=export_file) # Update comment cmt.set_content(self.comments[1]) @@ -371,10 +366,10 @@ def test_import_arg_comment_mode(self, temp_dir): # Export calc and UPDATED comment export_file_updated = os.path.join(temp_dir, 'export_file_updated.aiida') - export([calc], filename=export_file_updated, silent=True) + export([calc], filename=export_file_updated) # Reimport exported 'old' calc and comment - import_data(export_file, silent=True, comment_mode='newest') + import_data(export_file, comment_mode='newest') # Check there are exactly 1 CalculationNode and 1 Comment import_calcs = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid']) @@ -391,7 +386,7 @@ def test_import_arg_comment_mode(self, temp_dir): ## Test comment_mode='overwrite' # Reimport exported 'old' calc and comment - import_data(export_file, silent=True, comment_mode='overwrite') + import_data(export_file, comment_mode='overwrite') # Check there are exactly 1 CalculationNode and 1 Comment import_calcs = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid']) @@ -408,7 +403,7 @@ def test_import_arg_comment_mode(self, temp_dir): ## Test ValueError is raised when using a wrong comment_mode: with self.assertRaises(ImportValidationError): - import_data(export_file, silent=True, comment_mode='invalid') + import_data(export_file, comment_mode='invalid') @with_temp_dir def test_reimport_of_comments_for_single_node(self, temp_dir): @@ -474,7 +469,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "EXISTING" DB export_file_existing = os.path.join(temp_dir, export_filenames['EXISTING']) - export([calc], filename=export_file_existing, silent=True) + export([calc], filename=export_file_existing) # Add remaining Comments for comment in self.comments[1:]: @@ -494,14 +489,14 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "FULL" DB export_file_full = os.path.join(temp_dir, export_filenames['FULL']) - export([calc], filename=export_file_full, silent=True) + export([calc], filename=export_file_full) # Clean database - self.reset_database() + self.clean_db() ## Part II # Reimport "EXISTING" DB - import_data(export_file_existing, silent=True) + import_data(export_file_existing) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 1 Comment @@ -533,14 +528,14 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "NEW" DB export_file_new = os.path.join(temp_dir, export_filenames['NEW']) - export([calc], filename=export_file_new, silent=True) + export([calc], filename=export_file_new) # Clean database - self.reset_database() + self.clean_db() ## Part III # Reimport "EXISTING" DB - import_data(export_file_existing, silent=True) + import_data(export_file_existing) # Check the database is correctly imported. 
# There should be exactly: 1 CalculationNode, 1 Comment @@ -553,7 +548,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): self.assertIn(str(import_comments.all()[0][0]), existing_comment_uuids) # Import "FULL" DB - import_data(export_file_full, silent=True) + import_data(export_file_full) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 4 Comments (len(self.comments)) @@ -569,7 +564,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): ## Part IV # Import "NEW" DB - import_data(export_file_new, silent=True) + import_data(export_file_new) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 7 Comments (org. (1) + 2 x added (3) Comments) diff --git a/tests/tools/importexport/orm/test_computers.py b/tests/tools/importexport/orm/test_computers.py index 00065b5a3d..c6b18d0c19 100644 --- a/tests/tools/importexport/orm/test_computers.py +++ b/tests/tools/importexport/orm/test_computers.py @@ -13,21 +13,15 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. import AiidaArchiveTestCase -class TestComputer(AiidaTestCase): +class TestComputer(AiidaArchiveTestCase): """Test ex-/import cases related to Computers""" - def setUp(self): - self.reset_database() - - def tearDown(self): - self.reset_database() - @with_temp_dir def test_same_computer_import(self, temp_dir): """ @@ -65,15 +59,14 @@ def test_same_computer_import(self, temp_dir): # Export the first job calculation filename1 = os.path.join(temp_dir, 'export1.aiida') - export([calc1], filename=filename1, silent=True) + export([calc1], filename=filename1) # Export the second job calculation filename2 = os.path.join(temp_dir, 'export2.aiida') - export([calc2], filename=filename2, silent=True) + export([calc2], filename=filename2) # Clean the local database - self.clean_db() - self.create_user() + self.refurbish_db() # Check that there are no computers builder = orm.QueryBuilder() @@ -86,7 +79,7 @@ def test_same_computer_import(self, temp_dir): self.assertEqual(builder.count(), 0, 'There should not be any calculations in the database at this point.') # Import the first calculation - import_data(filename1, silent=True) + import_data(filename1) # Check that the calculation computer is imported correctly. builder = orm.QueryBuilder() @@ -105,7 +98,7 @@ def test_same_computer_import(self, temp_dir): comp_id = builder.first()[2] # Import the second calculation - import_data(filename2, silent=True) + import_data(filename2) # Check that the number of computers remains the same and its data # did not change. 
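Throughout the archive tests in this diff, the silent=True keyword is dropped from export() and import_data(); the new shared AiidaArchiveTestCase silences output by raising the export/import logger levels instead. A minimal sketch of that pattern outside the test class, with the temporary directory and the single Data node as placeholders:

    import os
    import tempfile

    from aiida import orm
    from aiida.tools.importexport import EXPORT_LOGGER, IMPORT_LOGGER, export, import_data

    # Silence archive progress output, as AiidaArchiveTestCase.setUpClass does.
    EXPORT_LOGGER.setLevel('CRITICAL')
    IMPORT_LOGGER.setLevel('CRITICAL')

    archive = os.path.join(tempfile.mkdtemp(), 'export.aiida')
    node = orm.Data().store()

    export([node], filename=archive)  # no ``silent=True`` keyword anymore
    import_data(archive)

    # Restore the default verbosity afterwards, mirroring tearDownClass.
    EXPORT_LOGGER.setLevel('INFO')
    IMPORT_LOGGER.setLevel('INFO')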
@@ -150,7 +143,7 @@ def test_same_computer_different_name_import(self, temp_dir): # Export the first job calculation filename1 = os.path.join(temp_dir, 'export1.aiida') - export([calc1], filename=filename1, silent=True) + export([calc1], filename=filename1) # Rename the computer comp1.label = f'{comp1_name}_updated' @@ -166,11 +159,10 @@ def test_same_computer_different_name_import(self, temp_dir): # Export the second job calculation filename2 = os.path.join(temp_dir, 'export2.aiida') - export([calc2], filename=filename2, silent=True) + export([calc2], filename=filename2) # Clean the local database - self.clean_db() - self.create_user() + self.refurbish_db() # Check that there are no computers builder = orm.QueryBuilder() @@ -183,7 +175,7 @@ def test_same_computer_different_name_import(self, temp_dir): self.assertEqual(builder.count(), 0, 'There should not be any calculations in the database at this point.') # Import the first calculation - import_data(filename1, silent=True) + import_data(filename1) # Check that the calculation computer is imported correctly. builder = orm.QueryBuilder() @@ -198,7 +190,7 @@ def test_same_computer_different_name_import(self, temp_dir): self.assertEqual(str(builder.first()[0]), comp1_name, 'The computer name is not correct.') # Import the second calculation - import_data(filename2, silent=True) + import_data(filename2) # Check that the number of computers remains the same and its data # did not change. @@ -230,11 +222,10 @@ def test_different_computer_same_name_import(self, temp_dir): # Export the first job calculation filename1 = os.path.join(temp_dir, 'export1.aiida') - export([calc1], filename=filename1, silent=True) + export([calc1], filename=filename1) # Reset the database - self.clean_db() - self.insert_data() + self.refurbish_db() # Set the computer name to the same name as before self.computer.label = comp1_name @@ -250,11 +241,10 @@ def test_different_computer_same_name_import(self, temp_dir): # Export the second job calculation filename2 = os.path.join(temp_dir, 'export2.aiida') - export([calc2], filename=filename2, silent=True) + export([calc2], filename=filename2) # Reset the database - self.clean_db() - self.insert_data() + self.refurbish_db() # Set the computer name to the same name as before self.computer.label = comp1_name @@ -270,11 +260,10 @@ def test_different_computer_same_name_import(self, temp_dir): # Export the third job calculation filename3 = os.path.join(temp_dir, 'export3.aiida') - export([calc3], filename=filename3, silent=True) + export([calc3], filename=filename3) # Clean the local database - self.clean_db() - self.create_user() + self.refurbish_db() # Check that there are no computers builder = orm.QueryBuilder() @@ -291,9 +280,9 @@ def test_different_computer_same_name_import(self, temp_dir): ) # Import all the calculations - import_data(filename1, silent=True) - import_data(filename2, silent=True) - import_data(filename3, silent=True) + import_data(filename1) + import_data(filename2) + import_data(filename3) # Retrieve the calculation-computer pairs builder = orm.QueryBuilder() @@ -327,14 +316,13 @@ def test_import_of_computer_json_params(self, temp_dir): # Export the first job calculation filename1 = os.path.join(temp_dir, 'export1.aiida') - export([calc1], filename=filename1, silent=True) + export([calc1], filename=filename1) # Clean the local database - self.clean_db() - self.create_user() + self.refurbish_db() # Import the data - import_data(filename1, silent=True) + import_data(filename1) builder = orm.QueryBuilder() 
builder.append(orm.Computer, project=['metadata'], tag='comp') @@ -349,7 +337,7 @@ def test_import_of_django_sqla_export_file(self): for archive in ['django.aiida', 'sqlalchemy.aiida']: # Clean the database - self.reset_database() + self.refurbish_db() # Import the needed data import_archive(archive, filepath='export/compare') diff --git a/tests/tools/importexport/orm/test_extras.py b/tests/tools/importexport/orm/test_extras.py index 3393557ca1..7e6f93d061 100644 --- a/tests/tools/importexport/orm/test_extras.py +++ b/tests/tools/importexport/orm/test_extras.py @@ -15,15 +15,15 @@ import tempfile from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export +from .. import AiidaArchiveTestCase -class TestExtras(AiidaTestCase): +class TestExtras(AiidaArchiveTestCase): """Test ex-/import cases related to Extras""" @classmethod - def setUpClass(cls, *args, **kwargs): + def setUpClass(cls): """Only run to prepare an archive file""" super().setUpClass() @@ -33,23 +33,18 @@ def setUpClass(cls, *args, **kwargs): data.set_extra_many({'b': 2, 'c': 3}) cls.tmp_folder = tempfile.mkdtemp() cls.export_file = os.path.join(cls.tmp_folder, 'export.aiida') - export([data], filename=cls.export_file, silent=True) + export([data], filename=cls.export_file) @classmethod - def tearDownClass(cls, *args, **kwargs): + def tearDownClass(cls): """Remove tmp_folder""" super().tearDownClass() shutil.rmtree(cls.tmp_folder, ignore_errors=True) - def setUp(self): - """This function runs before every test execution""" - self.clean_db() - self.insert_data() - def import_extras(self, mode_new='import'): """Import an aiida database""" - import_data(self.export_file, silent=True, extras_mode_new=mode_new) + import_data(self.export_file, extras_mode_new=mode_new) builder = orm.QueryBuilder().append(orm.Data, filters={'label': 'my_test_data_node'}) @@ -62,16 +57,13 @@ def modify_extras(self, mode_existing): self.imported_node.set_extra('b', 1000) self.imported_node.delete_extra('c') - import_data(self.export_file, silent=True, extras_mode_existing=mode_existing) + import_data(self.export_file, extras_mode_existing=mode_existing) # Query again the database builder = orm.QueryBuilder().append(orm.Data, filters={'label': 'my_test_data_node'}) self.assertEqual(builder.count(), 1) return builder.all()[0][0] - def tearDown(self): - pass - def test_import_of_extras(self): """Check if extras are properly imported""" self.import_extras() @@ -157,7 +149,7 @@ def test_extras_import_mode_correct(self): for mode2 in ['n', 'c']: # create or not create new extras for mode3 in ['l', 'u', 'd']: # leave old, update or delete collided extras mode = mode1 + mode2 + mode3 - import_data(self.export_file, silent=True, extras_mode_existing=mode) + import_data(self.export_file, extras_mode_existing=mode) def test_extras_import_mode_wrong(self): """Check a mode that is wrong""" @@ -165,14 +157,14 @@ def test_extras_import_mode_wrong(self): self.import_extras() with self.assertRaises(ImportValidationError): - import_data(self.export_file, silent=True, extras_mode_existing='xnd') # first letter is wrong + import_data(self.export_file, extras_mode_existing='xnd') # first letter is wrong with self.assertRaises(ImportValidationError): - import_data(self.export_file, silent=True, extras_mode_existing='nxd') # second letter is wrong + import_data(self.export_file, extras_mode_existing='nxd') # second letter is wrong with self.assertRaises(ImportValidationError): - import_data(self.export_file, 
silent=True, extras_mode_existing='nnx') # third letter is wrong + import_data(self.export_file, extras_mode_existing='nnx') # third letter is wrong with self.assertRaises(ImportValidationError): - import_data(self.export_file, silent=True, extras_mode_existing='n') # too short + import_data(self.export_file, extras_mode_existing='n') # too short with self.assertRaises(ImportValidationError): - import_data(self.export_file, silent=True, extras_mode_existing='nndnn') # too long + import_data(self.export_file, extras_mode_existing='nndnn') # too long with self.assertRaises(ImportValidationError): - import_data(self.export_file, silent=True, extras_mode_existing=5) # wrong type + import_data(self.export_file, extras_mode_existing=5) # wrong type diff --git a/tests/tools/importexport/orm/test_groups.py b/tests/tools/importexport/orm/test_groups.py index 004aa3182c..9fef09f69d 100644 --- a/tests/tools/importexport/orm/test_groups.py +++ b/tests/tools/importexport/orm/test_groups.py @@ -12,21 +12,16 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase + from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. import AiidaArchiveTestCase -class TestGroups(AiidaTestCase): +class TestGroups(AiidaArchiveTestCase): """Test ex-/import cases related to Groups""" - def setUp(self): - self.reset_database() - - def tearDown(self): - self.reset_database() - @with_temp_dir def test_nodes_in_group(self, temp_dir): """ @@ -63,11 +58,10 @@ def test_nodes_in_group(self, temp_dir): # At this point we export the generated data filename1 = os.path.join(temp_dir, 'export1.aiida') - export([sd1, jc1, gr1], filename=filename1, silent=True) + export([sd1, jc1, gr1], filename=filename1) n_uuids = [sd1.uuid, jc1.uuid] - self.clean_db() - self.insert_data() - import_data(filename1, silent=True) + self.refurbish_db() + import_data(filename1) # Check that the imported nodes are correctly imported and that # the user assigned to the nodes is the right one @@ -102,11 +96,10 @@ def test_group_export(self, temp_dir): # Export the generated data, clean the database and import it again filename = os.path.join(temp_dir, 'export.aiida') - export([group], filename=filename, silent=True) + export([group], filename=filename) n_uuids = [sd1.uuid] - self.clean_db() - self.insert_data() - import_data(filename, silent=True) + self.refurbish_db() + import_data(filename) # Check that the imported nodes are correctly imported and that # the user assigned to the nodes is the right one @@ -146,14 +139,13 @@ def test_group_import_existing(self, temp_dir): # At this point we export the generated data filename = os.path.join(temp_dir, 'export1.aiida') - export([group], filename=filename, silent=True) - self.clean_db() - self.insert_data() + export([group], filename=filename) + self.refurbish_db() # Creating a group of the same name group = orm.Group(label='node_group_existing') group.store() - import_data(filename, silent=True) + import_data(filename) # The import should have created a new group with a suffix # I check for this: builder = orm.QueryBuilder().append(orm.Group, filters={'label': {'like': f'{grouplabel}%'}}) @@ -166,7 +158,7 @@ def test_group_import_existing(self, temp_dir): # I check that the group name was changed: self.assertTrue(builder.all()[0][0] != grouplabel) # I import another name, the group should not be imported again - import_data(filename, silent=True) + import_data(filename) builder = orm.QueryBuilder() 
builder.append(orm.Group, filters={'label': {'like': f'{grouplabel}%'}}) self.assertEqual(builder.count(), 2) @@ -187,8 +179,8 @@ def test_import_to_group(self, temp_dir): # Export Nodes filename = os.path.join(temp_dir, 'export.aiida') - export([data1, data2], filename=filename, silent=True) - self.reset_database() + export([data1, data2], filename=filename) + self.refurbish_db() # Create Group, do not store group_label = 'import_madness' @@ -197,11 +189,11 @@ def test_import_to_group(self, temp_dir): # Try to import to this Group, providing only label - this should fail with self.assertRaises(ImportValidationError) as exc: - import_data(filename, group=group_label, silent=True) + import_data(filename, group=group_label) self.assertIn('group must be a Group entity', str(exc.exception)) # Import properly now, providing the Group object - import_data(filename, group=group, silent=True) + import_data(filename, group=group) # Check Group for content builder = orm.QueryBuilder().append(orm.Group, filters={'label': group_label}, project='uuid') @@ -226,7 +218,7 @@ def test_import_to_group(self, temp_dir): group = orm.Group(label=group_label) group_uuid = group.uuid - import_data(filename, group=group, silent=True) + import_data(filename, group=group) imported_group = load_group(label=group_label) self.assertEqual(imported_group.uuid, group_uuid) diff --git a/tests/tools/importexport/orm/test_links.py b/tests/tools/importexport/orm/test_links.py index 8efe66c530..a7fd20f445 100644 --- a/tests/tools/importexport/orm/test_links.py +++ b/tests/tools/importexport/orm/test_links.py @@ -13,7 +13,7 @@ import tarfile from aiida import orm -from aiida.backends.testbase import AiidaTestCase + from aiida.common import json from aiida.common.folders import SandboxFolder from aiida.common.links import LinkType @@ -23,19 +23,12 @@ from tests.utils.configuration import with_temp_dir from tests.tools.importexport.utils import get_all_node_links +from .. 
import AiidaArchiveTestCase -class TestLinks(AiidaTestCase): +class TestLinks(AiidaArchiveTestCase): """Test ex-/import cases related to Links""" - def setUp(self): - self.reset_database() - super().setUp() - - def tearDown(self): - self.reset_database() - super().tearDown() - @with_temp_dir def test_links_to_unknown_nodes(self, temp_dir): """Test importing of nodes, that have links to unknown nodes.""" @@ -68,7 +61,7 @@ def test_links_to_unknown_nodes(self, temp_dir): with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar: tar.add(unpack.abspath, arcname='') - self.reset_database() + self.clean_db() with self.assertRaises(DanglingLinkError): import_data(filename) @@ -96,7 +89,7 @@ def test_input_and_create_links(self, temp_dir): export_file = os.path.join(temp_dir, 'export.aiida') export([node_output], filename=export_file) - self.reset_database() + self.clean_db() import_data(export_file) import_links = get_all_node_links() @@ -267,7 +260,7 @@ def test_complex_workflow_graph_links(self, temp_dir): export_file = os.path.join(temp_dir, 'export.aiida') export(graph_nodes, filename=export_file) - self.reset_database() + self.clean_db() import_data(export_file) import_links = get_all_node_links() @@ -289,7 +282,7 @@ def test_complex_workflow_graph_export_sets(self, temp_dir): export([export_node], filename=export_file, overwrite=True) export_node_str = str(export_node) - self.reset_database() + self.clean_db() import_data(export_file) @@ -321,7 +314,7 @@ def test_high_level_workflow_links(self, temp_dir): for calcs in high_level_calc_nodes: for works in high_level_work_nodes: - self.reset_database() + self.refurbish_db() graph_nodes, _ = self.construct_complex_graph(calc_nodes=calcs, work_nodes=works) @@ -351,7 +344,7 @@ def test_high_level_workflow_links(self, temp_dir): export_file = os.path.join(temp_dir, 'export.aiida') export(graph_nodes, filename=export_file, overwrite=True) - self.reset_database() + self.refurbish_db() import_data(export_file) import_links = get_all_node_links() @@ -386,7 +379,7 @@ def prepare_link_flags_export(nodes_to_export, test_data): def link_flags_import_helper(self, test_data): """Helper function""" for test, (export_file, _, expected_nodes) in test_data.items(): - self.reset_database() + self.clean_db() import_data(export_file) @@ -594,7 +587,7 @@ def test_double_return_links_for_workflows(self, temp_dir): export_file = os.path.join(temp_dir, 'export.aiida') export([data_out, work1, work2, data_in], filename=export_file) - self.reset_database() + self.clean_db() import_data(export_file) @@ -688,7 +681,7 @@ def test_multiple_post_return_links(self, temp_dir): # pylint: disable=too-many export([data], filename=data_provenance, return_backward=False) export([data], filename=all_provenance, return_backward=True) - self.reset_database() + self.clean_db() # import data provenance import_data(data_provenance) diff --git a/tests/tools/importexport/orm/test_logs.py b/tests/tools/importexport/orm/test_logs.py index d61f6f8b41..2b191cacba 100644 --- a/tests/tools/importexport/orm/test_logs.py +++ b/tests/tools/importexport/orm/test_logs.py @@ -13,27 +13,16 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase + from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. 
import AiidaArchiveTestCase -class TestLogs(AiidaTestCase): +class TestLogs(AiidaArchiveTestCase): """Test ex-/import cases related to Logs""" - def setUp(self): - """Reset database prior to all tests""" - super().setUp() - self.reset_database() - - def tearDown(self): - """ - Delete all the created log entries - """ - super().tearDown() - orm.Log.objects.delete_all() - @with_temp_dir def test_critical_log_msg_and_metadata(self, temp_dir): """ Testing logging of critical message """ @@ -54,11 +43,11 @@ def test_critical_log_msg_and_metadata(self, temp_dir): log_metadata = orm.Log.objects.get(dbnode_id=calc.id).metadata export_file = os.path.join(temp_dir, 'export.aiida') - export([calc], filename=export_file, silent=True) + export([calc], filename=export_file) - self.reset_database() + self.clean_db() - import_data(export_file, silent=True) + import_data(export_file) # Finding all the log messages logs = orm.Log.objects.all() @@ -85,11 +74,11 @@ def test_exclude_logs_flag(self, temp_dir): # Export, excluding logs export_file = os.path.join(temp_dir, 'export.aiida') - export([calc], filename=export_file, silent=True, include_logs=False) + export([calc], filename=export_file, include_logs=False) # Clean database and reimport exported data - self.reset_database() - import_data(export_file, silent=True) + self.clean_db() + import_data(export_file) # Finding all the log messages import_calcs = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid']).all() @@ -122,11 +111,11 @@ def test_export_of_imported_logs(self, temp_dir): # Export export_file = os.path.join(temp_dir, 'export.aiida') - export([calc], filename=export_file, silent=True) + export([calc], filename=export_file) # Clean database and reimport exported data - self.reset_database() - import_data(export_file, silent=True) + self.clean_db() + import_data(export_file) # Finding all the log messages import_calcs = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid']).all() @@ -143,11 +132,11 @@ def test_export_of_imported_logs(self, temp_dir): # Re-export calc = orm.load_node(import_calcs[0][0]) re_export_file = os.path.join(temp_dir, 're_export.aiida') - export([calc], filename=re_export_file, silent=True) + export([calc], filename=re_export_file) # Clean database and reimport exported data - self.reset_database() - import_data(re_export_file, silent=True) + self.clean_db() + import_data(re_export_file) # Finding all the log messages import_calcs = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid']).all() @@ -176,7 +165,7 @@ def test_multiple_imports_for_single_node(self, temp_dir): # Export as "EXISTING" DB export_file_existing = os.path.join(temp_dir, 'export_EXISTING.aiida') - export([node], filename=export_file_existing, silent=True) + export([node], filename=export_file_existing) # Add 2 more Logs and save UUIDs for all three Logs prior to export node.logger.critical(log_msgs[1]) @@ -186,11 +175,11 @@ def test_multiple_imports_for_single_node(self, temp_dir): # Export as "FULL" DB export_file_full = os.path.join(temp_dir, 'export_FULL.aiida') - export([node], filename=export_file_full, silent=True) + export([node], filename=export_file_full) # Clean database and reimport "EXISTING" DB - self.reset_database() - import_data(export_file_existing, silent=True) + self.clean_db() + import_data(export_file_existing) # Check correct import builder = orm.QueryBuilder().append(orm.Node, tag='node', project=['uuid']) @@ -208,7 +197,7 @@ def test_multiple_imports_for_single_node(self, temp_dir): 
self.assertEqual(imported_log_message, log_msgs[0]) # Import "FULL" DB - import_data(export_file_full, silent=True) + import_data(export_file_full) # Since the UUID of the node is identical with the node already in the DB, # the Logs should be added to the existing node, avoiding the addition of @@ -290,7 +279,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "EXISTING" DB export_file_existing = os.path.join(temp_dir, export_filenames['EXISTING']) - export([calc], filename=export_file_existing, silent=True) + export([calc], filename=export_file_existing) # Add remaining Log messages for log_msg in log_msgs[1:]: @@ -310,14 +299,14 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "FULL" DB export_file_full = os.path.join(temp_dir, export_filenames['FULL']) - export([calc], filename=export_file_full, silent=True) + export([calc], filename=export_file_full) # Clean database - self.reset_database() + self.clean_db() ## Part II # Reimport "EXISTING" DB - import_data(export_file_existing, silent=True) + import_data(export_file_existing) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 1 Log @@ -348,14 +337,14 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "NEW" DB export_file_new = os.path.join(temp_dir, export_filenames['NEW']) - export([calc], filename=export_file_new, silent=True) + export([calc], filename=export_file_new) # Clean database - self.reset_database() + self.clean_db() ## Part III # Reimport "EXISTING" DB - import_data(export_file_existing, silent=True) + import_data(export_file_existing) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 1 Log @@ -368,7 +357,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): self.assertIn(str(import_logs.all()[0][0]), existing_log_uuids) # Import "FULL" DB - import_data(export_file_full, silent=True) + import_data(export_file_full) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 3 Logs (len(log_msgs)) @@ -384,7 +373,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): ## Part IV # Import "NEW" DB - import_data(export_file_new, silent=True) + import_data(export_file_new) # Check the database is correctly imported. # There should be exactly: 1 CalculationNode, 5 Logs (len(log_msgs)) diff --git a/tests/tools/importexport/orm/test_users.py b/tests/tools/importexport/orm/test_users.py index 47466caf33..e2e28a2d02 100644 --- a/tests/tools/importexport/orm/test_users.py +++ b/tests/tools/importexport/orm/test_users.py @@ -12,21 +12,15 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from .. 
import AiidaArchiveTestCase -class TestUsers(AiidaTestCase): +class TestUsers(AiidaArchiveTestCase): """Test ex-/import cases related to Users""" - def setUp(self): - self.reset_database() - - def tearDown(self): - self.reset_database() - @with_temp_dir def test_nodes_belonging_to_different_users(self, temp_dir): """ @@ -82,10 +76,9 @@ def test_nodes_belonging_to_different_users(self, temp_dir): filename = os.path.join(temp_dir, 'export.aiida') - export([sd3], filename=filename, silent=True) - self.clean_db() - self.create_user() - import_data(filename, silent=True) + export([sd3], filename=filename) + self.refurbish_db() + import_data(filename) # Check that the imported nodes are correctly imported and that # the user assigned to the nodes is the right one @@ -138,11 +131,10 @@ def test_non_default_user_nodes(self, temp_dir): # pylint: disable=too-many-sta # At this point we export the generated data filename1 = os.path.join(temp_dir, 'export1.aiidaz') - export([sd2], filename=filename1, silent=True) + export([sd2], filename=filename1) uuids1 = [sd1.uuid, jc1.uuid, sd2.uuid] - self.clean_db() - self.insert_data() - import_data(filename1, silent=True) + self.refurbish_db() + import_data(filename1) # Check that the imported nodes are correctly imported and that # the user assigned to the nodes is the right one @@ -171,10 +163,9 @@ def test_non_default_user_nodes(self, temp_dir): # pylint: disable=too-many-sta uuids2 = [jc2.uuid, sd3.uuid] filename2 = os.path.join(temp_dir, 'export2.aiida') - export([sd3], filename=filename2, silent=True) - self.clean_db() - self.insert_data() - import_data(filename2, silent=True) + export([sd3], filename=filename2) + self.refurbish_db() + import_data(filename2) # Check that the imported nodes are correctly imported and that # the user assigned to the nodes is the right one diff --git a/tests/tools/importexport/test_complex.py b/tests/tools/importexport/test_complex.py index f041a63d08..84e95f9458 100644 --- a/tests/tools/importexport/test_complex.py +++ b/tests/tools/importexport/test_complex.py @@ -13,22 +13,16 @@ import os from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.links import LinkType from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from . import AiidaArchiveTestCase -class TestComplex(AiidaTestCase): +class TestComplex(AiidaArchiveTestCase): """Test complex ex-/import cases""" - def setUp(self): - self.reset_database() - - def tearDown(self): - self.reset_database() - @with_temp_dir def test_complex_graph_import_export(self, temp_dir): """ @@ -89,12 +83,11 @@ def test_complex_graph_import_export(self, temp_dir): } filename = os.path.join(temp_dir, 'export.aiida') - export([fd1], filename=filename, silent=True) + export([fd1], filename=filename) - self.clean_db() - self.create_user() + self.refurbish_db() - import_data(filename, silent=True, ignore_unknown_nodes=True) + import_data(filename, ignore_unknown_nodes=True) for uuid, label in node_uuids_labels.items(): try: @@ -201,12 +194,11 @@ def get_hash_from_db_content(grouplabel): group = orm.Group.get(label=grouplabel) # exporting based on all members of the group # this also checks if group memberships are preserved! - export([group] + list(group.nodes), filename=filename, silent=True) + export([group] + list(group.nodes), filename=filename) # cleaning the DB! 
- self.clean_db() - self.create_user() + self.refurbish_db() # reimporting the data from the file - import_data(filename, silent=True, ignore_unknown_nodes=True) + import_data(filename, ignore_unknown_nodes=True) # creating the hash from db content new_hash = get_hash_from_db_content(grouplabel) # I check for equality against the first hash created, which implies that hashes diff --git a/tests/tools/importexport/test_deprecation.py b/tests/tools/importexport/test_deprecation.py index 05f4fc9361..05bbeed68b 100644 --- a/tests/tools/importexport/test_deprecation.py +++ b/tests/tools/importexport/test_deprecation.py @@ -10,6 +10,7 @@ """Test deprecated parts still work and emit deprecations warnings""" # pylint: disable=invalid-name import os +import warnings import pytest @@ -21,63 +22,68 @@ def test_export_functions(temp_dir): """Check `what` and `outfile` in export(), export_tar() and export_zip()""" - what = [] - outfile = os.path.join(temp_dir, 'deprecated.aiida') - - for export_function in (dbexport.export, dbexport.export_tar, dbexport.export_zip): - if os.path.exists(outfile): - os.remove(outfile) - with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): - export_function(what=what, filename=outfile) - - if os.path.exists(outfile): - os.remove(outfile) - with pytest.warns( - AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' - ): - export_function(entities=what, what=what, filename=outfile) - - if os.path.exists(outfile): - os.remove(outfile) - with pytest.warns( - AiidaDeprecationWarning, - match='`outfile` is deprecated, please use `filename` instead', - ): - export_function(what, outfile=outfile) - - if os.path.exists(outfile): - os.remove(outfile) - with pytest.warns( - AiidaDeprecationWarning, match='`outfile` is deprecated, the supplied `filename` input will be used' - ): - export_function(what, filename=outfile, outfile=outfile) - - if os.path.exists(outfile): - os.remove(outfile) - with pytest.raises(TypeError, match='`entities` must be specified'): - export_function(filename=outfile) + with warnings.catch_warnings(): # To avoid printing them in output (pytest.mark.filterwarnings does not work) + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + what = [] + outfile = os.path.join(temp_dir, 'deprecated.aiida') + + for export_function in (dbexport.export, dbexport.export_tar, dbexport.export_zip): + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): + export_function(what=what, filename=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' + ): + export_function(entities=what, what=what, filename=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, + match='`outfile` is deprecated, please use `filename` instead', + ): + export_function(what, outfile=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, match='`outfile` is deprecated, the supplied `filename` input will be used' + ): + export_function(what, filename=outfile, outfile=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.raises(TypeError, match='`entities` must be specified'): + export_function(filename=outfile) def test_export_tree(): 
"""Check `what` in export_tree()""" from aiida.common.folders import SandboxFolder - what = [] + with warnings.catch_warnings(): # To avoid printing them in output (pytest.mark.filterwarnings does not work) + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) - with SandboxFolder() as folder: - with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): - dbexport.export_tree(what=what, folder=folder) + what = [] - folder.erase(create_empty_folder=True) - with pytest.warns( - AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' - ): - dbexport.export_tree(entities=what, what=what, folder=folder) + with SandboxFolder() as folder: + with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): + dbexport.export_tree(what=what, folder=folder) - folder.erase(create_empty_folder=True) - with pytest.raises(TypeError, match='`entities` must be specified'): - dbexport.export_tree(folder=folder) + folder.erase(create_empty_folder=True) + with pytest.warns( + AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' + ): + dbexport.export_tree(entities=what, what=what, folder=folder) - folder.erase(create_empty_folder=True) - with pytest.raises(TypeError, match='`folder` must be specified'): - dbexport.export_tree(entities=what) + folder.erase(create_empty_folder=True) + with pytest.raises(TypeError, match='`entities` must be specified'): + dbexport.export_tree(folder=folder) + + folder.erase(create_empty_folder=True) + with pytest.raises(TypeError, match='`folder` must be specified'): + dbexport.export_tree(entities=what) diff --git a/tests/tools/importexport/test_prov_redesign.py b/tests/tools/importexport/test_prov_redesign.py index 801657624f..5a04caa1b6 100644 --- a/tests/tools/importexport/test_prov_redesign.py +++ b/tests/tools/importexport/test_prov_redesign.py @@ -12,14 +12,16 @@ import os +import pytest + from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.tools.importexport import import_data, export from tests.utils.configuration import with_temp_dir +from . import AiidaArchiveTestCase -class TestProvenanceRedesign(AiidaTestCase): +class TestProvenanceRedesign(AiidaArchiveTestCase): """ Check changes in database schema after upgrading to v0.4 (Provenance Redesign) This includes all migrations from "base_data_plugin_type_string" (django: 0008) until "dbgroup_type_string_change_content" (django: 0022), both included. 
@@ -59,13 +61,13 @@ def test_base_data_type_change(self, temp_dir): # Export nodes filename = os.path.join(temp_dir, 'export.aiida') - export(export_nodes, filename=filename, silent=True) + export(export_nodes, filename=filename) # Clean the database - self.reset_database() + self.clean_db() # Import nodes again - import_data(filename, silent=True) + import_data(filename) # Check whether types are correctly imported nlist = orm.load_node(list_node_uuid) # List @@ -86,6 +88,7 @@ def test_base_data_type_change(self, temp_dir): msg = f"type of node ('{nlist.node_type}') is not updated according to db schema v0.4" self.assertEqual(nlist.node_type, 'data.list.List.', msg=msg) + @pytest.mark.requires_rmq @with_temp_dir def test_node_process_type(self, temp_dir): """ Column `process_type` added to `Node` entity DB table """ @@ -108,11 +111,11 @@ def test_node_process_type(self, temp_dir): # Export nodes filename = os.path.join(temp_dir, 'export.aiida') - export([node], filename=filename, silent=True) + export([node], filename=filename) # Clean the database and reimport data - self.reset_database() - import_data(filename, silent=True) + self.clean_db() + import_data(filename) # Retrieve node and check exactly one node is imported builder = orm.QueryBuilder() @@ -151,11 +154,11 @@ def test_code_type_change(self, temp_dir): # Export node filename = os.path.join(temp_dir, 'export.aiida') - export([code], filename=filename, silent=True) + export([code], filename=filename) # Clean the database and reimport - self.reset_database() - import_data(filename, silent=True) + self.clean_db() + import_data(filename) # Retrieve Code node and make sure exactly 1 is retrieved builder = orm.QueryBuilder() @@ -233,11 +236,11 @@ def test_group_name_and_type_change(self, temp_dir): # Export node filename = os.path.join(temp_dir, 'export.aiida') - export([group_user, group_upf], filename=filename, silent=True) + export([group_user, group_upf], filename=filename) # Clean the database and reimport - self.reset_database() - import_data(filename, silent=True) + self.clean_db() + import_data(filename) # Retrieve Groups and make sure exactly 3 are retrieved (including the "import group") builder = orm.QueryBuilder() diff --git a/tests/tools/importexport/test_specific_import.py b/tests/tools/importexport/test_specific_import.py index 874a296e57..926f3956ac 100644 --- a/tests/tools/importexport/test_specific_import.py +++ b/tests/tools/importexport/test_specific_import.py @@ -12,29 +12,24 @@ import os import shutil import tempfile +import warnings import numpy as np from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.folders import RepositoryFolder +from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm.utils._repository import Repository from aiida.tools.importexport import import_data, export from aiida.tools.importexport.common import exceptions from tests.utils.configuration import with_temp_dir +from . 
import AiidaArchiveTestCase -class TestSpecificImport(AiidaTestCase): +class TestSpecificImport(AiidaArchiveTestCase): """Test specific ex-/import cases""" - def setUp(self): - super().setUp() - self.reset_database() - - def tearDown(self): - self.reset_database() - def test_simple_import(self): """ This is a very simple test which checks that an archive file with nodes @@ -64,18 +59,17 @@ def test_simple_import(self): with tempfile.NamedTemporaryFile() as handle: nodes = [parameters] - export(nodes, filename=handle.name, overwrite=True, silent=True) + export(nodes, filename=handle.name, overwrite=True) # Check that we have the expected number of nodes in the database self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) # Clean the database and verify there are no nodes left - self.clean_db() - self.create_user() + self.refurbish_db() self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), 0) # After importing we should have the original number of nodes again - import_data(handle.name, silent=True) + import_data(handle.name) self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) def test_cycle_structure_data(self): @@ -125,18 +119,17 @@ def test_cycle_structure_data(self): with tempfile.NamedTemporaryFile() as handle: nodes = [structure, child_calculation, parent_process, remote_folder] - export(nodes, filename=handle.name, overwrite=True, silent=True) + export(nodes, filename=handle.name, overwrite=True) # Check that we have the expected number of nodes in the database self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) # Clean the database and verify there are no nodes left - self.clean_db() - self.create_user() + self.refurbish_db() self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), 0) # After importing we should have the original number of nodes again - import_data(handle.name, silent=True) + import_data(handle.name) self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) # Verify that orm.CalculationNodes have non-empty attribute dictionaries @@ -230,14 +223,16 @@ def test_missing_node_repo_folder_import(self, temp_dir): # Export and reset db filename = os.path.join(temp_dir, 'export.aiida') - export([node], filename=filename, file_format='tar.gz', silent=True) - self.reset_database() + export([node], filename=filename, file_format='tar.gz') + self.clean_db() # Untar archive file, remove repository folder, re-tar node_shard_uuid = export_shard_uuid(node_uuid) node_top_folder = node_shard_uuid.split('/')[0] with SandboxFolder() as folder: - extract_tar(filename, folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + extract_tar(filename, folder, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) node_folder = folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, node_shard_uuid)) self.assertTrue( node_folder.exists(), msg="The Node's repository folder should still exist in the archive file" @@ -292,12 +287,14 @@ def test_empty_repo_folder_export(self, temp_dir): 'zip archive': os.path.join(temp_dir, 'export.zip') } - export_tree([node], folder=Folder(archive_variants['archive folder'])) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + export_tree([node], folder=Folder(archive_variants['archive folder'])) export([node], filename=archive_variants['tar archive'], file_format='tar.gz') export([node], 
filename=archive_variants['zip archive'], file_format='zip') for variant, filename in archive_variants.items(): - self.reset_database() + self.clean_db() node_count = orm.QueryBuilder().append(orm.Dict, project='uuid').count() self.assertEqual(node_count, 0, msg=f'After DB reset {node_count} Dict Nodes was (wrongly) found') @@ -330,7 +327,9 @@ def test_import_folder(self): archive = get_archive_file('arithmetic.add.aiida', filepath='calcjob') with SandboxFolder() as temp_dir: - extract_zip(archive, temp_dir, silent=True) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=AiidaDeprecationWarning) + extract_zip(archive, temp_dir) # Make sure the JSON files and the nodes subfolder was correctly extracted (is present), # then try to import it by passing the extracted folder to the import function. @@ -342,7 +341,7 @@ def test_import_folder(self): for dirpath, dirnames, _ in os.walk(temp_dir.abspath): org_folders += [os.path.join(dirpath, dirname) for dirname in dirnames] - import_data(temp_dir.abspath, silent=True) + import_data(temp_dir.abspath) # Check nothing from the source was deleted src_folders = [] diff --git a/tests/tools/visualization/test_graph.py b/tests/tools/visualization/test_graph.py index a816d9e012..c3352441aa 100644 --- a/tests/tools/visualization/test_graph.py +++ b/tests/tools/visualization/test_graph.py @@ -23,11 +23,7 @@ class TestVisGraph(AiidaTestCase): def setUp(self): super().setUp() - self.reset_database() - - def tearDown(self): - super().tearDown() - self.reset_database() + self.refurbish_db() def create_provenance(self): """create an example provenance graph diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 2920126cd2..7470dd15c9 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -15,7 +15,19 @@ Plugin specific tests will be written in the plugin itself. 
""" import io +import os +import random +import tempfile +import signal +import shutil +import string +import time import unittest +import uuid + +import psutil + +from aiida.plugins import SchedulerFactory # TODO : test for copy with pattern # TODO : test for copy with/without patterns, overwriting folder @@ -35,7 +47,6 @@ def get_all_custom_transports(): it was found) """ import importlib - import os modulename = __name__.rpartition('.')[0] this_full_fname = __file__ @@ -133,11 +144,6 @@ def test_makedirs(self, custom_transport): """ Verify the functioning of makedirs command """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -176,11 +182,6 @@ def test_rmtree(self, custom_transport): """ Verify the functioning of rmtree command """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -221,12 +222,6 @@ def test_listdir(self, custom_transport): """ create directories, verify listdir, delete a folder with subfolders """ - # Imports required later - import tempfile - import random - import string - import os - with custom_transport as trans: # We cannot use tempfile.mkdtemp because we're on a remote folder location = trans.normalize(os.path.join('/', 'tmp')) @@ -270,11 +265,6 @@ def test_listdir_withattributes(self, custom_transport): """ create directories, verify listdir_withattributes, delete a folder with subfolders """ - # Imports required later - import tempfile - import random - import string - import os def simplify_attributes(data): """ @@ -340,11 +330,6 @@ def simplify_attributes(data): @run_for_all_plugins def test_dir_creation_deletion(self, custom_transport): """Test creating and deleting directories.""" - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -370,11 +355,6 @@ def test_dir_copy(self, custom_transport): Verify if in the copy of a directory also the protection bits are carried over """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -403,11 +383,6 @@ def test_dir_permissions_creation_modification(self, custom_transport): # pylin verify if chmod raises IOError when trying to change bits on a non-existing folder """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -460,11 +435,6 @@ def test_dir_reading_permissions(self, custom_transport): Try to enter a directory with no read permissions. Verify that the cwd has not changed after failed try. 
""" - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -503,8 +473,6 @@ def test_isfile_isdir_to_empty_string(self, custom_transport): I check that isdir or isfile return False when executed on an empty string """ - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(location) @@ -517,8 +485,6 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport): I check that isdir or isfile return False when executed on an empty string """ - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(location) @@ -535,8 +501,6 @@ def test_chdir_to_empty_string(self, custom_transport): not change (this is a paramiko default behavior), but getcwd() is still correctly defined. """ - import os - with custom_transport as transport: new_dir = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(new_dir) @@ -555,10 +519,6 @@ class TestPutGetFile(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): """Test putting and getting files.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -605,10 +565,6 @@ def test_put_get_abs_path(self, custom_transport): """ test of exception for non existing files and abs path """ - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -669,10 +625,6 @@ def test_put_get_empty_string(self, custom_transport): test of exception put/get of empty strings """ # TODO : verify the correctness of \n at the end of a file - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -752,10 +704,6 @@ class TestPutGetTree(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): """Test putting and getting files.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -807,8 +755,6 @@ def test_put_and_get(self, custom_transport): self.assertTrue('file.txt' in list_pushed_file) self.assertTrue('file.txt' in list_retrieved_file) - import shutil - shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) transport.rmtree(remote_subfolder) @@ -819,11 +765,6 @@ def test_put_and_get(self, custom_transport): @run_for_all_plugins def test_put_and_get_overwrite(self, custom_transport): """Test putting and getting files with overwrites.""" - import os - import random - import shutil - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -877,10 +818,6 @@ def test_put_and_get_overwrite(self, custom_transport): @run_for_all_plugins def test_copy(self, custom_transport): """Test copying.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -952,10 +889,6 @@ def test_put(self, custom_transport): # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1033,11 +966,6 @@ def test_get(self, 
custom_transport): # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute - import os - import random - import shutil - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1119,10 +1047,6 @@ def test_put_get_abs_path(self, custom_transport): """ test of exception for non existing files and abs path """ - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1194,10 +1118,6 @@ def test_put_get_empty_string(self, custom_transport): test of exception put/get of empty strings """ # TODO : verify the correctness of \n at the end of a file - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1263,9 +1183,6 @@ def test_put_get_empty_string(self, custom_transport): @run_for_all_plugins def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use """Test `gettree` for a nested directory.""" - import os - import tempfile - with tempfile.TemporaryDirectory() as dir_remote, tempfile.TemporaryDirectory() as dir_local: content = b'dummy\ncontent' filepath = os.path.join(dir_remote, 'sub', 'path', 'filename.txt') @@ -1294,8 +1211,6 @@ def test_exec_pwd(self, custom_transport): creation (which should be done by paramiko) and in the command execution (done in this module, in the _exec_command_internal function). """ - import os - # Start value delete_at_end = False @@ -1365,3 +1280,84 @@ def test_exec_with_wrong_stdin(self, custom_transport): with custom_transport as transport: with self.assertRaises(ValueError): transport.exec_command_wait('cat', stdin=1) + + +class TestDirectScheduler(unittest.TestCase): + """ + Test how the direct scheduler works. + + While this is technically a scheduler test, I put it under the transport tests + because 1) in reality I am testing the interaction of each transport with the + direct scheduler; 2) the direct scheduler is always available; 3) I am reusing + the infrastructure to test on multiple transport plugins. + """ + + @run_for_all_plugins + def test_asynchronous_execution(self, custom_transport): + """Test that the execution of a long(ish) command via the direct scheduler does not block. + + This is a regression test for #3094, where running a long job on the direct scheduler + (via SSH) would lock the interpreter until the job was done. + """ + # Use a unique name, using a UUID, to avoid concurrent tests (or very rapid + # tests that follow each other) to overwrite the same destination + script_fname = f'sleep-submit-{uuid.uuid4().hex}-{custom_transport.__class__.__name__}.sh' + + scheduler = SchedulerFactory('direct')() + scheduler.set_transport(custom_transport) + with custom_transport as transport: + try: + with tempfile.NamedTemporaryFile() as tmpf: + # Put a submission script that sleeps 10 seconds + tmpf.write(b'#!/bin/bash\nsleep 10\n') + tmpf.flush() + + transport.chdir('/tmp') + transport.putfile(tmpf.name, script_fname) + + timestamp_before = time.time() + job_id_string = scheduler.submit_from_script('/tmp', script_fname) + + elapsed_time = time.time() - timestamp_before + # We want to get back control. If it takes < 5 seconds, it means that it is not blocking + # as the job is taking at least 10 seconds. I put 5 as the machine could be slow (including the + # SSH connection etc.) and I don't want to have false failures. 
+ # Actually, if the time is short, it could mean also that the execution failed! + # So I double check later that the execution was successful. + self.assertTrue( + elapsed_time < 5, 'Getting back control after remote execution took more than 5 seconds! ' + 'Probably submission is blocking' + ) + + # Check that the job is still running + # Wait 0.2 more seconds, so that I don't do a super-quick check that might return True + # even if it's not sleeping + time.sleep(0.2) + # Check that the job is still running - IMPORTANT, I'm assuming that all transports actually act + # on the *same* local machine, and that the job_id is actually the process PID. + # This needs to be adapted if: + # - a new transport plugin is tested and this does not test the same machine + # - a new scheduler is used and does not use the process PID, or the job_id of the 'direct' scheduler + # is not anymore simply the job PID + job_id = int(job_id_string) + self.assertTrue( + psutil.pid_exists(job_id), 'The job is not there after a bit more than 1 second! Probably it failed' + ) + finally: + # Clean up by killing the remote job. + # This assumes it's on the same machine; if we add tests on a different machine, + # we need to call 'kill' via the transport instead. + # In reality it's not critical to remove it since it will end after 10 seconds of + # sleeping, but this might avoid warnings (e.g. ResourceWarning) + try: + os.kill(job_id, signal.SIGTERM) + except ProcessLookupError: + # If the process is already dead (or has never run), I just ignore the error + pass + + # Also remove the script + try: + transport.remove(f'/tmp/{script_fname}') + except FileNotFoundError: + # If the file wasn't even created, I just ignore this error + pass diff --git a/tests/utils/archives.py b/tests/utils/archives.py index f683154fa1..1b6933ba46 100644 --- a/tests/utils/archives.py +++ b/tests/utils/archives.py @@ -77,7 +77,7 @@ def import_archive(archive, filepath=None, external_module=None): dirpath_archive = get_archive_file(archive, filepath=filepath, external_module=external_module) - import_data(dirpath_archive, silent=True) + import_data(dirpath_archive) def read_json_files(path, *, names=('metadata.json', 'data.json')) -> List[dict]: diff --git a/tests/utils/configuration.py b/tests/utils/configuration.py index ae922cef12..e3767af946 100644 --- a/tests/utils/configuration.py +++ b/tests/utils/configuration.py @@ -43,72 +43,6 @@ def create_mock_profile(name, repository_dirpath=None, **kwargs): return Profile(name, profile_dictionary) -@contextlib.contextmanager -def temporary_config_instance(): - """Create a temporary AiiDA instance.""" - current_config = None - current_config_path = None - current_profile_name = None - temporary_config_directory = None - - from aiida.common.utils import Capturing - from aiida.manage import configuration - from aiida.manage.configuration import settings, load_profile, reset_profile - - try: - from aiida.manage.configuration.settings import create_instance_directories - - # Store the current configuration instance and config directory path - current_config = configuration.CONFIG - current_config_path = current_config.dirpath - current_profile_name = configuration.PROFILE.name - - reset_profile() - configuration.CONFIG = None - - # Create a temporary folder, set it as the current config directory path and reset the loaded configuration - profile_name = 'test_profile_1234' - temporary_config_directory = tempfile.mkdtemp() - settings.AIIDA_CONFIG_FOLDER = temporary_config_directory - - # Create the 
instance base directory structure, the config file and a dummy profile - create_instance_directories() - - # The constructor of `Config` called by `load_config` will print warning messages about migrating it - with Capturing(): - configuration.CONFIG = configuration.load_config(create=True) - - profile = create_mock_profile(name=profile_name, repository_dirpath=temporary_config_directory) - - # Add the created profile and set it as the default - configuration.CONFIG.add_profile(profile) - configuration.CONFIG.set_default_profile(profile_name, overwrite=True) - configuration.CONFIG.store() - load_profile() - - yield configuration.CONFIG - finally: - # Reset the config folder path and the config instance - reset_profile() - settings.AIIDA_CONFIG_FOLDER = current_config_path - configuration.CONFIG = current_config - load_profile(current_profile_name) - - # Destroy the temporary instance directory - if temporary_config_directory and os.path.isdir(temporary_config_directory): - shutil.rmtree(temporary_config_directory) - - -def with_temporary_config_instance(function): - """Create a temporary AiiDA instance for the duration of the wrapped function.""" - - def decorated_function(*args, **kwargs): - with temporary_config_instance(): - function(*args, **kwargs) - - return decorated_function - - @contextlib.contextmanager def temporary_directory(): """Create a temporary directory.""" diff --git a/tests/utils/memory.py b/tests/utils/memory.py new file mode 100644 index 0000000000..9689108bfa --- /dev/null +++ b/tests/utils/memory.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utilities for testing memory leakage.""" +import asyncio +from pympler import muppy + + +def get_instances(classes, delay=0.0): + """Return all instances of provided classes that are in memory. + + Useful for investigating memory leaks. + + :param classes: A class or tuple of classes to check (passed to `isinstance`). + :param delay: How long to sleep (seconds) before collecting the memory dump. + This is a convenience function for tests involving Processes. For example, :py:func:`~aiida.engine.run` returns + before all futures are resolved/cleaned up. Dumping memory too early would catch those and the references they + carry, although they may not actually be leaking memory. 
+ """ + if delay > 0: + loop = asyncio.get_event_loop() + loop.run_until_complete(asyncio.sleep(delay)) + + all_objects = muppy.get_objects() # this also calls gc.collect() + return [o for o in all_objects if hasattr(o, '__class__') and isinstance(o, classes)] diff --git a/tests/workflows/arithmetic/test_add_multiply.py b/tests/workflows/arithmetic/test_add_multiply.py index ddbee359b1..b172930094 100644 --- a/tests/workflows/arithmetic/test_add_multiply.py +++ b/tests/workflows/arithmetic/test_add_multiply.py @@ -21,7 +21,8 @@ def test_factory(): assert loaded.is_process_function -@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.requires_rmq +@pytest.mark.usefixtures('clear_database_before_test', 'temporary_event_loop') def test_run(): """Test running the work function.""" x = Int(1) diff --git a/utils/dependency_management.py b/utils/dependency_management.py index 14f30c5280..4ff62257e4 100755 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -188,10 +188,10 @@ def update_pyproject_toml(): # update the build-system key pyproject.setdefault('build-system', {}) pyproject['build-system'].update({ - 'requires': ['setuptools>=40.8.0,<50', 'wheel', + 'requires': ['setuptools>=40.8.0', 'wheel', str(reentry_requirement), 'fastentrypoints~=0.12'], 'build-backend': - 'setuptools.build_meta:__legacy__', + 'setuptools.build_meta', }) # write the new file @@ -247,7 +247,8 @@ def validate_environment_yml(): # pylint: disable=too-many-branches # The Python version should be specified as supported in 'setup.json'. if not any(spec.version >= other_spec.version for other_spec in python_requires.specifier): raise DependencySpecificationError( - "Required Python version between 'setup.json' and 'environment.yml' not consistent." + f"Required Python version {spec.version} from 'environment.yaml' is not consistent with " + + "required version in 'setup.json'." ) break
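The new tests/utils/memory.py helper added above is meant for leak hunting: collect whatever is still alive after running a process and assert the list is empty. A usage sketch follows; the process class, builder argument and assertion are illustrative only and are not part of this changeset.

# Illustrative use of tests/utils/memory.get_instances; not taken from this diff.
from aiida.engine import run
from aiida.engine.processes import Process

from tests.utils.memory import get_instances


def assert_no_leaked_processes(builder):
    """Run a process and check that no Process instances stay alive afterwards."""
    run(builder)
    # The small delay lets the event loop resolve futures left over by `run`,
    # so they are not mistaken for leaks (see the get_instances docstring).
    leaked = get_instances(Process, delay=0.2)
    assert not leaked, f'process instances remained in memory: {leaked}'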