From 408155a8fae39bb2c2a4d5dde83b414c9144194b Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Sun, 29 Dec 2024 21:41:12 -0800
Subject: [PATCH] Fix layer conductance pyre fixme issues

Differential Revision: D67705320
---
 captum/attr/_core/layer/layer_conductance.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py
index 1f1a5f4676..0d9e4a28a6 100644
--- a/captum/attr/_core/layer/layer_conductance.py
+++ b/captum/attr/_core/layer/layer_conductance.py
@@ -2,7 +2,7 @@
 # pyre-strict
 
 import typing
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -73,8 +72,7 @@ def has_convergence_delta(self) -> bool:
         return True
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `75`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -91,8 +89,7 @@ def attribute(
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `91`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -108,8 +105,6 @@ def attribute(
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
     @log_usage()
-    # pyre-fixme[43]: This definition does not have the same decorators as the
-    #  preceding overload(s).
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -376,7 +371,7 @@ def _attribute(
             layer_evals,
         ) = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             additional_forward_args=input_additional_args,
             target_ind=expanded_target,
@@ -389,8 +384,6 @@ def _attribute(
         # This approximates the total input gradient of each step multiplied
         # by the step size.
         grad_diffs = tuple(
-            # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor,
-            #  ...]` and `Tuple[Tensor, ...]`.
             layer_eval[num_examples:] - layer_eval[:-num_examples]
             for layer_eval in layer_evals
         )
@@ -403,8 +396,7 @@ def _attribute(
                 grad_diff * layer_gradient[:-num_examples],
                 n_steps,
                 num_examples,
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_eval.shape[1:],
+                tuple(layer_eval.shape[1:]),
             )
             for layer_gradient, layer_eval, grad_diff in zip(
                 layer_gradients, layer_evals, grad_diffs
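
Usage sketch (not part of the patch): a minimal example of driving the class this patch touches, assuming only public captum APIs. The toy model, layer sizes, and variable names below are illustrative assumptions. An ordinary nn.Module's __call__ returns a Tensor, which is why the tightened forward_func: Callable[..., Tensor] annotation accepts a model directly; the cast(Module, self.layer) follows the same idea for the layer attribute, which is declared with a broader type upstream.

    # Illustrative only; requires torch and captum.
    import torch
    import torch.nn as nn

    from captum.attr import LayerConductance


    class TinyNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lin1 = nn.Linear(8, 4)
            self.lin2 = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.lin2(torch.relu(self.lin1(x)))


    model = TinyNet()
    # Calling an nn.Module returns a Tensor, so the model itself satisfies
    # the forward_func: Callable[..., Tensor] annotation from the diff above.
    lc = LayerConductance(model, model.lin1)

    inputs = torch.randn(3, 8)
    # return_convergence_delta=True selects the overload returning
    # Tuple[attributions, delta], matching the first overload in the diff.
    attributions, delta = lc.attribute(
        inputs, target=0, return_convergence_delta=True
    )
    print(attributions.shape)  # torch.Size([3, 4]): one value per hidden unit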