Forward_explain & Export with XAI: Detection & InstSeg (#3079)
* Merge draft export XAI version from negvet

* Draft implementation for detection

* Code prettifier for detection

* Remove extra changes

* Remove extra changes

* Forward_explain & export with XAI for InstSeg

* Fix comments

* Minor

* Merge changes with tests

* Update CLI tests

* Update tests

* Comments

* Draft

* Draft

* Fix pre-commit

* Fix pre-commit

* Fix tests

* Fix tests

* Add feature vector for det and InstSeg (#1)

* add fv for det

* update det and add isegm

* minor

* minor

* Fix tests

* Minor

* Fix output name order (#2)

* Disable tests for speed

* Disable tests for speed

* Restore changes

* Disable tests that cause undetermined failures for ATSS and Mask RCNN

* Disable rtmdet_inst_tiny XAI tests

* Minor

* Fixes from comments

---------

Co-authored-by: Evgeny Tsykunov <[email protected]>
GalyaZalesskaya and negvet authored Mar 14, 2024
1 parent 7b7de3c commit d7509e7
Showing 10 changed files with 389 additions and 175 deletions.
75 changes: 21 additions & 54 deletions src/otx/algo/hooks/recording_forward_hook.py
@@ -10,9 +10,8 @@
import numpy as np
import torch

from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI

if TYPE_CHECKING:
from mmengine.structures.instance_data import InstanceData
from torch.utils.hooks import RemovableHandle


@@ -333,54 +332,33 @@ class DetClassProbabilityMapHook(BaseRecordingForwardHook):

def __init__(
self,
cls_head_forward_fn: Callable,
num_classes: int,
num_anchors: list[int],
normalize: bool = True,
use_cls_softmax: bool = True,
) -> None:
super().__init__(cls_head_forward_fn, normalize)
super().__init__(head_forward_fn=None, normalize=normalize)
# SSD-like heads also have background class
self._num_classes = num_classes
self._num_anchors = num_anchors
# Should be switched off for tiling
self.use_cls_softmax = use_cls_softmax

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
cls_head_forward_fn: Callable,
num_classes: int,
num_anchors: list[int],
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls(
cls_head_forward_fn,
num_classes=num_classes,
num_anchors=num_anchors,
)
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook

def func(
self,
feature_map: torch.Tensor | Sequence[torch.Tensor],
cls_scores: torch.Tensor | Sequence[torch.Tensor],
_: int = -1,
) -> torch.Tensor:
"""Generate the saliency map from raw classification head output, then normalizing to (0, 255).
Args:
feature_map (torch.Tensor | Sequence[torch.Tensor]): Feature maps from backbone/FPN or
classification scores from cls_head.
cls_scores (torch.Tensor | Sequence[torch.Tensor]): Classification scores from cls_head.
Returns:
torch.Tensor: Class-wise Saliency Maps. One saliency map per class - [batch, class_id, H, W]
"""
cls_scores = self._head_forward_fn(feature_map) if self._head_forward_fn else feature_map

middle_idx = len(cls_scores) // 2
# resize to the middle feature map
# Resize to the middle feature map
batch_size, _, height, width = cls_scores[middle_idx].size()
saliency_maps = torch.empty(batch_size, self._num_classes, height, width)
for batch_idx in range(batch_size):
@@ -420,55 +398,45 @@ def __init__(self, num_classes: int) -> None:
super().__init__()
self.num_classes = num_classes

@classmethod
def create_and_register_hook(cls, num_classes: int) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
return cls(num_classes)

def func(
self,
preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI],
predictions: list[InstanceData],
_: int = -1,
) -> list[np.array]:
"""Generate saliency maps from predicted masks by averaging and normalizing them per-class.
Args:
preds (List[InstanceSegBatchPredEntity]): Predictions of Instance Segmentation model.
predictions (list[InstanceData]): Predictions of Instance Segmentation model.
Returns:
list[np.array]: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W]
torch.Tensor: Class-wise Saliency Maps. One saliency map per class - [batch, class_id, H, W]
"""
# TODO(gzalessk): Add unit tests # noqa: TD003
batch_size = len(preds)
batch_saliency_maps = list(range(batch_size))

for batch, pred in enumerate(preds):
class_averaged_masks = self.average_and_normalize(pred, self.num_classes)
batch_saliency_maps[batch] = class_averaged_masks
return batch_saliency_maps
batch_saliency_maps = []
for prediction in predictions:
class_averaged_masks = self.average_and_normalize(prediction, self.num_classes)
batch_saliency_maps.append(class_averaged_masks)
return torch.stack(batch_saliency_maps)

@classmethod
def average_and_normalize(
cls,
pred: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI,
pred: InstanceData,
num_classes: int,
) -> np.array:
"""Average and normalize masks in prediction per-class.
Args:
preds (InstanceSegBatchPredEntity): Predictions of Instance Segmentation model.
pred (InstanceData): Prediction of the Instance Segmentation model.
num_classes (int): Num classes that model can predict.
Returns:
np.array: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W]
np.array: Class-wise Saliency Maps. One saliency map per class - [class_id, H, W]
"""
_, height, width = pred.masks[0].data.shape
masks, scores, labels = (
pred.masks[0].data,
pred.scores[0].data,
pred.labels[0].data,
)
saliency_maps = torch.zeros((num_classes, height, width), dtype=torch.float32)
masks, scores, labels = (pred.masks, pred.scores, pred.labels)
_, height, width = masks.shape

saliency_maps = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
class_objects = [0 for _ in range(num_classes)]

for confidence, class_ind, raw_mask in zip(scores, labels, masks):
@@ -482,6 +450,5 @@ def average_and_normalize(

saliency_maps = saliency_maps.reshape((num_classes, -1))
saliency_maps = cls._normalize_map(saliency_maps)
saliency_maps = saliency_maps.reshape(num_classes, height, width)

return saliency_maps.numpy()
return saliency_maps.reshape(num_classes, height, width)
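For readers skimming the hunk above: `average_and_normalize` builds one saliency map per class by confidence-weighting each predicted instance mask, averaging over the instances of that class, and normalizing to the 0-255 range (min-max normalization is assumed here for `_normalize_map`). A minimal self-contained sketch of that logic follows; the function name and the epsilon guard are illustrative assumptions, and only the diff above is authoritative:

```python
import torch

def class_averaged_saliency_maps(
    masks: torch.Tensor,   # [num_objects, H, W] instance masks (bool or float)
    scores: torch.Tensor,  # [num_objects] per-instance confidences
    labels: torch.Tensor,  # [num_objects] per-instance class indices
    num_classes: int,
) -> torch.Tensor:
    """Confidence-weighted, per-class average of instance masks, scaled to 0-255."""
    _, height, width = masks.shape
    saliency_maps = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
    objects_per_class = torch.zeros(num_classes, device=labels.device)

    # Accumulate confidence-weighted masks per class and count instances.
    for confidence, class_ind, raw_mask in zip(scores, labels, masks):
        saliency_maps[class_ind] += confidence * raw_mask.float()
        objects_per_class[class_ind] += 1

    # Average over the number of instances of each (non-empty) class.
    nonempty = objects_per_class > 0
    saliency_maps[nonempty] /= objects_per_class[nonempty].view(-1, 1, 1)

    # Min-max normalize each class map to the 0-255 range.
    flat = saliency_maps.reshape(num_classes, -1)
    min_vals = flat.min(dim=-1, keepdim=True).values
    max_vals = flat.max(dim=-1, keepdim=True).values
    flat = 255 * (flat - min_vals) / (max_vals - min_vals + 1e-12)
    return flat.reshape(num_classes, height, width)
```

The hook in the diff performs the same steps on `InstanceData` fields (`pred.masks`, `pred.scores`, `pred.labels`) and stacks the per-image results into a batch tensor.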
8 changes: 7 additions & 1 deletion src/otx/core/exporter/mmdeploy.py
@@ -47,6 +47,7 @@ class MMdeployExporter(OTXModelExporter):
pad_value (int, optional): Padding value. Defaults to 0.
swap_rgb (bool, optional): Whether to convert the image from BGR to RGB. Defaults to False.
max_num_detections (int, optional): Maximum number of detections per image. Defaults to 0.
output_names (list[str], optional): Additional names for the output nodes, in addition to the ones in "ir_config". Defaults to None.
"""

def __init__(
@@ -63,14 +64,19 @@ def __init__(
swap_rgb: bool = False,
metadata: dict[tuple[str, str], str] | None = None,
max_num_detections: int = 0,
output_names: list[str] | None = None,
) -> None:
super().__init__(input_size, mean, std, resize_mode, pad_value, swap_rgb, metadata)
super().__init__(input_size, mean, std, resize_mode, pad_value, swap_rgb, metadata, output_names)
self._model_builder = model_builder
model_cfg = convert_conf_to_mmconfig_dict(model_cfg, "list")
self._model_cfg = MMConfig({"model": model_cfg, "test_pipeline": list(map(to_tuple, test_pipeline))})
self._deploy_cfg = deploy_cfg if isinstance(deploy_cfg, MMConfig) else load_mmconfig_from_pkg(deploy_cfg)

patch_input_shape(self._deploy_cfg, input_size[3], input_size[2])
if output_names is not None:
self._deploy_cfg.ir_config.output_names.extend(output_names)
self.output_names = self._deploy_cfg.ir_config.output_names

if max_num_detections > 0:
self._set_max_num_detections(max_num_detections)

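The core of the exporter change is small: whatever is passed as `output_names` is appended to the deploy config's `ir_config.output_names`, and the exporter keeps the merged list. A tiny self-contained sketch of that behavior; the config structure, the default IR names, and the XAI output names below are simplified stand-ins, not the real mmdeploy config:

```python
from dataclasses import dataclass, field

@dataclass
class IRConfig:
    # Simplified stand-in for deploy_cfg.ir_config.
    output_names: list[str] = field(default_factory=lambda: ["boxes", "labels"])

def merge_output_names(ir_config: IRConfig, output_names: list[str] | None) -> list[str]:
    """Append extra (e.g. XAI) output names to the IR outputs, mirroring the exporter's __init__."""
    if output_names is not None:
        ir_config.output_names.extend(output_names)
    return ir_config.output_names

ir_config = IRConfig()
merged = merge_output_names(ir_config, ["feature_vector", "saliency_map"])
print(merged)  # ['boxes', 'labels', 'feature_vector', 'saliency_map']
```

In the diff itself, the merged list is stored back on the exporter as `self.output_names`, so downstream export code sees the original IR outputs plus the extra XAI nodes.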
10 changes: 5 additions & 5 deletions src/otx/core/model/entity/classification.py
@@ -87,13 +87,13 @@ def forward_explain(
outputs = (
self._forward_explain_image_classifier(self.model, **self._customize_inputs(inputs))
if self._customize_inputs != ExplainableOTXClsModel._customize_inputs
else self.model(inputs)
else self._forward_explain_image_classifier(self.model, inputs)
)

return (
self._customize_outputs(outputs, inputs)
if self._customize_outputs != ExplainableOTXClsModel._customize_outputs
else outputs
else outputs["predictions"]
)

@staticmethod
@@ -102,8 +102,8 @@ def _forward_explain_image_classifier(
inputs: torch.Tensor,
data_samples: list[DataSample] | None = None,
mode: str = "tensor",
) -> dict:
"""Forward func of the ImageClassifier instance, which located in is in OTXModel().model."""
) -> dict[str, torch.Tensor]:
"""Forward func of the ImageClassifier instance, which located in ExplainableOTXClsModel().model."""
x = self.backbone(inputs)
backbone_feat = x

Expand Down Expand Up @@ -246,7 +246,7 @@ def _customize_inputs(self, entity: MulticlassClsBatchDataEntity) -> dict[str, A

def _customize_outputs(
self,
outputs: Any, # noqa: ANN401
outputs: dict[str, Any],
inputs: MulticlassClsBatchDataEntity,
) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity:
from mmpretrain.structures import DataSample
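The `forward_explain` change above relies on an override check: the explain path routes through `_customize_inputs`/`_customize_outputs` only when a subclass actually replaced them. A stripped-down, self-contained sketch of that dispatch pattern (the class and method names below are illustrative, not OTX APIs):

```python
class ExplainableBase:
    def _customize_inputs(self, inputs):
        return inputs  # default pass-through

    def _forward_explain(self, inputs):
        # Stand-in for the real explainable forward; returns predictions plus XAI artifacts.
        return {"predictions": inputs, "saliency_map": None, "feature_vector": None}

    def forward_explain(self, inputs):
        # A subclass overrode the hook iff its class attribute differs from the base one.
        has_custom_inputs = type(self)._customize_inputs is not ExplainableBase._customize_inputs
        outputs = self._forward_explain(self._customize_inputs(inputs) if has_custom_inputs else inputs)
        return outputs["predictions"]

class ScaledModel(ExplainableBase):
    def _customize_inputs(self, inputs):
        return [x / 255.0 for x in inputs]

print(ExplainableBase().forward_explain([128, 255]))  # pass-through branch
print(ScaledModel().forward_explain([128, 255]))      # customized branch
```

In the diff itself, the comparison is made against `ExplainableOTXClsModel._customize_inputs`, and the un-customized branch returns `outputs["predictions"]` directly.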