Skip to content

Commit

Permalink
Cam addition (#7)
Browse files Browse the repository at this point in the history
* Initialised cam parts

* Added new class for CAM

* Fixed licence link

* Several fixes

 * Deleted unused libraries in LRP
 * Deleted unused libraries in Occlusion
 * Changes mock data paths
 * Initialised tests

* Updated requirements and setup

* Added coverage

* Added tests

* Added cam notebooks

* Updated README

* Updated badges in README
  • Loading branch information
stavrostheocharis authored Feb 12, 2024
1 parent ccafc68 commit 4a6864f
Show file tree
Hide file tree
Showing 21 changed files with 2,959 additions and 16 deletions.
14 changes: 14 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[run]
source = easy_explain
omit =
*/tests/*
*/venv/*

[report]
exclude_lines =
pragma: no cover
def __repr__
if self.debug:
raise AssertionError
raise NotImplementedError
if __name__ == .__main__.:
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[![GitHub][github_badge]][github_link]
[![PyPI][pypi_badge]][pypi_link]
[![Download][download_badge]][download_link]
[![Download][total_download_badge]][download_link]
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Licence][licence_badge]][licence_link]

Expand Down Expand Up @@ -51,6 +52,7 @@ There are also other customade algorithms to support other models like the LRP i
Currently, `easy-explain` specializes in two cutting-edge XAI methodologies for images:

- Occlusion: For deep insight into classification model decisions.
- Cam: SmoothGradCAMpp & LayerCAM for explainability on image classification models.
- Layer-wise Relevance Propagation (LRP): Specifically tailored for YoloV8 models, unveiling the decision-making process in object detection tasks.

## Quick Start
Expand Down Expand Up @@ -83,6 +85,23 @@ explanation_lrp = lrp.explain(image, cls='your-class', contrastive=False).cpu()
lrp.plot_explanation(frame=image, explanation = explanation_lrp, contrastive=True, cmap='seismic', title='Explanation for your class"')
```

```python
from easy_explain import YOLOv8LRP

model = 'your-model'
image = 'your-image'

trans_params = {"ImageNet_transformation":
{"Resize": {"h": 224,"w": 224},
"Normalize": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}}}

explainer = CAMExplain(model)

input_tensor = explainer.transform_image(img, trans_params["ImageNet_transformation"])

explainer.generate_explanation(img, input_tensor, multiple_layers=["a_layer", "another_layer", "another_layer"])
```

For more information about how to begin have a look at the [examples notebooks](https://github.com/stavrostheocharis/easy_explain/tree/main/examples).

## Examples
Expand All @@ -95,6 +114,8 @@ Explore how `easy-explain` can be applied in various scenarios:

![Use Case Example](easy_explain/images/siberian-positive.png "Use Case Example")

![Use Case Example](easy_explain/images/jiraffe-cam-method.png "Use Case Example")

![Use Case Example](easy_explain/images/class-traffic.png "Use Case Example")

## How to contribute?
Expand All @@ -118,10 +139,12 @@ Join us in making AI models more interpretable, transparent, and trustworthy wit

[pypi_link]: https://pypi.org/project/easy-explain/

[download_badge]: https://badgen.net/pypi/dm/easy-explain
[download_badge]: https://static.pepy.tech/personalized-badge/easy-explain?period=month&units=international_system&left_color=grey&right_color=green&left_text=Monthly%20Downloads

[total_download_badge]: https://static.pepy.tech/personalized-badge/easy-explain?period=total&units=international_system&left_color=grey&right_color=green&left_text=Total%20Downloads

[download_link]: https://pypi.org/project/easy-explain/#files

[licence_badge]: https://img.shields.io/github/license/stavrostheocharis/easy-explain
[licence_badge]: https://img.shields.io/github/license/stavrostheocharis/easy_explain

[licence_link]: LICENSE
4 changes: 2 additions & 2 deletions easy_explain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .methods import YOLOv8LRP, OcclusionExplain
from .methods import YOLOv8LRP, OcclusionExplain, CAMExplain

__all__ = ["YOLOv8LRP", "OcclusionExplain"]
__all__ = ["YOLOv8LRP", "OcclusionExplain", "CAMExplain"]
Binary file added easy_explain/images/jiraffe-cam-method.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion easy_explain/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .lrp import YOLOv8LRP
from .occlusion import OcclusionExplain
from .cam import CAMExplain

__all__ = ["YOLOv8LRP", "OcclusionExplain"]
__all__ = ["YOLOv8LRP", "OcclusionExplain", "CAMExplain"]
3 changes: 3 additions & 0 deletions easy_explain/methods/cam/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .cam import CAMExplain

__all__ = ["CAMExplain"]
225 changes: 225 additions & 0 deletions easy_explain/methods/cam/cam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import torch
from torchcam.methods import SmoothGradCAMpp, LayerCAM
from torchcam.utils import overlay_mask
from torchvision import transforms
import matplotlib.pyplot as plt
from typing import List, Optional, Dict, Any
import logging
from easy_explain.methods.xai_base import ExplainabilityMethod


class CAMExplain(ExplainabilityMethod):
def __init__(self, model: torch.nn.Module):
self.model = model
logging.basicConfig(level=logging.INFO)

def transform_image(
self,
img: torch.Tensor,
trans_params: Dict[str, Dict[str, Any]],
) -> torch.Tensor:
"""
Transforms an image using specified resizing and normalization parameters.
Args:
img (Image.Image): The image to transform.
trans_params (Dict[str, Dict[str, Any]]): Parameters for resizing and normalization.
Returns:
torch.Tensor: The transformed image tensor.
"""
try:
resize_params = trans_params["Resize"]
normalize_params = trans_params["Normalize"]
input_tensor = transforms.functional.normalize(
transforms.functional.resize(
img, (resize_params["h"], resize_params["w"])
)
/ 255.0,
normalize_params["mean"],
normalize_params["std"],
)
return input_tensor

except Exception as e:
logging.error(f"Error transforming image: {e}")
raise

def get_multiple_layers_result(
self,
img: torch.Tensor,
input_tensor: torch.Tensor,
layers: List[str],
alpha: float,
):
"""
Visualizes CAMs for multiple layers and their fused result.
Args:
img (torch.Tensor): The original image tensor.
input_tensor (torch.Tensor): The tensor to input to the model.
layers (List[str]): List of layer names to visualize CAMs for.
alpha (float): Alpha value for blending CAMs on the original image.
"""
try:
# Retrieve the CAM from several layers at the same time
cam_extractor = LayerCAM(self.model, layers)
# Preprocess your data and feed it to the model
output = self.model(input_tensor.unsqueeze(0))
# Retrieve the CAM by passing the class index and the model output
cams = cam_extractor(output.squeeze(0).argmax().item(), output)
logging.info("Successfully retrieved CAMs for multiple layers")

cam_per_layer_list = []
# Get the cam per target layer provided
for cam in cams:
cam_per_layer_list.append(cam.shape)

logging.info(f"The cams per target layer are: {cam_per_layer_list}")

# Raw CAM
_, axes = plt.subplots(1, len(cam_extractor.target_names))
for id, name, cam in zip(
range(len(cam_extractor.target_names)), cam_extractor.target_names, cams
):
axes[id].imshow(cam.squeeze(0).numpy())
axes[id].axis("off")
axes[id].set_title(name)
plt.show()

fused_cam = cam_extractor.fuse_cams(cams)
# Plot the raw version
plt.imshow(fused_cam.squeeze(0).numpy())
plt.axis("off")
plt.title(" + ".join(cam_extractor.target_names))
plt.show()
# Plot the overlayed version
result = overlay_mask(
transforms.functional.to_pil_image(img),
transforms.functional.to_pil_image(fused_cam, mode="F"),
alpha=alpha,
)
plt.imshow(result)
plt.axis("off")
plt.title(" + ".join(cam_extractor.target_names))
plt.show()
cam_extractor.remove_hooks()

except Exception as e:
logging.error(f"Error retrieving CAMs for multiple layers: {e}")
raise

def get_localisation_mask(self, input_tensor: torch.Tensor, img: torch.Tensor):
"""
Generates and visualizes localization masks based on CAMs.
Args:
input_tensor (torch.Tensor): The tensor input to the model.
img (torch.Tensor): The original image tensor.
"""
try:
# Retrieve CAM for differnet layers at the same time
cam_extractor = LayerCAM(self.model)
output = self.model(input_tensor.unsqueeze(0))
cams = cam_extractor(output.squeeze(0).argmax().item(), output)

# Transformations
resized_cams = [
transforms.functional.resize(
transforms.functional.to_pil_image(cam.squeeze(0)), img.shape[-2:]
)
for cam in cams
]
segmaps = [
transforms.functional.to_pil_image(
(
transforms.functional.resize(cam, img.shape[-2:]).squeeze(0)
>= 0.5
).to(dtype=torch.float32)
)
for cam in cams
]

# Plots
for name, cam, seg in zip(
cam_extractor.target_names, resized_cams, segmaps
):
_, axes = plt.subplots(1, 2)
axes[0].imshow(cam)
axes[0].axis("off")
axes[0].set_title(name)
axes[1].imshow(seg)
axes[1].axis("off")
axes[1].set_title(name)
plt.show()
cam_extractor.remove_hooks()

except Exception as e:
logging.error(f"Error generating localization masks: {e}")
raise

def generate_explanation(
self,
img: torch.Tensor,
input_tensor: torch.Tensor,
target_layer: Optional[str] = None,
localisation_mask: bool = True,
multiple_layers: List[str] = [],
alpha=0.5,
):
"""
Extracts and visualizes CAMs for a target layer or multiple layers.
Args:
img (torch.Tensor): The original image tensor.
input_tensor (torch.Tensor): The tensor input to the model.
target_layer (Optional[str]): The target layer for CAM visualization.
localisation_mask (bool): Whether to generate localization masks.
multiple_layers (List[str]): Layers for multi-layer CAM visualization.
alpha (float): Alpha value for blending CAMs on the original image.
"""
try:
cam_extractor = SmoothGradCAMpp(self.model, target_layer=target_layer)
output = self.model(input_tensor.unsqueeze(0))
# Get the CAM giving the class index and output
cams = cam_extractor(output.squeeze(0).argmax().item(), output)

cam_per_layer_list = []
# Get the cam per target layer provided
for cam in cams:
cam_per_layer_list.append(cam.shape)

logging.info(f"The cams per target layer are: {cam_per_layer_list}")

# The raw CAM
for name, cam in zip(cam_extractor.target_names, cams):
plt.imshow(cam.squeeze(0).numpy())
plt.axis("off")
plt.title(name)
plt.show()

# Overlayed on the image
for name, cam in zip(cam_extractor.target_names, cams):
result = overlay_mask(
transforms.functional.to_pil_image(img),
transforms.functional.to_pil_image(cam.squeeze(0), mode="F"),
alpha=alpha,
)
plt.imshow(result)
plt.axis("off")
plt.title(name)
plt.show()

cam_extractor.remove_hooks()

if localisation_mask:
self.get_localisation_mask(input_tensor, img)

if len(multiple_layers) > 0:
self.get_multiple_layers_result(
img, input_tensor, multiple_layers, alpha
)

except Exception as e:
logging.error(f"Error extracting CAM: {e}")
raise
2 changes: 0 additions & 2 deletions easy_explain/methods/lrp/yolov8/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from scipy.ndimage import zoom
from sklearn.model_selection import train_test_split
import numpy as np


Expand Down Expand Up @@ -60,7 +59,6 @@ def scale_mask(mask, shape):


class LayerRelevance(torch.Tensor):

"""
LayerRelevance(relevance=None, contrastive=False, print_decimals=5)
Expand Down
5 changes: 2 additions & 3 deletions easy_explain/methods/occlusion/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from captum.attr import visualization as viz
from captum.attr import Occlusion
import json
from typing import Union, List, Dict, Any
import itertools
from easy_explain.methods.occlusion.xai_base import ExplainabilityMethod
from typing import Union, List, Dict
from easy_explain.methods.xai_base import ExplainabilityMethod


class OcclusionExplain(ExplainabilityMethod):
Expand Down
File renamed without changes.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/data/nam-anh-QJbyG6O0ick-unsplash.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 4a6864f

Please sign in to comment.