diff --git a/docs/requirements.txt b/docs/requirements.txt index 83bf516f8..01865167c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,13 @@ # myst requires sphinx <4 -sphinx<4 -myst-parser>=0.14.0 -sphinx-rtd-theme >= 0.5.0 +sphinx==3.5.4 +myst-parser==0.17.2 +sphinx-rtd-theme == 1.1.1 Pygments >= 2.1.1 funcsigs~=1.0.2 mypy_extensions~=0.4.1 -sphinxcontrib-spelling +sphinxcontrib-spelling==7.7.0 sphinx-comments +docutils==0.16 # Pipeline remote features fastapi==0.65.2 @@ -44,7 +45,7 @@ transformers>=3.1 nltk==3.6.6 nbsphinx==0.8.8 -jinja2<=3.0.3 +jinja2==3.0.3 asyml_utilities sphinx_autodoc_typehints diff --git a/forte/data/data_store.py b/forte/data/data_store.py index 2f0c660d8..f43fdef53 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -806,7 +806,6 @@ def fetch_entry_type_data( else: attr_fields: Dict = self._get_entry_attributes_by_class(type_name) for attr_name, attr_info in attr_fields.items(): - attr_class = get_origin(attr_info.type) # Since we store the class specified by get_origin, # if the output it None, we store the class for it, @@ -896,7 +895,7 @@ def _is_subclass( if cls_qualified_name in type_name_parent_class: return True else: - entry_class = get_class(type_name) + entry_class = get_class(type_name, cached_lookup=False) if issubclass(entry_class, cls): type_name_parent_class.add(cls_qualified_name) return True @@ -1047,7 +1046,6 @@ def _add_entry_raw( self._is_subclass(type_name, cls) for cls in (list(SinglePackEntries) + list(MultiPackEntries)) ): - try: self.__elements[type_name].append(entry) except KeyError: @@ -1081,10 +1079,7 @@ def _is_annotation_tid(self, tid: int) -> bool: elif tid in self.__tid_idx_dict: return False else: - raise KeyError( - f"Entry with tid {tid} not found." - f" Data store content is only {str(self.__dict__)}" - ) + raise KeyError(f"Entry with tid {tid} not found.") def _create_new_entry( self, type_name: str, tid: Optional[int] = None @@ -1246,7 +1241,6 @@ def add_entry_raw( allow_duplicate: bool = True, attribute_data: Optional[List] = None, ) -> int: - r""" This function provides a general implementation to add all types of entries to the data store. It can add namely @@ -1487,6 +1481,173 @@ def get_attribute(self, tid: int, attr_name: str) -> Any: return entry[attr_id] + def get_attributes_of_tid(self, tid: int, attr_names: List[str]) -> dict: + r"""This function returns the value of attributes listed in + ``attr_names`` for the entry with ``tid``. It locates the entry data + with ``tid`` and finds attributes listed in ``attr_names`` and return + as a dict. + + Args: + tid: Unique id of the entry. + attr_names: List of names of the attribute. + + Returns: + A dict with keys listed in ``attr_names`` for attributes of the + entry with ``tid``. + + Raises: + KeyError: when ``tid`` or ``attr_name`` is not found. + """ + entry, entry_type = self.get_entry(tid) + attrs: dict = {} + for attr_name in attr_names: + try: + attr_id = self._get_type_attribute_dict(entry_type)[attr_name][ + constants.ATTR_INDEX_KEY + ] + except KeyError as e: + raise KeyError( + f"{entry_type} has no {attr_name} attribute." + ) from e + attrs[attr_name] = entry[attr_id] + + return attrs + + def get_attributes_of_tids( + self, list_of_tid: List[int], attr_names: List[str] + ) -> List[Any]: + r"""This function returns the value of attributes listed in + ``attr_names`` for entries in listed in the ``list_of_tid``. + It locates the entries data with ``tid`` and put attributes + listed in ``attr_name`` in a dict for each entry. + + Args: + list_of_tid: List of unique ids of the entry. + attr_names: List of name of the attribute. + + Returns: + A list of dict with ``attr_name`` as key for attributes + of the entries requested. + + Raises: + KeyError: when ``tid`` or ``attr_name`` is not found. + """ + tids_attrs = [] + for tid in list_of_tid: + entry, entry_type = self.get_entry(tid) + attrs: dict = {} + for attr_name in attr_names: + try: + attr_id = self._get_type_attribute_dict(entry_type)[ + attr_name + ][constants.ATTR_INDEX_KEY] + except KeyError as e: + raise KeyError( + f"{entry_type} has no {attr_name} attribute." + ) from e + attrs[attr_name] = entry[attr_id] + + tids_attrs.append(attrs) + + return tids_attrs + + def get_attributes_of_type( + self, + type_name: str, + attributes_names: List[str], + include_sub_type: bool = True, + range_span: Optional[Tuple[int, int]] = None, + ) -> Iterator[dict]: + r"""This function fetches required attributes of entries from the + data store of type ``type_name``. If `include_sub_type` is set to + True and ``type_name`` is in [Annotation], this function also + fetches entries of subtype of ``type_name``. Otherwise, it only + fetches entries of type ``type_name``. + + Args: + type_name: The fully qualified name of the entry. + attributes_names: list of attributes to be fetched for each entry + include_sub_type: A boolean to indicate whether get its subclass. + range_span: A tuple that contains the begin and end indices + of the searching range of entries. + + Returns: + An iterator of the attributes of the entry in dict matching the + provided arguments. + """ + + entry_class = get_class(type_name) + all_types = set() + if include_sub_type: + for type in self.__elements: + if issubclass(get_class(type), entry_class): + all_types.add(type) + else: + all_types.add(type_name) + all_types = list(all_types) + all_types.sort() + + if self._is_annotation(type_name): + if range_span is None: + # yield from self.co_iterator_annotation_like(all_types) + for entry in self.co_iterator_annotation_like(all_types): + attrs: dict = {"tid": entry[0]} + for attr_name in attributes_names: + try: + attr_id = self._get_type_attribute_dict(type_name)[ + attr_name + ][constants.ATTR_INDEX_KEY] + except KeyError as e: + raise KeyError( + f"{type_name} has no {attr_name} attribute." + ) from e + attrs[attr_name] = entry[attr_id] + + yield attrs + else: + for entry in self.co_iterator_annotation_like( + all_types, range_span=range_span + ): + attrs = {"tid": entry[0]} + for attr_name in attributes_names: + try: + attr_id = self._get_type_attribute_dict(type_name)[ + attr_name + ][constants.ATTR_INDEX_KEY] + except KeyError as e: + raise KeyError( + f"{type_name} has no {attr_name} attribute." + ) from e + attrs[attr_name] = entry[attr_id] + + yield attrs # attrs instead of entry + elif issubclass(entry_class, Link): + raise NotImplementedError( + f"{type_name} of Link is not currently supported." + ) + elif issubclass(entry_class, Group): + raise NotImplementedError( + f"{type_name} of Group is not currently supported." + ) + else: + if type_name not in self.__elements: + raise ValueError(f"type {type_name} does not exist") + # yield from self.iter(type_name) + for entry in self.iter(type_name): + attrs = {"tid": entry[0]} + for attr_name in attributes_names: + try: + attr_id = self._get_type_attribute_dict(type_name)[ + attr_name + ][constants.ATTR_INDEX_KEY] + except KeyError as e: + raise KeyError( + f"{type_name} has no {attr_name} attribute." + ) from e + attrs[attr_name] = entry[attr_id] + + yield attrs + def _get_attr(self, tid: int, attr_id: int) -> Any: r"""This function locates the entry data with ``tid`` and gets the value of ``attr_id`` of this entry. Called by `get_attribute()`. @@ -1870,7 +2031,9 @@ def co_iterator_annotation_like( self.get_datastore_attr_idx(tn, constants.BEGIN_ATTR_NAME), self.get_datastore_attr_idx(tn, constants.END_ATTR_NAME), ) - except IndexError as e: # all_entries_range[tn][0] will be caught here. + except ( + IndexError + ) as e: # all_entries_range[tn][0] will be caught here. raise ValueError( f"Entry list of type name, {tn} which is" " one list item of input argument `type_names`," diff --git a/forte/data/ontology/ontology_code_generator.py b/forte/data/ontology/ontology_code_generator.py index c822f103f..28ecd6e0f 100644 --- a/forte/data/ontology/ontology_code_generator.py +++ b/forte/data/ontology/ontology_code_generator.py @@ -41,7 +41,7 @@ except ImportError: # Try backported to PY<39 `importlib_resources`. import importlib_resources as resources # type: ignore - from importlib_resources.abc import Traversable # type: ignore + from importlib_resources.abc import Traversable from forte.data.ontology import top, utils from forte.data.ontology.code_generation_exceptions import ( diff --git a/forte/utils/utils.py b/forte/utils/utils.py index 4035509e9..b17f9a4d3 100644 --- a/forte/utils/utils.py +++ b/forte/utils/utils.py @@ -16,7 +16,7 @@ """ import sys import difflib -from functools import wraps +from functools import wraps, lru_cache from inspect import getfullargspec from pydoc import locate from typing import Dict, List, Optional, get_type_hints, Tuple @@ -78,8 +78,26 @@ def get_class_name(o, lower: bool = False) -> str: return o.__name__ -def get_class(full_class_name: str, module_paths: Optional[List[str]] = None): - r"""Returns the class based on class name. +@lru_cache() +def cached_locate(name_to_locate_class): + r"""Wrapped version of locate for loading class based on + ``name_to_locate_class``, cached by lru_cache. + Args: + name_to_locate_class (str): Name or full path to the class. + + Returns: + The target class. + + """ + return locate(name_to_locate_class) + + +def get_class( + full_class_name: str, + module_paths: Optional[List[str]] = None, + cached_lookup: Optional[bool] = True, +): + r"""Returns the class based on class name, with cached lookup option. Args: full_class_name (str): Name or full path to the class. @@ -87,18 +105,22 @@ def get_class(full_class_name: str, module_paths: Optional[List[str]] = None): class. This is used if the class cannot be located solely based on ``class_name``. The first module in the list that contains the class is used. + cached_lookup (bool): Flag to use "cached_locate" for class loading Returns: The target class. - Raises: - ValueError: If class is not found based on :attr:`class_name` and - :attr:`module_paths`. """ - class_ = locate(full_class_name) + if cached_lookup: + class_ = cached_locate(full_class_name) + else: + class_ = locate(full_class_name) if (class_ is None) and (module_paths is not None): for module_path in module_paths: - class_ = locate(".".join([module_path, full_class_name])) + if cached_lookup: + class_ = cached_locate(".".join([module_path, full_class_name])) + else: + class_ = locate(".".join([module_path, full_class_name])) if class_ is not None: break diff --git a/setup.py b/setup.py index 3e9aa37a0..25e027a97 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ 'dataclasses~=0.7;python_version<"3.7"', 'importlib-resources>=5.1.4;python_version<"3.9"', "asyml-utilities", + "protobuf==3.20.0" ], extras_require={ "data_aug": [ diff --git a/tests/forte/data/data_store_test.py b/tests/forte/data/data_store_test.py index 33fc1fc93..b9e7833c9 100644 --- a/tests/forte/data/data_store_test.py +++ b/tests/forte/data/data_store_test.py @@ -699,7 +699,6 @@ def value_err_fn(): self.assertRaises(ValueError, value_err_fn) def test_add_annotation_raw(self): - # test add Document entry tid_doc: int = self.data_store.add_entry_raw( type_name="ft.onto.base_ontology.Document", @@ -1039,6 +1038,126 @@ def test_get_attribute(self): ): self.data_store.get_attribute(9999, "class") + def test_get_attributes_of_tid(self): + result_dict = self.data_store.get_attributes_of_tid( + 9999, ["begin", "end", "speaker"] + ) + result_dict2 = self.data_store.get_attributes_of_tid( + 3456, ["payload_idx", "classifications"] + ) + + self.assertEqual(result_dict["begin"], 6) + self.assertEqual(result_dict["end"], 9) + self.assertEqual(result_dict["speaker"], "teacher") + self.assertEqual(result_dict2["payload_idx"], 1) + self.assertEqual(result_dict2["classifications"], {}) + + # Entry with such tid does not exist + with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."): + self.data_store.get_attributes_of_tid(1111, ["speaker"]) + + # Get attribute field that does not exist + with self.assertRaisesRegex( + KeyError, "ft.onto.base_ontology.Sentence has no class attribute." + ): + self.data_store.get_attributes_of_tid(9999, ["class"]) + + def test_get_attributes_of_tids(self): + tids_attrs: list[dict] + # tids_attrs2: list[dict] + tids_attrs = self.data_store.get_attributes_of_tids( + [9999, 3456], ["begin", "end", "payload_idx"] + ) + tids_attrs2 = self.data_store.get_attributes_of_tids( + [9999], ["begin", "speaker"] + ) + + self.assertEqual(tids_attrs2[0]["begin"], 6) + self.assertEqual(tids_attrs[0]["end"], 9) + self.assertEqual(tids_attrs[1]["payload_idx"], 1) + self.assertEqual(tids_attrs2[0]["speaker"], "teacher") + + # Entry with such tid does not exist + with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."): + self.data_store.get_attributes_of_tids([1111], ["speaker"]) + + # Get attribute field that does not exist + with self.assertRaisesRegex( + KeyError, "ft.onto.base_ontology.Sentence has no class attribute." + ): + self.data_store.get_attributes_of_tids([9999], ["class"]) + + def test_get_attributes_of_type(self): + # get document entries + instances = list( + self.data_store.get_attributes_of_type( + "ft.onto.base_ontology.Document", + ["begin", "end", "payload_idx"], + ) + ) + # print(instances) + self.assertEqual(len(instances), 2) + # check tid + self.assertEqual(instances[0]["tid"], 1234) + self.assertEqual(instances[0]["end"], 5) + self.assertEqual(instances[1]["tid"], 3456) + self.assertEqual(instances[1]["begin"], 10) + + # For types other than annotation, group or link, not support include_subtype + instances = list( + self.data_store.get_attributes_of_type( + "forte.data.ontology.core.Entry", ["begin", "end"] + ) + ) + self.assertEqual(len(instances), 0) + + self.assertEqual( + self.data_store.get_length("forte.data.ontology.core.Entry"), 0 + ) + + # get annotations with subclasses and range annotation + instances = list( + self.data_store.get_attributes_of_type( + "forte.data.ontology.top.Annotation", + ["begin", "end"], + range_span=(1, 20), + ) + ) + self.assertEqual(len(instances), 2) + + # get groups with subclasses + # instances = list(self.data_store.get_attributes_of_type( + # "forte.data.ontology.top.Group", ["begin", "end"])) + # self.assertEqual(len(instances), 3) + + # # get groups with subclasses and range annotation + # instances = list( + # self.data_store.get( + # "forte.data.ontology.top.Group", range_span=(1, 20) + # ) + # ) + # self.assertEqual(len(instances), 0) + # + # # get links with subclasses + # instances = list(self.data_store.get("forte.data.ontology.top.Link")) + # self.assertEqual(len(instances), 1) + # + # # get links with subclasses and range annotation + # instances = list( + # self.data_store.get( + # "forte.data.ontology.top.Link", range_span=(0, 9) + # ) + # ) + # self.assertEqual(len(instances), 1) + # + # # get links with subclasses and range annotation + # instances = list( + # self.data_store.get( + # "forte.data.ontology.top.Link", range_span=(4, 11) + # ) + # ) + # self.assertEqual(len(instances), 0) + def test_set_attribute(self): # change attribute self.data_store.set_attribute(9999, "speaker", "student") @@ -1328,7 +1447,6 @@ def test_get_entry_attribute_by_class(self): ) def test_is_subclass(self): - import forte self.assertEqual( @@ -1396,7 +1514,6 @@ def test_is_subclass(self): ) def test_check_onto_file(self): - expected_type_attributes = { "ft.onto.test.Description": { "attributes": {