From ab8cf3b011a6b468ce87aa1470183bee5f5f44f6 Mon Sep 17 00:00:00 2001
From: Ryan Ly <rly@lbl.gov>
Date: Wed, 22 Jan 2025 08:13:56 -0800
Subject: [PATCH] Don't override col_cls in DynamicTable.add_column (#1091)

Co-authored-by: Matthew Avaylon <mavaylon1@berkeley.edu>
---
 CHANGELOG.md                             |  3 ++
 src/hdmf/common/io/table.py              |  3 +-
 src/hdmf/common/table.py                 | 51 +++++++++++++++---------
 tests/unit/common/test_generate_table.py | 49 ++++++++++++++++++++++-
 4 files changed, 83 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c72205cd..473be3100 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,9 @@
 ### Added
 - Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230)
 
+### Fixed
+- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091)
+
 ## HDMF 3.14.6 (December 20, 2024)
 
 ### Enhancements
diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py
index 50395ba24..379553c07 100644
--- a/src/hdmf/common/io/table.py
+++ b/src/hdmf/common/io/table.py
@@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
             required=field_spec.required
         )
         dtype = cls._get_type(field_spec, type_map)
+        column_conf['class'] = dtype
         if issubclass(dtype, DynamicTableRegion):
             # the spec does not know which table this DTR points to
             # the user must specify the table attribute on the DTR after it is generated
             column_conf['table'] = True
-        else:
-            column_conf['class'] = dtype
 
         index_counter = 0
         index_name = attr_name
diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py
index 84ac4da3b..2f6401672 100644
--- a/src/hdmf/common/table.py
+++ b/src/hdmf/common/table.py
@@ -521,7 +521,7 @@ def _init_class_columns(self):
                                     description=col['description'],
                                     index=col.get('index', False),
                                     table=col.get('table', False),
-                                    col_cls=col.get('class', VectorData),
+                                    col_cls=col.get('class'),
                                     # Pass through extra kwargs for add_column that subclasses may have added
                                     **{k: col[k] for k in col.keys()
                                        if k not in DynamicTable.__reserved_colspec_keys})
@@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict):
                 if not column_conf.get('table', False):
                     raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table."
                                         % colname)
-                self.add_column(name=column_conf['name'],
-                                description=column_conf['description'],
-                                index=column_conf.get('index', False),
-                                table=True)
+                self.add_column(
+                    name=column_conf['name'],
+                    description=column_conf['description'],
+                    index=column_conf.get('index', False),
+                    table=True,
+                    col_cls=column_conf.get('class'),
+                )
             if isinstance(self[colname], VectorIndex):
                 col = self[colname].target
             else:
@@ -681,7 +684,7 @@ def add_row(self, **kwargs):
                                         index=col.get('index', False),
                                         table=col.get('table', False),
                                         enum=col.get('enum', False),
-                                        col_cls=col.get('class', VectorData),
+                                        col_cls=col.get('class'),
                                         # Pass through extra keyword arguments for add_column that
                                         # subclasses may have added
                                         **{k: col[k] for k in col.keys()
@@ -753,7 +756,7 @@ def __eq__(self, other):
              'default': False},
             {'name': 'enum', 'type': (bool, 'array_data'), 'default': False,
              'doc': ('whether or not this column contains data from a fixed set of elements')},
-            {'name': 'col_cls', 'type': type, 'default': VectorData,
-             'doc': ('class to use to represent the column data. If table=True, this field is ignored and a '
-                     'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData '
+            {'name': 'col_cls', 'type': type, 'default': None,
+             'doc': ('class to use to represent the column data. If None, a DynamicTableRegion is used when table '
+                     'is set, an EnumData is used when enum is set, and otherwise a VectorData '
                      'object is used.')},
@@ -805,29 +808,39 @@ def add_column(self, **kwargs):  # noqa: C901
                        % (name, self.__class__.__name__, spec_index))
                 warn(msg, stacklevel=3)
 
-            spec_col_cls = self.__uninit_cols[name].get('class', VectorData)
-            if col_cls != spec_col_cls:
-                msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
-                       "col_cls argument. The predefined class spec will be ignored. "
-                       "Please ensure the new column complies with the spec. "
-                       "This will raise an error in a future version of HDMF."
-                       % (name, self.__class__.__name__, spec_col_cls))
-                warn(msg, stacklevel=2)
-
         ckwargs = dict(kwargs)
 
         # Add table if it's been specified
         if table and enum:
             raise ValueError("column '%s' cannot be both a table region "
                              "and come from an enumerable set of elements" % name)
+        # Default col_cls to DynamicTableRegion if table is specified and no col_cls was given
         if table is not False:
-            col_cls = DynamicTableRegion
+            if col_cls is None:
+                col_cls = DynamicTableRegion
             if isinstance(table, DynamicTable):
                 ckwargs['table'] = table
+        # Update col_cls if enum is specified
         if enum is not False:
-            col_cls = EnumData
+            if col_cls is None:
+                col_cls = EnumData
             if isinstance(enum, (list, tuple, np.ndarray, VectorData)):
                 ckwargs['elements'] = enum
