Skip to content

Commit

Permalink
Refactor XAI data entities (#3230)
Browse files Browse the repository at this point in the history
* Refactor XAI data entities

* Fix tests

* Fix test errors

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim authored Mar 29, 2024
1 parent 375e89e commit f1f55e2
Show file tree
Hide file tree
Showing 29 changed files with 306 additions and 404 deletions.
10 changes: 3 additions & 7 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
from torchvision.models import get_model, get_model_weights

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
MulticlassClsBatchDataEntity,
MulticlassClsBatchPredEntity,
MulticlassClsBatchPredEntityWithXAI,
)
from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXMulticlassClsModel
Expand Down Expand Up @@ -225,7 +221,7 @@ def _customize_outputs(
self,
outputs: Any, # noqa: ANN401
inputs: MulticlassClsBatchDataEntity,
) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity:
) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
if self.training:
return OTXBatchLossEntity(loss=outputs)

Expand All @@ -241,7 +237,7 @@ def _customize_outputs(

saliency_maps = outputs["saliency_map"].detach().cpu().numpy()

return MulticlassClsBatchPredEntityWithXAI(
return MulticlassClsBatchPredEntity(
batch_size=len(preds),
images=inputs.images,
imgs_info=inputs.imgs_info,
Expand Down
20 changes: 10 additions & 10 deletions src/otx/algo/utils/xai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from datumaro import Image

from otx.core.config.explain import ExplainConfig
from otx.core.data.entity.base import OTXBatchPredEntityWithXAI
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntityWithXAI
from otx.core.data.entity.base import OTXBatchPredEntity
from otx.core.types.explain import TargetExplainGroup

if TYPE_CHECKING:
Expand All @@ -23,22 +22,23 @@


def process_saliency_maps_in_pred_entity(
predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any],
predict_result: list[OTXBatchPredEntity],
explain_config: ExplainConfig,
) -> list[Any] | list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI]:
) -> list[OTXBatchPredEntity]:
"""Process saliency maps in PredEntity."""
for predict_result_per_batch in predict_result:

def _process(predict_result_per_batch: OTXBatchPredEntity) -> OTXBatchPredEntity:
saliency_maps = predict_result_per_batch.saliency_maps
imgs_info = predict_result_per_batch.imgs_info
ori_img_shapes = [img_info.ori_shape for img_info in imgs_info]
pred_labels = predict_result_per_batch.labels # type: ignore[union-attr]
if pred_labels:
if pred_labels := getattr(predict_result_per_batch, "labels", None):
pred_labels = [pred.tolist() for pred in pred_labels]

processed_saliency_maps = process_saliency_maps(saliency_maps, explain_config, pred_labels, ori_img_shapes)

predict_result_per_batch.saliency_maps = processed_saliency_maps
return predict_result
return predict_result_per_batch.wrap(saliency_maps=processed_saliency_maps)

return [_process(predict_result_per_batch) for predict_result_per_batch in predict_result]


def process_saliency_maps(
Expand Down Expand Up @@ -116,7 +116,7 @@ def postprocess(saliency_map: np.ndarray, output_size: tuple[int, int] | None) -


def dump_saliency_maps(
predict_result: list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI | Any],
predict_result: list[OTXBatchPredEntity],
explain_config: ExplainConfig,
datamodule: EVAL_DATALOADERS | OTXDataModule,
output_dir: Path,
Expand Down
13 changes: 7 additions & 6 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,12 +863,13 @@ def preprocess(self, x: Image) -> Image:

def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShotVisualPromptingBatchDataEntity:
"""Transforms for ZeroShotVisualPromptingBatchDataEntity."""
entity.images = [self.preprocess(self.apply_image(image)) for image in entity.images]
entity.prompts = [
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
]
return entity
return entity.wrap(
images=[self.preprocess(self.apply_image(image)) for image in entity.images],
prompts=[
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
],
)

def initialize_reference_info(self) -> None:
"""Initialize reference information."""
Expand Down
7 changes: 5 additions & 2 deletions src/otx/core/data/dataset/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
)
transformed_entity = self._apply_transforms(entity)

if transformed_entity is None:
msg = "This is not allowed."
raise RuntimeError(msg)

# insert masks to transformed_entity
transformed_entity.masks = masks # type: ignore[union-attr]
return transformed_entity
return transformed_entity.wrap(masks=masks)

@property
def collate_fn(self) -> Callable:
Expand Down
20 changes: 3 additions & 17 deletions src/otx/core/data/entity/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from otx.core.data.entity.base import (
OTXBatchDataEntity,
OTXBatchPredEntity,
OTXBatchPredEntityWithXAI,
OTXDataEntity,
OTXPredEntity,
OTXPredEntityWithXAI,
)
from otx.core.data.entity.utils import register_pytree_node
from otx.core.types.task import OTXTaskType
Expand Down Expand Up @@ -51,15 +49,10 @@ def task(self) -> OTXTaskType:


@dataclass
class ActionClsPredEntity(ActionClsDataEntity, OTXPredEntity):
class ActionClsPredEntity(OTXPredEntity, ActionClsDataEntity):
"""Data entity to represent the action classification model's output prediction."""


@dataclass
class ActionClsPredEntityWithXAI(ActionClsDataEntity, OTXPredEntityWithXAI):
"""Data entity to represent the detection model output prediction with explanations."""


