Skip to content

Commit

Permalink
Don't override col_cls in DynamicTable.add_column (#1091)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Avaylon <[email protected]>
  • Loading branch information
rly and mavaylon1 authored Jan 22, 2025
1 parent 775fa3b commit ab8cf3b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
### Added
- Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230)

### Fixed
- Fixed issue with `DynamicTable.add_column` not allowing the `col_cls` argument to be a subclass of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091)

## HDMF 3.14.6 (December 20, 2024)

### Enhancements
Expand Down
3 changes: 1 addition & 2 deletions src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
required=field_spec.required
)
dtype = cls._get_type(field_spec, type_map)
column_conf['class'] = dtype
if issubclass(dtype, DynamicTableRegion):
# the spec does not know which table this DTR points to
# the user must specify the table attribute on the DTR after it is generated
column_conf['table'] = True
else:
column_conf['class'] = dtype

index_counter = 0
index_name = attr_name
Expand Down
51 changes: 32 additions & 19 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _init_class_columns(self):
description=col['description'],
index=col.get('index', False),
table=col.get('table', False),
col_cls=col.get('class', VectorData),
col_cls=col.get('class'),
# Pass through extra kwargs for add_column that subclasses may have added
**{k: col[k] for k in col.keys()
if k not in DynamicTable.__reserved_colspec_keys})
Expand Down Expand Up @@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict):
if not column_conf.get('table', False):
raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table."
% colname)
self.add_column(name=column_conf['name'],
description=column_conf['description'],
index=column_conf.get('index', False),
table=True)
self.add_column(
name=column_conf['name'],
description=column_conf['description'],
index=column_conf.get('index', False),
table=True,
col_cls=column_conf.get('class'),
)
if isinstance(self[colname], VectorIndex):
col = self[colname].target
else:
Expand Down Expand Up @@ -681,7 +684,7 @@ def add_row(self, **kwargs):
index=col.get('index', False),
table=col.get('table', False),
enum=col.get('enum', False),
col_cls=col.get('class', VectorData),
col_cls=col.get('class'),
# Pass through extra keyword arguments for add_column that
# subclasses may have added
**{k: col[k] for k in col.keys()
Expand Down Expand Up @@ -753,7 +756,7 @@ def __eq__(self, other):
'default': False},
{'name': 'enum', 'type': (bool, 'array_data'), 'default': False,
'doc': ('whether or not this column contains data from a fixed set of elements')},
{'name': 'col_cls', 'type': type, 'default': VectorData,
{'name': 'col_cls', 'type': type, 'default': None,
'doc': ('class to use to represent the column data. If table=True, this field is ignored and a '
'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData '
'object is used.')},
Expand Down Expand Up @@ -805,29 +808,39 @@ def add_column(self, **kwargs): # noqa: C901
% (name, self.__class__.__name__, spec_index))
warn(msg, stacklevel=3)

spec_col_cls = self.__uninit_cols[name].get('class', VectorData)
if col_cls != spec_col_cls:
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
"col_cls argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF."
% (name, self.__class__.__name__, spec_col_cls))
warn(msg, stacklevel=2)

ckwargs = dict(kwargs)

# Add table if it's been specified
if table and enum:
raise ValueError("column '%s' cannot be both a table region "
"and come from an enumerable set of elements" % name)
# Update col_cls if table is specified
if table is not False:
col_cls = DynamicTableRegion
if col_cls is None:
col_cls = DynamicTableRegion
if isinstance(table, DynamicTable):
ckwargs['table'] = table
# Update col_cls if enum is specified
if enum is not False:
col_cls = EnumData
if col_cls is None:
col_cls = EnumData
if isinstance(enum, (list, tuple, np.ndarray, VectorData)):
ckwargs['elements'] = enum
# Update col_cls to the default VectorData if col_cls is None
if col_cls is None:
col_cls = VectorData

if name in self.__uninit_cols: # column is a predefined optional column from the spec
# check the given values against the predefined optional column spec. if they do not match, raise a warning.
# the given arguments are used and the predefined spec is ignored (this will become an error in a future version)
spec_col_cls = self.__uninit_cols[name].get('class')
if spec_col_cls is not None and col_cls != spec_col_cls:
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
"col_cls argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF."
% (name, self.__class__.__name__, spec_col_cls))
warn(msg, stacklevel=2)

