Skip to content

Commit

Permalink
v2.1: XAI clean-up & XAI unit tests (#3303)
Browse files Browse the repository at this point in the history
* XAI clean-up & XAI unit tests

* Update

* Add tests

* Pre-commit
  • Loading branch information
GalyaZalesskaya authored Apr 12, 2024
1 parent babae3b commit a03c294
Show file tree
Hide file tree
Showing 16 changed files with 314 additions and 188 deletions.
6 changes: 3 additions & 3 deletions src/otx/algo/classification/deit_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from mmpretrain.models.utils import resize_pos_embed

from otx.algo.hooks.recording_forward_hook import ViTReciproCAMHook
from otx.algo.explain.explain_algo import ViTReciproCAM
from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
Expand All @@ -33,7 +33,7 @@


class ForwardExplainMixInForDeit(ExplainableMixInMMPretrainModel):
"""Deit model which can attach a XAI hook."""
"""Deit model which can attach a XAI (Explainable AI) branch."""

@torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -133,7 +133,7 @@ def _forward_explain_image_classifier(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = ViTReciproCAMHook(
explainer = ViTReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
)
Expand Down
4 changes: 2 additions & 2 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import nn
from torchvision.models import get_model, get_model_weights

from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
from otx.algo.explain.explain_algo import ReciproCAM
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
from otx.core.exporter.base import OTXModelExporter
Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(
self.softmax = nn.Softmax(dim=-1)
self.loss = loss

self.explainer = ReciproCAMHook(
self.explainer = ReciproCAM(
self._head_forward_fn,
num_classes=num_classes,
optimize_gap=True,
Expand Down
4 changes: 4 additions & 0 deletions src/otx/algo/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX XAI algorithms."""
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Hooks for recording/updating model internal activations."""
"""Algorithms for calculcalating XAI branch for Explainable AI."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import numpy as np
import torch

from otx.core.types.explain import FeatureMapType

if TYPE_CHECKING:
import numpy as np
from mmengine.structures.instance_data import InstanceData
from torch.utils.hooks import RemovableHandle

HeadForwardFn = Callable[[FeatureMapType], torch.Tensor]
ExplainerForwardFn = HeadForwardFn
Expand All @@ -34,7 +33,7 @@ def get_feature_vector(feature_map: FeatureMapType) -> torch.Tensor:
return torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(start_dim=1)


class BaseRecordingForwardHook:
class BaseExplainAlgo:
"""While registered with the designated PyTorch module, this class caches feature vector during forward pass.
Args:
Expand All @@ -43,19 +42,8 @@ class BaseRecordingForwardHook:

def __init__(self, head_forward_fn: HeadForwardFn | None = None, normalize: bool = True) -> None:
self._head_forward_fn = head_forward_fn
self.handle: RemovableHandle | None = None
self._records: list[torch.Tensor] = []
self._norm_saliency_maps = normalize

@property
def records(self) -> list[torch.Tensor]:
"""Return records."""
return self._records

def reset(self) -> None:
"""Clear all history of records."""
self._records.clear()

def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor:
"""This method get the feature vector or saliency map from the output of the module.
Expand All @@ -69,25 +57,6 @@ def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor:
"""
raise NotImplementedError

def recording_forward(
self,
_: torch.nn.Module,
x: torch.Tensor,
output: torch.Tensor,
) -> None: # pylint: disable=unused-argument
"""Record the XAI result during executing model forward function."""
tensors = self.func(output)
if isinstance(tensors, torch.Tensor):
tensors_np = tensors.detach().cpu().numpy()
elif isinstance(tensors, np.ndarray):
tensors_np = tensors
else:
self._torch_to_numpy_from_list(tensors)
tensors_np = tensors

for tensor in tensors_np:
self._records.append(tensor)

def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
if self._head_forward_fn:
Expand All @@ -96,14 +65,6 @@ def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
x = torch.tensor(x)
return x

def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> None:
for i in range(len(tensor_list)):
tensor = tensor_list[i]
if isinstance(tensor, list):
self._torch_to_numpy_from_list(tensor)
elif isinstance(tensor, torch.Tensor):
tensor_list[i] = tensor.detach().cpu().numpy()

@staticmethod
def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
"""Normalize saliency maps."""
Expand All @@ -116,18 +77,8 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
return saliency_map.to(torch.uint8)


class ActivationMapHook(BaseRecordingForwardHook):
"""ActivationMapHook. Mean of the feature map along the channel dimension."""

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls()
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook
class ActivationMap(BaseExplainAlgo):
"""ActivationMap. Mean of the feature map along the channel dimension."""

def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor:
"""Generate the saliency map by average feature maps then normalizing to (0, 255)."""
Expand All @@ -144,7 +95,7 @@ def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor:
return activation_map.reshape((batch_size, h, w))


class ReciproCAMHook(BaseRecordingForwardHook):
class ReciproCAM(BaseExplainAlgo):
"""Implementation of Recipro-CAM for class-wise saliency map.
Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf)
Expand All @@ -161,23 +112,6 @@ def __init__(
self._num_classes = num_classes
self._optimize_gap = optimize_gap

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
head_forward_fn: HeadForwardFn,
num_classes: int,
optimize_gap: bool,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls(
head_forward_fn,
num_classes=num_classes,
optimize_gap=optimize_gap,
)
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook

def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor:
"""Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255).
Expand Down Expand Up @@ -226,7 +160,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
return mosaic_feature_map


class ViTReciproCAMHook(BaseRecordingForwardHook):
class ViTReciproCAM(BaseExplainAlgo):
"""Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers.
Args:
Expand All @@ -251,21 +185,6 @@ def __init__(
self._use_gaussian = use_gaussian
self._cls_token = cls_token

@classmethod
def create_and_register_hook(
cls,
target_layernorm: torch.nn.Module,
head_forward_fn: HeadForwardFn,
num_classes: int,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls(
head_forward_fn,
num_classes=num_classes,
)
hook.handle = target_layernorm.register_forward_hook(hook.recording_forward)
return hook

def func(self, feature_map: torch.Tensor, _: int = -1) -> torch.Tensor:
"""Generate the class-wise saliency maps using ViTRecipro-CAM and then normalizing to (0, 255).
Expand Down Expand Up @@ -328,8 +247,8 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor) -> torch.Tensor:
return mosaic_feature_map


class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""
class DetClassProbabilityMap(BaseExplainAlgo):
"""Saliency map generation algo for object detection models."""

def __init__(
self,
Expand Down Expand Up @@ -392,8 +311,8 @@ def func(
return saliency_map.reshape((batch_size, self._num_classes, height, width))


class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
"""Dummy saliency map hook for Mask R-CNN model."""
class MaskRCNNExplainAlgo(BaseExplainAlgo):
"""Dummy saliency map algo for Mask R-CNN model."""

def __init__(self, num_classes: int) -> None:
super().__init__()
Expand Down
4 changes: 0 additions & 4 deletions src/otx/algo/hooks/__init__.py

This file was deleted.

9 changes: 4 additions & 5 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from openvino.model_api.tilers import DetectionTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import get_feature_vector
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
Expand Down Expand Up @@ -167,15 +168,13 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa


class ExplainableOTXDetModel(OTXDetectionModel):
"""OTX detection model which can attach a XAI hook."""
"""OTX detection model which can attach a XAI (Explainable AI) branch."""

def forward_explain(
self,
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity:
"""Model forward function."""
from otx.algo.hooks.recording_forward_hook import get_feature_vector

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

Expand Down Expand Up @@ -239,11 +238,11 @@ def _forward_explain_detection(
def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.detection.heads.custom_ssd_head import SSDHead
from otx.algo.hooks.recording_forward_hook import DetClassProbabilityMapHook
from otx.algo.explain.explain_algo import DetClassProbabilityMap

# SSD-like heads also have background class
background_class = isinstance(self.model.bbox_head, SSDHead)
explainer = DetClassProbabilityMapHook(
explainer = DetClassProbabilityMap(
num_classes=self.num_classes + background_class,
num_anchors=self.get_num_anchors(),
)
Expand Down
9 changes: 3 additions & 6 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo, get_feature_vector
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
Expand Down Expand Up @@ -206,15 +207,13 @@ def _convert_pred_entity_to_compute_metric(


class ExplainableOTXInstanceSegModel(OTXInstanceSegModel):
"""OTX Instance Segmentation model which can attach a XAI hook."""
"""OTX Instance Segmentation model which can attach a XAI (Explainable AI) branch."""

def forward_explain(
self,
inputs: InstanceSegBatchDataEntity,
) -> InstanceSegBatchPredEntity:
"""Model forward function."""
from otx.algo.hooks.recording_forward_hook import get_feature_vector

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

Expand Down Expand Up @@ -272,9 +271,7 @@ def _forward_explain_inst_seg(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.hooks.recording_forward_hook import MaskRCNNRecordingForwardHook

explainer = MaskRCNNRecordingForwardHook(num_classes=self.num_classes)
explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes)
return explainer.func

def _reset_model_forward(self) -> None:
Expand Down
15 changes: 0 additions & 15 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import copy
import json
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -166,20 +165,6 @@ def _customize_outputs(
raise TypeError(output)
masks.append(output.pred_sem_seg.data)

if hasattr(self, "explain_hook"):
hook_records = self.explain_hook.records
explain_results = copy.deepcopy(hook_records[-len(outputs) :])

return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=masks,
saliency_map=explain_results,
feature_vector=[],
)

return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
Expand Down
6 changes: 2 additions & 4 deletions src/otx/core/model/utils/mmpretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mmpretrain.models.utils import ClsDataPreprocessor as _ClsDataPreprocessor
from mmpretrain.registry import MODELS

from otx.algo.hooks.recording_forward_hook import get_feature_vector
from otx.algo.explain.explain_algo import ReciproCAM, get_feature_vector
from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity
from otx.core.utils.build import build_mm_model, get_classification_layers

Expand Down Expand Up @@ -134,9 +134,7 @@ def get_explain_fn(self) -> Callable:
Note:
Can be redefined at the model's level.
"""
from otx.algo.hooks.recording_forward_hook import ReciproCAMHook

explainer = ReciproCAMHook(
explainer = ReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
optimize_gap=self.has_gap,
Expand Down
Loading

0 comments on commit a03c294

Please sign in to comment.