diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index ba19b4e10..1728de821 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -323,12 +323,17 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): elif attr_name not in attrs_not_to_set: attrs_to_set.append(attr_name) + # We want to use the skip_post_init of the current class and not the parent class + for item in docval_args: + if item['name'] == 'skip_post_init': + docval_args.remove(item) + @docval(*docval_args, - {'name': 'post_init_bool', 'type': bool, 'default': True, - 'doc': 'bool to skip post_init for MCIClassGenerator'}, + {'name': 'skip_post_init', 'type': bool, 'default': False, + 'doc': 'bool to skip post_init'}, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): - post_init_bool = popargs('post_init_bool', kwargs) + skip_post_init = popargs('skip_post_init', kwargs) original_kwargs = dict(kwargs) if name is not None: # force container name to be the fixed name in the spec @@ -356,8 +361,7 @@ def __init__(self, **kwargs): for f in fixed_value_attrs_to_set: self.fields[f] = getattr(not_inherited_fields[f], 'value') - if self.post_init_method is not None and post_init_bool: - # + if self.post_init_method is not None and not skip_post_init: self.post_init_method(**original_kwargs) classdict['__init__'] = __init__ @@ -430,6 +434,11 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): if '__clsconf__' in classdict: previous_init = classdict['__init__'] + # We want to use the skip_post_init of the current class and not the parent class + for item in docval_args: + if item['name'] == 'skip_post_init': + docval_args.remove(item) + @docval(*docval_args, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): # store the values passed to init for each MCI attribute so that they can be added @@ -455,7 +464,7 @@ def __init__(self, **kwargs): kwargs[attr_name] = list() # call the parent class init without the MCI attribute - kwargs['post_init_bool'] = False + kwargs['skip_post_init'] = True previous_init(self, **kwargs) # call the add method for each MCI attribute diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 2db1cf65e..52fdc4839 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -183,13 +183,15 @@ def test_dynamic_container_creation(self): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'} + expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'skip_post_init'} received_args = set() + for x in get_docval(cls.__init__): if x['name'] != 'foo': received_args.add(x['name']) with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x['name'] != 'skip_post_init': + self.assertNotIn('default', x) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') self.assertTrue(issubclass(cls, Bar)) @@ -209,7 +211,7 @@ def test_dynamic_container_creation_defaults(self): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'} + expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo', 'skip_post_init'} received_args = set(map(lambda x: x['name'], get_docval(cls.__init__))) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') @@ -359,13 +361,14 @@ def __init__(self, **kwargs): AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4'} + expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4', 'skip_post_init'} received_args = set() for x in get_docval(cls.__init__): if x['name'] != 'foo': received_args.add(x['name']) with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x['name'] != 'skip_post_init': + self.assertNotIn('default', x) self.assertSetEqual(expected_args, received_args) self.assertTrue(issubclass(cls, FixedAttrBar)) inst = cls(name="My Baz", data=[1, 2, 3, 4], attr2=1000, attr3=98.6, attr4=1.0) @@ -519,7 +522,7 @@ def setUp(self): def test_init_docval(self): cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class - expected_args = {'name'} # 'attr1' should not be included + expected_args = {'name', 'skip_post_init'} # 'attr1' should not be included received_args = set() for x in get_docval(cls.__init__): received_args.add(x['name']) @@ -592,6 +595,8 @@ def test_gen_parent_class(self): {'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls}, {'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls}, {'name': 'my_baz1_link', 'doc': 'A composition inside without a fixed name', 'type': baz1_cls}, + {'name': 'skip_post_init', 'type': bool, 'default': False, + 'doc': 'bool to skip post_init'} )) def test_init_fields(self):