# If the user provided a list of lists that needs to be indexed, then we now need to flatten the data
# We can only create the index actual VectorIndex once we have the VectorData column so we compute
Expand Down Expand Up @@ -873,7 +886,7 @@ def add_column(self, **kwargs): # noqa: C901
if col in self.__uninit_cols:
self.__uninit_cols.pop(col)

if col_cls is EnumData:
if issubclass(col_cls, EnumData):
columns.append(col.elements)
col.elements.parent = self

Expand Down
49 changes: 47 additions & 2 deletions tests/unit/common/test_generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
class TestDynamicDynamicTable(TestCase):

def setUp(self):

self.dtr_spec = DatasetSpec(
data_type_def='CustomDTR',
data_type_inc='DynamicTableRegion',
doc='a test DynamicTableRegion column', # this is overridden where it is used
)

self.dt_spec = GroupSpec(
'A test extension that contains a dynamic table',
data_type_def='TestTable',
Expand Down Expand Up @@ -99,14 +106,21 @@ def setUp(self):
doc='a test column',
dtype='float',
quantity='?',
)
),
DatasetSpec(
data_type_inc='CustomDTR',
name='optional_custom_dtr_col',
doc='a test DynamicTableRegion column',
quantity='?'
),
]
)

from hdmf.spec.write import YAMLSpecWriter
writer = YAMLSpecWriter(outdir='.')

self.spec_catalog = SpecCatalog()
self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml')
self.spec_catalog.register_spec(self.dt_spec, 'test.yaml')
self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml')
self.namespace = SpecNamespace(
Expand All @@ -124,7 +138,7 @@ def setUp(self):
self.test_dir = tempfile.mkdtemp()
spec_fpath = os.path.join(self.test_dir, 'test.yaml')
namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml')
writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
writer.write_namespace(self.namespace, namespace_fpath)
self.namespace_catalog = NamespaceCatalog()
hdmf_typemap = get_type_map()
Expand All @@ -133,6 +147,7 @@ def setUp(self):
self.type_map.load_namespaces(namespace_fpath)
self.manager = BuildManager(self.type_map)

self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE)
self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE)
self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE)

Expand Down Expand Up @@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self):
self.TestDTRTable(name='test_dtr_table', description='my table',
target_tables={'optional_col3': test_table})

def test_custom_dtr_class(self):
    """Build a TestDTRTable targeting 'optional_custom_dtr_col' and check the resulting column.

    Asserts that the column is a CustomDTR instance, carries the expected description,
    points at the supplied target table, and stores the row values added afterwards.
    """
    col_name = 'optional_custom_dtr_col'

    # Target table with two rows for the region column to reference.
    target = self.TestTable(name='test_table', description='my test table')
    for my_col, indexed_col in ((3.0, [1.0, 3.0]), (4.0, [2.0, 4.0])):
        target.add_row(my_col=my_col, indexed_col=indexed_col, optional_col2=.5)

    referrer = self.TestDTRTable(
        name='test_dtr_table',
        description='my table',
        target_tables={col_name: target},
    )

    self.assertIsInstance(referrer[col_name], self.CustomDTR)
    self.assertEqual(referrer[col_name].description, "a test DynamicTableRegion column")
    self.assertIs(referrer[col_name].table, target)

    # Add one row per target-row index and confirm the stored region data.
    for region_idx in (0, 1):
        referrer.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=region_idx)
    self.assertEqual(referrer[col_name].data, [0, 1])

def test_attribute(self):
test_table = self.TestTable(name='test_table', description='my test table')
assert test_table.my_col is not None
Expand Down Expand Up @@ -266,3 +297,17 @@ def test_roundtrip(self):
for err in errors:
raise Exception(err)
self.reader.close()

def test_add_custom_dtr_column(self):
    """Add a column with col_cls=CustomDTR and check its class, description and row data."""
    col_name = 'custom_dtr_column'
    col_description = 'this is a custom DynamicTableRegion column'

    table = self.TestTable(name='test_table', description='my test table')
    table.add_column(name=col_name, description=col_description, col_cls=self.CustomDTR)

    self.assertIsInstance(table[col_name], self.CustomDTR)
    self.assertEqual(table[col_name].description, col_description)

    # Two rows whose custom column values should end up as the column's data.
    rows = (
        dict(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0),
        dict(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1),
    )
    for row in rows:
        table.add_row(**row)
    self.assertEqual(table[col_name].data, [0, 1])

0 comments on commit ab8cf3b

Please sign in to comment.