Skip to content

Commit

Permalink
concept
Browse files Browse the repository at this point in the history
  • Loading branch information
mavaylon1 committed Apr 4, 2024
1 parent 5df6bc1 commit 4799bb9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def generate_class(self, **kwargs):
cls.post_init_method = tp.MethodType(post_init_method, cls) # set as bounded method
else:
cls.post_init_method = post_init_method # set to None

return cls


Expand Down Expand Up @@ -330,6 +329,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):

@docval(*docval_args, allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
original_kwargs = dict(kwargs)
if name is not None: # force container name to be the fixed name in the spec
kwargs.update(name=name)

Expand All @@ -356,7 +356,7 @@ def __init__(self, **kwargs):
self.fields[f] = getattr(not_inherited_fields[f], 'value')

if self.post_init_method is not None:
self.post_init_method(kwargs)
self.post_init_method(**original_kwargs)


classdict['__init__'] = __init__
Expand Down
1 change: 1 addition & 0 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def get_dt_container_cls(self, **kwargs):
raise ValueError("Namespace could not be resolved.")

cls = self.__get_container_cls(namespace, data_type)

if cls is None and autogen: # dynamically generate a class
spec = self.__ns_catalog.get_spec(namespace, data_type)
self.__check_dependent_types(spec, namespace)
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,58 @@ def test_no_generators(self):
self.assertTrue(hasattr(cls, '__init__'))


class TestPostInitGetClass(TestCase):
def setUp(self):
# self.bar_spec = GroupSpec(
# doc='A test group specification with a data type',
# data_type_def='Bar',
# datasets=[
# DatasetSpec(
# doc='a dataset',
# dtype='int',
# name='data',
# attributes=[AttributeSpec(name='attr1', doc='an integer attribute', dtype='int')]
# )
# ])
# specs = [self.bar_spec]
# containers = {'Bar': Bar}
# from hdmf.common import get_type_map
# self.type_map = get_type_map()
# self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
attributes=[
AttributeSpec(name='attr1', doc='a int attribute', dtype='int')
]
)

spec_catalog = SpecCatalog()
spec_catalog.register_spec(spec, 'test.yaml')
namespace = SpecNamespace(
doc='a test namespace',
name=CORE_NAMESPACE,
schema=[{'source': 'test.yaml'}],
version='0.1.0',
catalog=spec_catalog
)
namespace_catalog = NamespaceCatalog()
namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
self.type_map = TypeMap(namespace_catalog)


def test_post_init(self):
def post_init_method(self, **kwargs):
attr1 = kwargs['attr1']
if attr1<10:
msg = "attr1 should be >=10"
raise ValueError(msg)

cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, post_init_method)

with self.assertRaises(ValueError):
instance = cls(name='instance', attr1=9)

class TestDynamicContainer(TestCase):

def setUp(self):
Expand Down

0 comments on commit 4799bb9

Please sign in to comment.