@dataclass
class ActionClsBatchDataEntity(OTXBatchDataEntity[ActionClsDataEntity]):
"""Batch data entity for action classification.
Expand Down Expand Up @@ -92,16 +85,9 @@ def collate_fn(

def pin_memory(self) -> ActionClsBatchDataEntity:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
return self
return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels])


@dataclass
class ActionClsBatchPredEntity(ActionClsBatchDataEntity, OTXBatchPredEntity):
class ActionClsBatchPredEntity(OTXBatchPredEntity, ActionClsBatchDataEntity):
"""Data entity to represent model output predictions for action classification task."""


@dataclass
class ActionClsBatchPredEntityWithXAI(ActionClsBatchDataEntity, OTXBatchPredEntityWithXAI):
"""Data entity to represent model output predictions for multi-class classification task with explanations."""
24 changes: 11 additions & 13 deletions src/otx/core/data/entity/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from otx.core.data.entity.base import (
OTXBatchDataEntity,
OTXBatchPredEntity,
OTXBatchPredEntityWithXAI,
OTXDataEntity,
OTXPredEntity,
)
Expand Down Expand Up @@ -48,7 +47,7 @@ def task(self) -> OTXTaskType:


@dataclass
class ActionDetPredEntity(ActionDetDataEntity, OTXPredEntity):
class ActionDetPredEntity(OTXPredEntity, ActionDetDataEntity):
"""Data entity to represent the action classification model's output prediction."""


Expand Down Expand Up @@ -89,18 +88,17 @@ def collate_fn(

def pin_memory(self) -> ActionDetBatchDataEntity:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.bboxes = [tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes]
self.labels = [label.pin_memory() for label in self.labels]
self.proposals = [tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals]
return self
return (
super()
.pin_memory()
.wrap(
bboxes=[tv_tensors.wrap(bbox.pin_memory(), like=bbox) for bbox in self.bboxes],
labels=[label.pin_memory() for label in self.labels],
proposals=[tv_tensors.wrap(proposal.pin_memory(), like=proposal) for proposal in self.proposals],
)
)


@dataclass
class ActionDetBatchPredEntity(ActionDetBatchDataEntity, OTXBatchPredEntity):
class ActionDetBatchPredEntity(OTXBatchPredEntity, ActionDetBatchDataEntity):
"""Data entity to represent model output predictions for action classification task."""


@dataclass
class ActionDetBatchPredEntityWithXAI(ActionDetBatchDataEntity, OTXBatchPredEntityWithXAI):
"""Data entity to represent model output predictions for multi-class classification task with explanations."""
10 changes: 4 additions & 6 deletions src/otx/core/data/entity/anomaly/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,16 @@ def collate_fn(

def pin_memory(self) -> AnomalyClassificationDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
return self
return super().pin_memory().wrap(labels=[label.pin_memory() for label in self.labels])


@dataclass
class AnomalyClassificationPrediction(AnomalyClassificationDataItem, OTXPredEntity):
class AnomalyClassificationPrediction(OTXPredEntity, AnomalyClassificationDataItem):
"""Anomaly classification Prediction item."""


@dataclass
class AnomalyClassificationBatchPrediction(AnomalyClassificationDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalyClassificationBatchPrediction(OTXBatchPredEntity, AnomalyClassificationDataBatch):
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
20 changes: 12 additions & 8 deletions src/otx/core/data/entity/anomaly/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,24 @@ def collate_fn(

def pin_memory(self) -> AnomalyDetectionDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
self.masks = self.masks.pin_memory()
self.boxes = [box.pin_memory() for box in self.boxes]
return self
return (
super()
.pin_memory()
.wrap(
labels=[label.pin_memory() for label in self.labels],
masks=self.masks.pin_memory(),
boxes=[box.pin_memory() for box in self.boxes],
)
)


@dataclass
class AnomalyDetectionPrediction(AnomalyDetectionDataItem, OTXPredEntity):
class AnomalyDetectionPrediction(OTXPredEntity, AnomalyDetectionDataItem):
"""Anomaly Detection Prediction item."""


@dataclass
class AnomalyDetectionBatchPrediction(AnomalyDetectionDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalyDetectionBatchPrediction(OTXBatchPredEntity, AnomalyDetectionDataBatch):
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
Expand Down
18 changes: 11 additions & 7 deletions src/otx/core/data/entity/anomaly/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,23 @@ def collate_fn(

def pin_memory(self) -> AnomalySegmentationDataBatch:
"""Pin memory for member tensor variables."""
super().pin_memory()
self.labels = [label.pin_memory() for label in self.labels]
self.masks = self.masks.pin_memory()
return self
return (
super()
.pin_memory()
.wrap(
labels=[label.pin_memory() for label in self.labels],
masks=self.masks.pin_memory(),
)
)


@dataclass
class AnomalySegmentationPrediction(AnomalySegmentationDataItem, OTXPredEntity):
class AnomalySegmentationPrediction(OTXPredEntity, AnomalySegmentationDataItem):
"""Anomaly Segmentation Prediction item."""


@dataclass
class AnomalySegmentationBatchPrediction(AnomalySegmentationDataBatch, OTXBatchPredEntity):
@dataclass(kw_only=True)
class AnomalySegmentationBatchPrediction(OTXBatchPredEntity, AnomalySegmentationDataBatch):
"""Anomaly classification batch prediction."""

anomaly_maps: torch.Tensor
Loading

0 comments on commit f1f55e2

Please sign in to comment.