From 0b6b588e8737e4947d14e4ee608ca1555ba4e6e3 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 21 Oct 2024 09:22:10 -0700 Subject: [PATCH] Fix pyre errors in LRP (#1401) Summary: Initial work on fixing Pyre errors in LRP Differential Revision: D64677351 --- captum/attr/_core/lrp.py | 72 ++++++++++------------------------------ 1 file changed, 18 insertions(+), 54 deletions(-) diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 06c2fd5ae..d08b7b4de 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -4,7 +4,7 @@ import typing from collections import defaultdict -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union import torch.nn as nn from captum._utils.common import ( @@ -18,7 +18,7 @@ apply_gradient_requirements, undo_gradient_requirements, ) -from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.common import _sum_rows from captum.attr._utils.custom_modules import Addition_Module @@ -43,6 +43,12 @@ class LRP(GradientAttribution): Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW]. """ + verbose: bool = False + _original_state_dict: Dict[str, Any] = {} + layers: List[Module] = [] + backward_handles: List[RemovableHandle] = [] + forward_handles: List[RemovableHandle] = [] + def __init__(self, model: Module) -> None: r""" Args: @@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool: return True @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `75`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], verbose: bool = False, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `65`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + additional_forward_args: object = None, return_convergence_delta: Literal[False] = False, verbose: bool = False, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -100,7 +95,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - additional_forward_args: Any = None, + additional_forward_args: object = None, return_convergence_delta: bool = False, verbose: bool = False, ) -> Union[ @@ -199,35 +194,22 @@ def attribute( >>> attribution = lrp.attribute(input, target=5) """ - # pyre-fixme[16]: `LRP` has no attribute `verbose`. self.verbose = verbose - # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`. self._original_state_dict = self.model.state_dict() - # pyre-fixme[16]: `LRP` has no attribute `layers`. - self.layers: List[Module] = [] + self.layers = [] self._get_layers(self.model) self._check_and_attach_rules() - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles: List[RemovableHandle] = [] - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles: List[RemovableHandle] = [] - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + input_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(input_tuple) try: # 1. Forward pass: Change weights of layers according to selected rules. output = self._compute_output_and_change_weights( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but - # got `TensorOrTupleOfTensorsGeneric`. - inputs, + input_tuple, target, additional_forward_args, ) @@ -235,7 +217,7 @@ def attribute( # propagation and execute back-propagation. self._register_forward_hooks() normalized_relevances = self.gradient_func( - self._forward_fn_wrapper, inputs, target, additional_forward_args + self._forward_fn_wrapper, input_tuple, target, additional_forward_args ) relevances = tuple( normalized_relevance @@ -245,9 +227,7 @@ def attribute( finally: self._restore_model() - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(input_tuple, gradient_mask) if return_convergence_delta: # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... @@ -310,13 +290,11 @@ def compute_convergence_delta( def _get_layers(self, model: Module) -> None: for layer in model.children(): if len(list(layer.children())) == 0: - # pyre-fixme[16]: `LRP` has no attribute `layers`. self.layers.append(layer) else: self._get_layers(layer) def _check_and_attach_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): layer.activations = {} # type: ignore @@ -355,50 +333,41 @@ def _check_rules(self) -> None: ) def _register_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: backward_handles = _register_backward_hook( layer, PropagationRule.backward_hook_activation, self ) - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles.extend(backward_handles) else: forward_handle = layer.register_forward_hook( layer.rule.forward_hook # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) - # pyre-fixme[16]: `LRP` has no attribute `verbose`. if self.verbose: print(f"Applied {layer.rule} on layer {layer}") def _register_weight_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_hook( layer.rule.forward_hook_weights # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _register_pre_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_pre_hook( layer.rule.forward_pre_hook_activations # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _compute_output_and_change_weights( self, inputs: Tuple[Tensor, ...], target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, + additional_forward_args: object, ) -> Tensor: try: self._register_weight_hooks() @@ -416,15 +385,12 @@ def _compute_output_and_change_weights( return cast(Tensor, output) def _remove_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. for forward_handle in self.forward_handles: forward_handle.remove() def _remove_backward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. for backward_handle in self.backward_handles: backward_handle.remove() - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer.rule, "_handle_input_hooks"): for handle in layer.rule._handle_input_hooks: # type: ignore @@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None: layer.rule._handle_output_hook.remove() # type: ignore def _remove_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): del layer.rule def _clear_properties(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "activation"): del layer.activation