diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d5a2cc62..909ef5253 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Added `TypeConfigurator` to automatically wrap fields with `TermSetWrapper` according to a configuration file. @mavaylon1 [#1016](https://github.com/hdmf-dev/hdmf/pull/1016) - Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061) - Updated testing to not install in editable mode and not run `coverage` by default. @rly [#1107](https://github.com/hdmf-dev/hdmf/pull/1107) +- Added `post_init_method` parameter when generating classes to perform post-init functionality, i.e., validation. @mavaylon1 [#1089](https://github.com/hdmf-dev/hdmf/pull/1089) ## HDMF 3.13.0 (March 20, 2024) diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index d2e7d4fc0..a3336b98e 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -1,5 +1,6 @@ from copy import deepcopy from datetime import datetime, date +from collections.abc import Callable import numpy as np @@ -35,6 +36,8 @@ def register_generator(self, **kwargs): {'name': 'spec', 'type': BaseStorageSpec, 'doc': ''}, {'name': 'parent_cls', 'type': type, 'doc': ''}, {'name': 'attr_names', 'type': dict, 'doc': ''}, + {'name': 'post_init_method', 'type': Callable, 'default': None, + 'doc': 'The function used as a post_init method to validate the class generation.'}, {'name': 'type_map', 'type': 'hdmf.build.manager.TypeMap', 'doc': ''}, returns='the class for the given namespace and data_type', rtype=type) def generate_class(self, **kwargs): @@ -42,8 +45,10 @@ def generate_class(self, **kwargs): If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. 
""" - data_type, spec, parent_cls, attr_names, type_map = getargs('data_type', 'spec', 'parent_cls', 'attr_names', - 'type_map', kwargs) + data_type, spec, parent_cls, attr_names, type_map, post_init_method = getargs('data_type', 'spec', + 'parent_cls', 'attr_names', + 'type_map', + 'post_init_method', kwargs) not_inherited_fields = dict() for k, field_spec in attr_names.items(): @@ -82,6 +87,8 @@ def generate_class(self, **kwargs): + str(e) + " Please define that type before defining '%s'." % name) cls = ExtenderMeta(data_type, tuple(bases), classdict) + cls.post_init_method = post_init_method + return cls @@ -316,8 +323,19 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): elif attr_name not in attrs_not_to_set: attrs_to_set.append(attr_name) - @docval(*docval_args, allow_positional=AllowPositional.WARNING) + # We want to use the skip_post_init of the current class and not the parent class + for item in docval_args: + if item['name'] == 'skip_post_init': + docval_args.remove(item) + + @docval(*docval_args, + {'name': 'skip_post_init', 'type': bool, 'default': False, + 'doc': 'bool to skip post_init'}, + allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): + skip_post_init = popargs('skip_post_init', kwargs) + + original_kwargs = dict(kwargs) if name is not None: # force container name to be the fixed name in the spec kwargs.update(name=name) @@ -343,6 +361,9 @@ def __init__(self, **kwargs): for f in fixed_value_attrs_to_set: self.fields[f] = getattr(not_inherited_fields[f], 'value') + if self.post_init_method is not None and not skip_post_init: + self.post_init_method(**original_kwargs) + classdict['__init__'] = __init__ @@ -417,6 +438,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): def __init__(self, **kwargs): # store the values passed to init for each MCI attribute so that they can be added # after calling __init__ + original_kwargs = dict(kwargs) new_kwargs = list() for 
field_clsconf in classdict['__clsconf__']: attr_name = field_clsconf['attr'] @@ -437,6 +459,7 @@ def __init__(self, **kwargs): kwargs[attr_name] = list() # call the parent class init without the MCI attribute + kwargs['skip_post_init'] = True previous_init(self, **kwargs) # call the add method for each MCI attribute @@ -444,5 +467,8 @@ def __init__(self, **kwargs): add_method = getattr(self, new_kwarg['add_method_name']) add_method(new_kwarg['value']) + if self.post_init_method is not None: + self.post_init_method(**original_kwargs) + # override __init__ classdict['__init__'] = __init__ diff --git a/src/hdmf/build/manager.py b/src/hdmf/build/manager.py index a26de3279..25b9b81bd 100644 --- a/src/hdmf/build/manager.py +++ b/src/hdmf/build/manager.py @@ -1,6 +1,7 @@ import logging from collections import OrderedDict, deque from copy import copy +from collections.abc import Callable from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator @@ -498,11 +499,14 @@ def get_container_cls(self, **kwargs): created and returned. 
""" # NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls + # Deprecated: Will be removed by HDMF 4.0 namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) return self.get_dt_container_cls(data_type, namespace, autogen) @docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, {"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None}, + {'name': 'post_init_method', 'type': Callable, 'default': None, + 'doc': 'The function used as a post_init method to validate the class generation.'}, {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, returns='the class for the given namespace and data_type', rtype=type) def get_dt_container_cls(self, **kwargs): @@ -513,7 +517,8 @@ def get_dt_container_cls(self, **kwargs): Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from all namespaces. 
""" - namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) + namespace, data_type, post_init_method, autogen = getargs('namespace', 'data_type', + 'post_init_method','autogen', kwargs) # namespace is unknown, so look it up if namespace is None: @@ -527,12 +532,18 @@ def get_dt_container_cls(self, **kwargs): raise ValueError("Namespace could not be resolved.") cls = self.__get_container_cls(namespace, data_type) + if cls is None and autogen: # dynamically generate a class spec = self.__ns_catalog.get_spec(namespace, data_type) self.__check_dependent_types(spec, namespace) parent_cls = self.__get_parent_cls(namespace, data_type, spec) attr_names = self.__default_mapper_cls.get_attr_names(spec) - cls = self.__class_generator.generate_class(data_type, spec, parent_cls, attr_names, self) + cls = self.__class_generator.generate_class(data_type=data_type, + spec=spec, + parent_cls=parent_cls, + attr_names=attr_names, + post_init_method=post_init_method, + type_map=self) self.register_container_type(namespace, data_type, cls) return cls diff --git a/src/hdmf/common/__init__.py b/src/hdmf/common/__init__.py index 248ca1095..4d724d1d1 100644 --- a/src/hdmf/common/__init__.py +++ b/src/hdmf/common/__init__.py @@ -3,6 +3,7 @@ ''' import os.path from copy import deepcopy +from collections.abc import Callable CORE_NAMESPACE = 'hdmf-common' EXP_NAMESPACE = 'hdmf-experimental' @@ -136,12 +137,14 @@ def available_namespaces(): @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the Container class for'}, {'name': 'namespace', 'type': str, 'doc': 'the namespace the data_type is defined in'}, + {'name': 'post_init_method', 'type': Callable, 'default': None, + 'doc': 'The function used as a post_init method to validate the class generation.'}, is_method=False) def get_class(**kwargs): """Get the class object of the Container subclass corresponding to a given neurdata_type. 
""" - data_type, namespace = getargs('data_type', 'namespace', kwargs) - return __TYPE_MAP.get_dt_container_cls(data_type, namespace) + data_type, namespace, post_init_method = getargs('data_type', 'namespace', 'post_init_method', kwargs) + return __TYPE_MAP.get_dt_container_cls(data_type, namespace, post_init_method) @docval({'name': 'extensions', 'type': (str, TypeMap, list), diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 0c117820b..52fdc4839 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -2,6 +2,7 @@ import os import shutil import tempfile +from warnings import warn from hdmf.build import TypeMap, CustomClassGenerator from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator @@ -82,6 +83,79 @@ def test_no_generators(self): self.assertTrue(hasattr(cls, '__init__')) +class TestPostInitGetClass(TestCase): + def setUp(self): + def post_init_method(self, **kwargs): + attr1 = kwargs['attr1'] + if attr1<10: + msg = "attr1 should be >=10" + warn(msg) + self.post_init=post_init_method + + def test_post_init(self): + spec = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz', + attributes=[ + AttributeSpec(name='attr1', doc='a int attribute', dtype='int') + ] + ) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(spec, 'test.yaml') + namespace = SpecNamespace( + doc='a test namespace', + name=CORE_NAMESPACE, + schema=[{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog + ) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + type_map = TypeMap(namespace_catalog) + + cls = type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, self.post_init) + + with self.assertWarns(Warning): + cls(name='instance', attr1=9) + + def test_multi_container_post_init(self): + bar_spec = GroupSpec( + doc='A test group specification with a 
data type', + data_type_def='Bar', + datasets=[ + DatasetSpec( + doc='a dataset', + dtype='int', + name='data', + attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')] + ) + ], + attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')]) + + multi_spec = GroupSpec(doc='A test extension that contains a multi', + data_type_def='Multi', + groups=[GroupSpec(data_type_inc=bar_spec, doc='test multi', quantity='*')], + attributes=[AttributeSpec(name='attr1', doc='a float attribute', dtype='float')]) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(bar_spec, 'test.yaml') + spec_catalog.register_spec(multi_spec, 'test.yaml') + namespace = SpecNamespace( + doc='a test namespace', + name=CORE_NAMESPACE, + schema=[{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog + ) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + type_map = TypeMap(namespace_catalog) + Multi = type_map.get_dt_container_cls('Multi', CORE_NAMESPACE, self.post_init) + + with self.assertWarns(Warning): + Multi(name='instance', attr1=9.1) + class TestDynamicContainer(TestCase): def setUp(self): @@ -109,13 +183,15 @@ def test_dynamic_container_creation(self): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'} + expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'skip_post_init'} received_args = set() + for x in get_docval(cls.__init__): if x['name'] != 'foo': received_args.add(x['name']) with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x['name'] != 'skip_post_init': + self.assertNotIn('default', x) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') self.assertTrue(issubclass(cls, Bar)) @@ -135,7 
+211,7 @@ def test_dynamic_container_creation_defaults(self): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'} + expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo', 'skip_post_init'} received_args = set(map(lambda x: x['name'], get_docval(cls.__init__))) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') @@ -285,13 +361,14 @@ def __init__(self, **kwargs): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4'} + expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4', 'skip_post_init'} received_args = set() for x in get_docval(cls.__init__): if x['name'] != 'foo': received_args.add(x['name']) with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x['name'] != 'skip_post_init': + self.assertNotIn('default', x) self.assertSetEqual(expected_args, received_args) self.assertTrue(issubclass(cls, FixedAttrBar)) inst = cls(name="My Baz", data=[1, 2, 3, 4], attr2=1000, attr3=98.6, attr4=1.0) @@ -445,7 +522,7 @@ def setUp(self): def test_init_docval(self): cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class - expected_args = {'name'} # 'attr1' should not be included + expected_args = {'name', 'skip_post_init'} # 'attr1' should not be included received_args = set() for x in get_docval(cls.__init__): received_args.add(x['name']) @@ -518,6 +595,8 @@ def test_gen_parent_class(self): {'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls}, {'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls}, {'name': 
'my_baz1_link', 'doc': 'A composition inside without a fixed name', 'type': baz1_cls}, + {'name': 'skip_post_init', 'type': bool, 'default': False, + 'doc': 'bool to skip post_init'} )) def test_init_fields(self):