From 5c1eb3fadf257faa978ed6dee961f7d01ee21c48 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Sat, 21 May 2022 08:57:30 +0200 Subject: [PATCH] Engine: fix bug that allowed non-storable inputs to be passed to process (#5532) The basic assumption for a `Process` in `aiida-core` is that all of its inputs should be storable in the database as nodes. Under the current link model, this means that they should be instances of the `Data` class or subclasses thereof. There is a noticeable exception for ports that are explicitly marked as `non_db=True`, in which case the value is not linked as a node, but is stored as an attribute directly on the process node itself, or not stored whatsoever. This basic rule was never explicitly enforced, which made it possible to define processes that would happily take non-storable inputs. The input would not get stored in the database, but would be available within the processes lifetimes from the `inputs` property allowing it to be used. This will most likely result into unintentional loss of provenance. The reason is that the default `valid_type` of the top-level inputs namespace of the `Process` class was never being set to `Data`. This meant that any type would be excepted for a `Process` and all its subclasses unless the valid type of a port was explicitly overridden. This meant that for normal dynamic namespaces, even non-storable types would be accepted just fine. Setting `valid_type=Data` for the input namespace of the `Process` class fixes the problem therefore. --- aiida/engine/processes/process.py | 2 + tests/engine/processes/test_builder.py | 26 +- tests/engine/test_process.py | 21 + tests/engine/test_process_function.py | 809 +++++++++++++------------ 4 files changed, 460 insertions(+), 398 deletions(-) diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index cbe20755b7..e50ff6a342 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -112,6 +112,8 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] default='CALL', help='The label to use for the `CALL` link if the process is called by another process.' ) + spec.inputs.valid_type = orm.Data + spec.inputs.dynamic = False # Settings a ``valid_type`` automatically makes it dynamic, so we reset it again spec.exit_code(1, 'ERROR_UNSPECIFIED', message='The process has failed with an unspecified error.') spec.exit_code(2, 'ERROR_LEGACY_FAILURE', message='The process failed with legacy failure mode.') spec.exit_code(10, 'ERROR_INVALID_OUTPUT', message='The process returned an invalid output.') diff --git a/tests/engine/processes/test_builder.py b/tests/engine/processes/test_builder.py index 817e255dba..1ac24c1d0a 100644 --- a/tests/engine/processes/test_builder.py +++ b/tests/engine/processes/test_builder.py @@ -45,11 +45,11 @@ class LazyProcessNamespace(Process): @classmethod def define(cls, spec): super().define(spec) - spec.input_namespace('namespace') - spec.input_namespace('namespace.nested') - spec.input('namespace.nested.bird') - spec.input('namespace.a') - spec.input('namespace.c') + spec.input_namespace('namespace', non_db=True) + spec.input_namespace('namespace.nested', non_db=True) + spec.input('namespace.nested.bird', non_db=True) + spec.input('namespace.a', non_db=True) + spec.input('namespace.c', non_db=True) class SimpleProcessNamespace(Process): @@ -58,9 +58,9 @@ class SimpleProcessNamespace(Process): @classmethod def define(cls, spec): super().define(spec) - spec.input_namespace('namespace.nested', dynamic=True) - spec.input('namespace.a', valid_type=int) - spec.input('namespace.c', valid_type=dict) + spec.input_namespace('namespace.nested', dynamic=True, non_db=True) + spec.input('namespace.a', valid_type=int, non_db=True) + spec.input('namespace.c', valid_type=dict, non_db=True) class NestedNamespaceProcess(Process): @@ -69,9 +69,9 @@ class NestedNamespaceProcess(Process): @classmethod def define(cls, spec): super().define(spec) - spec.input('nested.namespace.int', valid_type=int, required=True) - spec.input('nested.namespace.float', valid_type=float, required=True) - spec.input('nested.namespace.str', valid_type=str, required=False) + spec.input('nested.namespace.int', valid_type=int, required=True, non_db=True) + spec.input('nested.namespace.float', valid_type=float, required=True, non_db=True) + spec.input('nested.namespace.str', valid_type=str, required=False, non_db=True) class MappingData(Mapping, orm.Data): @@ -398,7 +398,7 @@ class ProcessOne(Process): @classmethod def define(cls, spec): super().define(spec) - spec.input('port', valid_type=int, default=1) + spec.input('port', valid_type=int, default=1, non_db=True) class ProcessTwo(Process): """Process with nested required ports to check the update functionality.""" @@ -406,7 +406,7 @@ class ProcessTwo(Process): @classmethod def define(cls, spec): super().define(spec) - spec.input('port', valid_type=int, default=2) + spec.input('port', valid_type=int, default=2, non_db=True) builder_one = ProcessOne.get_builder() assert builder_one.port == 1 diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index 708456976f..9d45f39cbb 100644 --- a/tests/engine/test_process.py +++ b/tests/engine/test_process.py @@ -433,3 +433,24 @@ def define(cls, spec): # If the ``namespace`` does not exist, for example because it is slightly misspelled, a ``KeyError`` is raised with pytest.raises(KeyError): process.exposed_outputs(node_child, ChildProcess, namespace='cildh') + + +class TestValidateDynamicNamespaceProcess(Process): + """Simple process with dynamic input namespace.""" + + _node_class = orm.WorkflowNode + + @classmethod + def define(cls, spec): + super().define(spec) + spec.inputs.dynamic = True + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_input_validation_storable_nodes(): + """Test that validation catches non-storable inputs even if nested in dictionary for dynamic namespace. + + Regression test for #5128. + """ + with pytest.raises(ValueError): + run(TestValidateDynamicNamespaceProcess, **{'namespace': {'a': 1}}) diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index 7dc40c54b6..519aedfe73 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -7,12 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=no-self-use -"""Tests for the process_function decorator.""" +"""Tests for the process_function decorator. + +Note that here we use ``workfunction`` and ``calcfunction``, the concrete versions of the ``process_function`` decorator +even though we are testing only the shared functionality that is captured in the ``process_function`` decorator, +relating to the transformation of the wrapped function into a ``FunctionProcess``. The reason we do not use the +``process_function`` decorator itself, is because it does not have a node class by default. We could create one on the +fly, but then anytime inputs or outputs would be attached to it in the tests, the ``validate_link`` function would +complain as the dummy node class is not recognized as a valid process node. +""" import pytest from aiida import orm -from aiida.engine import ExitCode, Process, calcfunction, run, run_get_node, submit, workfunction +from aiida.engine import ExitCode, calcfunction, run, run_get_node, submit, workfunction from aiida.orm.nodes.data.bool import get_true_node from aiida.workflows.arithmetic.add_multiply import add_multiply @@ -22,404 +29,436 @@ CUSTOM_LABEL = 'Custom label' CUSTOM_DESCRIPTION = 'Custom description' +pytest.mark.requires_rmq # pylint: disable=pointless-statement -@pytest.mark.requires_rmq -class TestProcessFunction: + +@workfunction +def function_return_input(data): + return data + + +@calcfunction +def function_return_true(): + return get_true_node() + + +@workfunction +def function_args(data_a): + return data_a + + +@workfunction +def function_args_with_default(data_a=lambda: orm.Int(DEFAULT_INT)): + return data_a + + +@calcfunction +def function_with_none_default(int_a, int_b, int_c=None): + if int_c is not None: + return orm.Int(int_a + int_b + int_c) + return orm.Int(int_a + int_b) + + +@workfunction +def function_kwargs(**kwargs): + return kwargs + + +@workfunction +def function_args_and_kwargs(data_a, **kwargs): + result = {'data_a': data_a} + result.update(kwargs) + return result + + +@workfunction +def function_args_and_default(data_a, data_b=lambda: orm.Int(DEFAULT_INT)): + return {'data_a': data_a, 'data_b': data_b} + + +@workfunction +def function_defaults( + data_a=lambda: orm.Int(DEFAULT_INT), metadata={ + 'label': DEFAULT_LABEL, + 'description': DEFAULT_DESCRIPTION + } +): # pylint: disable=unused-argument,dangerous-default-value,missing-docstring + return data_a + + +@workfunction +def function_default_label(): + return + + +@workfunction +def function_exit_code(exit_status, exit_message): + return ExitCode(exit_status.value, exit_message.value) + + +@workfunction +def function_excepts(exception): + raise RuntimeError(exception.value) + + +@workfunction +def function_out_unstored(): + return orm.Int(DEFAULT_INT) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_properties(): + """Test that the `is_process_function` and `node_class` attributes are set.""" + assert function_return_input.is_process_function + assert function_return_input.node_class == orm.WorkFunctionNode + assert function_return_true.is_process_function + assert function_return_true.node_class == orm.CalcFunctionNode + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_plugin_version(): + """Test the version attributes of a process function.""" + from aiida import __version__ as version_core + + _, node = function_args_with_default.run_get_node() + + # Since the "plugin" i.e. the process function is defined in `aiida-core` the `version.plugin` is the same as + # the version of `aiida-core` itself + version_info = node.get_attribute('version') + assert version_info['core'] == version_core + assert version_info['plugin'] == version_core + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_process_state(): + """Test the process state for a process function.""" + _, node = function_args_with_default.run_get_node() + + assert node.is_terminated + assert not node.is_excepted + assert not node.is_killed + assert node.is_finished + assert node.is_finished_ok + assert not node.is_failed + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_process_type(): + """Test that the process type correctly contains the module and name of original decorated function.""" + _, node = function_defaults.run_get_node() + process_type = f'{function_defaults.__module__}.{function_defaults.__name__}' + assert node.process_type == process_type + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_exit_status(): + """A FINISHED process function has to have an exit status of 0""" + _, node = function_args_with_default.run_get_node() + assert node.exit_status == 0 + assert node.is_finished_ok + assert not node.is_failed + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_source_code_attributes(): + """Verify function properties are properly introspected and stored in the nodes attributes and repository.""" + function_name = 'test_process_function' + + @calcfunction + def test_process_function(data): + return {'result': orm.Int(data.value + 1)} + + _, node = test_process_function.run_get_node(data=orm.Int(5)) + + # Read the source file of the calculation function that should be stored in the repository + function_source_code = node.get_function_source_code().split('\n') + + # Verify that the function name is correct and the first source code linenumber is stored + assert node.function_name == function_name + assert isinstance(node.function_starting_line_number, int) + + # Check that first line number is correct. Note that the first line should correspond + # to the `@workfunction` directive, but since the list is zero-indexed we actually get the + # following line, which should correspond to the function name i.e. `def test_process_function(data)` + function_name_from_source = function_source_code[node.function_starting_line_number] + assert node.function_name in function_name_from_source + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_varargs(): + """Variadic arguments are not supported and should raise.""" + with pytest.raises(ValueError): + + @workfunction + def function_varargs(*args): # pylint: disable=unused-variable + return args + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_args(): + """Simple process function that defines a single positional argument.""" + arg = 1 + + with pytest.raises(ValueError): + result = function_args() # pylint: disable=no-value-for-parameter + + result = function_args(data_a=orm.Int(arg)) + assert isinstance(result, orm.Int) + assert result == arg + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_args_with_default(): + """Simple process function that defines a single argument with a default.""" + arg = 1 + + result = function_args_with_default() + assert isinstance(result, orm.Int) + assert result == orm.Int(DEFAULT_INT) + + result = function_args_with_default(data_a=orm.Int(arg)) + assert isinstance(result, orm.Int) + assert result == arg + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_with_none_default(): + """Simple process function that defines a keyword with `None` as default value.""" + int_a = orm.Int(1) + int_b = orm.Int(2) + int_c = orm.Int(3) + + result = function_with_none_default(int_a, int_b) + assert isinstance(result, orm.Int) + assert result == orm.Int(3) + + result = function_with_none_default(int_a, int_b, int_c) + assert isinstance(result, orm.Int) + assert result == orm.Int(6) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_kwargs(): + """Simple process function that defines keyword arguments.""" + kwargs = {'data_a': orm.Int(DEFAULT_INT)} + + result, node = function_kwargs.run_get_node() + assert isinstance(result, dict) + assert len(node.get_incoming().all()) == 0 + assert result == {} + + result, node = function_kwargs.run_get_node(**kwargs) + assert isinstance(result, dict) + assert len(node.get_incoming().all()) == 1 + assert result == kwargs + + # Calling with any number of positional arguments should raise + with pytest.raises(TypeError): + function_kwargs.run_get_node(orm.Int(1)) + + with pytest.raises(TypeError): + function_kwargs.run_get_node(orm.Int(1), b=orm.Int(2)) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_args_and_kwargs(): + """Simple process function that defines a positional argument and keyword arguments.""" + arg = 1 + args = (orm.Int(DEFAULT_INT),) + kwargs = {'data_b': orm.Int(arg)} + + result = function_args_and_kwargs(*args) + assert isinstance(result, dict) + assert result == {'data_a': args[0]} + + result = function_args_and_kwargs(*args, **kwargs) + assert isinstance(result, dict) + assert result == {'data_a': args[0], 'data_b': kwargs['data_b']} + + # Calling with more positional arguments than defined in the signature should raise + with pytest.raises(TypeError): + function_kwargs.run_get_node(orm.Int(1), orm.Int(2)) + + with pytest.raises(TypeError): + function_kwargs.run_get_node(orm.Int(1), orm.Int(2), b=orm.Int(2)) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_args_and_kwargs_default(): + """Simple process function that defines a positional argument and an argument with a default.""" + arg = 1 + args_input_default = (orm.Int(DEFAULT_INT),) + args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg)) + + result = function_args_and_default(*args_input_default) + assert isinstance(result, dict) + assert result == {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)} + + result = function_args_and_default(*args_input_explicit) + assert isinstance(result, dict) + assert result == {'data_a': args_input_explicit[0], 'data_b': args_input_explicit[1]} + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_args_passing_kwargs(): + """Cannot pass kwargs if the function does not explicitly define it accepts kwargs.""" + with pytest.raises(ValueError): + function_args(data_a=orm.Int(1), data_b=orm.Int(1)) # pylint: disable=unexpected-keyword-arg + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_set_label_description(): + """Verify that the label and description can be set for all process function variants.""" + metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} + + _, node = function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + _, node = function_args_with_default.run_get_node(metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + _, node = function_kwargs.run_get_node(metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + _, node = function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + _, node = function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_defaults(): + """Verify that a process function can define a default label and description but can be overriden.""" + metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} + + _, node = function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT)) + assert node.label == DEFAULT_LABEL + assert node.description == DEFAULT_DESCRIPTION + + _, node = function_defaults.run_get_node(metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_default_label(): + """Verify unless specified label is taken from function name.""" + metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} + + _, node = function_default_label.run_get_node() + assert node.label == 'function_default_label' + assert node.description == '' + + _, node = function_default_label.run_get_node(metadata=metadata) + assert node.label == CUSTOM_LABEL + assert node.description == CUSTOM_DESCRIPTION + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_launchers(): + """Verify that the various launchers are working.""" + result = run(function_return_true) + assert result + + result, node = run_get_node(function_return_true) + assert result + assert result == get_true_node() + assert isinstance(node, orm.CalcFunctionNode) + + # Process function can be submitted and will be run by a daemon worker as long as the function is importable + # Note that the actual running is not tested here but is done so in `.github/system_tests/test_daemon.py`. + node = submit(add_multiply, x=orm.Int(1), y=orm.Int(2), z=orm.Int(3)) + assert isinstance(node, orm.WorkFunctionNode) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_return_exit_code(): """ - Note that here we use `@workfunctions` and `@calculations`, the concrete versions of the - `@process_function` decorator, even though we are testing only the shared functionality - that is captured in the `@process_function` decorator, relating to the transformation - of the wrapped function into a `FunctionProcess`. - The reason we do not use the `@process_function` decorator itself, is because it - does not have a node class by default. We could create one on the fly, but then - anytime inputs or outputs would be attached to it in the tests, the `validate_link` - function would complain as the dummy node class is not recognized as a valid process node. + A process function that returns an ExitCode namedtuple should have its exit status and message set FINISHED """ + exit_status = 418 + exit_message = 'I am a teapot' - # pylint: disable=too-many-public-methods,too-many-instance-attributes + message = orm.Str(exit_message) + _, node = function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message) - @pytest.fixture(autouse=True) - def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument - """Initialize the profile.""" - # pylint: disable=attribute-defined-outside-init - assert Process.current() is None + assert node.is_finished + assert not node.is_finished_ok + assert node.exit_status == exit_status + assert node.exit_message == exit_message - @workfunction - def function_return_input(data): - return data - @calcfunction - def function_return_true(): - return get_true_node() +@pytest.mark.usefixtures('clear_database_before_test') +def test_normal_exception(): + """If a process, for example a FunctionProcess, excepts, the exception should be stored in the node.""" + exception = 'This process function excepted' - @workfunction - def function_args(data_a): - return data_a + with pytest.raises(RuntimeError): + _, node = function_excepts.run_get_node(exception=orm.Str(exception)) + assert node.is_excepted + assert node.exception == exception - @workfunction - def function_args_with_default(data_a=lambda: orm.Int(DEFAULT_INT)): - return data_a - @calcfunction - def function_with_none_default(int_a, int_b, int_c=None): - if int_c is not None: - return orm.Int(int_a + int_b + int_c) - return orm.Int(int_a + int_b) +@pytest.mark.usefixtures('clear_database_before_test') +def test_function_out_unstored(): + """A workfunction that returns an unstored node should raise as it indicates users tried to create data.""" + with pytest.raises(ValueError): + function_out_unstored() - @workfunction - def function_kwargs(**kwargs): - return kwargs - @workfunction - def function_args_and_kwargs(data_a, **kwargs): - result = {'data_a': data_a} - result.update(kwargs) - return result +@pytest.mark.usefixtures('clear_database_before_test') +def test_simple_workflow(): + """Test construction of simple workflow by chaining process functions.""" - @workfunction - def function_args_and_default(data_a, data_b=lambda: orm.Int(DEFAULT_INT)): - return {'data_a': data_a, 'data_b': data_b} + @calcfunction + def add(data_a, data_b): + return data_a + data_b - @workfunction - def function_defaults( - data_a=lambda: orm.Int(DEFAULT_INT), metadata={ - 'label': DEFAULT_LABEL, - 'description': DEFAULT_DESCRIPTION - } - ): # pylint: disable=unused-argument,dangerous-default-value,missing-docstring - return data_a + @calcfunction + def mul(data_a, data_b): + return data_a * data_b - @workfunction - def function_default_label(): - return + @workfunction + def add_mul_wf(data_a, data_b, data_c): + return mul(add(data_a, data_b), data_c) - @workfunction - def function_exit_code(exit_status, exit_message): - return ExitCode(exit_status.value, exit_message.value) + result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5)) - @workfunction - def function_excepts(exception): - raise RuntimeError(exception.value) + assert result == (3 + 4) * 5 + assert isinstance(node, orm.WorkFunctionNode) - @workfunction - def function_out_unstored(): - return orm.Int(DEFAULT_INT) - - self.function_return_input = function_return_input - self.function_return_true = function_return_true - self.function_args = function_args - self.function_args_with_default = function_args_with_default - self.function_with_none_default = function_with_none_default - self.function_kwargs = function_kwargs - self.function_args_and_kwargs = function_args_and_kwargs - self.function_args_and_default = function_args_and_default - self.function_defaults = function_defaults - self.function_default_label = function_default_label - self.function_exit_code = function_exit_code - self.function_excepts = function_excepts - self.function_out_unstored = function_out_unstored - - yield - assert Process.current() is None - - def test_properties(self): - """Test that the `is_process_function` and `node_class` attributes are set.""" - assert self.function_return_input.is_process_function is True - assert self.function_return_input.node_class == orm.WorkFunctionNode - assert self.function_return_true.is_process_function is True - assert self.function_return_true.node_class == orm.CalcFunctionNode - - def test_plugin_version(self): - """Test the version attributes of a process function.""" - from aiida import __version__ as version_core - - _, node = self.function_args_with_default.run_get_node() - - # Since the "plugin" i.e. the process function is defined in `aiida-core` the `version.plugin` is the same as - # the version of `aiida-core` itself - version_info = node.base.attributes.get('version') - assert version_info['core'] == version_core - assert version_info['plugin'] == version_core - - def test_process_state(self): - """Test the process state for a process function.""" - _, node = self.function_args_with_default.run_get_node() - - assert node.is_terminated is True - assert node.is_excepted is False - assert node.is_killed is False - assert node.is_finished is True - assert node.is_finished_ok is True - assert node.is_failed is False - - def test_process_type(self): - """Test that the process type correctly contains the module and name of original decorated function.""" - _, node = self.function_defaults.run_get_node() - process_type = f'{self.function_defaults.__module__}.{self.function_defaults.__name__}' - assert node.process_type == process_type - - def test_exit_status(self): - """A FINISHED process function has to have an exit status of 0""" - _, node = self.function_args_with_default.run_get_node() - assert node.exit_status == 0 - assert node.is_finished_ok is True - assert node.is_failed is False - - def test_source_code_attributes(self): - """Verify function properties are properly introspected and stored in the nodes attributes and repository.""" - function_name = 'test_process_function' - - @calcfunction - def test_process_function(data): - return {'result': orm.Int(data.value + 1)} - - _, node = test_process_function.run_get_node(data=orm.Int(5)) - - # Read the source file of the calculation function that should be stored in the repository - function_source_code = node.get_function_source_code().split('\n') - - # Verify that the function name is correct and the first source code linenumber is stored - assert node.function_name == function_name - assert isinstance(node.function_starting_line_number, int) - - # Check that first line number is correct. Note that the first line should correspond - # to the `@workfunction` directive, but since the list is zero-indexed we actually get the - # following line, which should correspond to the function name i.e. `def test_process_function(data)` - function_name_from_source = function_source_code[node.function_starting_line_number] - assert node.function_name in function_name_from_source - - def test_function_varargs(self): - """Variadic arguments are not supported and should raise.""" - with pytest.raises(ValueError): - - @workfunction - def function_varargs(*args): # pylint: disable=unused-variable - return args - - def test_function_args(self): - """Simple process function that defines a single positional argument.""" - arg = 1 - - with pytest.raises(ValueError): - result = self.function_args() # pylint: disable=no-value-for-parameter - - result = self.function_args(data_a=orm.Int(arg)) - assert isinstance(result, orm.Int) - assert result == arg - - def test_function_args_with_default(self): - """Simple process function that defines a single argument with a default.""" - arg = 1 - - result = self.function_args_with_default() - assert isinstance(result, orm.Int) - assert result == orm.Int(DEFAULT_INT) - - result = self.function_args_with_default(data_a=orm.Int(arg)) - assert isinstance(result, orm.Int) - assert result == arg - - def test_function_with_none_default(self): - """Simple process function that defines a keyword with `None` as default value.""" - int_a = orm.Int(1) - int_b = orm.Int(2) - int_c = orm.Int(3) - - result = self.function_with_none_default(int_a, int_b) - assert isinstance(result, orm.Int) - assert result == orm.Int(3) - - result = self.function_with_none_default(int_a, int_b, int_c) - assert isinstance(result, orm.Int) - assert result == orm.Int(6) - - def test_function_kwargs(self): - """Simple process function that defines keyword arguments.""" - kwargs = {'data_a': orm.Int(DEFAULT_INT)} - - result, node = self.function_kwargs.run_get_node() - assert isinstance(result, dict) - assert len(node.base.links.get_incoming().all()) == 0 - assert result == {} - - result, node = self.function_kwargs.run_get_node(**kwargs) - assert isinstance(result, dict) - assert len(node.base.links.get_incoming().all()) == 1 - assert result == kwargs - - # Calling with any number of positional arguments should raise - with pytest.raises(TypeError): - self.function_kwargs.run_get_node(orm.Int(1)) - - with pytest.raises(TypeError): - self.function_kwargs.run_get_node(orm.Int(1), b=orm.Int(2)) - - def test_function_args_and_kwargs(self): - """Simple process function that defines a positional argument and keyword arguments.""" - arg = 1 - args = (orm.Int(DEFAULT_INT),) - kwargs = {'data_b': orm.Int(arg)} - - result = self.function_args_and_kwargs(*args) - assert isinstance(result, dict) - assert result == {'data_a': args[0]} - - result = self.function_args_and_kwargs(*args, **kwargs) - assert isinstance(result, dict) - assert result == {'data_a': args[0], 'data_b': kwargs['data_b']} - - # Calling with more positional arguments than defined in the signature should raise - with pytest.raises(TypeError): - self.function_kwargs.run_get_node(orm.Int(1), orm.Int(2)) - - with pytest.raises(TypeError): - self.function_kwargs.run_get_node(orm.Int(1), orm.Int(2), b=orm.Int(2)) - - def test_function_args_and_kwargs_default(self): - """Simple process function that defines a positional argument and an argument with a default.""" - arg = 1 - args_input_default = (orm.Int(DEFAULT_INT),) - args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg)) - - result = self.function_args_and_default(*args_input_default) - assert isinstance(result, dict) - assert result == {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)} - - result = self.function_args_and_default(*args_input_explicit) - assert isinstance(result, dict) - assert result == {'data_a': args_input_explicit[0], 'data_b': args_input_explicit[1]} - - def test_function_args_passing_kwargs(self): - """Cannot pass kwargs if the function does not explicitly define it accepts kwargs.""" - arg = 1 - - with pytest.raises(ValueError): - self.function_args(data_a=orm.Int(arg), data_b=orm.Int(arg)) # pylint: disable=unexpected-keyword-arg - - def test_function_set_label_description(self): - """Verify that the label and description can be set for all process function variants.""" - metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} - - _, node = self.function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - _, node = self.function_args_with_default.run_get_node(metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - _, node = self.function_kwargs.run_get_node(metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - _, node = self.function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - _, node = self.function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - def test_function_defaults(self): - """Verify that a process function can define a default label and description but can be overriden.""" - metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} - - _, node = self.function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT)) - assert node.label == DEFAULT_LABEL - assert node.description == DEFAULT_DESCRIPTION - - _, node = self.function_defaults.run_get_node(metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - def test_function_default_label(self): - """Verify unless specified label is taken from function name.""" - metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} - - _, node = self.function_default_label.run_get_node() - assert node.label == 'function_default_label' - assert node.description == '' - - _, node = self.function_default_label.run_get_node(metadata=metadata) - assert node.label == CUSTOM_LABEL - assert node.description == CUSTOM_DESCRIPTION - - def test_launchers(self): - """Verify that the various launchers are working.""" - result = run(self.function_return_true) - assert result - - result, node = run_get_node(self.function_return_true) - assert result - assert result == get_true_node() - assert isinstance(node, orm.CalcFunctionNode) - - # Process function can be submitted and will be run by a daemon worker as long as the function is importable - # Note that the actual running is not tested here but is done so in `.github/system_tests/test_daemon.py`. - node = submit(add_multiply, x=orm.Int(1), y=orm.Int(2), z=orm.Int(3)) - assert isinstance(node, orm.WorkFunctionNode) - - def test_return_exit_code(self): - """ - A process function that returns an ExitCode namedtuple should have its exit status and message set FINISHED - """ - exit_status = 418 - exit_message = 'I am a teapot' - - message = orm.Str(exit_message) - _, node = self.function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message) - - assert node.is_finished - assert not node.is_finished_ok - assert node.exit_status == exit_status - assert node.exit_message == exit_message - - def test_normal_exception(self): - """If a process, for example a FunctionProcess, excepts, the exception should be stored in the node.""" - exception = 'This process function excepted' - - with pytest.raises(RuntimeError): - _, node = self.function_excepts.run_get_node(exception=orm.Str(exception)) - assert node.is_excepted - assert node.exception == exception - - def test_function_out_unstored(self): - """A workfunction that returns an unstored node should raise as it indicates users tried to create data.""" - with pytest.raises(ValueError): - self.function_out_unstored() - - def test_simple_workflow(self): - """Test construction of simple workflow by chaining process functions.""" - - @calcfunction - def add(data_a, data_b): - return data_a + data_b - - @calcfunction - def mul(data_a, data_b): - return data_a * data_b - @workfunction - def add_mul_wf(data_a, data_b, data_c): - return mul(add(data_a, data_b), data_c) - - result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5)) - - assert result == (3 + 4) * 5 - assert isinstance(node, orm.WorkFunctionNode) - - def test_hashes(self): - """Test that the hashes generated for identical process functions with identical inputs are the same.""" - _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) - _, node2 = self.function_return_input.run_get_node(data=orm.Int(2)) - assert node1.base.caching.get_hash() == node1.base.extras.get('_aiida_hash') - assert node2.base.caching.get_hash() == node2.base.extras.get('_aiida_hash') - assert node1.base.caching.get_hash() == node2.base.caching.get_hash() - - def test_hashes_different(self): - """Test that the hashes generated for identical process functions with different inputs are the different.""" - _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) - _, node2 = self.function_return_input.run_get_node(data=orm.Int(3)) - assert node1.base.caching.get_hash() == node1.base.extras.get('_aiida_hash') - assert node2.base.caching.get_hash() == node2.base.extras.get('_aiida_hash') - assert node1.base.caching.get_hash() != node2.base.caching.get_hash() +@pytest.mark.usefixtures('clear_database_before_test') +def test_hashes(): + """Test that the hashes generated for identical process functions with identical inputs are the same.""" + _, node1 = function_return_input.run_get_node(data=orm.Int(2)) + _, node2 = function_return_input.run_get_node(data=orm.Int(2)) + assert node1.get_hash() == node1.get_extra('_aiida_hash') + assert node2.get_hash() == node2.get_extra('_aiida_hash') + assert node1.get_hash() == node2.get_hash() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_hashes_different(): + """Test that the hashes generated for identical process functions with different inputs are the different.""" + _, node1 = function_return_input.run_get_node(data=orm.Int(2)) + _, node2 = function_return_input.run_get_node(data=orm.Int(3)) + assert node1.get_hash() == node1.get_extra('_aiida_hash') + assert node2.get_hash() == node2.get_extra('_aiida_hash') + assert node1.get_hash() != node2.get_hash() + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_input_validation(): + """Test that process functions do not allow non-storable inputs, even when hidden in nested namespaces. + + Regression test for #5128. + """ + with pytest.raises(ValueError): + function_kwargs.run_get_node(**{'namespace': {'valid': orm.Int(1), 'invalid': 1}})