diff --git a/tests/unit/spec_tests/test_group_spec.py b/tests/unit/spec_tests/test_group_spec.py index 31c00cfbb..4cf6ad71d 100644 --- a/tests/unit/spec_tests/test_group_spec.py +++ b/tests/unit/spec_tests/test_group_spec.py @@ -401,6 +401,91 @@ def test_is_inherited_attribute(self): with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): self.inc_group_spec.is_inherited_attribute('attribute4') + def test_is_overridden_spec_nested(self): + """Test that is_overridden_spec correctly identifies overridden specs in nested structures.""" + # Create base spec with a dataset containing an attribute + base_dataset = DatasetSpec('Base dataset', + 'int', + name='test_dataset', + attributes=[AttributeSpec('attr1', 'Base attr', 'text')]) + base_group = GroupSpec('Base group', + name='test_group', + attributes=[AttributeSpec('attr1', 'Base attr', 'text')]) + base_spec = GroupSpec('A base group', + data_type_def='BaseType', + datasets=[base_dataset], + groups=[base_group]) + + # Create extending spec that overrides both dataset and group with new attribute values + override_dataset = DatasetSpec('Override dataset', + 'int', + name='test_dataset', + attributes=[AttributeSpec('attr1', 'Override attr', 'text')]) + override_group = GroupSpec('Override group', + name='test_group', + attributes=[AttributeSpec('attr1', 'Override attr', 'text')]) + ext_spec = GroupSpec('An extending group', + data_type_inc='BaseType', + data_type_def='ExtType', + datasets=[override_dataset], + groups=[override_group]) + + # Resolve the extension + ext_spec.resolve_spec(base_spec) + + # Test attribute in overridden dataset is marked as overridden + dataset_attr = ext_spec.get_dataset('test_dataset').get_attribute('attr1') + self.assertTrue(ext_spec.is_overridden_spec(dataset_attr)) + + # Test attribute in overridden group is marked as overridden + group_attr = ext_spec.get_group('test_group').get_attribute('attr1') + self.assertTrue(ext_spec.is_overridden_spec(group_attr)) + + # Test attributes in base spec are not marked as overridden + base_dataset_attr = base_spec.get_dataset('test_dataset').get_attribute('attr1') + base_group_attr = base_spec.get_group('test_group').get_attribute('attr1') + self.assertFalse(base_spec.is_overridden_spec(base_dataset_attr)) + self.assertFalse(base_spec.is_overridden_spec(base_group_attr)) + + def test_is_overridden_group(self): + """Test that is_overridden_group correctly identifies overridden groups.""" + # Create base spec with a group + base_group = GroupSpec('Base group', + name='test_group', + attributes=[]) + base_spec = GroupSpec('A base group', + data_type_def='BaseType', + groups=[base_group]) + + # Create extending spec that overrides the group + override_group = GroupSpec('Override group', + name='test_group', + attributes=[]) + ext_spec = GroupSpec('An extending group', + data_type_inc='BaseType', + data_type_def='ExtType', + groups=[override_group]) + + # Resolve the extension + ext_spec.resolve_spec(base_spec) + + # Test base spec has no overridden groups + self.assertFalse(base_spec.is_overridden_group('test_group')) + + # Test extending spec correctly identifies overridden group + self.assertTrue(ext_spec.is_overridden_group('test_group')) + + # Test non-existent group raises error + with self.assertRaisesWith(ValueError, "Group 'nonexistent_group' not found in spec"): + ext_spec.is_overridden_group('nonexistent_group') + + # Test new group in extending spec is not overridden + new_group = GroupSpec('New group', + name='new_group', + attributes=[]) + ext_spec.set_group(new_group) + self.assertFalse(ext_spec.is_overridden_group('new_group')) + def test_is_overridden_attribute(self): self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute1')) self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute2')) @@ -410,6 +495,84 @@ def test_is_overridden_attribute(self): with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): self.inc_group_spec.is_overridden_attribute('attribute4') + def test_resolve_group_inheritance(self): + """Test resolution of inherited groups in GroupSpec.resolve_spec.""" + # Create base group with named and unnamed groups + unnamed_group = GroupSpec('An unnamed group', + data_type_def='UnnamedType', + attributes=[]) + named_group = GroupSpec('A named group', + name='named_group', + attributes=[]) + base_groups = [unnamed_group, named_group] + + base_spec = GroupSpec('A test group', + data_type_def='BaseType', + groups=base_groups) + + # Create extending group that overrides the named group and adds a new one + override_group = GroupSpec('Override named group', + name='named_group', + attributes=[]) + new_group = GroupSpec('A new group', + name='new_group', + attributes=[]) + ext_groups = [override_group, new_group] + + ext_spec = GroupSpec('An extending group', + data_type_inc='BaseType', + data_type_def='ExtType', + groups=ext_groups) + + # Resolve the extension + ext_spec.resolve_spec(base_spec) + + # Test unnamed group is added to data_types + self.assertEqual(ext_spec.get_data_type('UnnamedType'), unnamed_group) + + # Test named group is overridden + resolved_group = ext_spec.get_group('named_group') + self.assertEqual(resolved_group.doc, 'Override named group') + self.assertTrue(ext_spec.is_overridden_spec(resolved_group)) + + # Test new group is added + new_resolved = ext_spec.get_group('new_group') + self.assertEqual(new_resolved.doc, 'A new group') + self.assertFalse(ext_spec.is_overridden_spec(new_resolved)) + + def test_resolve_group_inheritance_multiple(self): + """Test resolution of multiple levels of group inheritance.""" + # Base spec with a named group + base_group = GroupSpec('Base group', + name='test_group', + attributes=[]) + base_spec = GroupSpec('A base group', + data_type_def='BaseType', + groups=[base_group]) + + # First extension overrides the group + mid_group = GroupSpec('Mid group', + name='test_group', + attributes=[]) + mid_spec = GroupSpec('A middle group', + data_type_inc='BaseType', + data_type_def='MidType', + groups=[mid_group]) + + # Second extension inherits without override + ext_spec = GroupSpec('An extending group', + data_type_inc='MidType', + data_type_def='ExtType') + + # Resolve the extensions + mid_spec.resolve_spec(base_spec) + ext_spec.resolve_spec(mid_spec) + + # Test group inheritance through multiple levels + resolved_group = ext_spec.get_group('test_group') + self.assertEqual(resolved_group.doc, 'Mid group') + self.assertTrue(ext_spec.is_inherited_spec(resolved_group)) + class TestResolveGroupSameAttributeName(TestCase): # https://github.com/hdmf-dev/hdmf/issues/1121 diff --git a/tests/unit/spec_tests/test_spec_write.py b/tests/unit/spec_tests/test_spec_write.py index a9410df2a..8582e80e4 100644 --- a/tests/unit/spec_tests/test_spec_write.py +++ b/tests/unit/spec_tests/test_spec_write.py @@ -95,16 +95,100 @@ def setUp(self): title='My lab extensions') self.ns_builder.export(self.namespace_path) + # Additional paths for export tests + self.output_path = "test_export.namespace.yaml" + self.source_path = "test_source.yaml" + + # Create a test spec for reuse + self.test_spec = GroupSpec('A test group', + data_type_def='TestGroup', + datasets=[], + attributes=[]) + def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) + # Additional cleanup for export tests + if os.path.exists(self.output_path): + os.remove(self.output_path) + if os.path.exists(self.source_path): + os.remove(self.source_path) + def test_export_namespace(self): + """Test basic namespace export functionality.""" self._test_namespace_file() self._test_extensions_file() + def test_export_with_included_types(self): + """Test export with included types from source.""" + self.ns_builder.include_type('TestType1', source=self.source_path) + self.ns_builder.include_type('TestType2', source=self.source_path) + + self.ns_builder.export(self.output_path) + + # Verify the exported namespace + with open(self.output_path, 'r') as f: + content = f.read() + # Check that both types are included + self.assertIn('TestType1', content) + self.assertIn('TestType2', content) + # Check they're included from the correct source + self.assertIn(self.source_path, content) + + def test_export_with_included_namespaces(self): + """Test export with included namespaces.""" + namespace = "test_namespace" + self.ns_builder.include_namespace(namespace) + self.ns_builder.include_type('TestType1', namespace=namespace) + + self.ns_builder.export(self.output_path) + + # Verify the exported namespace + with open(self.output_path, 'r') as f: + content = f.read() + self.assertIn(namespace, content) + self.assertIn('TestType1', content) + + def test_export_source_with_specs(self): + """Test export with source containing specs.""" + self.ns_builder.add_spec(self.source_path, self.test_spec) + self.ns_builder.export(self.output_path) + + # Verify the spec was written to source file + self.assertTrue(os.path.exists(self.source_path)) + with open(self.source_path, 'r') as f: + content = f.read() + self.assertIn('TestGroup', content) + self.assertIn('A test group', content) + + def test_export_source_conflict_error(self): + """Test error when trying to both include from and write to same source.""" + # Add both an included type and a spec to the same source + self.ns_builder.include_type('TestType', source=self.source_path) + self.ns_builder.add_spec(self.source_path, self.test_spec) + + # Verify export raises error + with self.assertRaises(ValueError): + self.ns_builder.export(self.output_path) + + def test_export_source_with_doc_title(self): + """Test export with source containing doc and title.""" + self.ns_builder.add_source(self.source_path, + doc='Test documentation', + title='Test Title') + self.ns_builder.add_spec(self.source_path, self.test_spec) + + self.ns_builder.export(self.output_path) + + # Verify doc and title in namespace file + with open(self.output_path, 'r') as f: + content = f.read() + self.assertIn('doc: Test documentation', content) + self.assertIn('title: Test documentation', content) + def test_read_namespace(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) @@ -147,6 +231,37 @@ def test_missing_version(self): namespace_cls=SpecNamespace, date=self.date) + def test_include_type(self): + """Test including types from source files and namespaces.""" + # Test including type from source + source_path = "test_source.yaml" + self.ns_builder.include_type('TestType', source=source_path) + self.assertIn(source_path, self.ns_builder._NamespaceBuilder__sources) + self.assertIn('TestType', self.ns_builder._NamespaceBuilder__sources[source_path].get('data_types', [])) + + # Test including type from namespace + namespace = "test_namespace" + self.ns_builder.include_type('TestType2', namespace=namespace) + self.assertIn(namespace, self.ns_builder._NamespaceBuilder__namespaces) + self.assertIn('TestType2', self.ns_builder._NamespaceBuilder__namespaces[namespace].get('data_types', [])) + + # Test error when neither source nor namespace is provided + msg = "must specify 'source' or 'namespace' when including type" + with self.assertRaisesWith(ValueError, msg): + self.ns_builder.include_type('TestType3') + + # Test including multiple types from same source + self.ns_builder.include_type('TestType4', source=source_path) + types_in_source = self.ns_builder._NamespaceBuilder__sources[source_path].get('data_types', []) + self.assertIn('TestType', types_in_source) + self.assertIn('TestType4', types_in_source) + + # Test including multiple types from same namespace + self.ns_builder.include_type('TestType5', namespace=namespace) + types_in_namespace = self.ns_builder._NamespaceBuilder__namespaces[namespace].get('data_types', []) + self.assertIn('TestType2', types_in_namespace) + self.assertIn('TestType5', types_in_namespace) + class TestYAMLSpecWrite(TestSpec): @@ -158,11 +273,33 @@ def setUp(self): doc='Extensions for my lab', title='My lab extensions') + # Create a temporary YAML file for reorder_yaml testing + self.temp_yaml = 'temp_test.yaml' + with open(self.temp_yaml, 'w') as f: + f.write(""" +doc: test doc +name: test name +dtype: int +attributes: +- name: attr1 + doc: attr1 doc + dtype: float +groups: +- name: group1 + doc: group1 doc + datasets: + - name: dataset1 + doc: dataset1 doc + dtype: int +""") + def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) + if os.path.exists(self.temp_yaml): + os.remove(self.temp_yaml) def test_init(self): temp = YAMLSpecWriter('.') @@ -177,6 +314,96 @@ def test_write_namespace(self): def test_get_name(self): self.assertEqual(self.ns_name, self.ns_builder.name) + def test_reorder_yaml(self): + """Test that reorder_yaml correctly loads, reorders, and saves a YAML file.""" + writer = YAMLSpecWriter() + + # Reorder the YAML file + writer.reorder_yaml(self.temp_yaml) + + # Read the reordered content + with open(self.temp_yaml, 'r') as f: + content = f.read() + + # Verify the order of keys in the reordered content + # The name should come before dtype and doc + name_pos = content.find('name: test name') + dtype_pos = content.find('dtype: int') + doc_pos = content.find('doc: test doc') + self.assertLess(name_pos, dtype_pos) + self.assertLess(dtype_pos, doc_pos) + + # Verify nested structures are also reordered + attr_block = content[content.find('- name: attr1'):content.find('groups:')] + self.assertLess(attr_block.find('name: attr1'), attr_block.find('dtype: float')) + self.assertLess(attr_block.find('dtype: float'), attr_block.find('doc: attr1 doc')) + + def test_sort_keys(self): + """Test that sort_keys correctly orders dictionary keys according to the predefined order.""" + writer = YAMLSpecWriter() + + # Test basic ordering with predefined keys + input_dict = { + 'doc': 'documentation', + 'dtype': 'int', + 'name': 'test_name', + 'attributes': [1], + 'datasets': [2], + 'groups': [3] + } + result = writer.sort_keys(input_dict) + + # Check that the keys are in the correct order + expected_order = ['name', 'dtype', 'doc', 'attributes', 'datasets', 'groups'] + self.assertEqual(list(result.keys()), expected_order) + + # Test neurodata_type_def positioning + input_dict = { + 'doc': 'documentation', + 'name': 'test_name', + 'neurodata_type_def': 'MyType', + 'attributes': [1] + } + result = writer.sort_keys(input_dict) + self.assertEqual(list(result.keys())[0], 'neurodata_type_def') + + # Test nested dictionary ordering + input_dict = { + 'doc': 'documentation', + 'nested': { + 'groups': [1], + 'name': 'nested_name', + 'dtype': 'int', + 'attributes': [2] + } + } + result = writer.sort_keys(input_dict) + self.assertEqual(list(result['nested'].keys()), ['name', 'dtype', 'attributes', 'groups']) + + # Test list handling + input_dict = { + 'attributes': [ + {'doc': 'attr1', 'name': 'attr1_name', 'dtype': 'int'}, + {'doc': 'attr2', 'name': 'attr2_name', 'dtype': 'float'} + ] + } + result = writer.sort_keys(input_dict) + for attr in result['attributes']: + self.assertEqual(list(attr.keys()), ['name', 'dtype', 'doc']) + + # Test tuple handling + input_tuple = ( + {'doc': 'item1', 'name': 'name1', 'dtype': 'int'}, + {'doc': 'item2', 'name': 'name2', 'dtype': 'float'} + ) + result = writer.sort_keys(input_tuple) + # Convert generator to list for testing + result_list = list(result) + for item in result_list: + self.assertEqual(list(item.keys()), ['name', 'dtype', 'doc']) + # Verify the original order is maintained + self.assertEqual(result_list[0]['name'], 'name1') + self.assertEqual(result_list[1]['name'], 'name2') class TestExportSpec(TestSpec):