Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create_interface_for_attributes_920 #1

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# myst requires sphinx <4
sphinx<4
myst-parser>=0.14.0
sphinx-rtd-theme >= 0.5.0
sphinx==3.5.4
myst-parser==0.17.2
sphinx-rtd-theme == 1.1.1
Pygments >= 2.1.1
funcsigs~=1.0.2
mypy_extensions~=0.4.1
sphinxcontrib-spelling
sphinxcontrib-spelling==7.7.0
sphinx-comments
docutils==0.16

# Pipeline remote features
fastapi==0.65.2
Expand Down Expand Up @@ -44,7 +45,7 @@ transformers>=3.1
nltk==3.6.6

nbsphinx==0.8.8
jinja2<=3.0.3
jinja2==3.0.3
asyml_utilities
sphinx_autodoc_typehints

181 changes: 172 additions & 9 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,6 @@ def fetch_entry_type_data(
else:
attr_fields: Dict = self._get_entry_attributes_by_class(type_name)
for attr_name, attr_info in attr_fields.items():

attr_class = get_origin(attr_info.type)
# Since we store the class specified by get_origin,
# if the output it None, we store the class for it,
Expand Down Expand Up @@ -896,7 +895,7 @@ def _is_subclass(
if cls_qualified_name in type_name_parent_class:
return True
else:
entry_class = get_class(type_name)
entry_class = get_class(type_name, cached_lookup=False)
if issubclass(entry_class, cls):
type_name_parent_class.add(cls_qualified_name)
return True
Expand Down Expand Up @@ -1047,7 +1046,6 @@ def _add_entry_raw(
self._is_subclass(type_name, cls)
for cls in (list(SinglePackEntries) + list(MultiPackEntries))
):

try:
self.__elements[type_name].append(entry)
except KeyError:
Expand Down Expand Up @@ -1081,10 +1079,7 @@ def _is_annotation_tid(self, tid: int) -> bool:
elif tid in self.__tid_idx_dict:
return False
else:
raise KeyError(
f"Entry with tid {tid} not found."
f" Data store content is only {str(self.__dict__)}"
)
raise KeyError(f"Entry with tid {tid} not found.")

def _create_new_entry(
self, type_name: str, tid: Optional[int] = None
Expand Down Expand Up @@ -1246,7 +1241,6 @@ def add_entry_raw(
allow_duplicate: bool = True,
attribute_data: Optional[List] = None,
) -> int:

r"""
This function provides a general implementation to add all
types of entries to the data store. It can add namely
Expand Down Expand Up @@ -1487,6 +1481,173 @@ def get_attribute(self, tid: int, attr_name: str) -> Any:

return entry[attr_id]

def get_attributes_of_tid(self, tid: int, attr_names: List[str]) -> dict:
r"""This function returns the value of attributes listed in
``attr_names`` for the entry with ``tid``. It locates the entry data
with ``tid`` and finds attributes listed in ``attr_names`` and return
as a dict.

Args:
tid: Unique id of the entry.
attr_names: List of names of the attribute.

Returns:
A dict with keys listed in ``attr_names`` for attributes of the
entry with ``tid``.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[attr_name][
constants.ATTR_INDEX_KEY
]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

return attrs

def get_attributes_of_tids(
self, list_of_tid: List[int], attr_names: List[str]
) -> List[Any]:
r"""This function returns the value of attributes listed in
``attr_names`` for entries in listed in the ``list_of_tid``.
It locates the entries data with ``tid`` and put attributes
listed in ``attr_name`` in a dict for each entry.

Args:
list_of_tid: List of unique ids of the entry.
attr_names: List of name of the attribute.

Returns:
A list of dict with ``attr_name`` as key for attributes
of the entries requested.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
tids_attrs = []
for tid in list_of_tid:
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

tids_attrs.append(attrs)

return tids_attrs

def get_attributes_of_type(
self,
type_name: str,
attributes_names: List[str],
include_sub_type: bool = True,
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[dict]:
r"""This function fetches required attributes of entries from the
data store of type ``type_name``. If `include_sub_type` is set to
True and ``type_name`` is in [Annotation], this function also
fetches entries of subtype of ``type_name``. Otherwise, it only
fetches entries of type ``type_name``.

Args:
type_name: The fully qualified name of the entry.
attributes_names: list of attributes to be fetched for each entry
include_sub_type: A boolean to indicate whether get its subclass.
range_span: A tuple that contains the begin and end indices
of the searching range of entries.

Returns:
An iterator of the attributes of the entry in dict matching the
provided arguments.
"""

entry_class = get_class(type_name)
all_types = set()
if include_sub_type:
for type in self.__elements:
if issubclass(get_class(type), entry_class):
all_types.add(type)
else:
all_types.add(type_name)
all_types = list(all_types)
all_types.sort()

if self._is_annotation(type_name):
if range_span is None:
# yield from self.co_iterator_annotation_like(all_types)
for entry in self.co_iterator_annotation_like(all_types):
attrs: dict = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs
else:
for entry in self.co_iterator_annotation_like(
all_types, range_span=range_span
):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs # attrs instead of entry
elif issubclass(entry_class, Link):
raise NotImplementedError(
f"{type_name} of Link is not currently supported."
)
elif issubclass(entry_class, Group):
raise NotImplementedError(
f"{type_name} of Group is not currently supported."
)
else:
if type_name not in self.__elements:
raise ValueError(f"type {type_name} does not exist")
# yield from self.iter(type_name)
for entry in self.iter(type_name):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs

def _get_attr(self, tid: int, attr_id: int) -> Any:
r"""This function locates the entry data with ``tid`` and gets the value
of ``attr_id`` of this entry. Called by `get_attribute()`.
Expand Down Expand Up @@ -1870,7 +2031,9 @@ def co_iterator_annotation_like(
self.get_datastore_attr_idx(tn, constants.BEGIN_ATTR_NAME),
self.get_datastore_attr_idx(tn, constants.END_ATTR_NAME),
)
except IndexError as e: # all_entries_range[tn][0] will be caught here.
except (
IndexError
) as e: # all_entries_range[tn][0] will be caught here.
raise ValueError(
f"Entry list of type name, {tn} which is"
" one list item of input argument `type_names`,"
Expand Down
2 changes: 1 addition & 1 deletion forte/data/ontology/ontology_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
except ImportError:
# Try backported to PY<39 `importlib_resources`.
import importlib_resources as resources # type: ignore
from importlib_resources.abc import Traversable # type: ignore
from importlib_resources.abc import Traversable

from forte.data.ontology import top, utils
from forte.data.ontology.code_generation_exceptions import (
Expand Down
38 changes: 30 additions & 8 deletions forte/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
import sys
import difflib
from functools import wraps
from functools import wraps, lru_cache
from inspect import getfullargspec
from pydoc import locate
from typing import Dict, List, Optional, get_type_hints, Tuple
Expand Down Expand Up @@ -78,27 +78,49 @@ def get_class_name(o, lower: bool = False) -> str:
return o.__name__


def get_class(full_class_name: str, module_paths: Optional[List[str]] = None):
r"""Returns the class based on class name.
@lru_cache()
def cached_locate(name_to_locate_class):
r"""Wrapped version of locate for loading class based on
``name_to_locate_class``, cached by lru_cache.
Args:
name_to_locate_class (str): Name or full path to the class.

Returns:
The target class.

"""
return locate(name_to_locate_class)


def get_class(
full_class_name: str,
module_paths: Optional[List[str]] = None,
cached_lookup: Optional[bool] = True,
):
r"""Returns the class based on class name, with cached lookup option.

Args:
full_class_name (str): Name or full path to the class.
module_paths (list): Paths to candidate modules to search for the
class. This is used if the class cannot be located solely based on
``class_name``. The first module in the list that contains the class
is used.
cached_lookup (bool): Flag to use "cached_locate" for class loading

Returns:
The target class.

Raises:
ValueError: If class is not found based on :attr:`class_name` and
:attr:`module_paths`.
"""
class_ = locate(full_class_name)
if cached_lookup:
class_ = cached_locate(full_class_name)
else:
class_ = locate(full_class_name)
if (class_ is None) and (module_paths is not None):
for module_path in module_paths:
class_ = locate(".".join([module_path, full_class_name]))
if cached_lookup:
class_ = cached_locate(".".join([module_path, full_class_name]))
else:
class_ = locate(".".join([module_path, full_class_name]))
if class_ is not None:
break

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'dataclasses~=0.7;python_version<"3.7"',
'importlib-resources>=5.1.4;python_version<"3.9"',
"asyml-utilities",
"protobuf==3.20.0"
],
extras_require={
"data_aug": [
Expand Down
Loading