From dd55c7a86cf48f958b637498881f15ddebda1bcd Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 30 Dec 2024 08:55:00 -0800
Subject: [PATCH] Fix layer LRP pyre fixme issues (#1474)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1474

Differential Revision: D67706680
---
 captum/attr/_core/layer/layer_lrp.py | 81 ++++++++++++++--------------
 captum/attr/_utils/lrp_rules.py      |  8 ++-
 2 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py
index ba6a73d701..3621c118ac 100644
--- a/captum/attr/_core/layer/layer_lrp.py
+++ b/captum/attr/_core/layer/layer_lrp.py
@@ -2,7 +2,9 @@
 # pyre-strict
 
 import typing
-from typing import Any, cast, List, Literal, Optional, Tuple, Union
+from typing import cast, Dict, List, Literal, Optional, Tuple, TypeVar, Union
+
+import torch
 
 from captum._utils.common import (
     _format_tensor_into_tuples,
@@ -21,8 +23,12 @@
 )
 from captum.attr._core.lrp import LRP
 from captum.attr._utils.attribution import LayerAttribution
+from captum.attr._utils.lrp_rules import PropagationRule
 from torch import Tensor
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+
+Generic = TypeVar("Generic")
 
 
 class LayerLRP(LRP, LayerAttribution):
@@ -39,6 +45,13 @@ class LayerLRP(LRP, LayerAttribution):
     Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
     """
 
+    device_ids: List[int]
+    verbose: bool
+    layers: List[Module]
+    attribute_to_layer_input: bool = False
+    backward_handles: List[RemovableHandle]
+    forward_handles: List[RemovableHandle]
+
     def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
         """
         Args:
@@ -59,7 +72,6 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
         LayerAttribution.__init__(self, model, layer)
         LRP.__init__(self, model)
         if hasattr(self.model, "device_ids"):
-            # pyre-fixme[4]: Attribute must be annotated.
             self.device_ids = cast(List[int], self.model.device_ids)
 
     @typing.overload  # type: ignore
@@ -208,48 +220,34 @@ def attribute(
             >>> attribution = layer_lrp.attribute(input, target=5)
         """
 
-        # pyre-fixme[16]: `LayerLRP` has no attribute `verbose`.
         self.verbose = verbose
-        # pyre-fixme[16]: `LayerLRP` has no attribute `_original_state_dict`.
         self._original_state_dict = self.model.state_dict()
-        # pyre-fixme[16]: `LayerLRP` has no attribute `layers`.
         self.layers = []
         self._get_layers(self.model)
         self._check_and_attach_rules()
-        # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
         self.attribute_to_layer_input = attribute_to_layer_input
-        # pyre-fixme[16]: `LayerLRP` has no attribute `backward_handles`.
         self.backward_handles = []
-        # pyre-fixme[16]: `LayerLRP` has no attribute `forward_handles`.
         self.forward_handles = []
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs_tuple)
 
         try:
             # 1. Forward pass
             output = self._compute_output_and_change_weights(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `TensorOrTupleOfTensorsGeneric`.
-                inputs,
+                inputs_tuple,
                 target,
                 additional_forward_args,
             )
             self._register_forward_hooks()
             # 2. Forward pass + backward pass
             _ = compute_gradients(
-                self._forward_fn_wrapper, inputs, target, additional_forward_args
+                self._forward_fn_wrapper, inputs_tuple, target, additional_forward_args
             )
             relevances = self._get_output_relevance(output)
         finally:
             self._restore_model()
-            # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-            #  `TensorOrTupleOfTensorsGeneric`.
-            undo_gradient_requirements(inputs, gradient_mask)
+            undo_gradient_requirements(inputs_tuple, gradient_mask)
 
         if return_convergence_delta:
             delta: Union[Tensor, List[Tensor]]
@@ -257,7 +255,10 @@ def attribute(
                 delta = []
                 for relevance_layer in relevances:
                     delta.append(
-                        self.compute_convergence_delta(relevance_layer, output)
+                        self.compute_convergence_delta(
+                            cast(Union[Tensor, Tuple[Tensor, ...]], relevance_layer),
+                            output,
+                        )
                     )
             else:
                 delta = self.compute_convergence_delta(
@@ -267,33 +268,35 @@ def attribute(
         else:
             return relevances  # type: ignore
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _get_single_output_relevance(self, layer, output):
-        # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`.
+    def _get_single_output_relevance(
+        self, layer: Module, output: Tensor
+    ) -> Union[Tensor, Tuple[Tensor, ...]]:
         if self.attribute_to_layer_input:
-            normalized_relevances = layer.rule.relevance_input
+            normalized_relevances = cast(
+                Dict[torch.device, Tensor],
+                cast(PropagationRule, layer.rule).relevance_input,
+            )
         else:
-            normalized_relevances = layer.rule.relevance_output
+            normalized_relevances = cast(PropagationRule, layer.rule).relevance_output
         key_list = _sort_key_list(list(normalized_relevances.keys()), self.device_ids)
-        normalized_relevances = _reduce_list(
+        normalized_relevances_reduced = _reduce_list(
            [normalized_relevances[device_id] for device_id in key_list]
        )
 
-        if isinstance(normalized_relevances, tuple):
+        if isinstance(normalized_relevances_reduced, tuple):
             return tuple(
                 normalized_relevance
                 * output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1))
-                for normalized_relevance in normalized_relevances
+                for normalized_relevance in normalized_relevances_reduced
             )
         else:
-            return normalized_relevances * output.reshape(
-                (-1,) + (1,) * (normalized_relevances.dim() - 1)
+            return normalized_relevances_reduced * output.reshape(
+                (-1,) + (1,) * (normalized_relevances_reduced.dim() - 1)
             )
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _get_output_relevance(self, output):
+    def _get_output_relevance(
+        self, output: Tensor
+    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
         if isinstance(self.layer, list):
             relevances = []
             for layer in self.layer:
@@ -303,11 +306,9 @@ def _get_output_relevance(self, output):
         return self._get_single_output_relevance(self.layer, output)
 
     @staticmethod
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def _convert_list_to_tuple(
-        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-        relevances: Union[List[Any], Tuple[Any, ...]]
-    ) -> Tuple[Any, ...]:
+        relevances: Union[List[Generic], Tuple[Generic, ...]]
+    ) -> Tuple[Generic, ...]:
         if isinstance(relevances, list):
             return tuple(relevances)
         else:
diff --git a/captum/attr/_utils/lrp_rules.py b/captum/attr/_utils/lrp_rules.py
index 2dd8dc4fe8..91761c226c 100644
--- a/captum/attr/_utils/lrp_rules.py
+++ b/captum/attr/_utils/lrp_rules.py
@@ -3,10 +3,11 @@
 # pyre-strict
 
 from abc import ABC, abstractmethod
+from typing import cast, Dict, List, Union
 
 import torch
-
 from captum._utils.common import _format_tensor_into_tuples
+from torch import Tensor
 
 
 class PropagationRule(ABC):
@@ -15,6 +16,9 @@ class PropagationRule(ABC):
     STABILITY_FACTOR is used to assure that no zero divison occurs.
     """
 
+    relevance_input: Dict[torch.device, Union[torch.Tensor, List[torch.Tensor]]] = {}
+    relevance_output: Dict[torch.device, torch.Tensor] = {}
+
     STABILITY_FACTOR = 1e-9
 
     # pyre-fixme[3]: Return type must be annotated.
@@ -67,7 +71,7 @@ def _backward_hook_input(grad):
                 # pyre-fixme[16]: `PropagationRule` has no attribute `relevance_input`.
                 self.relevance_input[device] = relevance.data
             else:
-                self.relevance_input[device].append(relevance.data)
+                cast(List[Tensor], self.relevance_input[device]).append(relevance.data)
 
             # replace_out is needed since two hooks are set on the same tensor
             # The output of this hook is needed in backward_hook_activation