From a5f33c2433e1f59ebd1478a577dd8ababae6d5c9 Mon Sep 17 00:00:00 2001 From: Ryan Ly Date: Thu, 4 Apr 2024 03:02:35 -0700 Subject: [PATCH 1/9] Don't override col_cls in DynamicTable.add_column --- src/hdmf/common/table.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 3b67ff19d..1e8c3af7a 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -821,11 +821,13 @@ def add_column(self, **kwargs): # noqa: C901 raise ValueError("column '%s' cannot be both a table region " "and come from an enumerable set of elements" % name) if table is not False: - col_cls = DynamicTableRegion + if col_cls is None: + col_cls = DynamicTableRegion if isinstance(table, DynamicTable): ckwargs['table'] = table if enum is not False: - col_cls = EnumData + if col_cls is None: + col_cls = EnumData if isinstance(enum, (list, tuple, np.ndarray, VectorData)): ckwargs['elements'] = enum From fcdc3e5f4c70dd7bf1c3a17d877ef5b59d1d6395 Mon Sep 17 00:00:00 2001 From: Ryan Ly Date: Thu, 4 Apr 2024 03:05:21 -0700 Subject: [PATCH 2/9] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e72015601..f69839cfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ - Added `TermSetConfigurator` to automatically wrap fields with `TermSetWrapper` according to a configuration file. @mavaylon1 [#1016](https://github.com/hdmf-dev/hdmf/pull/1016) - Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061) +### Bug fixes +- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091) + ## HDMF 3.13.0 (March 20, 2024) ### Enhancements From 683bf37d73254b5f4c70eb46ff6ae3124ca67260 Mon Sep 17 00:00:00 2001 From: rly Date: Thu, 4 Apr 2024 18:49:58 -0700 Subject: [PATCH 3/9] Fix adding elements in enumdata subclass --- src/hdmf/common/table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 1e8c3af7a..59acbcd96 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -875,7 +875,7 @@ def add_column(self, **kwargs): # noqa: C901 if col in self.__uninit_cols: self.__uninit_cols.pop(col) - if col_cls is EnumData: + if issubclass(col_cls, EnumData): columns.append(col.elements) col.elements.parent = self From 398ab2dbb428170f6901bacdf5def5290edab564 Mon Sep 17 00:00:00 2001 From: rly Date: Wed, 10 Apr 2024 10:16:08 -0700 Subject: [PATCH 4/9] stash changes --- src/hdmf/common/table.py | 8 ++++++++ src/hdmf/container.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 59acbcd96..dc1a2569e 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -7,6 +7,7 @@ from collections import OrderedDict from typing import NamedTuple, Union from warnings import warn +# import h5py import numpy as np import pandas as pd @@ -1621,6 +1622,13 @@ def _get_helper(self, idx, index=False, join=False, **kwargs): return idx if not np.isscalar(idx): idx = np.asarray(idx) + # if isinstance(self.elements.data, h5py.Dataset): + # # h5py datasets cannot index with an array of indices that are not in increasing order + # # so unpack them one by one + # ret = np.empty(idx.shape, dtype=self.elements.data.dtype) + # for i, j in enumerate(idx.ravel()): + # ret[i] = self.elements.get(j, **kwargs) + # else: ret = np.asarray(self.elements.get(idx.ravel(), **kwargs)).reshape(idx.shape) if join: ret = ''.join(ret.ravel()) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index f93c06199..4c5d9f3bc 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -937,7 +937,10 @@ def __getitem__(self, args): def get(self, args): if isinstance(self.data, (tuple, list)) and isinstance(args, (tuple, list, np.ndarray)): + # try: return [self.data[i] for i in args] + # except: + # breakpoint() if isinstance(self.data, h5py.Dataset) and isinstance(args, np.ndarray): # This is needed for h5py 2.9 compatibility args = args.tolist() From 0277a08ca7ef9c3950f10ed700c578b2ac7c5848 Mon Sep 17 00:00:00 2001 From: rly Date: Fri, 17 Jan 2025 23:57:22 -0800 Subject: [PATCH 5/9] Revert "stash changes" This reverts commit 398ab2dbb428170f6901bacdf5def5290edab564. --- src/hdmf/common/table.py | 8 -------- src/hdmf/container.py | 3 --- 2 files changed, 11 deletions(-) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index dc1a2569e..59acbcd96 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -7,7 +7,6 @@ from collections import OrderedDict from typing import NamedTuple, Union from warnings import warn -# import h5py import numpy as np import pandas as pd @@ -1622,13 +1621,6 @@ def _get_helper(self, idx, index=False, join=False, **kwargs): return idx if not np.isscalar(idx): idx = np.asarray(idx) - # if isinstance(self.elements.data, h5py.Dataset): - # # h5py datasets cannot index with an array of indices that are not in increasing order - # # so unpack them one by one - # ret = np.empty(idx.shape, dtype=self.elements.data.dtype) - # for i, j in enumerate(idx.ravel()): - # ret[i] = self.elements.get(j, **kwargs) - # else: ret = np.asarray(self.elements.get(idx.ravel(), **kwargs)).reshape(idx.shape) if join: ret = ''.join(ret.ravel()) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index 4c5d9f3bc..f93c06199 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -937,10 +937,7 @@ def __getitem__(self, args): def get(self, args): if isinstance(self.data, (tuple, list)) and isinstance(args, (tuple, list, np.ndarray)): - # try: return [self.data[i] for i in args] - # except: - # breakpoint() if isinstance(self.data, h5py.Dataset) and isinstance(args, np.ndarray): # This is needed for h5py 2.9 compatibility args = args.tolist() From cbc5968dbb193da5c8d94d3ed6cc13f4334c309a Mon Sep 17 00:00:00 2001 From: rly Date: Tue, 21 Jan 2025 13:48:27 -0800 Subject: [PATCH 6/9] Fix how column spec is generated from spec --- src/hdmf/common/io/table.py | 3 +- src/hdmf/common/table.py | 44 +++++++++++++++--------- tests/unit/common/test_generate_table.py | 35 +++++++++++++++++-- 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py index 50395ba24..379553c07 100644 --- a/src/hdmf/common/io/table.py +++ b/src/hdmf/common/io/table.py @@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i required=field_spec.required ) dtype = cls._get_type(field_spec, type_map) + column_conf['class'] = dtype if issubclass(dtype, DynamicTableRegion): # the spec does not know which table this DTR points to # the user must specify the table attribute on the DTR after it is generated column_conf['table'] = True - else: - column_conf['class'] = dtype index_counter = 0 index_name = attr_name diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 2b03c2c76..a6aa6cb87 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -521,7 +521,7 @@ def _init_class_columns(self): description=col['description'], index=col.get('index', False), table=col.get('table', False), - col_cls=col.get('class', VectorData), + col_cls=col.get('class'), # Pass through extra kwargs for add_column that subclasses may have added **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}) @@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict): if not column_conf.get('table', False): raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table." % colname) - self.add_column(name=column_conf['name'], - description=column_conf['description'], - index=column_conf.get('index', False), - table=True) + self.add_column( + name=column_conf['name'], + description=column_conf['description'], + index=column_conf.get('index', False), + table=True, + col_cls=column_conf.get('class'), + ) if isinstance(self[colname], VectorIndex): col = self[colname].target else: @@ -681,7 +684,7 @@ def add_row(self, **kwargs): index=col.get('index', False), table=col.get('table', False), enum=col.get('enum', False), - col_cls=col.get('class', VectorData), + col_cls=col.get('class'), # Pass through extra keyword arguments for add_column that # subclasses may have added **{k: col[k] for k in col.keys() @@ -753,7 +756,7 @@ def __eq__(self, other): 'default': False}, {'name': 'enum', 'type': (bool, 'array_data'), 'default': False, 'doc': ('whether or not this column contains data from a fixed set of elements')}, - {'name': 'col_cls', 'type': type, 'default': VectorData, + {'name': 'col_cls', 'type': type, 'default': None, 'doc': ('class to use to represent the column data. If table=True, this field is ignored and a ' 'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData ' 'object is used.')}, @@ -805,31 +808,40 @@ def add_column(self, **kwargs): # noqa: C901 % (name, self.__class__.__name__, spec_index)) warn(msg, stacklevel=3) - spec_col_cls = self.__uninit_cols[name].get('class', VectorData) - if col_cls != spec_col_cls: - msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " - "col_cls argument. The predefined class spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF." - % (name, self.__class__.__name__, spec_col_cls)) - warn(msg, stacklevel=2) - ckwargs = dict(kwargs) # Add table if it's been specified if table and enum: raise ValueError("column '%s' cannot be both a table region " "and come from an enumerable set of elements" % name) + # Update col_cls if table is specified if table is not False: + # breakpoint() if col_cls is None: col_cls = DynamicTableRegion if isinstance(table, DynamicTable): ckwargs['table'] = table + # Update col_cls if enum is specified if enum is not False: if col_cls is None: col_cls = EnumData if isinstance(enum, (list, tuple, np.ndarray, VectorData)): ckwargs['elements'] = enum + # Update col_cls to the default VectorData if col_cls is None + if col_cls is None: + col_cls = VectorData + + if name in self.__uninit_cols: # column is a predefined optional column from the spec + # check the given values against the predefined optional column spec. if they do not match, raise a warning + # and ignore the given arguments. users should not be able to override these values + spec_col_cls = self.__uninit_cols[name].get('class') + if spec_col_cls is not None and col_cls != spec_col_cls: + msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " + "col_cls argument. The predefined class spec will be ignored. " + "Please ensure the new column complies with the spec. " + "This will raise an error in a future version of HDMF." + % (name, self.__class__.__name__, spec_col_cls)) + warn(msg, stacklevel=2) # If the user provided a list of lists that needs to be indexed, then we now need to flatten the data # We can only create the index actual VectorIndex once we have the VectorData column so we compute diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py index 7f7d7da40..914882820 100644 --- a/tests/unit/common/test_generate_table.py +++ b/tests/unit/common/test_generate_table.py @@ -16,6 +16,13 @@ class TestDynamicDynamicTable(TestCase): def setUp(self): + + self.dtr_spec = DatasetSpec( + data_type_def='CustomDTR', + data_type_inc='DynamicTableRegion', + doc='a test DynamicTableRegion column', # this is overridden where it is used + ) + self.dt_spec = GroupSpec( 'A test extension that contains a dynamic table', data_type_def='TestTable', @@ -99,7 +106,13 @@ def setUp(self): doc='a test column', dtype='float', quantity='?', - ) + ), + DatasetSpec( + data_type_inc='CustomDTR', + name='optional_custom_dtr_col', + doc='a test DynamicTableRegion column', + quantity='?' + ), ] ) @@ -107,6 +120,7 @@ def setUp(self): writer = YAMLSpecWriter(outdir='.') self.spec_catalog = SpecCatalog() + self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml') self.spec_catalog.register_spec(self.dt_spec, 'test.yaml') self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml') self.namespace = SpecNamespace( @@ -124,7 +138,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() spec_fpath = os.path.join(self.test_dir, 'test.yaml') namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml') - writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath) + writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath) writer.write_namespace(self.namespace, namespace_fpath) self.namespace_catalog = NamespaceCatalog() hdmf_typemap = get_type_map() @@ -133,6 +147,7 @@ def setUp(self): self.type_map.load_namespaces(namespace_fpath) self.manager = BuildManager(self.type_map) + self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE) self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE) self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE) @@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self): self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_col3': test_table}) + def test_custom_dtr_class(self): + test_table = self.TestTable(name='test_table', description='my test table') + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) + + test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', + target_tables={'optional_custom_dtr_col': test_table}) + + self.assertIsInstance(test_dtr_table.optional_custom_dtr_col, self.CustomDTR) + self.assertEqual(test_dtr_table.optional_custom_dtr_col.description, "a test DynamicTableRegion column") + self.assertIs(test_dtr_table.optional_custom_dtr_col.table, test_table) + + test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=0) + test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=1) + self.assertEqual(test_dtr_table['optional_custom_dtr_col'].data, [0, 1]) + def test_attribute(self): test_table = self.TestTable(name='test_table', description='my test table') assert test_table.my_col is not None From 8f452ba6fcc98e26130f2b99b539376ebce5165a Mon Sep 17 00:00:00 2001 From: Ryan Ly Date: Tue, 21 Jan 2025 14:08:53 -0800 Subject: [PATCH 7/9] Update CHANGELOG.md --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a37ae084..396d0f744 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ ### Added - Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230) +### Fixed +- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091) + ## HDMF 3.14.6 (December 20, 2024) ### Enhancements @@ -99,9 +102,6 @@ is available on build (during the write process), but not on read of a dataset f ### Bug Fixes - Fixed `TermSetWrapper` warning raised during the setters. @mavaylon1 [#1116](https://github.com/hdmf-dev/hdmf/pull/1116) -### Bug fixes -- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091) - ## HDMF 3.13.0 (March 20, 2024) ### Enhancements From 9604739ef2a38913f1a45cec5bfd7d4e408206fe Mon Sep 17 00:00:00 2001 From: Matthew Avaylon Date: Tue, 21 Jan 2025 18:11:27 -0800 Subject: [PATCH 8/9] Update table.py --- src/hdmf/common/table.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index a6aa6cb87..2f6401672 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -816,7 +816,6 @@ def add_column(self, **kwargs): # noqa: C901 "and come from an enumerable set of elements" % name) # Update col_cls if table is specified if table is not False: - # breakpoint() if col_cls is None: col_cls = DynamicTableRegion if isinstance(table, DynamicTable): From a5b6671cb059ac1ad735cf1d433bd63dba4e5bbc Mon Sep 17 00:00:00 2001 From: rly Date: Tue, 21 Jan 2025 22:21:21 -0800 Subject: [PATCH 9/9] Add test for custom column in add_column --- tests/unit/common/test_generate_table.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py index 914882820..71c15aad0 100644 --- a/tests/unit/common/test_generate_table.py +++ b/tests/unit/common/test_generate_table.py @@ -251,9 +251,9 @@ def test_custom_dtr_class(self): test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_custom_dtr_col': test_table}) - self.assertIsInstance(test_dtr_table.optional_custom_dtr_col, self.CustomDTR) - self.assertEqual(test_dtr_table.optional_custom_dtr_col.description, "a test DynamicTableRegion column") - self.assertIs(test_dtr_table.optional_custom_dtr_col.table, test_table) + self.assertIsInstance(test_dtr_table['optional_custom_dtr_col'], self.CustomDTR) + self.assertEqual(test_dtr_table['optional_custom_dtr_col'].description, "a test DynamicTableRegion column") + self.assertIs(test_dtr_table['optional_custom_dtr_col'].table, test_table) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=0) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=1) @@ -297,3 +297,17 @@ def test_roundtrip(self): for err in errors: raise Exception(err) self.reader.close() + + def test_add_custom_dtr_column(self): + test_table = self.TestTable(name='test_table', description='my test table') + test_table.add_column( + name='custom_dtr_column', + description='this is a custom DynamicTableRegion column', + col_cls=self.CustomDTR, + ) + self.assertIsInstance(test_table['custom_dtr_column'], self.CustomDTR) + self.assertEqual(test_table['custom_dtr_column'].description, 'this is a custom DynamicTableRegion column') + + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1) + self.assertEqual(test_table['custom_dtr_column'].data, [0, 1])