diff --git a/forte/data/__init__.py b/forte/data/__init__.py index 01858ebca..12b458e39 100644 --- a/forte/data/__init__.py +++ b/forte/data/__init__.py @@ -20,3 +20,4 @@ from forte.data.data_store import * from forte.data.selector import * from forte.data.index import * +from forte.data.modality import * diff --git a/forte/data/data_pack.py b/forte/data/data_pack.py index d18fc2d52..fd548832b 100644 --- a/forte/data/data_pack.py +++ b/forte/data/data_pack.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import IntEnum import logging from pathlib import Path from typing import ( @@ -26,7 +27,9 @@ Set, Callable, Tuple, + cast, ) + import numpy as np from sortedcontainers import SortedList @@ -51,7 +54,10 @@ AudioAnnotation, ImageAnnotation, Grids, + Payload, ) + +from forte.data.modality import Modality from forte.data.span import Span from forte.data.types import ReplaceOperationsType, DataRequest from forte.utils import get_class, get_full_module_name @@ -161,19 +167,16 @@ class DataPack(BasePack[Entry, Link, Group]): def __init__(self, pack_name: Optional[str] = None): super().__init__(pack_name) - self._text = "" self._audio: Optional[np.ndarray] = None self._data_store: DataStore = DataStore() self._entry_converter: EntryConverter = EntryConverter() self.image_annotations: List[ImageAnnotation] = [] self.grids: List[Grids] = [] - self.payloads: List[np.ndarray] = [] - - self.__replace_back_operations: ReplaceOperationsType = [] - self.__processed_original_spans: List[Tuple[Span, Span]] = [] - self.__orig_text_len: int = 0 + self.text_payloads: List[Payload] = [] + self.audio_payloads: List[Payload] = [] + self.image_payloads: List[Payload] = [] self._index: DataIndex = DataIndex() @@ -196,18 +199,10 @@ def __setstate__(self, state): """ self._entry_converter = EntryConverter() super().__setstate__(state) - - # For backward compatibility. 
- if "replace_back_operations" in self.__dict__: - self.__replace_back_operations = self.__dict__.pop( - "replace_back_operations" - ) - if "processed_original_spans" in self.__dict__: - self.__processed_original_spans = self.__dict__.pop( - "processed_original_spans" - ) - if "orig_text_len" in self.__dict__: - self.__orig_text_len = self.__dict__.pop("orig_text_len") + for payload in ( + self.text_payloads + self.audio_payloads + self.image_payloads + ): + payload.set_pack(self) self._index = DataIndex() self._index.update_basic_index(list(iter(self))) @@ -227,18 +222,25 @@ def _validate(self, entry: EntryType) -> bool: @property def text(self) -> str: - r"""Return the text of the data pack""" - return self._text + """ + Get the first text data stored in the DataPack. + If there is no text payload in the DataPack, it will return empty + string. - @property - def audio(self) -> Optional[np.ndarray]: - r"""Return the audio of the data pack""" - return self._audio + Note: + this property takes no arguments and always reads text payload 0. - @property - def sample_rate(self) -> Optional[int]: - r"""Return the sample rate of the audio data""" - return getattr(self._meta, "sample_rate") + Use ``get_payload_data_at(Modality.Text, i)`` to read a text + payload at another index; no exception is raised here, a missing + text payload simply yields "". + + Returns: + text data in the text payload. + """ + if len(self.text_payloads) > 0: + return str(self.get_payload_data_at(Modality.Text, 0)) + else: + return "" @property def all_annotations(self) -> Iterator[Annotation]: @@ -421,19 +423,99 @@ def groups(self): def groups(self, val): self._groups = val - def get_span_text(self, begin: int, end: int) -> str: + def get_payload_at( + self, modality: IntEnum, payload_index: int + ): # -> Union[TextPayload, AudioPayload, ImagePayload]: + """ + Get Payload of requested modality at the requested payload index. 
+ + Args: + modality: data modality among "text", "audio", "image" + payload_index: the zero-based index of the Payload + in this DataPack's Payload entries of the requested modality. + + Raises: + ValueError: raised when the requested modality is not supported. + + Returns: + Payload entry containing text data, image or audio data. + + """ + supported_modality = [enum.name for enum in Modality] + + try: + # if modality.name == "text": + if modality == Modality.Text: + payloads_length = len(self.text_payloads) + payload = self.text_payloads[payload_index] + # elif modality.name == "audio": + elif modality == Modality.Audio: + payloads_length = len(self.audio_payloads) + payload = self.audio_payloads[payload_index] + # elif modality.name == "image": + elif modality == Modality.Image: + payloads_length = len(self.image_payloads) + payload = self.image_payloads[payload_index] + else: + raise ValueError( + f"Provided modality {modality.name} is not supported." + "Please provide one of modality among" + f" {supported_modality}." + ) + except IndexError as e: + raise ProcessExecutionException( + f"payload index ({payload_index}) " + f"is larger or equal to {modality.name} payload list" + f" length ({payloads_length}). " + f"Please input a {modality.name} payload index less than it." + ) from e + return payload + + def get_payload_data_at( + self, modality: IntEnum, payload_index: int + ) -> Union[str, np.ndarray]: + """ + Get Payload of requested modality at the requested payload index. + + Args: + modality: data modality among "text", "audio", "image" + payload_index: the zero-based index of the Payload + in this DataPack's Payload entries of the requested modality. + + Raises: + ValueError: raised when the requested modality is not supported. + + Returns: + different data types for different data modalities. + + 1. str data for text data. + + 2. Numpy array for image and audio data. 
+ + """ + return self.get_payload_at(modality, payload_index).cache + + def get_span_text( + self, begin: int, end: int, text_payload_index: int = 0 + ) -> str: r"""Get the text in the data pack contained in the span. Args: begin: begin index to query. end: end index to query. + text_payload_index: the zero-based index of the TextPayload + in this DataPack's TextPayload entries. Defaults to 0. Returns: The text within this span. """ - return self._text[begin:end] + return cast( + str, self.get_payload_data_at(Modality.Text, text_payload_index) + )[begin:end] - def get_span_audio(self, begin: int, end: int) -> np.ndarray: + def get_span_audio( + self, begin: int, end: int, audio_payload_index=0 + ) -> np.ndarray: r"""Get the audio in the data pack contained in the span. `begin` and `end` represent the starting and ending indices of the span in audio payload respectively. Each index corresponds to one sample in @@ -442,73 +524,104 @@ def get_span_audio(self, begin: int, end: int) -> np.ndarray: Args: begin: begin index to query. end: end index to query. + audio_payload_index: the zero-based index of the AudioPayload + in this DataPack's AudioPayload entries. Defaults to 0. Returns: The audio within this span. """ - if self._audio is None: - raise ProcessExecutionException( - "The audio payload of this DataPack is not set. Please call" - " method `set_audio` before running `get_span_audio`." - ) - return self._audio[begin:end] - - def get_image_array(self, image_payload_idx: int): - if image_payload_idx >= len(self.payloads): - raise ValueError( - f"The input image payload index{(image_payload_idx)}" - f" out of range. 
It should be less than {len(self.payloads)}" - ) - return self.payloads[image_payload_idx] + return cast( + np.ndarray, + self.get_payload_data_at(Modality.Audio, audio_payload_index)[ + begin:end + ], + ) def set_text( self, text: str, replace_func: Optional[Callable[[str], ReplaceOperationsType]] = None, + text_payload_index: int = 0, ): + """ + Set text for TextPayload at a specified index. - if len(text) < len(self._text): - raise ProcessExecutionException( - "The new text is overwriting the original one with shorter " - "length, which might cause unexpected behavior." - ) - - if len(self._text): - logging.warning( - "Need to be cautious when changing the text of a " - "data pack, existing entries may get affected. " - ) + Args: + text: a str text. + replace_func: function that replace text. Defaults to None. + text_payload_index: the zero-based index of the TextPayload + in this DataPack's TextPayload entries. Defaults to 0. + """ + # Temporary imports span_ops = [] if replace_func is None else replace_func(text) - # The spans should be mutually exclusive ( - self._text, - self.__replace_back_operations, - self.__processed_original_spans, - self.__orig_text_len, + text, + replace_back_operations, + processed_original_spans, + orig_text_len, ) = data_utils_io.modify_text_and_track_ops(text, span_ops) + # temporary solution for backward compatibility + # past API use this method to add a single text in the datapack + if len(self.text_payloads) == 0 and text_payload_index == 0: + from ft.onto.base_ontology import ( # pylint: disable=import-outside-toplevel + TextPayload, + ) - def set_audio(self, audio: np.ndarray, sample_rate: int): + tp = TextPayload(self, text_payload_index) + else: + tp = self.get_payload_at(Modality.Text, text_payload_index) + + tp.set_cache(text) + + tp.replace_back_operations = replace_back_operations + tp.processed_original_spans = processed_original_spans + tp.orig_text_len = orig_text_len + + def set_audio( + self, + audio: np.ndarray, 
+ sample_rate: int, + audio_payload_index: int = 0, + ): r"""Set the audio payload and sample rate of the :class:`~forte.data.data_pack.DataPack` object. Args: audio: A numpy array storing the audio waveform. sample_rate: An integer specifying the sample rate. + audio_payload_index: the zero-based index of the AudioPayload + in this DataPack's AudioPayload entries. Defaults to 0. """ - self._audio = audio - self.set_meta(sample_rate=sample_rate) + # temporary solution for backward compatibility + # past API use this method to add a single audio in the datapack + if len(self.audio_payloads) == 0 and audio_payload_index == 0: + from ft.onto.base_ontology import ( # pylint: disable=import-outside-toplevel + AudioPayload, + ) + + ap = AudioPayload(self) + else: + ap = self.get_payload_at(Modality.Audio, audio_payload_index) + + ap.set_cache(audio) + ap.sample_rate = sample_rate - def get_original_text(self): + def get_original_text(self, text_payload_index: int = 0): r"""Get original unmodified text from the :class:`~forte.data.data_pack.DataPack` object. + Args: + text_payload_index: the zero-based index of the TextPayload + in this DataPack's entries. Defaults to 0. 
+ Returns: Original text after applying the `replace_back_operations` of :class:`~forte.data.data_pack.DataPack` object to the modified text """ + tp = self.get_payload_at(Modality.Text, text_payload_index) original_text, _, _, _ = data_utils_io.modify_text_and_track_ops( - self._text, self.__replace_back_operations + tp.cache, tp.replace_back_operations ) return original_text @@ -588,16 +701,19 @@ def get_original_index( Returns: Original index that aligns with input_index """ - if len(self.__processed_original_spans) == 0: + processed_original_spans = self.get_payload_at( + Modality.Text, 0 + ).processed_original_spans + if len(processed_original_spans) == 0: return input_index - len_processed_text = len(self._text) + len_processed_text = len(self.text) orig_index = None prev_end = 0 for ( inverse_span, original_span, - ) in self.__processed_original_spans: + ) in processed_original_spans: # check if the input_index lies between one of the unprocessed # spans if prev_end <= input_index < inverse_span.begin: @@ -626,9 +742,7 @@ def get_original_index( if orig_index is None: # check if the input_index lies between the last unprocessed # span - inverse_span, original_span = self.__processed_original_spans[ - -1 - ] + inverse_span, original_span = processed_original_spans[-1] if inverse_span.end <= input_index < len_processed_text: increment = original_span.end - inverse_span.end orig_index = input_index + increment @@ -770,6 +884,7 @@ def get_data( context_type: Union[str, Type[Annotation], Type[AudioAnnotation]], request: Optional[DataRequest] = None, skip_k: int = 0, + payload_index: int = 0, ) -> Iterator[Dict[str, Any]]: r"""Fetch data from entries in the data_pack of type `context_type`. Data includes `"span"`, annotation-specific @@ -844,6 +959,9 @@ def get_data( returned by default. skip_k: Will skip the first `skip_k` instances and generate data from the (`offset` + 1)th instance. 
+ payload_index: the zero-based index of the Payload + in this DataPack's Payload entries of a particular modality. + The modality is dependent on ``context_type``. Defaults to 0. Returns: A data generator, which generates one piece of data (a dict @@ -932,9 +1050,13 @@ def get_annotation_list( " [Annotation, AudioAnnotation]." ) - def get_context_data(c_type, context): - r"""Get context-specific data of a given context type and - context. + def get_context_data( + c_type: Union[Type[Annotation], Type[AudioAnnotation]], + context: Union[Annotation, AudioAnnotation], + payload_index: int, + ): + r""" + Get context-specific data of a given context type and context. Args: c_type: @@ -942,6 +1064,9 @@ def get_context_data(c_type, context): could be any :class:`~forte.data.ontology.top.Annotation` type. context: context that contains data to be extracted. + payload_index: the zero-based index of the Payload + in this DataPack's Payload entries of a particular modality. + The modality is dependent on ``c_type``. Raises: NotImplementedError: raised when the given context type is @@ -951,9 +1076,13 @@ def get_context_data(c_type, context): str: context data. 
""" if issubclass(c_type, Annotation): - return self.text[context.begin : context.end] + return self.get_payload_data_at(Modality.Text, payload_index)[ + context.begin : context.end + ] elif issubclass(c_type, AudioAnnotation): - return self.audio[context.begin : context.end] + return self.get_payload_data_at(Modality.Audio, payload_index)[ + context.begin : context.end + ] else: raise NotImplementedError( f"Context type is set to {context_type}" @@ -971,7 +1100,9 @@ def get_context_data(c_type, context): skipped += 1 continue data: Dict[str, Any] = {} - data["context"] = get_context_data(context_type_, context) + data["context"] = get_context_data( + context_type_, context, payload_index + ) data["offset"] = context.begin for field in context_fields: @@ -1417,6 +1548,17 @@ def _save_entry_to_data_store(self, entry: Entry): r"""Save an existing entry object into DataStore""" self._entry_converter.save_entry_object(entry=entry, pack=self) + if isinstance(entry, Payload): + if entry.modality == Modality.Text: + entry.set_payload_index(len(self.text_payloads)) + self.text_payloads.append(entry) + elif entry.modality == Modality.Audio: + entry.set_payload_index(len(self.audio_payloads)) + self.audio_payloads.append(entry) + elif entry.modality == Modality.Image: + entry.set_payload_index(len(self.image_payloads)) + self.image_payloads.append(entry) + def _get_entry_from_data_store(self, tid: int) -> EntryType: r"""Generate a class object from entry data in DataStore""" return self._entry_converter.get_entry_object(tid=tid, pack=self) diff --git a/forte/data/data_store.py b/forte/data/data_store.py index acb3450fb..8671b68e8 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -33,6 +33,7 @@ ImageAnnotation, Link, Generics, + Payload, MultiPackGeneric, MultiPackGroup, MultiPackLink, @@ -775,6 +776,7 @@ def _add_entry_raw( Generics, ImageAnnotation, Grids, + Payload, MultiPackLink, MultiPackGroup, MultiPackGeneric, diff --git 
a/forte/data/entry_converter.py b/forte/data/entry_converter.py index ec3abc4f2..343d6693e 100644 --- a/forte/data/entry_converter.py +++ b/forte/data/entry_converter.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast from forte.data.base_pack import PackType from forte.data.ontology.core import Entry, FList, FDict from forte.data.ontology.core import EntryType @@ -28,6 +28,7 @@ MultiPackGeneric, MultiPackGroup, MultiPackLink, + Payload, SinglePackEntries, MultiPackEntries, ) @@ -114,6 +115,15 @@ def save_entry_object( tid=entry.tid, allow_duplicate=allow_duplicate, ) + elif data_store_ref._is_subclass(entry.entry_type(), Payload): + entry = cast(Payload, entry) + data_store_ref.add_entry_raw( + type_name=entry.entry_type(), + attribute_data=[entry.payload_index, entry.modality_name], + base_class=Payload, + tid=entry.tid, + allow_duplicate=allow_duplicate, + ) elif data_store_ref._is_subclass(entry.entry_type(), Grids): # Will be deprecated in future data_store_ref.add_entry_raw( diff --git a/forte/data/modality.py b/forte/data/modality.py new file mode 100644 index 000000000..e2a682066 --- /dev/null +++ b/forte/data/modality.py @@ -0,0 +1,16 @@ +# Copyright 2022 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from enum import IntEnum + +Modality = IntEnum("modality", "Text Audio Image") diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index e621fe028..39688b906 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass +from enum import IntEnum from functools import total_ordering -from typing import Optional, Tuple, Type, Any, Dict, Union, Iterable, List - +from typing import ( + Optional, + Sequence, + Tuple, + Type, + Any, + Dict, + Union, + Iterable, + List, +) import numpy as np +from forte.data.modality import Modality from forte.data.base_pack import PackType from forte.data.ontology.core import ( Entry, @@ -51,6 +62,7 @@ "Region", "Box", "BoundingBox", + "Payload", ] QueryType = Union[Dict[str, Any], np.ndarray] @@ -848,8 +860,10 @@ def __init__(self, pack: PackType, image_payload_idx: int = 0): Args: pack: The container that this image annotation will be added to. - image_payload_idx: the index of the image payload. If it's not set, - it defaults to 0 which means it will load the first image payload. + image_payload_idx: the index of the image payload in the DataPack's + image payload list. + If it's not set, it defaults to 0 which means it will load the + first image payload. """ self._image_payload_idx = image_payload_idx super().__init__(pack) @@ -889,7 +903,9 @@ class Grids(Entry): pack: The container that this grids will be added to. height: the number of grid cell per column, the unit is one grid cell. width: the number of grid cell per row, the unit is one grid cell. - image_payload_idx: the index of the image payload. If it's not set, + image_payload_idx: the index of the image payload in the DataPack's + image payload list. + If it's not set, it defaults to 0 which meaning it will load the first image payload. 
""" @@ -909,7 +925,9 @@ def __init__( self._width = width self._image_payload_idx = image_payload_idx super().__init__(pack) - self.img_arr = self.pack.get_image_array(self._image_payload_idx) + self.img_arr = self.pack.get_payload_data_at( + Modality.Image, self._image_payload_idx + ) self.c_h, self.c_w = ( self.img_arr.shape[0] // self._height, self.img_arr.shape[1] // self._width, @@ -1021,7 +1039,9 @@ class Region(ImageAnnotation): Args: pack: the container that this ``Region`` will be added to. - image_payload_idx: the index of the image payload. If it's not set, + image_payload_idx: the index of the image payload in the DataPack's + image payload list. + If it's not set, it defaults to 0 which meaning it will load the first image payload. """ @@ -1047,7 +1067,8 @@ class Box(Region): Args: pack: the container that this ``Box`` will be added to. - image_payload_idx: the index of the image payload. If it's not set, + image_payload_idx: the index of the image payload in the DataPack's + image payload list. If it's not set, it defaults to 0 which meaning it will load the first image payload. cy: the row index of the box center in the image array, the unit is one image array entry. @@ -1170,7 +1191,8 @@ class BoundingBox(Box): Args: pack: The container that this BoundingBox will be added to. - image_payload_idx: the index of the image payload. If it's not set, + image_payload_idx: the index of the image payload in the DataPack's + image payload list. If it's not set, it defaults to 0 which means it will load the first image payload. height: the height of the bounding box, the unit is one image array entry. @@ -1207,6 +1229,142 @@ def __init__( ) +class Payload(Entry): + """ + A payload class that holds data cache of one modality and its data source uri. + + Args: + pack: The container that this `Payload` will + be added to. + modality: modality of the payload such as text, audio and image. 
payload_idx: the index of the payload in the DataPack's + image payload list of the same modality. For example, if we + instantiate a ``TextPayload`` inherited from ``Payload``, we assign + the payload index in DataPack's text payload list. + uri: universal resource identifier of the data source. Defaults to None. + + Raises: + ValueError: raised when the modality is not supported. + """ + + def __init__( + self, + pack: PackType, + payload_idx: int = 0, + uri: Optional[str] = None, + ): + from ft.onto.base_ontology import ( # pylint: disable=import-outside-toplevel + TextPayload, + AudioPayload, + ImagePayload, + ) + + # since we cannot pass different modality from generated ontology, and + # we don't want to import base ontology in the header of the file + # we import it here. + if isinstance(self, TextPayload): + self._modality = Modality.Text + elif isinstance(self, AudioPayload): + self._modality = Modality.Audio + elif isinstance(self, ImagePayload): + self._modality = Modality.Image + else: + supported_modality = [enum.name for enum in Modality] + raise ValueError( + f"Payload subclass {type(self).__name__} has no supported modality. " + f"Currently we only support {supported_modality}" + ) + self._payload_idx: int = payload_idx + self._uri: Optional[str] = uri + + super().__init__(pack) + self._cache: Union[str, np.ndarray] = "" + self.replace_back_operations: Sequence[Tuple] = [] + self.processed_original_spans: Sequence[Tuple] = [] + self.orig_text_len: int = 0 + + def get_type(self) -> type: + """ + Get the class type of the payload class. For example, suppose a + ``TextPayload`` inherits this ``Payload`` class, ``TextPayload`` will be + returned. + + Returns: + the type of the payload class. + """ + return type(self) + + @property + def cache(self) -> Union[str, np.ndarray]: + return self._cache + + @property + def modality(self) -> IntEnum: + """ + Get the modality of the payload class. + + Returns: + the modality of the payload class in ``IntEnum`` format. 
+ """ + return self._modality + + @property + def modality_name(self) -> str: + """ + Get the modality of the payload class in str format. + + Returns: + the modality of the payload class in str format. + """ + return self._modality.name + + @property + def payload_index(self) -> int: + return self._payload_idx + + @property + def uri(self) -> Optional[str]: + return self._uri + + def set_cache(self, data: Union[str, np.ndarray]): + """ + Load cache data into the payload. + + Args: + data: data to be set in the payload. It can be str for text data or + numpy array for audio or image data. + """ + self._cache = data + + def set_payload_index(self, payload_index: int): + """ + Set payload index for the DataPack. + + Args: + payload_index: a new payload index to be set. + """ + self._payload_idx = payload_index + + def __getstate__(self): + r""" + Convert ``_modality`` ``Enum`` object to str format for serialization. + """ + # TODO: this function will be removed since + # Entry store is being integrated into DataStore + state = self.__dict__.copy() + state["_modality"] = self._modality.name + return state + + def __setstate__(self, state): + r""" + Convert ``_modality`` string to ``Enum`` object for deserialization. 
+ """ + # TODO: this function will be removed since + # Entry store is being integrated into DataStore + self.__dict__.update(state) + self._modality = getattr(Modality, state["_modality"]) + + SinglePackEntries = ( Link, Group, @@ -1214,5 +1372,6 @@ def __init__( Generics, AudioAnnotation, ImageAnnotation, + Payload, ) MultiPackEntries = (MultiPackLink, MultiPackGroup, MultiPackGeneric) diff --git a/forte/ontology_specs/base_ontology.json b/forte/ontology_specs/base_ontology.json index 8b78809f7..f3d854301 100644 --- a/forte/ontology_specs/base_ontology.json +++ b/forte/ontology_specs/base_ontology.json @@ -444,6 +444,29 @@ "type": "str" } ] + }, + { + "entry_name": "ft.onto.base_ontology.AudioPayload", + "parent_entry": "forte.data.ontology.top.Payload", + "description": "A payload that caches audio data", + "attributes":[ + { + "name": "sample_rate", + "type": "int" + } + ] + }, + { + "entry_name": "ft.onto.base_ontology.TextPayload", + "parent_entry": "forte.data.ontology.top.Payload", + "description": "A payload that caches text data", + "attributes": [] + }, + { + "entry_name": "ft.onto.base_ontology.ImagePayload", + "parent_entry": "forte.data.ontology.top.Payload", + "description": "A payload that caches image data", + "attributes":[] } ] } diff --git a/ft/onto/base_ontology.py b/ft/onto/base_ontology.py index 0be5956d7..678b4c63e 100644 --- a/ft/onto/base_ontology.py +++ b/ft/onto/base_ontology.py @@ -8,7 +8,6 @@ """ from dataclasses import dataclass -from forte.data.base_pack import PackType from forte.data.data_pack import DataPack from forte.data.multi_pack import MultiPack from forte.data.ontology.core import Entry @@ -20,6 +19,7 @@ from forte.data.ontology.top import Group from forte.data.ontology.top import Link from forte.data.ontology.top import MultiPackLink +from forte.data.ontology.top import Payload from typing import Dict from typing import Iterable from typing import List @@ -54,6 +54,9 @@ "MRCQuestion", "Recording", "AudioUtterance", + 
"AudioPayload", + "TextPayload", + "ImagePayload", ] @@ -337,9 +340,8 @@ def __init__(self, pack: DataPack, parent: Optional[Entry] = None, child: Option @dataclass class EnhancedDependency(Link): """ - A `Link` type entry which represent a `enhanced dependency - `_. - + A `Link` type entry which represent a enhanced dependency: + https://universaldependencies.org/u/overview/enhanced-syntax.html Attributes: dep_label (Optional[str]): The enhanced dependency label in Universal Dependency. """ @@ -540,7 +542,7 @@ class Recording(AudioAnnotation): recording_class: List[str] - def __init__(self, pack: PackType, begin: int, end: int): + def __init__(self, pack: DataPack, begin: int, end: int): super().__init__(pack, begin, end) self.recording_class: List[str] = [] @@ -555,6 +557,41 @@ class AudioUtterance(AudioAnnotation): speaker: Optional[str] - def __init__(self, pack: PackType, begin: int, end: int): + def __init__(self, pack: DataPack, begin: int, end: int): super().__init__(pack, begin, end) self.speaker: Optional[str] = None + + +@dataclass +class AudioPayload(Payload): + """ + A payload that caches audio data + Attributes: + sample_rate (Optional[int]): + """ + + sample_rate: Optional[int] + + def __init__(self, pack: DataPack, payload_idx: int = 0, uri: Optional[str] = None): + super().__init__(pack, payload_idx, uri) + self.sample_rate: Optional[int] = None + + +@dataclass +class TextPayload(Payload): + """ + A payload that caches text data + """ + + def __init__(self, pack: DataPack, payload_idx: int = 0, uri: Optional[str] = None): + super().__init__(pack, payload_idx, uri) + + +@dataclass +class ImagePayload(Payload): + """ + A payload that caches image data + """ + + def __init__(self, pack: DataPack, payload_idx: int = 0, uri: Optional[str] = None): + super().__init__(pack, payload_idx, uri) diff --git a/requirements.txt b/requirements.txt index 34dda65d1..c1a9c73e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ dataclasses~=0.7; 
python_version <'3.7' importlib-resources==5.1.4;python_version<'3.7' packaging~=21.2 asyml-utilities +enum34;python_version<'3.4' diff --git a/tests/forte/data/audio_annotation_test.py b/tests/forte/data/audio_annotation_test.py index 357201f33..b06963e04 100644 --- a/tests/forte/data/audio_annotation_test.py +++ b/tests/forte/data/audio_annotation_test.py @@ -16,6 +16,7 @@ """ import os import unittest +from forte.data.modality import Modality import numpy as np from typing import Dict, List @@ -27,18 +28,29 @@ from forte.data.data_pack import DataPack from forte.data.readers import AudioReader from forte.data.ontology.top import ( - Annotation, AudioAnnotation, Generics, Group, Link + Annotation, + AudioAnnotation, + Generics, + Group, + Link, +) +from ft.onto.base_ontology import ( + Recording, + AudioUtterance, + Utterance, ) -from ft.onto.base_ontology import Recording, AudioUtterance, Utterance class RecordingProcessor(PackProcessor): """ A processor to add a Recording ontology to the whole audio data. """ + def _process(self, input_pack: DataPack): Recording( - pack=input_pack, begin=0, end=len(input_pack.audio) + pack=input_pack, + begin=0, + end=len(input_pack.get_payload_data_at(Modality.Audio, 0)), ) @@ -77,11 +89,10 @@ class AudioUtteranceProcessor(PackProcessor): A processor to add an AudioUtterance annotation to the specified span of audio payload. 
""" + def _process(self, input_pack: DataPack): audio_utter: AudioUtterance = AudioUtterance( - pack=input_pack, - begin=self.configs.begin, - end=self.configs.end + pack=input_pack, begin=self.configs.begin, end=self.configs.end ) audio_utter.speaker = self.configs.speaker @@ -109,20 +120,20 @@ def setUp(self): os.pardir, os.pardir, os.pardir, - "data_samples/audio_reader_test" + "data_samples/audio_reader_test", ) ) self._test_configs = { "Alice": {"begin": 200, "end": 35000}, - "Bob": {"begin": 35200, "end": 72000} + "Bob": {"begin": 35200, "end": 72000}, } # Define and config the Pipeline self._pipeline = Pipeline[DataPack]() - self._pipeline.set_reader(AudioReader(), config={ - "read_kwargs": {"always_2d": "True"} - }) + self._pipeline.set_reader( + AudioReader(), config={"read_kwargs": {"always_2d": "True"}} + ) self._pipeline.add(RecordingProcessor()) for speaker, span in self._test_configs.items(): self._pipeline.add( @@ -131,9 +142,7 @@ def setUp(self): self._pipeline.add(TextUtteranceProcessor()) self._pipeline.initialize() - def test_audio_annotation(self): - # Test `DataPack.get_span_audio()` with None audio payload with self.assertRaises(ProcessExecutionException): pack: DataPack = DataPack() @@ -143,39 +152,65 @@ def test_audio_annotation(self): for pack in self._pipeline.process_dataset(self._test_audio_path): # test get all audio annotation # test get selective fields data from subclass of AudioAnnotation - raw_data_generator = pack.get_data(AudioAnnotation, - {Recording: - {"fields": ["recording_class"]}, - AudioUtterance: - {"fields": ["speaker"]}} - ) + raw_data_generator = pack.get_data( + AudioAnnotation, + { + Recording: {"fields": ["recording_class"]}, + AudioUtterance: {"fields": ["speaker"]}, + }, + ) for data_instance in pack.get(AudioAnnotation): raw_data = next(raw_data_generator) - - self.assertTrue('Recording' in raw_data.keys() and - "recording_class" in raw_data['Recording']) - self.assertTrue('AudioUtterance' in raw_data.keys() and 
- "speaker" in raw_data['AudioUtterance']) + + self.assertTrue( + "Recording" in raw_data.keys() + and "recording_class" in raw_data["Recording"] + ) + self.assertTrue( + "AudioUtterance" in raw_data.keys() + and "speaker" in raw_data["AudioUtterance"] + ) # test grouped data if isinstance(data_instance, Recording): - self.assertTrue(array_equal(np.array([data_instance.audio]), raw_data['Recording']['audio'])) - self.assertTrue(data_instance.recording_class ==np.squeeze(raw_data['Recording']['recording_class']).tolist()) + self.assertTrue( + array_equal( + np.array([data_instance.audio]), + raw_data["Recording"]["audio"], + ) + ) + self.assertTrue( + data_instance.recording_class + == np.squeeze( + raw_data["Recording"]["recording_class"] + ).tolist() + ) elif isinstance(data_instance, AudioUtterance): - self.assertTrue(array_equal(np.array([data_instance.audio]), raw_data['AudioUtterance']['audio'])) - self.assertTrue(data_instance.speaker - ==raw_data['AudioUtterance']['speaker'][0]) + self.assertTrue( + array_equal( + np.array([data_instance.audio]), + raw_data["AudioUtterance"]["audio"], + ) + ) + self.assertTrue( + data_instance.speaker + == raw_data["AudioUtterance"]["speaker"][0] + ) # check non-existence of non-requested data fields raw_data_generator = pack.get_data(AudioAnnotation) for raw_data in raw_data_generator: self.assertFalse("Recording" in raw_data) self.assertFalse("AudioUtterance" in raw_data) - + # Check Recording recordings = list(pack.get(Recording)) self.assertEqual(len(recordings), 1) - self.assertTrue(array_equal(recordings[0].audio, pack.audio)) - + self.assertTrue( + array_equal( + recordings[0].audio, + pack.get_payload_data_at(Modality.Audio, 0), + ) + ) # Check serialization/deserialization of AudioAnnotation new_pack = DataPack.from_string(pack.to_string()) self.assertEqual(new_pack.audio_annotations, pack.audio_annotations) @@ -192,10 +227,14 @@ def test_audio_annotation(self): for audio_utter in audio_utters: configs: Dict = 
self._test_configs[audio_utter.speaker] - self.assertTrue(array_equal( - audio_utter.audio, - pack.audio[configs["begin"]:configs["end"]] - )) + self.assertTrue( + array_equal( + audio_utter.audio, + pack.get_payload_data_at(Modality.Audio, 0)[ + configs["begin"] : configs["end"] + ], + ) + ) # Check `AudioAnnotation.get(Group/Link/Generics)`. Note that only # `DummyGroup` and `DummyLink` entries can be retrieved because @@ -203,8 +242,9 @@ def test_audio_annotation(self): for entry_type in (Group, Link): self.assertEqual( len(list(recordings[0].get(entry_type))), - len(self._test_configs) + len(self._test_configs), ) + # we have one generics meta data self.assertEqual(len(list(recordings[0].get(Generics))), 0) # Check operations with mixing types of entries. @@ -216,12 +256,27 @@ def test_audio_annotation(self): # Verify the new conditional branches in DataPack.get() when dealing # with empty annotation/audio_annotation list. empty_pack: DataPack = DataPack() - self.assertEqual(len(list(empty_pack.get( - entry_type=Annotation, range_annotation=utter - ))), 0) - self.assertEqual(len(list(empty_pack.get( - entry_type=AudioAnnotation, range_annotation=recordings[0] - ))), 0) + self.assertEqual( + len( + list( + empty_pack.get( + entry_type=Annotation, range_annotation=utter + ) + ) + ), + 0, + ) + self.assertEqual( + len( + list( + empty_pack.get( + entry_type=AudioAnnotation, + range_annotation=recordings[0], + ) + ) + ), + 0, + ) # Check `DataPack.delete_entry(AudioAnnotation)` for audio_annotation in list(pack.get(AudioAnnotation)): @@ -241,8 +296,7 @@ def test_build_coverage_for(self): # Add coverage index for (Recording, AudioUtterance) pack.build_coverage_for( - context_type=Recording, - covered_type=AudioUtterance + context_type=Recording, covered_type=AudioUtterance ) self.assertTrue(pack._index.coverage_index_is_valid) self.assertEqual( @@ -250,24 +304,32 @@ def test_build_coverage_for(self): ) # Check DataIndex.get_covered() - self.assertTrue(pack.covers( - 
context_entry=recording, covered_entry=audio_utters[0] - )) - self.assertFalse(pack.covers( - context_entry=audio_utters[0], covered_entry=recording - )) + self.assertTrue( + pack.covers( + context_entry=recording, covered_entry=audio_utters[0] + ) + ) + self.assertFalse( + pack.covers( + context_entry=audio_utters[0], covered_entry=recording + ) + ) # Check DataIndex.coverage_index_is_valid flag pack._index.deactivate_coverage_index() - self.assertTrue(pack._index.coverage_index( - outer_type=Recording, - inner_type=AudioUtterance - ) is None) + self.assertTrue( + pack._index.coverage_index( + outer_type=Recording, inner_type=AudioUtterance + ) + is None + ) pack._index.activate_coverage_index() - self.assertFalse(pack._index.coverage_index( - outer_type=Recording, - inner_type=AudioUtterance - ) is None) + self.assertFalse( + pack._index.coverage_index( + outer_type=Recording, inner_type=AudioUtterance + ) + is None + ) # Check DataIndex.have_overlap() with self.assertRaises(TypeError): @@ -286,8 +348,7 @@ def test_build_coverage_for(self): # Check coverage index when inner and outer entries are the same pack._index.deactivate_coverage_index() pack.build_coverage_for( - context_type=Utterance, - covered_type=Utterance + context_type=Utterance, covered_type=Utterance ) self.assertEqual(len(pack._index._coverage_index), 1) utter = pack.get_single(Utterance) diff --git a/tests/forte/data/readers/audio_reader_test.py b/tests/forte/data/readers/audio_reader_test.py index 23a8c91ce..419976370 100644 --- a/tests/forte/data/readers/audio_reader_test.py +++ b/tests/forte/data/readers/audio_reader_test.py @@ -17,6 +17,7 @@ import os import unittest from typing import Dict +from forte.data import Modality from torch import argmax from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC @@ -27,12 +28,14 @@ from forte.data.readers import AudioReader from forte.pipeline import Pipeline from forte.processors.base.pack_processor import PackProcessor +from ft.onto.base_ontology 
import TextPayload class TestASRProcessor(PackProcessor): """ An audio processor for automatic speech recognition. """ + def initialize(self, resources: Resources, configs: Config): super().initialize(resources, configs) @@ -42,8 +45,11 @@ def initialize(self, resources: Resources, configs: Config): self._model = Wav2Vec2ForCTC.from_pretrained(pretrained_model) def _process(self, input_pack: DataPack): + ap = input_pack.get_payload_at(Modality.Audio, 0) + sample_rate = ap.sample_rate + audio_data = ap.cache required_sample_rate: int = 16000 - if input_pack.sample_rate != required_sample_rate: + if sample_rate != required_sample_rate: raise ProcessFlowException( f"A sample rate of {required_sample_rate} Hz is requied by the" " pretrained model." @@ -51,7 +57,7 @@ def _process(self, input_pack: DataPack): # tokenize input_values = self._tokenizer( - input_pack.audio, return_tensors="pt", padding="longest" + audio_data, return_tensors="pt", padding="longest" ).input_values # Batch size 1 # take argmax and decode @@ -75,10 +81,9 @@ def setUp(self): os.pardir, os.pardir, os.pardir, - "data_samples/audio_reader_test" + "data_samples/audio_reader_test", ) ) - # Define and config the Pipeline self._pipeline = Pipeline[DataPack]() self._pipeline.set_reader(AudioReader()) @@ -87,12 +92,13 @@ def setUp(self): def test_asr_pipeline(self): target_transcription: Dict[str, str] = { - self._test_audio_path + "/test_audio_0.flac": - "A MAN SAID TO THE UNIVERSE SIR I EXIST", - self._test_audio_path + "/test_audio_1.flac": ( + self._test_audio_path + + "/test_audio_0.flac": "A MAN SAID TO THE UNIVERSE SIR I EXIST", + self._test_audio_path + + "/test_audio_1.flac": ( "NOR IS MISTER QUILTER'S MANNER LESS INTERESTING " "THAN HIS MATTER" - ) + ), } # Verify the ASR result of each datapack diff --git a/tests/forte/grids_test.py b/tests/forte/grids_test.py index 24fd9c3c0..df54a28a8 100644 --- a/tests/forte/grids_test.py +++ b/tests/forte/grids_test.py @@ -15,6 +15,8 @@ Unit tests for 
Grids. """ import unittest +from forte.data.modality import Modality +from ft.onto.base_ontology import ImagePayload import numpy as np from numpy import array_equal @@ -34,7 +36,8 @@ def setUp(self): line[2, 2] = 1 line[3, 3] = 1 line[4, 4] = 1 - self.datapack.payloads.append(line) + ip = ImagePayload(self.datapack) + ip.set_cache(line) self.datapack.image_annotations.append( ImageAnnotation(self.datapack, 0) ) @@ -45,7 +48,11 @@ def setUp(self): self.zeros = np.zeros((6, 12)) self.ref_arr = np.zeros((6, 12)) self.ref_arr[2, 2] = 1 - self.datapack.payloads.append(self.ref_arr) + ip = ImagePayload(self.datapack) + ip.set_cache(self.ref_arr) + self.datapack.image_annotations.append( + ImageAnnotation(self.datapack, 0) + ) def test_grids(self): diff --git a/tests/forte/image_annotation_test.py b/tests/forte/image_annotation_test.py index 7d2233498..35f523c80 100644 --- a/tests/forte/image_annotation_test.py +++ b/tests/forte/image_annotation_test.py @@ -14,14 +14,17 @@ """ Unit tests for ImageAnnotation. 
""" -import os import unittest +from forte.data.modality import Modality import numpy as np -from typing import Dict from numpy import array_equal from forte.data.ontology.top import ImageAnnotation + +from ft.onto.base_ontology import ImagePayload + from forte.data.data_pack import DataPack +import unittest class ImageAnnotationTest(unittest.TestCase): @@ -35,16 +38,21 @@ def setUp(self): self.line[2, 2] = 1 self.line[3, 3] = 1 self.line[4, 4] = 1 - self.datapack.payloads.append(self.line) - self.datapack.image_annotations.append( - ImageAnnotation(self.datapack, 0) - ) + ip = ImagePayload(self.datapack, 0) + ip.set_cache(self.line) + ImageAnnotation(self.datapack) def test_image_annotation(self): self.assertEqual( - self.datapack.image_annotations[0].image_payload_idx, 0 + self.datapack.get_single(ImageAnnotation).image_payload_idx, 0 ) self.assertTrue( - array_equal(self.datapack.image_annotations[0].image, self.line) + array_equal( + self.datapack.get_payload_at(Modality.Image, 0).cache, self.line + ) + ) + new_pack = DataPack.from_string(self.datapack.to_string()) + self.assertEqual( + new_pack.audio_annotations, self.datapack.audio_annotations )