Skip to content

Commit

Permalink
XAI clean-up & XAI unit tests (#3161)
Browse files Browse the repository at this point in the history
* Rename hooks& add unit tests

* Clean up

* Fix pre-commit
  • Loading branch information
GalyaZalesskaya authored Mar 20, 2024
1 parent 10f66e8 commit dc68dcb
Show file tree
Hide file tree
Showing 13 changed files with 311 additions and 156 deletions.
6 changes: 3 additions & 3 deletions src/otx/algo/classification/deit_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from mmpretrain.models.utils import resize_pos_embed

from otx.algo.hooks.recording_forward_hook import ViTReciproCAMHook
from otx.algo.explain.explain_algo import ViTReciproCAM
from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.classification import (
Expand All @@ -26,7 +26,7 @@


class ExplainableDeit(ExplainableOTXClsModel):
"""Deit model which can attach a XAI hook."""
"""Deit model which can attach a XAI (Explainable AI) branch."""

@torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -126,7 +126,7 @@ def _forward_explain_image_classifier(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = ViTReciproCAMHook(
explainer = ViTReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX custom hooks."""
"""Module for OTX XAI algorithms."""
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Hooks for recording/updating model internal activations."""
"""Algorithms for calculcalating XAI branch for Explainable AI."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Sequence

import numpy as np
import torch

if TYPE_CHECKING:
import numpy as np
from mmengine.structures.instance_data import InstanceData
from torch.utils.hooks import RemovableHandle


def feature_vector_fn(feature_map: torch.Tensor | Sequence[torch.Tensor]) -> torch.Tensor:
Expand All @@ -31,7 +30,7 @@ def feature_vector_fn(feature_map: torch.Tensor | Sequence[torch.Tensor]) -> tor
return torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(start_dim=1)


class BaseRecordingForwardHook:
class BaseExplainAlgo:
"""While registered with the designated PyTorch module, this class caches feature vector during forward pass.
Args:
Expand All @@ -40,19 +39,8 @@ class BaseRecordingForwardHook:

def __init__(self, head_forward_fn: Callable | None = None, normalize: bool = True) -> None:
self._head_forward_fn = head_forward_fn
self.handle: RemovableHandle | None = None
self._records: list[torch.Tensor] = []
self._norm_saliency_maps = normalize

@property
def records(self) -> list[torch.Tensor]:
"""Return records."""
return self._records

def reset(self) -> None:
"""Clear all history of records."""
self._records.clear()

def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor:
"""This method get the feature vector or saliency map from the output of the module.
Expand All @@ -66,25 +54,6 @@ def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor:
"""
raise NotImplementedError

def recording_forward(
self,
_: torch.nn.Module,
x: torch.Tensor,
output: torch.Tensor,
) -> None: # pylint: disable=unused-argument
"""Record the XAI result during executing model forward function."""
tensors = self.func(output)
if isinstance(tensors, torch.Tensor):
tensors_np = tensors.detach().cpu().numpy()
elif isinstance(tensors, np.ndarray):
tensors_np = tensors
else:
self._torch_to_numpy_from_list(tensors)
tensors_np = tensors

for tensor in tensors_np:
self._records.append(tensor)

def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
if self._head_forward_fn:
Expand All @@ -93,14 +62,6 @@ def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
x = torch.tensor(x)
return x

def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> None:
for i in range(len(tensor_list)):
tensor = tensor_list[i]
if isinstance(tensor, list):
self._torch_to_numpy_from_list(tensor)
elif isinstance(tensor, torch.Tensor):
tensor_list[i] = tensor.detach().cpu().numpy()

@staticmethod
def _normalize_map(saliency_maps: torch.Tensor) -> torch.Tensor:
"""Normalize saliency maps."""
Expand All @@ -115,18 +76,8 @@ def _normalize_map(saliency_maps: torch.Tensor) -> torch.Tensor:
return saliency_maps.to(torch.uint8)


class ActivationMapHook(BaseRecordingForwardHook):
"""ActivationMapHook. Mean of the feature map along the channel dimension."""

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls()
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook
class ActivationMap(BaseExplainAlgo):
"""ActivationMap. Mean of the feature map along the channel dimension."""

def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor:
"""Generate the saliency map by average feature maps then normalizing to (0, 255)."""
Expand All @@ -143,7 +94,7 @@ def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int
return activation_map.reshape((batch_size, h, w))


class ReciproCAMHook(BaseRecordingForwardHook):
class ReciproCAM(BaseExplainAlgo):
"""Implementation of Recipro-CAM for class-wise saliency map.
Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf)
Expand All @@ -160,23 +111,6 @@ def __init__(
self._num_classes = num_classes
self._optimize_gap = optimize_gap

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
head_forward_fn: Callable,
num_classes: int,
optimize_gap: bool,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls(
head_forward_fn,
num_classes=num_classes,
optimize_gap=optimize_gap,
)
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook

def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor:
"""Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255).
Expand Down Expand Up @@ -225,7 +159,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
return mosaic_feature_map


class ViTReciproCAMHook(BaseRecordingForwardHook):
class ViTReciproCAM(BaseExplainAlgo):
"""Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers.
Args:
Expand All @@ -250,21 +184,6 @@ def __init__(
self._use_gaussian = use_gaussian
self._cls_token = cls_token

@classmethod
def create_and_register_hook(
cls,
target_layernorm: torch.nn.Module,
head_forward_fn: Callable,
num_classes: int,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls(
head_forward_fn,
num_classes=num_classes,
)
hook.handle = target_layernorm.register_forward_hook(hook.recording_forward)
return hook

def func(self, feature_map: torch.Tensor, _: int = -1) -> torch.Tensor:
"""Generate the class-wise saliency maps using ViTRecipro-CAM and then normalizing to (0, 255).
Expand Down Expand Up @@ -327,8 +246,8 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor) -> torch.Tensor:
return mosaic_feature_map


class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""
class DetClassProbabilityMap(BaseExplainAlgo):
"""Saliency map generation algo for object detection models."""

def __init__(
self,
Expand Down Expand Up @@ -391,8 +310,8 @@ def func(
return saliency_maps.reshape((batch_size, self._num_classes, height, width))


class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
"""Dummy saliency map hook for Mask R-CNN model."""
class MaskRCNNExplainAlgo(BaseExplainAlgo):
"""Dummy saliency map algo for Mask R-CNN model."""

def __init__(self, num_classes: int) -> None:
super().__init__()
Expand Down
8 changes: 4 additions & 4 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import torch

from otx.algo.hooks.recording_forward_hook import feature_vector_fn
from otx.algo.explain.explain_algo import feature_vector_fn
from otx.core.data.dataset.classification import HLabelInfo
from otx.core.data.entity.base import (
OTXBatchLossEntity,
Expand Down Expand Up @@ -50,7 +50,7 @@
class ExplainableOTXClsModel(
OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity],
):
"""OTX classification model which can attach a XAI hook."""
"""OTX classification model which can attach a XAI (Explainable AI) branch."""

@property
def has_gap(self) -> bool:
Expand Down Expand Up @@ -129,9 +129,9 @@ def _forward_explain_image_classifier(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
from otx.algo.explain.explain_algo import ReciproCAM

explainer = ReciproCAMHook(
explainer = ReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
optimize_gap=self.has_gap,
Expand Down
8 changes: 4 additions & 4 deletions src/otx/core/model/entity/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def _export_parameters(self) -> dict[str, Any]:


class ExplainableOTXDetModel(OTXDetectionModel):
"""OTX detection model which can attach a XAI hook."""
"""OTX detection model which can attach a XAI (Explainable AI) branch."""

def forward_explain(
self,
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity | DetBatchPredEntityWithXAI | OTXBatchLossEntity:
"""Model forward function."""
from otx.algo.hooks.recording_forward_hook import feature_vector_fn
from otx.algo.explain.explain_algo import feature_vector_fn

self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()
Expand Down Expand Up @@ -177,11 +177,11 @@ def _forward_explain_detection(
def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.detection.heads.custom_ssd_head import CustomSSDHead
from otx.algo.hooks.recording_forward_hook import DetClassProbabilityMapHook
from otx.algo.explain.explain_algo import DetClassProbabilityMap

# SSD-like heads also have background class
background_class = isinstance(self.model.bbox_head, CustomSSDHead)
explainer = DetClassProbabilityMapHook(
explainer = DetClassProbabilityMap(
num_classes=self.num_classes + background_class,
num_anchors=self.get_num_anchors(),
)
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/model/entity/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.hooks.recording_forward_hook import MaskRCNNRecordingForwardHook, feature_vector_fn
from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo, feature_vector_fn
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import (
OTXBatchLossEntity,
Expand Down Expand Up @@ -132,7 +132,7 @@ def _export_parameters(self) -> dict[str, Any]:


class ExplainableOTXInstanceSegModel(OTXInstanceSegModel):
"""OTX Instance Segmentation model which can attach a XAI hook."""
"""OTX Instance Segmentation model which can attach a XAI (Explainable AI) branch."""

def forward_explain(
self,
Expand Down Expand Up @@ -196,7 +196,7 @@ def _forward_explain_inst_seg(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = MaskRCNNRecordingForwardHook(num_classes=self.num_classes)
explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes)
return explainer.func

def _reset_model_forward(self) -> None:
Expand Down
Loading

0 comments on commit dc68dcb

Please sign in to comment.