diff --git a/forte/data/ontology/core.py b/forte/data/ontology/core.py index ca5dc4663..58a2446b7 100644 --- a/forte/data/ontology/core.py +++ b/forte/data/ontology/core.py @@ -31,6 +31,7 @@ Union, Dict, Iterator, + cast, overload, List, Any, @@ -297,11 +298,13 @@ class FList(Generic[ParentEntryType], MutableSequence): def __init__( self, parent_entry: ParentEntryType, - data: Optional[List[int]] = None, + data: Optional[List[Union[int, Tuple[int, int]]]] = None, ): super().__init__() self.__parent_entry = parent_entry - self.__data: List[int] = [] if data is None else data + self.__data: List[Union[int, Tuple[int, int]]] = ( + [] if data is None else data + ) def __eq__(self, other): return self.__data == other._FList__data @@ -310,7 +313,15 @@ def _set_parent(self, parent_entry: ParentEntryType): self.__parent_entry = parent_entry def insert(self, index: int, entry: EntryType): - self.__data.insert(index, entry.tid) + # If the pack id of the entry is not equal to the pack id + # of the parent, it indicates that the entries being stored + # are MultiPack entries. Thus, we store the entries as a tuple + # of the entry's pack id and the entry's tid in contrast to + # regular entries which are just stored by their tid + if entry.pack.pack_id != self.__parent_entry.pack.pack_id: + self.__data.insert(index, (entry.pack.pack_id, entry.tid)) + else: + self.__data.insert(index, entry.tid) @overload @abstractmethod @@ -326,22 +337,66 @@ def __getitem__( self, index: Union[int, slice] ) -> Union[EntryType, MutableSequence]: if isinstance(index, slice): - return [ - self.__parent_entry.pack.get_entry(tid) - for tid in self.__data[index] - ] + if all(isinstance(val, int) for val in self.__data): + # If entry data is stored just be an integer, it indicates + # that this is a Single Pack entry (stored just by its tid) + return [ + self.__parent_entry.pack.get_entry(tid) + for tid in self.__data[index] + ] + else: + # else, it indicates that this is a Multi Pack + # entry (stored as a tuple) + return [ + self.__parent_entry.pack.get_subentry(*attr) + for attr in self.__data[index] + ] else: - return self.__parent_entry.pack.get_entry(self.__data[index]) + if all(isinstance(val, int) for val in self.__data): + # If entry data is stored just be an integer, it indicates + # that this is a Single Pack entry (stored just by its tid) + return self.__parent_entry.pack.get_entry(self.__data[index]) + else: + # else, it indicates that this is a Multi Pack + # entry (stored as a tuple) + return self.__parent_entry.pack.get_subentry( + *self.__data[index] + ) def __setitem__( self, index: Union[int, slice], value: Union[EntryType, Iterable[EntryType]], ) -> None: + if isinstance(index, int): - self.__data[index] = value.tid # type: ignore + value = cast(EntryType, value) + if value.pack.pack_id != self.__parent_entry.pack.pack_id: + # If the pack id of the entry is not equal to the pack id + # of the parent, it indicates that the entries being stored + # are MultiPack entries. + self.__data[index] = (value.pack.pack_id, value.tid) + else: + # If the pack id of the entry is equal to the pack id + # of the parent, it indicates that the entries being stored + # are Single Pack entries. + self.__data[index] = value.tid else: - self.__data[index] = [v.tid for v in value] # type: ignore + value = cast(Iterable[EntryType], value) + if all( + val.pack.pack_id != self.__parent_entry.pack.pack_id + for val in value + ): + # If the pack id of the entry is not equal to the pack id + # of the parent for all entries in the FList data, + # it indicates that the entries being stored + # are MultiPack entries. + self.__data[index] = [(v.pack.pack_id, v.tid) for v in value] + else: + # If the pack id of the entry is equal to the pack id + # of the parent for any FList data item, it indicates that + # the entries being stored are Single Pack entries. + self.__data[index] = [v.tid for v in value] def __delitem__(self, index: Union[int, slice]) -> None: del self.__data[index] diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index 2be964f4b..24f569aef 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -23,7 +23,6 @@ Union, Iterable, List, - cast, ) import numpy as np @@ -343,11 +342,7 @@ def get_members(self) -> List[Entry]: "attached to any data pack." ) - member_entries = [] - if self.members is not None: - for m in self.members: - member_entries.append(m) - return member_entries + return list(self.members) class MultiPackGeneric(MultiEntry, Entry): @@ -498,7 +493,7 @@ class MultiPackGroup(MultiEntry, BaseGroup[Entry]): of members. """ member_type: str - members: Optional[FList[Entry]] + members: FList[Entry] MemberType = Entry @@ -520,18 +515,11 @@ def add_member(self, member: Entry): f"The members of {type(self)} should be " f"instances of {self.MemberType}, but got {type(member)}" ) - if self.members is None: - self.members = cast(FList, [member]) - else: - self.members.append(member) + + self.members.append(member) def get_members(self) -> List[Entry]: - members = [] - if self.members is not None: - member_data = self.members - for m in member_data: - members.append(m) - return members + return list(self.members) @dataclass