Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post init option for class generator #1089

Merged
merged 28 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy
from datetime import datetime, date
from collections.abc import Callable
import types as tp

import numpy as np

Expand Down Expand Up @@ -35,15 +37,19 @@ def register_generator(self, **kwargs):
{'name': 'spec', 'type': BaseStorageSpec, 'doc': ''},
{'name': 'parent_cls', 'type': type, 'doc': ''},
{'name': 'attr_names', 'type': dict, 'doc': ''},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
{'name': 'type_map', 'type': 'hdmf.build.manager.TypeMap', 'doc': ''},
returns='the class for the given namespace and data_type', rtype=type)
def generate_class(self, **kwargs):
"""Get the container class from data type specification.
If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically
created and returned.
"""
data_type, spec, parent_cls, attr_names, type_map = getargs('data_type', 'spec', 'parent_cls', 'attr_names',
'type_map', kwargs)
data_type, spec, parent_cls, attr_names, type_map, post_init_method = getargs('data_type', 'spec',
'parent_cls', 'attr_names',
'type_map',
'post_init_method', kwargs)

not_inherited_fields = dict()
for k, field_spec in attr_names.items():
Expand Down Expand Up @@ -82,6 +88,11 @@ def generate_class(self, **kwargs):
+ str(e)
+ " Please define that type before defining '%s'." % name)
cls = ExtenderMeta(data_type, tuple(bases), classdict)

if post_init_method is not None:
cls.post_init_method = tp.MethodType(post_init_method, cls) # set as bounded method
else:
cls.post_init_method = post_init_method # set to None
return cls


Expand Down Expand Up @@ -318,6 +329,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):

@docval(*docval_args, allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
original_kwargs = dict(kwargs)
if name is not None: # force container name to be the fixed name in the spec
kwargs.update(name=name)

Expand All @@ -343,6 +355,10 @@ def __init__(self, **kwargs):
for f in fixed_value_attrs_to_set:
self.fields[f] = getattr(not_inherited_fields[f], 'value')

if self.post_init_method is not None:
self.post_init_method(**original_kwargs)


classdict['__init__'] = __init__


Expand Down
15 changes: 13 additions & 2 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections import OrderedDict, deque
from copy import copy
from collections.abc import Callable

from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
Expand Down Expand Up @@ -498,11 +499,14 @@ def get_container_cls(self, **kwargs):
created and returned.
"""
# NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls
# Deprecated: Will be removed by HDMF 4.0
namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs)
return self.get_dt_container_cls(data_type, namespace, autogen)

@docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"},
{"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
{"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True},
returns='the class for the given namespace and data_type', rtype=type)
def get_dt_container_cls(self, **kwargs):
Expand All @@ -513,7 +517,8 @@ def get_dt_container_cls(self, **kwargs):
Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from
all namespaces.
"""
namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs)
namespace, data_type, post_init_method, autogen = getargs('namespace', 'data_type',
'post_init_method','autogen', kwargs)

# namespace is unknown, so look it up
if namespace is None:
Expand All @@ -527,12 +532,18 @@ def get_dt_container_cls(self, **kwargs):
raise ValueError("Namespace could not be resolved.")

cls = self.__get_container_cls(namespace, data_type)

if cls is None and autogen: # dynamically generate a class
spec = self.__ns_catalog.get_spec(namespace, data_type)
self.__check_dependent_types(spec, namespace)
parent_cls = self.__get_parent_cls(namespace, data_type, spec)
attr_names = self.__default_mapper_cls.get_attr_names(spec)
cls = self.__class_generator.generate_class(data_type, spec, parent_cls, attr_names, self)
cls = self.__class_generator.generate_class(data_type=data_type,
spec=spec,
parent_cls=parent_cls,
attr_names=attr_names,
post_init_method=post_init_method,
type_map=self)
self.register_container_type(namespace, data_type, cls)
return cls

Expand Down
8 changes: 6 additions & 2 deletions src/hdmf/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''
import os.path
from copy import deepcopy
from collections.abc import Callable

CORE_NAMESPACE = 'hdmf-common'
EXP_NAMESPACE = 'hdmf-experimental'
Expand Down Expand Up @@ -136,12 +137,15 @@ def available_namespaces():
@docval({'name': 'data_type', 'type': str,
'doc': 'the data_type to get the Container class for'},
{'name': 'namespace', 'type': str, 'doc': 'the namespace the data_type is defined in'},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
{"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True},
is_method=False)
def get_class(**kwargs):
"""Get the class object of the Container subclass corresponding to a given neurdata_type.
"""
data_type, namespace = getargs('data_type', 'namespace', kwargs)
return __TYPE_MAP.get_dt_container_cls(data_type, namespace)
data_type, namespace, post_init_method = getargs('data_type', 'namespace', 'post_init_method', kwargs)
return __TYPE_MAP.get_dt_container_cls(data_type, namespace, post_init_method)


@docval({'name': 'extensions', 'type': (str, TypeMap, list),
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,58 @@ def test_no_generators(self):
self.assertTrue(hasattr(cls, '__init__'))


class TestPostInitGetClass(TestCase):
def setUp(self):
# self.bar_spec = GroupSpec(
# doc='A test group specification with a data type',
# data_type_def='Bar',
# datasets=[
# DatasetSpec(
# doc='a dataset',
# dtype='int',
# name='data',
# attributes=[AttributeSpec(name='attr1', doc='an integer attribute', dtype='int')]
# )
# ])
# specs = [self.bar_spec]
# containers = {'Bar': Bar}
# from hdmf.common import get_type_map
# self.type_map = get_type_map()
# self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
attributes=[
AttributeSpec(name='attr1', doc='a int attribute', dtype='int')
]
)

spec_catalog = SpecCatalog()
spec_catalog.register_spec(spec, 'test.yaml')
namespace = SpecNamespace(
doc='a test namespace',
name=CORE_NAMESPACE,
schema=[{'source': 'test.yaml'}],
version='0.1.0',
catalog=spec_catalog
)
namespace_catalog = NamespaceCatalog()
namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
self.type_map = TypeMap(namespace_catalog)


def test_post_init(self):
def post_init_method(self, **kwargs):
attr1 = kwargs['attr1']
if attr1<10:
msg = "attr1 should be >=10"
raise ValueError(msg)

cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, post_init_method)

with self.assertRaises(ValueError):
instance = cls(name='instance', attr1=9)

class TestDynamicContainer(TestCase):

def setUp(self):
Expand Down
Loading