Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't override col_cls in DynamicTable.add_column #1091

Merged
merged 14 commits into from
Jan 22, 2025
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
### Added
- Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230)

### Fixed
- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091)

## HDMF 3.14.6 (December 20, 2024)

### Enhancements
Expand Down
3 changes: 1 addition & 2 deletions src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
required=field_spec.required
)
dtype = cls._get_type(field_spec, type_map)
column_conf['class'] = dtype
if issubclass(dtype, DynamicTableRegion):
# the spec does not know which table this DTR points to
# the user must specify the table attribute on the DTR after it is generated
column_conf['table'] = True
else:
column_conf['class'] = dtype

index_counter = 0
index_name = attr_name
Expand Down
51 changes: 32 additions & 19 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _init_class_columns(self):
description=col['description'],
index=col.get('index', False),
table=col.get('table', False),
col_cls=col.get('class', VectorData),
col_cls=col.get('class'),
# Pass through extra kwargs for add_column that subclasses may have added
**{k: col[k] for k in col.keys()
if k not in DynamicTable.__reserved_colspec_keys})
Expand Down Expand Up @@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict):
if not column_conf.get('table', False):
raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table."
% colname)
self.add_column(name=column_conf['name'],
description=column_conf['description'],
index=column_conf.get('index', False),
table=True)
self.add_column(
name=column_conf['name'],
description=column_conf['description'],
index=column_conf.get('index', False),
table=True,
col_cls=column_conf.get('class'),
)
if isinstance(self[colname], VectorIndex):
col = self[colname].target
else:
Expand Down Expand Up @@ -681,7 +684,7 @@ def add_row(self, **kwargs):
index=col.get('index', False),
table=col.get('table', False),
enum=col.get('enum', False),
col_cls=col.get('class', VectorData),
col_cls=col.get('class'),
# Pass through extra keyword arguments for add_column that
# subclasses may have added
**{k: col[k] for k in col.keys()
Expand Down Expand Up @@ -753,7 +756,7 @@ def __eq__(self, other):
'default': False},
{'name': 'enum', 'type': (bool, 'array_data'), 'default': False,
'doc': ('whether or not this column contains data from a fixed set of elements')},
{'name': 'col_cls', 'type': type, 'default': VectorData,
{'name': 'col_cls', 'type': type, 'default': None,
'doc': ('class to use to represent the column data. If table=True, this field is ignored and a '
'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData '
'object is used.')},
Expand Down Expand Up @@ -805,29 +808,39 @@ def add_column(self, **kwargs): # noqa: C901
% (name, self.__class__.__name__, spec_index))
warn(msg, stacklevel=3)

spec_col_cls = self.__uninit_cols[name].get('class', VectorData)
if col_cls != spec_col_cls:
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
"col_cls argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF."
% (name, self.__class__.__name__, spec_col_cls))
warn(msg, stacklevel=2)

ckwargs = dict(kwargs)

# Add table if it's been specified
if table and enum:
raise ValueError("column '%s' cannot be both a table region "
"and come from an enumerable set of elements" % name)
# Update col_cls if table is specified
if table is not False:
col_cls = DynamicTableRegion
if col_cls is None:
col_cls = DynamicTableRegion
if isinstance(table, DynamicTable):
ckwargs['table'] = table
# Update col_cls if enum is specified
if enum is not False:
col_cls = EnumData
if col_cls is None:
col_cls = EnumData
if isinstance(enum, (list, tuple, np.ndarray, VectorData)):
ckwargs['elements'] = enum
# Update col_cls to the default VectorData if col_cls is None
if col_cls is None:
col_cls = VectorData

if name in self.__uninit_cols: # column is a predefined optional column from the spec
# check the given values against the predefined optional column spec. if they do not match, raise a warning
# and ignore the given arguments. users should not be able to override these values
spec_col_cls = self.__uninit_cols[name].get('class')
if spec_col_cls is not None and col_cls != spec_col_cls:
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
"col_cls argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF."
% (name, self.__class__.__name__, spec_col_cls))
warn(msg, stacklevel=2)

# If the user provided a list of lists that needs to be indexed, then we now need to flatten the data
# We can only create the actual VectorIndex once we have the VectorData column so we compute
Expand Down Expand Up @@ -873,7 +886,7 @@ def add_column(self, **kwargs): # noqa: C901
if col in self.__uninit_cols:
self.__uninit_cols.pop(col)

