Fix layer LRP pyre fixme issues (#1474)
Summary: Pull Request resolved: #1474

Differential Revision: D67706680
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent c922793 commit dd55c7a
Showing 2 changed files with 47 additions and 42 deletions.
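The diff below applies one pattern throughout: attributes that were previously silenced with pyre-fixme[16]/[4] suppressions are declared as class-level annotations, and loosely typed values are narrowed with cast or a TypeVar. A minimal sketch of the annotation pattern under pyre-strict (hypothetical Example class, not part of this commit):

# pyre-strict
from typing import List


class Example:
    # Declaring instance attributes at class scope gives pyre their types,
    # so later assignments in methods need no pyre-fixme[16] suppressions.
    verbose: bool
    layers: List[str]

    def configure(self, verbose: bool) -> None:
        self.verbose = verbose  # would otherwise need an inline annotation or a fixme
        self.layers = []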
81 changes: 41 additions & 40 deletions captum/attr/_core/layer/layer_lrp.py
@@ -2,7 +2,9 @@

# pyre-strict
import typing
from typing import Any, cast, List, Literal, Optional, Tuple, Union
from typing import cast, Dict, List, Literal, Optional, Tuple, TypeVar, Union

import torch

from captum._utils.common import (
_format_tensor_into_tuples,
@@ -21,8 +23,12 @@
)
from captum.attr._core.lrp import LRP
from captum.attr._utils.attribution import LayerAttribution
from captum.attr._utils.lrp_rules import PropagationRule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

Generic = TypeVar("Generic")


class LayerLRP(LRP, LayerAttribution):
@@ -39,6 +45,13 @@ class LayerLRP(LRP, LayerAttribution):
Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
"""

device_ids: List[int]
verbose: bool
layers: List[Module]
attribute_to_layer_input: bool = False
backward_handles: List[RemovableHandle]
forward_handles: List[RemovableHandle]

def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
"""
Args:
@@ -59,7 +72,6 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
LayerAttribution.__init__(self, model, layer)
LRP.__init__(self, model)
if hasattr(self.model, "device_ids"):
# pyre-fixme[4]: Attribute must be annotated.
self.device_ids = cast(List[int], self.model.device_ids)

@typing.overload # type: ignore
@@ -208,56 +220,45 @@ def attribute(
>>> attribution = layer_lrp.attribute(input, target=5)
"""
# pyre-fixme[16]: `LayerLRP` has no attribute `verbose`.
self.verbose = verbose
# pyre-fixme[16]: `LayerLRP` has no attribute `_original_state_dict`.
self._original_state_dict = self.model.state_dict()
# pyre-fixme[16]: `LayerLRP` has no attribute `layers`.
self.layers = []
self._get_layers(self.model)
self._check_and_attach_rules()
# pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
self.attribute_to_layer_input = attribute_to_layer_input
# pyre-fixme[16]: `LayerLRP` has no attribute `backward_handles`.
self.backward_handles = []
# pyre-fixme[16]: `LayerLRP` has no attribute `forward_handles`.
self.forward_handles = []

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

try:
# 1. Forward pass
output = self._compute_output_and_change_weights(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `TensorOrTupleOfTensorsGeneric`.
inputs,
inputs_tuple,
target,
additional_forward_args,
)
self._register_forward_hooks()
# 2. Forward pass + backward pass
_ = compute_gradients(
self._forward_fn_wrapper, inputs, target, additional_forward_args
self._forward_fn_wrapper, inputs_tuple, target, additional_forward_args
)
relevances = self._get_output_relevance(output)
finally:
self._restore_model()
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(inputs_tuple, gradient_mask)

if return_convergence_delta:
delta: Union[Tensor, List[Tensor]]
if isinstance(self.layer, list):
delta = []
for relevance_layer in relevances:
delta.append(
self.compute_convergence_delta(relevance_layer, output)
self.compute_convergence_delta(
cast(Union[Tensor, Tuple[Tensor, ...]], relevance_layer),
output,
)
)
else:
delta = self.compute_convergence_delta(
@@ -267,33 +268,35 @@ def attribute(
else:
return relevances # type: ignore
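
For context, attribute() keeps its temporary changes reversible: it stashes the original state_dict, registers hooks, and undoes both in the finally block. A standalone sketch of that save/restore discipline (hypothetical run_with_clean_restore helper, not captum code):

import copy

import torch
from torch import nn


def run_with_clean_restore(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # Deep-copy so in-place weight edits cannot leak into the saved state.
    original_state = copy.deepcopy(model.state_dict())
    handles = []
    try:
        handles.append(model.register_forward_hook(lambda mod, inp, out: out))
        return model(inputs)
    finally:
        for handle in handles:  # always detach hooks and restore weights,
            handle.remove()     # even if the forward/backward pass raises
        model.load_state_dict(original_state)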

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_single_output_relevance(self, layer, output):
# pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
def _get_single_output_relevance(
self, layer: Module, output: Tensor
) -> Union[Tensor, Tuple[Tensor, ...]]:
if self.attribute_to_layer_input:
normalized_relevances = layer.rule.relevance_input
normalized_relevances = cast(
Dict[torch.device, Tensor],
cast(PropagationRule, layer.rule).relevance_input,
)
else:
normalized_relevances = layer.rule.relevance_output
normalized_relevances = cast(PropagationRule, layer.rule).relevance_output
key_list = _sort_key_list(list(normalized_relevances.keys()), self.device_ids)
normalized_relevances = _reduce_list(
normalized_relevances_reduced = _reduce_list(
[normalized_relevances[device_id] for device_id in key_list]
)

if isinstance(normalized_relevances, tuple):
if isinstance(normalized_relevances_reduced, tuple):
return tuple(
normalized_relevance
* output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1))
for normalized_relevance in normalized_relevances
for normalized_relevance in normalized_relevances_reduced
)
else:
return normalized_relevances * output.reshape(
(-1,) + (1,) * (normalized_relevances.dim() - 1)
return normalized_relevances_reduced * output.reshape(
(-1,) + (1,) * (normalized_relevances_reduced.dim() - 1)
)
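
The reshape above turns the per-example output into a shape that broadcasts over every remaining relevance dimension. A small standalone check of that arithmetic (example shapes chosen for illustration, not captum code):

import torch

relevance = torch.ones(2, 3, 4, 4)  # (batch, channels, height, width)
output = torch.tensor([0.5, 2.0])  # one scalar output per example

# (-1,) + (1,) * (relevance.dim() - 1) == (-1, 1, 1, 1), broadcastable per example
scaled = relevance * output.reshape((-1,) + (1,) * (relevance.dim() - 1))

assert scaled.shape == (2, 3, 4, 4)
assert scaled[0, 0, 0, 0].item() == 0.5 and scaled[1, 0, 0, 0].item() == 2.0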

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_output_relevance(self, output):
def _get_output_relevance(
self, output: Tensor
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
if isinstance(self.layer, list):
relevances = []
for layer in self.layer:
@@ -303,11 +306,9 @@ def _get_output_relevance(output):
return self._get_single_output_relevance(self.layer, output)

@staticmethod
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def _convert_list_to_tuple(
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
relevances: Union[List[Any], Tuple[Any, ...]]
) -> Tuple[Any, ...]:
relevances: Union[List[Generic], Tuple[Generic, ...]]
) -> Tuple[Generic, ...]:
if isinstance(relevances, list):
return tuple(relevances)
else:
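The Any-to-TypeVar change in _convert_list_to_tuple is the standard way to carry the element type through the conversion; a minimal standalone sketch (hypothetical to_tuple helper, not captum code):

from typing import List, Tuple, TypeVar, Union

T = TypeVar("T")


def to_tuple(values: Union[List[T], Tuple[T, ...]]) -> Tuple[T, ...]:
    # With a TypeVar the element type survives, so callers keep precise
    # types under pyre-strict instead of falling back to Any.
    if isinstance(values, list):
        return tuple(values)
    return values


ints: Tuple[int, ...] = to_tuple([1, 2, 3])  # element type inferred as int
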
8 changes: 6 additions & 2 deletions captum/attr/_utils/lrp_rules.py
@@ -3,10 +3,11 @@
# pyre-strict

from abc import ABC, abstractmethod
from typing import cast, Dict, List, Union

import torch

from captum._utils.common import _format_tensor_into_tuples
from torch import Tensor


class PropagationRule(ABC):
@@ -15,6 +16,9 @@ class PropagationRule(ABC):
STABILITY_FACTOR is used to assure that no zero division occurs.
"""

relevance_input: Dict[torch.device, Union[torch.Tensor, List[torch.Tensor]]] = {}
relevance_output: Dict[torch.device, torch.Tensor] = {}

STABILITY_FACTOR = 1e-9

# pyre-fixme[3]: Return type must be annotated.
@@ -67,7 +71,7 @@ def _backward_hook_input(grad):
# pyre-fixme[16]: `PropagationRule` has no attribute `relevance_input`.
self.relevance_input[device] = relevance.data
else:
self.relevance_input[device].append(relevance.data)
cast(List[Tensor], self.relevance_input[device]).append(relevance.data)

# replace_out is needed since two hooks are set on the same tensor
# The output of this hook is needed in backward_hook_activation
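The cast added in _backward_hook_input narrows the Union-typed dictionary value before .append is called on it. A standalone sketch of the same narrowing (hypothetical Recorder class, not captum code):

from typing import Dict, List, Union, cast

import torch
from torch import Tensor


class Recorder:
    # A value may hold a single tensor or a list of tensors, hence the Union.
    relevance_input: Dict[torch.device, Union[Tensor, List[Tensor]]] = {}

    def append_relevance(self, device: torch.device, relevance: Tensor) -> None:
        if device not in self.relevance_input:
            self.relevance_input[device] = [relevance.detach()]
        else:
            # pyre sees Union[Tensor, List[Tensor]] here; the cast narrows the
            # value so .append type-checks under pyre-strict.
            cast(List[Tensor], self.relevance_input[device]).append(relevance.detach())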
