From f1f55e208bafd313754c2460ec61aa7c935adb72 Mon Sep 17 00:00:00 2001 From: Vinnam Kim Date: Fri, 29 Mar 2024 17:08:19 +0900 Subject: [PATCH] Refactor XAI data entities (#3230) * Refactor XAI data entities * Fix tests * Fix test errors Signed-off-by: Kim, Vinnam --------- Signed-off-by: Kim, Vinnam --- .../algo/classification/torchvision_model.py | 10 +- src/otx/algo/utils/xai_utils.py | 20 ++-- .../zero_shot_segment_anything.py | 13 +-- src/otx/core/data/dataset/visual_prompting.py | 7 +- .../core/data/entity/action_classification.py | 20 +--- src/otx/core/data/entity/action_detection.py | 24 +++-- .../data/entity/anomaly/classification.py | 10 +- src/otx/core/data/entity/anomaly/detection.py | 20 ++-- .../core/data/entity/anomaly/segmentation.py | 18 ++-- src/otx/core/data/entity/base.py | 93 ++++++++++++------- src/otx/core/data/entity/classification.py | 52 ++--------- src/otx/core/data/entity/detection.py | 28 ++---- .../core/data/entity/instance_segmentation.py | 30 +++--- src/otx/core/data/entity/segmentation.py | 20 +--- src/otx/core/data/entity/visual_prompting.py | 66 ++++++------- .../core/data/transform_libs/torchvision.py | 8 +- src/otx/core/model/action_classification.py | 11 +-- src/otx/core/model/action_detection.py | 9 +- src/otx/core/model/base.py | 30 ++---- src/otx/core/model/classification.py | 60 +++++------- src/otx/core/model/detection.py | 26 +++--- src/otx/core/model/instance_segmentation.py | 29 +++--- src/otx/core/model/segmentation.py | 20 ++-- src/otx/core/model/visual_prompting.py | 25 ++--- src/otx/core/utils/tile_merge.py | 12 +-- tests/conftest.py | 24 +++-- tests/integration/api/test_xai.py | 11 ++- .../algo/hooks/test_saliency_map_dumping.py | 4 +- .../hooks/test_saliency_map_processing.py | 10 +- 29 files changed, 306 insertions(+), 404 deletions(-) diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 64e015724e3..2331cdf106a 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -13,11 +13,7 @@ from torchvision.models import get_model, get_model_weights from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.classification import ( - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - MulticlassClsBatchPredEntityWithXAI, -) +from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity from otx.core.metrics.accuracy import MultiClassClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import OTXMulticlassClsModel @@ -225,7 +221,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: if self.training: return OTXBatchLossEntity(loss=outputs) @@ -241,7 +237,7 @@ def _customize_outputs( saliency_maps = outputs["saliency_map"].detach().cpu().numpy() - return MulticlassClsBatchPredEntityWithXAI( + return MulticlassClsBatchPredEntity( batch_size=len(preds), images=inputs.images, imgs_info=inputs.imgs_info, diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py index c37ff209f51..47bfe16b6b4 100644 --- a/src/otx/algo/utils/xai_utils.py +++ b/src/otx/algo/utils/xai_utils.py @@ -12,8 +12,7 @@ from datumaro import Image from otx.core.config.explain import ExplainConfig -from otx.core.data.entity.base import OTXBatchPredEntityWithXAI -from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntityWithXAI +from otx.core.data.entity.base import OTXBatchPredEntity from otx.core.types.explain import TargetExplainGroup if TYPE_CHECKING: @@ -23,22 +22,23 @@ def process_saliency_maps_in_pred_entity( - predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any], + predict_result: list[OTXBatchPredEntity], explain_config: ExplainConfig, -) -> list[Any] | list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI]: +) -> list[OTXBatchPredEntity]: """Process saliency maps in PredEntity.""" - for predict_result_per_batch in predict_result: + + def _process(predict_result_per_batch: OTXBatchPredEntity) -> OTXBatchPredEntity: saliency_maps = predict_result_per_batch.saliency_maps imgs_info = predict_result_per_batch.imgs_info ori_img_shapes = [img_info.ori_shape for img_info in imgs_info] - pred_labels = predict_result_per_batch.labels # type: ignore[union-attr] - if pred_labels: + if pred_labels := getattr(predict_result_per_batch, "labels", None): pred_labels = [pred.tolist() for pred in pred_labels] processed_saliency_maps = process_saliency_maps(saliency_maps, explain_config, pred_labels, ori_img_shapes) - predict_result_per_batch.saliency_maps = processed_saliency_maps - return predict_result + return predict_result_per_batch.wrap(saliency_maps=processed_saliency_maps) + + return [_process(predict_result_per_batch) for predict_result_per_batch in predict_result] def process_saliency_maps( @@ -116,7 +116,7 @@ def postprocess(saliency_map: np.ndarray, output_size: tuple[int, int] | None) - def dump_saliency_maps( - predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any], + predict_result: list[OTXBatchPredEntity], explain_config: ExplainConfig, datamodule: EVAL_DATALOADERS | OTXDataModule, output_dir: Path, diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index b3456c668a3..740d366f3f8 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -863,12 +863,13 @@ def preprocess(self, x: Image) -> Image: def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShotVisualPromptingBatchDataEntity: """Transforms for ZeroShotVisualPromptingBatchDataEntity.""" - entity.images = [self.preprocess(self.apply_image(image)) for image in entity.images] - entity.prompts = [ - self.apply_prompts(prompt, info.ori_shape, self.model.image_size) - for prompt, info in zip(entity.prompts, entity.imgs_info) - ] - return entity + return entity.wrap( + images=[self.preprocess(self.apply_image(image)) for image in entity.images], + prompts=[ + self.apply_prompts(prompt, info.ori_shape, self.model.image_size) + for prompt, info in zip(entity.prompts, entity.imgs_info) + ], + ) def initialize_reference_info(self) -> None: """Initialize reference information.""" diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index c7d8d5aa0e9..4ece94158e4 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -146,9 +146,12 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: ) transformed_entity = self._apply_transforms(entity) + if transformed_entity is None: + msg = "This is not allowed." + raise RuntimeError(msg) + # insert masks to transformed_entity - transformed_entity.masks = masks # type: ignore[union-attr] - return transformed_entity + return transformed_entity.wrap(masks=masks) @property def collate_fn(self) -> Callable: diff --git a/src/otx/core/data/entity/action_classification.py b/src/otx/core/data/entity/action_classification.py index a10ac00c48c..41d9a0e0899 100644 --- a/src/otx/core/data/entity/action_classification.py +++ b/src/otx/core/data/entity/action_classification.py @@ -11,10 +11,8 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, - OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -51,15 +49,10 @@ def task(self) -> OTXTaskType: @dataclass -class ActionClsPredEntity(ActionClsDataEntity, OTXPredEntity): +class ActionClsPredEntity(OTXPredEntity, ActionClsDataEntity): """Data entity to represent the action classification model's output prediction.""" -@dataclass -class ActionClsPredEntityWithXAI(ActionClsDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the detection model output prediction with explanations.""" - - @dataclass class ActionClsBatchDataEntity(OTXBatchDataEntity[ActionClsDataEntity]): """Batch data entity for action classification. @@ -92,16 +85,9 @@ def collate_fn( def pin_memory(self) -> ActionClsBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - return self + return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels]) @dataclass -class ActionClsBatchPredEntity(ActionClsBatchDataEntity, OTXBatchPredEntity): +class ActionClsBatchPredEntity(OTXBatchPredEntity, ActionClsBatchDataEntity): """Data entity to represent model output predictions for action classification task.""" - - -@dataclass -class ActionClsBatchPredEntityWithXAI(ActionClsBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for multi-class classification task with explanations.""" diff --git a/src/otx/core/data/entity/action_detection.py b/src/otx/core/data/entity/action_detection.py index 37c7cde05fc..b4c2f0b7d9e 100644 --- a/src/otx/core/data/entity/action_detection.py +++ b/src/otx/core/data/entity/action_detection.py @@ -13,7 +13,6 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, ) @@ -48,7 +47,7 @@ def task(self) -> OTXTaskType: @dataclass -class ActionDetPredEntity(ActionDetDataEntity, OTXPredEntity): +class ActionDetPredEntity(OTXPredEntity, ActionDetDataEntity): """Data entity to represent the action classification model's output prediction.""" @@ -89,18 +88,17 @@ def collate_fn( def pin_memory(self) -> ActionDetBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.bboxes = [tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes] - self.labels = [label.pin_memory() for label in self.labels] - self.proposals = [tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals] - return self + return ( + super() + .pin_memory() + .wrap( + bboxes=[tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes], + labels=[label.pin_memory() for label in self.labels], + proposals=[tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals], + ) + ) @dataclass -class ActionDetBatchPredEntity(ActionDetBatchDataEntity, OTXBatchPredEntity): +class ActionDetBatchPredEntity(OTXBatchPredEntity, ActionDetBatchDataEntity): """Data entity to represent model output predictions for action classification task.""" - - -@dataclass -class ActionDetBatchPredEntityWithXAI(ActionDetBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for multi-class classification task with explanations.""" diff --git a/src/otx/core/data/entity/anomaly/classification.py b/src/otx/core/data/entity/anomaly/classification.py index cacdeb792c5..4f9d9826518 100644 --- a/src/otx/core/data/entity/anomaly/classification.py +++ b/src/otx/core/data/entity/anomaly/classification.py @@ -63,18 +63,16 @@ def collate_fn( def pin_memory(self) -> AnomalyClassificationDataBatch: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - return self + return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels]) @dataclass -class AnomalyClassificationPrediction(AnomalyClassificationDataItem, OTXPredEntity): +class AnomalyClassificationPrediction(OTXPredEntity, AnomalyClassificationDataItem): """Anomaly classification Prediction item.""" -@dataclass -class AnomalyClassificationBatchPrediction(AnomalyClassificationDataBatch, OTXBatchPredEntity): +@dataclass(kw_only=True) +class AnomalyClassificationBatchPrediction(OTXBatchPredEntity, AnomalyClassificationDataBatch): """Anomaly classification batch prediction.""" anomaly_maps: torch.Tensor diff --git a/src/otx/core/data/entity/anomaly/detection.py b/src/otx/core/data/entity/anomaly/detection.py index 08a5e63805a..0bb5c6153a0 100644 --- a/src/otx/core/data/entity/anomaly/detection.py +++ b/src/otx/core/data/entity/anomaly/detection.py @@ -68,20 +68,24 @@ def collate_fn( def pin_memory(self) -> AnomalyDetectionDataBatch: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - self.masks = self.masks.pin_memory() - self.boxes = [box.pin_memory() for box in self.boxes] - return self + return ( + super() + .pin_memory() + .wrap( + labels=[label.pin_memory() for label in self.labels], + masks=self.masks.pin_memory(), + boxes=[box.pin_memory() for box in self.boxes], + ) + ) @dataclass -class AnomalyDetectionPrediction(AnomalyDetectionDataItem, OTXPredEntity): +class AnomalyDetectionPrediction(OTXPredEntity, AnomalyDetectionDataItem): """Anomaly Detection Prediction item.""" -@dataclass -class AnomalyDetectionBatchPrediction(AnomalyDetectionDataBatch, OTXBatchPredEntity): +@dataclass(kw_only=True) +class AnomalyDetectionBatchPrediction(OTXBatchPredEntity, AnomalyDetectionDataBatch): """Anomaly classification batch prediction.""" anomaly_maps: torch.Tensor diff --git a/src/otx/core/data/entity/anomaly/segmentation.py b/src/otx/core/data/entity/anomaly/segmentation.py index 0ec64845cd8..6766e19dfcb 100644 --- a/src/otx/core/data/entity/anomaly/segmentation.py +++ b/src/otx/core/data/entity/anomaly/segmentation.py @@ -65,19 +65,23 @@ def collate_fn( def pin_memory(self) -> AnomalySegmentationDataBatch: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - self.masks = self.masks.pin_memory() - return self + return ( + super() + .pin_memory() + .wrap( + labels=[label.pin_memory() for label in self.labels], + masks=self.masks.pin_memory(), + ) + ) @dataclass -class AnomalySegmentationPrediction(AnomalySegmentationDataItem, OTXPredEntity): +class AnomalySegmentationPrediction(OTXPredEntity, AnomalySegmentationDataItem): """Anomaly Segmentation Prediction item.""" -@dataclass -class AnomalySegmentationBatchPrediction(AnomalySegmentationDataBatch, OTXBatchPredEntity): +@dataclass(kw_only=True) +class AnomalySegmentationBatchPrediction(OTXBatchPredEntity, AnomalySegmentationDataBatch): """Anomaly classification batch prediction.""" anomaly_maps: torch.Tensor diff --git a/src/otx/core/data/entity/base.py b/src/otx/core/data/entity/base.py index 78b83d7fd9f..5373d249d64 100644 --- a/src/otx/core/data/entity/base.py +++ b/src/otx/core/data/entity/base.py @@ -8,7 +8,7 @@ import warnings from collections.abc import Mapping -from dataclasses import dataclass, fields +from dataclasses import asdict, dataclass, field, fields from typing import TYPE_CHECKING, Any, Dict, Generic, Iterator, TypeVar import torch @@ -501,16 +501,20 @@ def image_type(self) -> ImageType: return ImageType.get_image_type(self.image) def to_tv_image(self: T_OTXDataEntity) -> T_OTXDataEntity: - """Convert `self.image` to TorchVision Image if it is a Numpy array (inplace operation).""" + """Return a new instance with the `image` attribute converted to a TorchVision Image if it is a NumPy array. + + Returns: + A new instance with the `image` attribute converted to a TorchVision Image, if applicable. + Otherwise, return this instance as is. + """ if isinstance(self.image, tv_tensors.Image): return self - self.image = F.to_image(self.image) - return self + return self.wrap(image=F.to_image(self.image)) def __iter__(self) -> Iterator[str]: - for field in fields(self): - yield field.name + for field_ in fields(self): + yield field_.name def __getitem__(self, key: str) -> Any: # noqa: ANN401 return getattr(self, key) @@ -519,6 +523,18 @@ def __len__(self) -> int: """Get the number of fields in this data entity.""" return len(fields(self)) + def wrap(self: T_OTXDataEntity, **kwargs) -> T_OTXDataEntity: + """Wrap this dataclass with the given keyword arguments. + + Args: + **kwargs: Keyword arguments to be overwritten on top of this dataclass + Returns: + Updated dataclass + """ + updated_kwargs = asdict(self) + updated_kwargs.update(**kwargs) + return self.__class__(**updated_kwargs) + @dataclass class OTXPredEntity(OTXDataEntity): @@ -526,13 +542,8 @@ class OTXPredEntity(OTXDataEntity): score: np.ndarray | Tensor - -@dataclass -class OTXPredEntityWithXAI(OTXPredEntity): - """Data entity to represent model output prediction with explanations.""" - - saliency_map: np.ndarray | Tensor - feature_vector: np.ndarray | list + saliency_map: np.ndarray | Tensor | None = None + feature_vector: np.ndarray | list | None = None T_OTXBatchDataEntity = TypeVar( @@ -631,27 +642,52 @@ def pin_memory(self: T_OTXBatchDataEntity) -> T_OTXBatchDataEntity: """Pin memory for member tensor variables.""" # TODO(vinnamki): Keep track this issue # https://github.com/pytorch/pytorch/issues/116403 - self.images = ( - [tv_tensors.wrap(image.pin_memory(), like=image) for image in self.images] - if isinstance(self.images, list) - else tv_tensors.wrap(self.images.pin_memory(), like=self.images) + return self.wrap( + images=( + [tv_tensors.wrap(image.pin_memory(), like=image) for image in self.images] + if isinstance(self.images, list) + else tv_tensors.wrap(self.images.pin_memory(), like=self.images) + ), ) - return self + + def wrap(self: T_OTXBatchDataEntity, **kwargs) -> T_OTXBatchDataEntity: + """Wrap this dataclass with the given keyword arguments. + + Args: + **kwargs: Keyword arguments to be overwritten on top of this dataclass + Returns: + Updated dataclass + """ + updated_kwargs = asdict(self) + updated_kwargs.update(**kwargs) + return self.__class__(**updated_kwargs) @dataclass class OTXBatchPredEntity(OTXBatchDataEntity): - """Data entity to represent model output predictions.""" + """Data entity to represent model output predictions. + + Attributes: + scores: List of probability scores representing model predictions. + saliency_maps: List of saliency maps used to explain model predictions. + This field is optional and will be an empty list for non-XAI pipelines. + feature_vectors: List of intermediate feature vectors used for model predictions. + This field is optional and will be an empty list for non-XAI pipelines. + """ scores: list[np.ndarray] | list[Tensor] + # (Optional) XAI-related outputs + saliency_maps: list[np.ndarray] | list[Tensor] = field(default_factory=list) + feature_vectors: list[np.ndarray] | list[Tensor] = field(default_factory=list) -@dataclass -class OTXBatchPredEntityWithXAI(OTXBatchPredEntity): - """Data entity to represent model output predictions with explanations.""" - - saliency_maps: list[np.ndarray] | list[Tensor] - feature_vectors: list[np.ndarray] | list[Tensor] + @property + def has_xai_outputs(self) -> bool: + """If the XAI related fields are fulfilled, return True.""" + # NOTE: Don't know why but some of test cases in tests/integration/api/test_xai.py + # produce `len(self.saliency_maps) > 0` and `len(self.feature_vectors) == 0` + # return len(self.saliency_maps) > 0 and len(self.feature_vectors) > 0 + return len(self.saliency_maps) > 0 class OTXBatchLossEntity(Dict[str, Tensor]): @@ -663,13 +699,6 @@ class OTXBatchLossEntity(Dict[str, Tensor]): bound=OTXBatchPredEntity, ) - -T_OTXBatchPredEntityWithXAI = TypeVar( - "T_OTXBatchPredEntityWithXAI", - bound=OTXBatchPredEntityWithXAI, -) - - T_OTXBatchLossEntity = TypeVar( "T_OTXBatchLossEntity", bound=OTXBatchLossEntity, diff --git a/src/otx/core/data/entity/classification.py b/src/otx/core/data/entity/classification.py index 878fc44dc0a..49dd341862c 100644 --- a/src/otx/core/data/entity/classification.py +++ b/src/otx/core/data/entity/classification.py @@ -11,10 +11,8 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, - OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -40,15 +38,10 @@ def task(self) -> OTXTaskType: @dataclass -class MulticlassClsPredEntity(MulticlassClsDataEntity, OTXPredEntity): +class MulticlassClsPredEntity(OTXPredEntity, MulticlassClsDataEntity): """Data entity to represent the multi-class classification model output prediction.""" -@dataclass -class MulticlassClsPredEntityWithXAI(MulticlassClsDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the multi-class classification model output prediction with explanations.""" - - @dataclass class MulticlassClsBatchDataEntity(OTXBatchDataEntity[MulticlassClsDataEntity]): """Data entity for multi-class classification task. @@ -80,21 +73,14 @@ def collate_fn( def pin_memory(self) -> MulticlassClsBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - return self + return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels]) @dataclass -class MulticlassClsBatchPredEntity(MulticlassClsBatchDataEntity, OTXBatchPredEntity): +class MulticlassClsBatchPredEntity(OTXBatchPredEntity, MulticlassClsBatchDataEntity): """Data entity to represent model output predictions for multi-class classification task.""" -@dataclass -class MulticlassClsBatchPredEntityWithXAI(MulticlassClsBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for multi-class classification task with explanations.""" - - @register_pytree_node @dataclass class MultilabelClsDataEntity(OTXDataEntity): @@ -112,15 +98,10 @@ def task(self) -> OTXTaskType: @dataclass -class MultilabelClsPredEntity(MultilabelClsDataEntity, OTXPredEntity): +class MultilabelClsPredEntity(OTXPredEntity, MultilabelClsDataEntity): """Data entity to represent the multi-label classification model output prediction.""" -@dataclass -class MultilabelClsPredEntityWithXAI(MultilabelClsDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the multi-label classification model output prediction with explanations.""" - - @dataclass class MultilabelClsBatchDataEntity(OTXBatchDataEntity[MultilabelClsDataEntity]): """Data entity for multi-label classification task. @@ -152,21 +133,14 @@ def collate_fn( def pin_memory(self) -> MultilabelClsBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.labels = [label.pin_memory() for label in self.labels] - return self + return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels]) @dataclass -class MultilabelClsBatchPredEntity(MultilabelClsBatchDataEntity, OTXBatchPredEntity): +class MultilabelClsBatchPredEntity(OTXBatchPredEntity, MultilabelClsBatchDataEntity): """Data entity to represent model output predictions for multi-label classification task.""" -@dataclass -class MultilabelClsBatchPredEntityWithXAI(MultilabelClsBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for multi-label classification task with explanations.""" - - @register_pytree_node @dataclass class HlabelClsDataEntity(OTXDataEntity): @@ -185,15 +159,10 @@ def task(self) -> OTXTaskType: @dataclass -class HlabelClsPredEntity(HlabelClsDataEntity, OTXPredEntity): +class HlabelClsPredEntity(OTXPredEntity, HlabelClsDataEntity): """Data entity to represent the H-label classification model output prediction.""" -@dataclass -class HlabelClsPredEntityWithXAI(HlabelClsDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the H-label classification model output prediction with explanation.""" - - @dataclass class HlabelClsBatchDataEntity(OTXBatchDataEntity[HlabelClsDataEntity]): """Data entity for H-label classification task. @@ -226,10 +195,5 @@ def collate_fn( @dataclass -class HlabelClsBatchPredEntity(HlabelClsBatchDataEntity, OTXBatchPredEntity): +class HlabelClsBatchPredEntity(OTXBatchPredEntity, HlabelClsBatchDataEntity): """Data entity to represent model output predictions for H-label classification task.""" - - -@dataclass -class HlabelClsBatchPredEntityWithXAI(HlabelClsBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for H-label classification task with explanations.""" diff --git a/src/otx/core/data/entity/detection.py b/src/otx/core/data/entity/detection.py index c72dd91f9d5..3f1bc45f114 100644 --- a/src/otx/core/data/entity/detection.py +++ b/src/otx/core/data/entity/detection.py @@ -13,10 +13,8 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, - OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -45,15 +43,10 @@ def task(self) -> OTXTaskType: @dataclass -class DetPredEntity(DetDataEntity, OTXPredEntity): +class DetPredEntity(OTXPredEntity, DetDataEntity): """Data entity to represent the detection model output prediction.""" -@dataclass -class DetPredEntityWithXAI(DetDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the detection model output prediction with explanations.""" - - @dataclass class DetBatchDataEntity(OTXBatchDataEntity[DetDataEntity]): """Data entity for detection task. @@ -98,17 +91,16 @@ def collate_fn( def pin_memory(self) -> DetBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.bboxes = [tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes] - self.labels = [label.pin_memory() for label in self.labels] - return self + return ( + super() + .pin_memory() + .wrap( + bboxes=[tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes], + labels=[label.pin_memory() for label in self.labels], + ) + ) @dataclass -class DetBatchPredEntity(DetBatchDataEntity, OTXBatchPredEntity): +class DetBatchPredEntity(OTXBatchPredEntity, DetBatchDataEntity): """Data entity to represent model output predictions for detection task.""" - - -@dataclass -class DetBatchPredEntityWithXAI(DetBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for detection task with explanations.""" diff --git a/src/otx/core/data/entity/instance_segmentation.py b/src/otx/core/data/entity/instance_segmentation.py index 89729639f41..62046513608 100644 --- a/src/otx/core/data/entity/instance_segmentation.py +++ b/src/otx/core/data/entity/instance_segmentation.py @@ -12,7 +12,7 @@ from otx.core.types.task import OTXTaskType -from .base import OTXBatchDataEntity, OTXBatchPredEntity, OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity +from .base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity if TYPE_CHECKING: from datumaro import Polygon @@ -42,15 +42,10 @@ def task(self) -> OTXTaskType: @dataclass -class InstanceSegPredEntity(InstanceSegDataEntity, OTXPredEntity): +class InstanceSegPredEntity(OTXPredEntity, InstanceSegDataEntity): """Data entity to represent the detection model output prediction.""" -@dataclass -class InstanceSegPredEntityWithXAI(InstanceSegDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent the detection model output prediction with explanation.""" - - @dataclass class InstanceSegBatchDataEntity(OTXBatchDataEntity[InstanceSegDataEntity]): """Batch entity for InstanceSegDataEntity. @@ -101,18 +96,17 @@ def collate_fn( def pin_memory(self) -> InstanceSegBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.bboxes = [tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes] - self.masks = [tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks] - self.labels = [label.pin_memory() for label in self.labels] - return self + return ( + super() + .pin_memory() + .wrap( + bboxes=[tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes], + masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], + labels=[label.pin_memory() for label in self.labels], + ) + ) @dataclass -class InstanceSegBatchPredEntity(InstanceSegBatchDataEntity, OTXBatchPredEntity): +class InstanceSegBatchPredEntity(OTXBatchPredEntity, InstanceSegBatchDataEntity): """Data entity to represent model output predictions for instance segmentation task.""" - - -@dataclass -class InstanceSegBatchPredEntityWithXAI(InstanceSegBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for instance segmentation task with explanations.""" diff --git a/src/otx/core/data/entity/segmentation.py b/src/otx/core/data/entity/segmentation.py index 457bc26989a..638461fb209 100644 --- a/src/otx/core/data/entity/segmentation.py +++ b/src/otx/core/data/entity/segmentation.py @@ -12,10 +12,8 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, - OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -38,15 +36,10 @@ def task(self) -> OTXTaskType: @dataclass -class SegPredEntity(SegDataEntity, OTXPredEntity): +class SegPredEntity(OTXPredEntity, SegDataEntity): """Data entity to represent the segmentation model output prediction.""" -@dataclass -class SegPredEntityWithXAI(SegDataEntity, OTXPredEntityWithXAI): - """Data entity to represent the segmentation model output prediction with explanation.""" - - @dataclass class SegBatchDataEntity(OTXBatchDataEntity[SegDataEntity]): """Data entity for segmentation task. @@ -78,16 +71,9 @@ def collate_fn( def pin_memory(self) -> SegBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.masks = [tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks] - return self + return super().pin_memory().wrap(masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks]) @dataclass -class SegBatchPredEntity(SegBatchDataEntity, OTXBatchPredEntity): +class SegBatchPredEntity(OTXBatchPredEntity, SegBatchDataEntity): """Data entity to represent model output predictions for segmentation task.""" - - -@dataclass -class SegBatchPredEntityWithXAI(SegBatchDataEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for segmentation task with explanations.""" diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 0d31106585a..2924a65ab64 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -13,7 +13,6 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, - OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, Points, @@ -52,7 +51,7 @@ def task(self) -> OTXTaskType: @dataclass -class VisualPromptingPredEntity(VisualPromptingDataEntity, OTXPredEntity): +class VisualPromptingPredEntity(OTXPredEntity, VisualPromptingDataEntity): """Data entity to represent the visual prompting model output prediction.""" @@ -107,27 +106,28 @@ def collate_fn( def pin_memory(self) -> VisualPromptingBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.points = [ - tv_tensors.wrap(point.pin_memory(), like=point) if point is not None else point for point in self.points - ] - self.bboxes = [ - tv_tensors.wrap(bbox.pin_memory(), like=bbox) if bbox is not None else bbox for bbox in self.bboxes - ] - self.masks = [tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks] - self.labels = [ - {prompt_type: values.pin_memory() for prompt_type, values in labels.items()} for labels in self.labels - ] - return self - - -@dataclass -class VisualPromptingBatchPredEntity(VisualPromptingBatchDataEntity, OTXBatchPredEntity): - """Data entity to represent model output predictions for visual prompting task.""" + return ( + super() + .pin_memory() + .wrap( + points=[ + tv_tensors.wrap(point.pin_memory(), like=point) if point is not None else point + for point in self.points + ], + bboxes=[ + tv_tensors.wrap(bbox.pin_memory(), like=bbox) if bbox is not None else bbox for bbox in self.bboxes + ], + masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], + labels=[ + {prompt_type: values.pin_memory() for prompt_type, values in labels.items()} + for labels in self.labels + ], + ) + ) @dataclass -class VisualPromptingBatchPredEntityWithXAI(VisualPromptingBatchPredEntity, OTXBatchPredEntityWithXAI): +class VisualPromptingBatchPredEntity(OTXBatchPredEntity, VisualPromptingBatchDataEntity): """Data entity to represent model output predictions for visual prompting task.""" @@ -202,22 +202,22 @@ def collate_fn( def pin_memory(self) -> ZeroShotVisualPromptingBatchDataEntity: """Pin memory for member tensor variables.""" - super().pin_memory() - self.prompts = [ - [tv_tensors.wrap(prompt.pin_memory(), like=prompt) for prompt in prompts] for prompts in self.prompts - ] - self.masks = [tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks] - self.labels = [label.pin_memory() for label in self.labels] - return self + return ( + super() + .pin_memory() + .wrap( + prompts=[ + [tv_tensors.wrap(prompt.pin_memory(), like=prompt) for prompt in prompts] + for prompts in self.prompts + ], + masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], + labels=[label.pin_memory() for label in self.labels], + ) + ) @dataclass -class ZeroShotVisualPromptingBatchPredEntity(ZeroShotVisualPromptingBatchDataEntity, OTXBatchPredEntity): +class ZeroShotVisualPromptingBatchPredEntity(OTXBatchPredEntity, ZeroShotVisualPromptingBatchDataEntity): """Data entity to represent model output predictions for zero-shot visual prompting task.""" prompts: list[Points] # type: ignore[assignment] - - -@dataclass -class ZeroShotVisualPromptingBatchPredEntityWithXAI(ZeroShotVisualPromptingBatchPredEntity, OTXBatchPredEntityWithXAI): - """Data entity to represent model output predictions for visual prompting task.""" diff --git a/src/otx/core/data/transform_libs/torchvision.py b/src/otx/core/data/transform_libs/torchvision.py index a99cff29cdb..e90bac09c71 100644 --- a/src/otx/core/data/transform_libs/torchvision.py +++ b/src/otx/core/data/transform_libs/torchvision.py @@ -48,7 +48,8 @@ def custom_query_size(flat_inputs: list[Any]) -> tuple[int, int]: # noqa: D103 if not sizes: raise TypeError("No image, video, mask, bounding box, or point was found in the sample") # noqa: EM101, TRY003 elif len(sizes) > 1: # noqa: RET506 - raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}") # noqa: EM102, TRY003 + msg = f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}" + raise ValueError(msg) h, w = sizes.pop() return h, w @@ -275,10 +276,7 @@ class PackVideo(tvt_v2.Transform): def forward(self, *inputs: ActionClsDataEntity) -> ActionClsDataEntity: """Replace ActionClsDataEntity's image to ActionClsDataEntity's video.""" - inputs[0].image = inputs[0].video - inputs[0].video = [] - - return inputs[0] + return inputs[0].wrap(image=inputs[0].video, video=[]) tvt_v2.PerturbBoundingBoxes = PerturbBoundingBoxes diff --git a/src/otx/core/model/action_classification.py b/src/otx/core/model/action_classification.py index 20dfa94b3c1..77f164780e3 100644 --- a/src/otx/core/model/action_classification.py +++ b/src/otx/core/model/action_classification.py @@ -10,11 +10,7 @@ import numpy as np import torch -from otx.core.data.entity.action_classification import ( - ActionClsBatchDataEntity, - ActionClsBatchPredEntity, - ActionClsBatchPredEntityWithXAI, -) +from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.native import OTXNativeModelExporter @@ -39,7 +35,6 @@ class OTXActionClsModel( OTXModel[ ActionClsBatchDataEntity, ActionClsBatchPredEntity, - ActionClsBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -75,7 +70,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: ActionClsBatchPredEntity | ActionClsBatchPredEntityWithXAI, + preds: ActionClsBatchPredEntity, inputs: ActionClsBatchDataEntity, ) -> MetricInput: pred = torch.tensor(preds.labels) @@ -201,7 +196,7 @@ def _exporter(self) -> OTXModelExporter: class OVActionClsModel( - OVModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity, ActionClsBatchPredEntityWithXAI], + OVModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity], ): """Action Classification model compatible for OpenVINO IR inference. diff --git a/src/otx/core/model/action_detection.py b/src/otx/core/model/action_detection.py index b74a9c11131..7529f93602c 100644 --- a/src/otx/core/model/action_detection.py +++ b/src/otx/core/model/action_detection.py @@ -9,11 +9,7 @@ from torchvision import tv_tensors -from otx.core.data.entity.action_detection import ( - ActionDetBatchDataEntity, - ActionDetBatchPredEntity, - ActionDetBatchPredEntityWithXAI, -) +from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.metrics import MetricInput @@ -34,7 +30,6 @@ class OTXActionDetModel( OTXModel[ ActionDetBatchDataEntity, ActionDetBatchPredEntity, - ActionDetBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -58,7 +53,7 @@ def __init__( def _convert_pred_entity_to_compute_metric( self, - preds: ActionDetBatchPredEntity | ActionDetBatchPredEntityWithXAI, + preds: ActionDetBatchPredEntity, inputs: ActionDetBatchDataEntity, ) -> MetricInput: return { diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 7df705a0f3d..00d4474410a 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -28,7 +28,6 @@ OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity, - T_OTXBatchPredEntityWithXAI, ) from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter @@ -75,10 +74,7 @@ def _default_scheduler_callable( DefaultSchedulerCallable = _default_scheduler_callable -class OTXModel( - LightningModule, - Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], -): +class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity]): """Base class for the models used in OTX. Args: @@ -203,7 +199,7 @@ def predict_step( batch: T_OTXBatchDataEntity, batch_idx: int, dataloader_idx: int = 0, - ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI: + ) -> T_OTXBatchPredEntity: """Step function called during PyTorch Lightning Trainer's predict.""" if self.explain_mode: return self.forward_explain(inputs=batch) @@ -300,7 +296,7 @@ def metric(self) -> Metric | MetricCollection: @abstractmethod def _convert_pred_entity_to_compute_metric( self, - preds: T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI, + preds: T_OTXBatchPredEntity, inputs: T_OTXBatchDataEntity, ) -> MetricInput: """Convert given inputs to a Python dictionary for the metric computation.""" @@ -443,14 +439,14 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: """Customize OTX output batch data entity if needed for model.""" raise NotImplementedError def forward( self, inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: """Model forward function.""" # If customize_inputs is overridden if isinstance(inputs, OTXTileBatchDataEntity): @@ -468,10 +464,7 @@ def forward( else outputs ) - def forward_explain( - self, - inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntityWithXAI: + def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward explain function.""" raise NotImplementedError @@ -488,7 +481,7 @@ def _restore_model_forward(self) -> None: def forward_tiles( self, inputs: T_OTXTileBatchDataEntity, - ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: """Model forward function for tile task.""" raise NotImplementedError @@ -674,7 +667,7 @@ def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Tensor) -> return super().lr_scheduler_step(scheduler=scheduler, metric=metric) -class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI]): +class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the OpenVINO model. This is a base class representing interface for interacting with OpenVINO @@ -742,10 +735,7 @@ def _customize_inputs(self, entity: T_OTXBatchDataEntity) -> dict[str, Any]: images = [np.transpose(im.cpu().numpy(), (1, 2, 0)) for im in entity.images] return {"inputs": images} - def _forward( - self, - inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI: + def _forward(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward function.""" def _callback(result: NamedTuple, idx: int) -> None: @@ -775,7 +765,7 @@ def forward(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward function.""" return self._forward(inputs=inputs) # type: ignore[return-value] - def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntityWithXAI: + def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward explain function.""" return self._forward(inputs=inputs) # type: ignore[return-value] diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index 39684b11073..0e6bcd1b5c1 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -18,18 +18,14 @@ OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity, - T_OTXBatchPredEntityWithXAI, ) from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, - HlabelClsBatchPredEntityWithXAI, MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, - MulticlassClsBatchPredEntityWithXAI, MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, - MultilabelClsBatchPredEntityWithXAI, ) from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter @@ -59,7 +55,7 @@ class ExplainableOTXClsModel( - OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], + OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity], ): """OTX classification model which can attach a XAI hook.""" @@ -86,10 +82,7 @@ def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor: output = neck(x) return head([output]) - def forward_explain( - self, - inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntityWithXAI: + def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward function.""" self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() @@ -184,7 +177,6 @@ class OTXMulticlassClsModel( ExplainableOTXClsModel[ MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, - MulticlassClsBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -222,7 +214,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI, + preds: MulticlassClsBatchPredEntity, inputs: MulticlassClsBatchDataEntity, ) -> MetricInput: pred = torch.tensor(preds.labels) @@ -301,7 +293,7 @@ def _customize_outputs( self, outputs: dict[str, Any], inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -340,7 +332,7 @@ def _customize_outputs( feature_vectors = outputs["feature_vector"].detach().cpu().numpy() saliency_maps = outputs["saliency_map"].detach().cpu().numpy() - return MulticlassClsBatchPredEntityWithXAI( + return MulticlassClsBatchPredEntity( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, @@ -381,7 +373,6 @@ class OTXMultilabelClsModel( ExplainableOTXClsModel[ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, - MultilabelClsBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -420,7 +411,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI, + preds: MultilabelClsBatchPredEntity, inputs: MultilabelClsBatchDataEntity, ) -> MetricInput: return { @@ -499,7 +490,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -538,7 +529,7 @@ def _customize_outputs( feature_vectors = outputs["feature_vector"].detach().cpu().numpy() saliency_maps = outputs["saliency_map"].detach().cpu().numpy() - return MultilabelClsBatchPredEntityWithXAI( + return MultilabelClsBatchPredEntity( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, @@ -575,7 +566,6 @@ class OTXHlabelClsModel( ExplainableOTXClsModel[ HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, - HlabelClsBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -626,7 +616,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI, + preds: HlabelClsBatchPredEntity, inputs: HlabelClsBatchDataEntity, ) -> MetricInput: hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] @@ -720,7 +710,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -759,7 +749,7 @@ def _customize_outputs( feature_vectors = outputs["feature_vector"].detach().cpu().numpy() saliency_maps = outputs["saliency_map"].detach().cpu().numpy() - return HlabelClsBatchPredEntityWithXAI( + return HlabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -793,7 +783,7 @@ def _export_parameters(self) -> dict[str, Any]: class OVMulticlassClassificationModel( - OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, MulticlassClsBatchPredEntityWithXAI], + OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity], ): """Classification model compatible for OpenVINO IR inference. @@ -826,7 +816,7 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI: + ) -> MulticlassClsBatchPredEntity: pred_labels = [torch.tensor(out.top_labels[0][0], dtype=torch.long) for out in outputs] pred_scores = [torch.tensor(out.top_labels[0][2]) for out in outputs] @@ -836,7 +826,7 @@ def _customize_outputs( # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] - return MulticlassClsBatchPredEntityWithXAI( + return MulticlassClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -856,7 +846,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI, + preds: MulticlassClsBatchPredEntity, inputs: MulticlassClsBatchDataEntity, ) -> MetricInput: pred = torch.tensor(preds.labels) @@ -867,9 +857,7 @@ def _convert_pred_entity_to_compute_metric( } -class OVMultilabelClassificationModel( - OVModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, MultilabelClsBatchPredEntityWithXAI], -): +class OVMultilabelClassificationModel(OVModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]): """Multilabel classification model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -903,7 +891,7 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI: + ) -> MultilabelClsBatchPredEntity: pred_scores = [torch.tensor([top_label[2] for top_label in out.top_labels]) for out in outputs] if outputs and outputs[0].saliency_map.size != 0: @@ -912,7 +900,7 @@ def _customize_outputs( # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] - return MultilabelClsBatchPredEntityWithXAI( + return MultilabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -932,7 +920,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI, + preds: MultilabelClsBatchPredEntity, inputs: MultilabelClsBatchDataEntity, ) -> MetricInput: return { @@ -941,9 +929,7 @@ def _convert_pred_entity_to_compute_metric( } -class OVHlabelClassificationModel( - OVModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, HlabelClsBatchPredEntityWithXAI], -): +class OVHlabelClassificationModel(OVModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]): """Hierarchical classification model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -977,7 +963,7 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI: + ) -> HlabelClsBatchPredEntity: all_pred_labels = [] all_pred_scores = [] for output in outputs: @@ -1012,7 +998,7 @@ def _customize_outputs( # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] - return HlabelClsBatchPredEntityWithXAI( + return HlabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -1032,7 +1018,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI, + preds: HlabelClsBatchPredEntity, inputs: HlabelClsBatchDataEntity, ) -> MetricInput: cls_heads_info = self.model.hierarchical_info["cls_heads_info"] diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index f1f654c34c3..e0e2fbefd8c 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -17,7 +17,7 @@ from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI +from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity from otx.core.data.entity.tile import TileBatchDetDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.metrics import MetricInput @@ -41,9 +41,7 @@ from otx.core.metrics import MetricCallable -class OTXDetectionModel( - OTXModel[DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI, TileBatchDetDataEntity], -): +class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): """Base class for the detection models used in OTX.""" def __init__( @@ -63,7 +61,7 @@ def __init__( ) self.tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity | DetBatchPredEntityWithXAI: + def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: """Unpack detection tiles. Args: @@ -72,7 +70,7 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity | Returns: DetBatchPredEntity: Merged detection prediction. """ - tile_preds: list[DetBatchPredEntity | DetBatchPredEntityWithXAI] = [] + tile_preds: list[DetBatchPredEntity] = [] tile_attrs: list[list[dict[str, int | str]]] = [] merger = DetectionTileMerge( inputs.imgs_info, @@ -124,7 +122,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: DetBatchPredEntity | DetBatchPredEntityWithXAI, + preds: DetBatchPredEntity, inputs: DetBatchDataEntity, ) -> MetricInput: return { @@ -190,7 +188,7 @@ class ExplainableOTXDetModel(OTXDetectionModel): def forward_explain( self, inputs: DetBatchDataEntity, - ) -> DetBatchPredEntityWithXAI: + ) -> DetBatchPredEntity: """Model forward function.""" from otx.algo.hooks.recording_forward_hook import feature_vector_fn @@ -413,7 +411,7 @@ def _customize_outputs( self, outputs: dict[str, Any], inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity | DetBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> DetBatchPredEntity | OTXBatchLossEntity: from mmdet.structures import DetDataSample if self.training: @@ -465,7 +463,7 @@ def _customize_outputs( saliency_maps = outputs["saliency_map"].detach().cpu().numpy() feature_vectors = outputs["feature_vector"].detach().cpu().numpy() - return DetBatchPredEntityWithXAI( + return DetBatchPredEntity( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, @@ -493,7 +491,7 @@ def _exporter(self) -> OTXModelExporter: return MMdeployExporter(**self._export_parameters) -class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI]): +class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity]): """Object detection model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -567,7 +565,7 @@ def _customize_outputs( self, outputs: list[DetectionResult], inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity | DetBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> DetBatchPredEntity | OTXBatchLossEntity: # add label index bboxes = [] scores = [] @@ -606,7 +604,7 @@ def _customize_outputs( # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] - return DetBatchPredEntityWithXAI( + return DetBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -628,7 +626,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: DetBatchPredEntity | DetBatchPredEntityWithXAI, + preds: DetBatchPredEntity, inputs: DetBatchDataEntity, ) -> MetricInput: return { diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 927b6dfa02b..ea2bd1bccff 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -19,14 +19,8 @@ from otx.algo.hooks.recording_forward_hook import MaskRCNNRecordingForwardHook, feature_vector_fn from otx.core.config.data import TileConfig -from otx.core.data.entity.base import ( - OTXBatchLossEntity, -) -from otx.core.data.entity.instance_segmentation import ( - InstanceSegBatchDataEntity, - InstanceSegBatchPredEntity, - InstanceSegBatchPredEntityWithXAI, -) +from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity from otx.core.data.entity.tile import TileBatchInstSegDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.metrics import MetricInput @@ -55,7 +49,6 @@ class OTXInstanceSegModel( OTXModel[ InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, - InstanceSegBatchPredEntityWithXAI, TileBatchInstSegDataEntity, ], ): @@ -87,7 +80,7 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP Returns: InstanceSegBatchPredEntity: Merged instance segmentation prediction. """ - tile_preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI] = [] + tile_preds: list[InstanceSegBatchPredEntity] = [] tile_attrs: list[list[dict[str, int | str]]] = [] merger = InstanceSegTileMerge( inputs.imgs_info, @@ -186,7 +179,7 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa def _convert_pred_entity_to_compute_metric( self, - preds: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI, + preds: InstanceSegBatchPredEntity, inputs: InstanceSegBatchDataEntity, ) -> MetricInput: """Convert the prediction entity to the format that the metric can compute and cache the ground truth. @@ -246,7 +239,7 @@ class ExplainableOTXInstanceSegModel(OTXInstanceSegModel): def forward_explain( self, inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntityWithXAI: + ) -> InstanceSegBatchPredEntity: """Model forward function.""" self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() @@ -455,7 +448,7 @@ def _customize_outputs( self, outputs: dict[str, Any], inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity: from mmdet.structures import DetDataSample if self.training: @@ -511,7 +504,7 @@ def _customize_outputs( saliency_maps = outputs["saliency_map"].detach().cpu().numpy() feature_vectors = outputs["feature_vector"].detach().cpu().numpy() - return InstanceSegBatchPredEntityWithXAI( + return InstanceSegBatchPredEntity( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, @@ -544,7 +537,7 @@ def _exporter(self) -> OTXModelExporter: class OVInstanceSegmentationModel( - OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI], + OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity], ): """Instance segmentation model compatible for OpenVINO IR inference. @@ -619,7 +612,7 @@ def _customize_outputs( self, outputs: list[InstanceSegmentationResult], inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity: # add label index bboxes = [] scores = [] @@ -654,7 +647,7 @@ def _customize_outputs( # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] - return InstanceSegBatchPredEntityWithXAI( + return InstanceSegBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -680,7 +673,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI, + preds: InstanceSegBatchPredEntity, inputs: InstanceSegBatchDataEntity, ) -> MetricInput: """Convert the prediction entity to the format that the metric can compute and cache the ground truth. diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index a1e05db9c58..d498bb75b65 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -12,7 +12,7 @@ from torchvision import tv_tensors from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI +from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter @@ -34,9 +34,7 @@ from otx.core.metrics import MetricCallable -class OTXSegmentationModel( - OTXModel[SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], -): +class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity, T_OTXTileBatchDataEntity]): """Base class for the detection models used in OTX.""" def __init__( @@ -76,7 +74,7 @@ def _export_parameters(self) -> dict[str, Any]: def _convert_pred_entity_to_compute_metric( self, - preds: SegBatchPredEntity | SegBatchPredEntityWithXAI, + preds: SegBatchPredEntity, inputs: SegBatchDataEntity, ) -> MetricInput: return [ @@ -158,7 +156,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: SegBatchDataEntity, - ) -> SegBatchPredEntity | SegBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> SegBatchPredEntity | OTXBatchLossEntity: from mmseg.structures import SegDataSample if self.training: @@ -182,7 +180,7 @@ def _customize_outputs( hook_records = self.explain_hook.records explain_results = copy.deepcopy(hook_records[-len(outputs) :]) - return SegBatchPredEntityWithXAI( + return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -220,7 +218,7 @@ def _exporter(self) -> OTXModelExporter: return OTXNativeModelExporter(**self._export_parameters) -class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI]): +class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity]): """Semantic segmentation model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -252,11 +250,11 @@ def _customize_outputs( self, outputs: list[ImageResultWithSoftPrediction], inputs: SegBatchDataEntity, - ) -> SegBatchPredEntity | SegBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> SegBatchPredEntity | OTXBatchLossEntity: if outputs and outputs[0].saliency_map.size != 1: predicted_s_maps = [out.saliency_map for out in outputs] predicted_f_vectors = [out.feature_vector for out in outputs] - return SegBatchPredEntityWithXAI( + return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, @@ -276,7 +274,7 @@ def _customize_outputs( def _convert_pred_entity_to_compute_metric( self, - preds: SegBatchPredEntity | SegBatchPredEntityWithXAI, + preds: SegBatchPredEntity, inputs: SegBatchDataEntity, ) -> MetricInput: return [ diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 49b1157482b..bcb1d9608bf 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -25,15 +25,13 @@ from torch import Tensor from torchvision import tv_tensors -from otx.core.data.entity.base import OTXBatchLossEntity, Points, T_OTXBatchPredEntityWithXAI +from otx.core.data.entity.base import OTXBatchLossEntity, Points from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, - VisualPromptingBatchPredEntityWithXAI, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, - ZeroShotVisualPromptingBatchPredEntityWithXAI, ) from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter @@ -171,12 +169,7 @@ def _inference_step_for_zero_shot( class OTXVisualPromptingModel( - OTXModel[ - VisualPromptingBatchDataEntity, - VisualPromptingBatchPredEntity, - VisualPromptingBatchPredEntityWithXAI, - T_OTXTileBatchDataEntity, - ], + OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, T_OTXTileBatchDataEntity], ): """Base class for the visual prompting models used in OTX.""" @@ -270,7 +263,7 @@ def test_step(self, inputs: VisualPromptingBatchDataEntity, batch_idx: int) -> N def _convert_pred_entity_to_compute_metric( self, - preds: VisualPromptingBatchPredEntity | VisualPromptingBatchPredEntityWithXAI, + preds: VisualPromptingBatchPredEntity, inputs: VisualPromptingBatchDataEntity, ) -> MetricInput: """Convert the prediction entity to the format required by the compute metric function.""" @@ -285,7 +278,6 @@ class OTXZeroShotVisualPromptingModel( OTXModel[ ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, - ZeroShotVisualPromptingBatchPredEntityWithXAI, T_OTXTileBatchDataEntity, ], ): @@ -446,7 +438,7 @@ def test_step( def _convert_pred_entity_to_compute_metric( self, - preds: ZeroShotVisualPromptingBatchPredEntity | ZeroShotVisualPromptingBatchPredEntityWithXAI, + preds: ZeroShotVisualPromptingBatchPredEntity, inputs: ZeroShotVisualPromptingBatchDataEntity, ) -> MetricInput: """Convert the prediction entity to the format required by the compute metric function.""" @@ -461,7 +453,6 @@ class OVVisualPromptingModel( OVModel[ VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, - VisualPromptingBatchPredEntityWithXAI, ], ): """Visual prompting model compatible for OpenVINO IR inference. @@ -597,7 +588,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: VisualPromptingBatchDataEntity, # type: ignore[override] - ) -> VisualPromptingBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> VisualPromptingBatchPredEntity | OTXBatchLossEntity: """Customize OTX output batch data entity if needed for model.""" masks: list[tv_tensors.Mask] = [] scores: list[torch.Tensor] = [] @@ -921,7 +912,7 @@ def infer( def forward( # type: ignore[override] self, inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override] - ) -> ZeroShotVisualPromptingBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity: """Model forward function.""" kwargs: dict[str, Any] = {} fn = self.learn if self.training else self.infer @@ -992,7 +983,7 @@ def _customize_outputs( # type: ignore[override] self, outputs: Any, # noqa: ANN401 inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override] - ) -> ZeroShotVisualPromptingBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: + ) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity: """Customize OTX output batch data entity if needed for model.""" if self.training: return outputs @@ -1426,7 +1417,7 @@ def test_step( def _convert_pred_entity_to_compute_metric( self, - preds: ZeroShotVisualPromptingBatchPredEntity | ZeroShotVisualPromptingBatchPredEntityWithXAI, + preds: ZeroShotVisualPromptingBatchPredEntity, inputs: ZeroShotVisualPromptingBatchDataEntity, ) -> MetricInput: """Convert the prediction entity to the format required by the compute metric function.""" diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index 97f19660981..41fe707cbd3 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -14,12 +14,8 @@ from torchvision.ops import batched_nms from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity -from otx.core.data.entity.detection import DetBatchPredEntity, DetBatchPredEntityWithXAI, DetPredEntity -from otx.core.data.entity.instance_segmentation import ( - InstanceSegBatchPredEntity, - InstanceSegBatchPredEntityWithXAI, - InstanceSegPredEntity, -) +from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity +from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]): @@ -94,7 +90,7 @@ class DetectionTileMerge(TileMerge): def merge( self, - batch_tile_preds: list[DetBatchPredEntity | DetBatchPredEntityWithXAI], + batch_tile_preds: list[DetBatchPredEntity], batch_tile_attrs: list[list[dict]], ) -> list[DetPredEntity]: """Merge batch tile predictions to a list of full-size prediction data entities. @@ -187,7 +183,7 @@ class InstanceSegTileMerge(TileMerge): def merge( self, - batch_tile_preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI], + batch_tile_preds: list[InstanceSegBatchPredEntity], batch_tile_attrs: list[list[dict]], ) -> list[InstanceSegPredEntity]: """Merge inst-seg tile predictions to one single prediction. diff --git a/tests/conftest.py b/tests/conftest.py index 3fcffb2a655..e5b22579144 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,19 +18,23 @@ def fxt_seg_data_entity() -> tuple[tuple, SegDataEntity, SegBatchDataEntity]: fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) fake_masks = Mask(torch.randint(low=0, high=255, size=img_size, dtype=torch.uint8)) # define data entity - single_data_entity = SegDataEntity(fake_image, fake_image_info, fake_masks) + single_data_entity = SegDataEntity( + image=fake_image, + img_info=fake_image_info, + gt_seg_map=fake_masks, + ) batch_data_entity = SegBatchDataEntity( - 1, - [Image(data=torch.from_numpy(fake_image))], - [fake_image_info], - [fake_masks], + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + masks=[fake_masks], ) batch_pred_data_entity = SegBatchPredEntity( - 1, - [Image(data=torch.from_numpy(fake_image))], - [fake_image_info], - [], - [fake_masks], + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + masks=[fake_masks], + scores=[], ) return single_data_entity, batch_pred_data_entity, batch_data_entity diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index 5702ceabe1c..63e3bc4890c 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -6,7 +6,7 @@ import numpy as np import openvino.runtime as ov import pytest -from otx.core.data.entity.base import OTXBatchPredEntity, OTXBatchPredEntityWithXAI +from otx.core.data.entity.base import OTXBatchPredEntity from otx.engine import Engine RECIPE_LIST_ALL = pytest.RECIPE_LIST @@ -57,7 +57,8 @@ def test_forward_explain( assert isinstance(predict_result[0], OTXBatchPredEntity) predict_result_explain = engine.predict(explain=True) - assert isinstance(predict_result_explain[0], OTXBatchPredEntityWithXAI) + assert isinstance(predict_result_explain[0], OTXBatchPredEntity) + assert predict_result_explain[0].has_xai_outputs batch_size = len(predict_result[0].scores) for i in range(batch_size): @@ -106,7 +107,8 @@ def test_predict_with_explain( # Predict with explain torch & process maps predict_result_explain_torch = engine.predict(explain=True) - assert isinstance(predict_result_explain_torch[0], OTXBatchPredEntityWithXAI) + assert isinstance(predict_result_explain_torch[0], OTXBatchPredEntity) + assert predict_result_explain_torch[0].has_xai_outputs assert predict_result_explain_torch[0].saliency_maps is not None assert isinstance(predict_result_explain_torch[0].saliency_maps[0], dict) @@ -134,7 +136,8 @@ def test_predict_with_explain( # Predict OV model with xai & process maps predict_result_explain_ov = engine.predict(checkpoint=exported_model_path, explain=True) - assert isinstance(predict_result_explain_ov[0], OTXBatchPredEntityWithXAI) + assert isinstance(predict_result_explain_ov[0], OTXBatchPredEntity) + assert predict_result_explain_ov[0].has_xai_outputs assert predict_result_explain_ov[0].saliency_maps is not None assert isinstance(predict_result_explain_ov[0].saliency_maps[0], dict) assert predict_result_explain_ov[0].feature_vectors is not None diff --git a/tests/unit/algo/hooks/test_saliency_map_dumping.py b/tests/unit/algo/hooks/test_saliency_map_dumping.py index 28c9d3254cf..3f790981d02 100644 --- a/tests/unit/algo/hooks/test_saliency_map_dumping.py +++ b/tests/unit/algo/hooks/test_saliency_map_dumping.py @@ -8,7 +8,7 @@ from otx.algo.utils.xai_utils import dump_saliency_maps from otx.core.config.explain import ExplainConfig from otx.core.data.entity.base import ImageInfo -from otx.core.data.entity.classification import MulticlassClsBatchPredEntityWithXAI +from otx.core.data.entity.classification import MulticlassClsBatchPredEntity from otx.core.types.task import OTXTaskType from otx.engine.utils.auto_configurator import AutoConfigurator @@ -30,7 +30,7 @@ def test_sal_map_dump( datamodule = auto_configurator.get_datamodule() predict_result = [ - MulticlassClsBatchPredEntityWithXAI( + MulticlassClsBatchPredEntity( batch_size=BATCH_SIZE, images=None, imgs_info=IMGS_INFO, diff --git a/tests/unit/algo/hooks/test_saliency_map_processing.py b/tests/unit/algo/hooks/test_saliency_map_processing.py index 628649925b0..fdc8f2739b5 100644 --- a/tests/unit/algo/hooks/test_saliency_map_processing.py +++ b/tests/unit/algo/hooks/test_saliency_map_processing.py @@ -6,7 +6,7 @@ from otx.algo.utils.xai_utils import process_saliency_maps, process_saliency_maps_in_pred_entity from otx.core.config.explain import ExplainConfig from otx.core.data.entity.base import ImageInfo -from otx.core.data.entity.classification import MulticlassClsBatchPredEntityWithXAI, MultilabelClsBatchPredEntityWithXAI +from otx.core.data.entity.classification import MulticlassClsBatchPredEntity, MultilabelClsBatchPredEntity from otx.core.types.explain import TargetExplainGroup NUM_CLASSES = 5 @@ -100,8 +100,8 @@ def test_process_image(postprocess) -> None: assert all(s_map_dict["map_per_image"].shape == (RAW_SIZE, RAW_SIZE) for s_map_dict in processed_saliency_maps) -def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntityWithXAI: - return MulticlassClsBatchPredEntityWithXAI( +def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntity: + return MulticlassClsBatchPredEntity( batch_size=BATCH_SIZE, images=None, imgs_info=IMGS_INFO, @@ -112,8 +112,8 @@ def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntityWith ) -def _get_pred_result_multilabel(pred_labels) -> MultilabelClsBatchPredEntityWithXAI: - return MultilabelClsBatchPredEntityWithXAI( +def _get_pred_result_multilabel(pred_labels) -> MultilabelClsBatchPredEntity: + return MultilabelClsBatchPredEntity( batch_size=BATCH_SIZE, images=None, imgs_info=IMGS_INFO,