Skip to content

Commit

Permalink
Fix pyre errors in LRP (pytorch#1401)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in LRP

Reviewed By: craymichael

Differential Revision: D64677351
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 5615e43 commit ba7cf86
Showing 1 changed file with 18 additions and 54 deletions.
72 changes: 18 additions & 54 deletions captum/attr/_core/lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import typing
from collections import defaultdict
from typing import Any, Callable, cast, List, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union

import torch.nn as nn
from captum._utils.common import (
Expand All @@ -18,7 +18,7 @@
apply_gradient_requirements,
undo_gradient_requirements,
)
from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.attr._utils.common import _sum_rows
from captum.attr._utils.custom_modules import Addition_Module
Expand All @@ -43,6 +43,12 @@ class LRP(GradientAttribution):
Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
"""

verbose: bool = False
_original_state_dict: Dict[str, Any] = {}
layers: List[Module] = []
backward_handles: List[RemovableHandle] = []
forward_handles: List[RemovableHandle] = []

def __init__(self, model: Module) -> None:
r"""
Args:
Expand All @@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool:
return True

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `75`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
verbose: bool = False,
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `65`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
additional_forward_args: object = None,
return_convergence_delta: Literal[False] = False,
verbose: bool = False,
) -> TensorOrTupleOfTensorsGeneric: ...
Expand All @@ -100,7 +95,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: object = None,
return_convergence_delta: bool = False,
verbose: bool = False,
) -> Union[
Expand Down Expand Up @@ -199,43 +194,30 @@ def attribute(
>>> attribution = lrp.attribute(input, target=5)
"""
# pyre-fixme[16]: `LRP` has no attribute `verbose`.
self.verbose = verbose
# pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`.
self._original_state_dict = self.model.state_dict()
# pyre-fixme[16]: `LRP` has no attribute `layers`.
self.layers: List[Module] = []
self.layers = []
self._get_layers(self.model)
self._check_and_attach_rules()
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
self.backward_handles: List[RemovableHandle] = []
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles: List[RemovableHandle] = []

# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
input_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(input_tuple)

try:
# 1. Forward pass: Change weights of layers according to selected rules.
output = self._compute_output_and_change_weights(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `TensorOrTupleOfTensorsGeneric`.
inputs,
input_tuple,
target,
additional_forward_args,
)
# 2. Forward pass + backward pass: Register hooks to configure relevance
# propagation and execute back-propagation.
self._register_forward_hooks()
normalized_relevances = self.gradient_func(
self._forward_fn_wrapper, inputs, target, additional_forward_args
self._forward_fn_wrapper, input_tuple, target, additional_forward_args
)
relevances = tuple(
normalized_relevance
Expand All @@ -245,9 +227,7 @@ def attribute(
finally:
self._restore_model()

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(input_tuple, gradient_mask)

if return_convergence_delta:
# pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
Expand Down Expand Up @@ -310,13 +290,11 @@ def compute_convergence_delta(
def _get_layers(self, model: Module) -> None:
for layer in model.children():
if len(list(layer.children())) == 0:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
self.layers.append(layer)
else:
self._get_layers(layer)

def _check_and_attach_rules(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "rule"):
layer.activations = {} # type: ignore
Expand Down Expand Up @@ -355,50 +333,41 @@ def _check_rules(self) -> None:
)

def _register_forward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
backward_handles = _register_backward_hook(
layer, PropagationRule.backward_hook_activation, self
)
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
self.backward_handles.extend(backward_handles)
else:
forward_handle = layer.register_forward_hook(
layer.rule.forward_hook # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)
# pyre-fixme[16]: `LRP` has no attribute `verbose`.
if self.verbose:
print(f"Applied {layer.rule} on layer {layer}")

def _register_weight_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if layer.rule is not None:
forward_handle = layer.register_forward_hook(
layer.rule.forward_hook_weights # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)

def _register_pre_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if layer.rule is not None:
forward_handle = layer.register_forward_pre_hook(
layer.rule.forward_pre_hook_activations # type: ignore
)
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
self.forward_handles.append(forward_handle)

def _compute_output_and_change_weights(
self,
inputs: Tuple[Tensor, ...],
target: TargetType,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
additional_forward_args: object,
) -> Tensor:
try:
self._register_weight_hooks()
Expand All @@ -416,15 +385,12 @@ def _compute_output_and_change_weights(
return cast(Tensor, output)

def _remove_forward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `forward_handles`.
for forward_handle in self.forward_handles:
forward_handle.remove()

def _remove_backward_hooks(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `backward_handles`.
for backward_handle in self.backward_handles:
backward_handle.remove()
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer.rule, "_handle_input_hooks"):
for handle in layer.rule._handle_input_hooks: # type: ignore
Expand All @@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None:
layer.rule._handle_output_hook.remove() # type: ignore

def _remove_rules(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "rule"):
del layer.rule

def _clear_properties(self) -> None:
# pyre-fixme[16]: `LRP` has no attribute `layers`.
for layer in self.layers:
if hasattr(layer, "activation"):
del layer.activation
Expand Down

0 comments on commit ba7cf86

Please sign in to comment.