if col_cls is EnumData:
if issubclass(col_cls, EnumData):
columns.append(col.elements)
col.elements.parent = self

Expand Down
49 changes: 47 additions & 2 deletions tests/unit/common/test_generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
class TestDynamicDynamicTable(TestCase):

def setUp(self):

self.dtr_spec = DatasetSpec(
data_type_def='CustomDTR',
data_type_inc='DynamicTableRegion',
doc='a test DynamicTableRegion column', # this is overridden where it is used
)

self.dt_spec = GroupSpec(
'A test extension that contains a dynamic table',
data_type_def='TestTable',
Expand Down Expand Up @@ -99,14 +106,21 @@ def setUp(self):
doc='a test column',
dtype='float',
quantity='?',
)
),
DatasetSpec(
data_type_inc='CustomDTR',
name='optional_custom_dtr_col',
doc='a test DynamicTableRegion column',
quantity='?'
),
]
)

from hdmf.spec.write import YAMLSpecWriter
writer = YAMLSpecWriter(outdir='.')

self.spec_catalog = SpecCatalog()
self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml')
self.spec_catalog.register_spec(self.dt_spec, 'test.yaml')
self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml')
self.namespace = SpecNamespace(
Expand All @@ -124,7 +138,7 @@ def setUp(self):
self.test_dir = tempfile.mkdtemp()
spec_fpath = os.path.join(self.test_dir, 'test.yaml')
namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml')
writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
writer.write_namespace(self.namespace, namespace_fpath)
self.namespace_catalog = NamespaceCatalog()
hdmf_typemap = get_type_map()
Expand All @@ -133,6 +147,7 @@ def setUp(self):
self.type_map.load_namespaces(namespace_fpath)
self.manager = BuildManager(self.type_map)

self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE)
self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE)
self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE)

Expand Down Expand Up @@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self):
self.TestDTRTable(name='test_dtr_table', description='my table',
target_tables={'optional_col3': test_table})

def test_custom_dtr_class(self):
    """Check that a TestDTRTable column given a target table is created as a CustomDTR.

    Builds a populated TestTable, passes it as the target for
    'optional_custom_dtr_col' when constructing a TestDTRTable, then verifies
    the column's class, description, and target table, and that rows can be
    appended to it.
    """
    col_name = 'optional_custom_dtr_col'

    # Table that the region column will refer to.
    target = self.TestTable(name='test_table', description='my test table')
    for my_col_value, indexed_values in ((3.0, [1.0, 3.0]), (4.0, [2.0, 4.0])):
        target.add_row(my_col=my_col_value, indexed_col=indexed_values, optional_col2=.5)

    dtr_table = self.TestDTRTable(
        name='test_dtr_table',
        description='my table',
        target_tables={col_name: target},
    )

    # The column must use the custom class and point at the target table.
    self.assertIsInstance(dtr_table[col_name], self.CustomDTR)
    self.assertEqual(dtr_table[col_name].description, "a test DynamicTableRegion column")
    self.assertIs(dtr_table[col_name].table, target)

    # Rows added afterwards land in the custom column in order.
    for region_ref in (0, 1):
        dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=region_ref)
    self.assertEqual(dtr_table[col_name].data, [0, 1])

def test_attribute(self):
test_table = self.TestTable(name='test_table', description='my test table')
assert test_table.my_col is not None
Expand Down Expand Up @@ -266,3 +297,17 @@ def test_roundtrip(self):
for err in errors:
raise Exception(err)
self.reader.close()

def test_add_custom_dtr_column(self):
    """Check that add_column with col_cls=self.CustomDTR creates a column of that class.

    Adds a new column to a TestTable with an explicit ``col_cls``, verifies the
    resulting column's class and description, then appends two rows and checks
    the column's stored data.
    """
    col_name = 'custom_dtr_column'
    col_description = 'this is a custom DynamicTableRegion column'

    table = self.TestTable(name='test_table', description='my test table')
    table.add_column(name=col_name, description=col_description, col_cls=self.CustomDTR)

    self.assertIsInstance(table[col_name], self.CustomDTR)
    self.assertEqual(table[col_name].description, col_description)

    # Append rows that also populate the newly added column.
    rows = (
        dict(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0),
        dict(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1),
    )
    for row in rows:
        table.add_row(**row)
    self.assertEqual(table[col_name].data, [0, 1])
Loading