+        # Update col_cls to the default VectorData if col_cls is None
+        if col_cls is None:
+            col_cls = VectorData
+
+        if name in self.__uninit_cols:  # column is a predefined optional column from the spec
+            # compare the resolved col_cls with the class predefined in the spec, if one is given. on a mismatch,
+            # only warn; as the message states, the predefined class is ignored in favor of the resolved col_cls
+            spec_col_cls = self.__uninit_cols[name].get('class')
+            if spec_col_cls is not None and col_cls != spec_col_cls:
+                msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
+                       "col_cls argument. The predefined class spec will be ignored. "
+                       "Please ensure the new column complies with the spec. "
+                       "This will raise an error in a future version of HDMF."
+                       % (name, self.__class__.__name__, spec_col_cls))
+                warn(msg, stacklevel=2)
 
         # If the user provided a list of lists that needs to be indexed, then we now need to flatten the data
         # We can only create the index actual VectorIndex once we have the VectorData column so we compute
@@ -873,7 +886,7 @@ def add_column(self, **kwargs):  # noqa: C901
         if col in self.__uninit_cols:
             self.__uninit_cols.pop(col)
 
-        if col_cls is EnumData:
+        if issubclass(col_cls, EnumData):
             columns.append(col.elements)
             col.elements.parent = self
 
diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py
index 7f7d7da40..71c15aad0 100644
--- a/tests/unit/common/test_generate_table.py
+++ b/tests/unit/common/test_generate_table.py
@@ -16,6 +16,13 @@
 class TestDynamicDynamicTable(TestCase):
 
     def setUp(self):
+
+        self.dtr_spec = DatasetSpec(
+            data_type_def='CustomDTR',
+            data_type_inc='DynamicTableRegion',
+            doc='a test DynamicTableRegion column',  # this is overridden where it is used
+        )
+
         self.dt_spec = GroupSpec(
             'A test extension that contains a dynamic table',
             data_type_def='TestTable',
@@ -99,7 +106,13 @@ def setUp(self):
                     doc='a test column',
                     dtype='float',
                     quantity='?',
-                )
+                ),
+                DatasetSpec(
+                    data_type_inc='CustomDTR',
+                    name='optional_custom_dtr_col',
+                    doc='a test DynamicTableRegion column',
+                    quantity='?'
+                ),
             ]
         )
 
@@ -107,6 +120,7 @@ def setUp(self):
         writer = YAMLSpecWriter(outdir='.')
 
         self.spec_catalog = SpecCatalog()
+        self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml')
         self.spec_catalog.register_spec(self.dt_spec, 'test.yaml')
         self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml')
         self.namespace = SpecNamespace(
@@ -124,7 +138,7 @@ def setUp(self):
         self.test_dir = tempfile.mkdtemp()
         spec_fpath = os.path.join(self.test_dir, 'test.yaml')
         namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml')
-        writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
+        writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
         writer.write_namespace(self.namespace, namespace_fpath)
         self.namespace_catalog = NamespaceCatalog()
         hdmf_typemap = get_type_map()
@@ -133,6 +147,7 @@ def setUp(self):
         self.type_map.load_namespaces(namespace_fpath)
         self.manager = BuildManager(self.type_map)
 
+        self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE)
         self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE)
         self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE)
 
@@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self):
             self.TestDTRTable(name='test_dtr_table', description='my table',
                               target_tables={'optional_col3': test_table})
 
+    def test_custom_dtr_class(self):
+        test_table = self.TestTable(name='test_table', description='my test table')
+        test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5)
+        test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5)
+
+        test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table',
+                                           target_tables={'optional_custom_dtr_col': test_table})
+
+        self.assertIsInstance(test_dtr_table['optional_custom_dtr_col'], self.CustomDTR)
+        self.assertEqual(test_dtr_table['optional_custom_dtr_col'].description, "a test DynamicTableRegion column")
+        self.assertIs(test_dtr_table['optional_custom_dtr_col'].table, test_table)
+
+        test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=0)
+        test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=1)
+        self.assertEqual(test_dtr_table['optional_custom_dtr_col'].data, [0, 1])
+
     def test_attribute(self):
         test_table = self.TestTable(name='test_table', description='my test table')
         assert test_table.my_col is not None
@@ -266,3 +297,17 @@ def test_roundtrip(self):
             for err in errors:
                 raise Exception(err)
         self.reader.close()
+
+    def test_add_custom_dtr_column(self):
+        test_table = self.TestTable(name='test_table', description='my test table')
+        test_table.add_column(
+            name='custom_dtr_column',
+            description='this is a custom DynamicTableRegion column',
+            col_cls=self.CustomDTR,
+        )
+        self.assertIsInstance(test_table['custom_dtr_column'], self.CustomDTR)
+        self.assertEqual(test_table['custom_dtr_column'].description, 'this is a custom DynamicTableRegion column')
+
+        test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0)
+        test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1)
+        self.assertEqual(test_table['custom_dtr_column'].data, [0, 1])