From d0da72766ef909a224f8eefa28646294c65951f7 Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 30 Dec 2024 16:00:37 -0800
Subject: [PATCH] Fix layer integrated gradients pyre fixme issues (#1473)

Summary: Fixing unresolved pyre fixme issues in corresponding file

Differential Revision: D67706224
---
 .../_core/layer/layer_integrated_gradients.py | 72 ++++++++++---------
 1 file changed, 40 insertions(+), 32 deletions(-)

diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py
index 406acef96..6590fa75e 100644
--- a/captum/attr/_core/layer/layer_integrated_gradients.py
+++ b/captum/attr/_core/layer/layer_integrated_gradients.py
@@ -33,6 +33,7 @@
 )
 from captum.log import log_usage
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
 
@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -128,8 +128,7 @@ def _make_gradient_func(
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
+            forward_fn: Callable[..., Tensor],
            inputs: Union[Tensor, Tuple[Tensor, ...]],
            target_ind: TargetType = None,
            additional_forward_args: Optional[object] = None,
@@ -146,28 +145,21 @@ def _gradient_func(
                     target_gpus=self.device_ids,
                 )
 
-            scattered_inputs_dict = {
+            scattered_inputs_dict: Dict[
+                torch.device, Union[Tensor, Tuple[Tensor, ...]]
+            ] = {
                 scattered_input[0].device: scattered_input
                 for scattered_input in scattered_inputs
             }
 
             with torch.autograd.set_grad_enabled(True):
 
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                #  annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                #  annotated.
-                # pyre-fixme[3]: Return type must be annotated.
                 def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
+                    module: Module,
+                    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
+                    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+                    layer_idx: int = 0,
+                ) -> Union[Tensor, Tuple[Tensor, ...]]:
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
@@ -177,11 +169,14 @@ def layer_forward_hook(
                     )
 
                     if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
+                        return cast(
+                            Union[Tensor, Tuple[Tensor, ...]],
+                            scattered_inputs_dict[device][
+                                num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                    layer_idx + 1
+                                ]
+                            ],
+                        )
 
                     return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
 
@@ -502,11 +497,22 @@ def attribute(
             additional_forward_args
         )
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def flatten_tuple(tup):
+        def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
             return tuple(
-                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+                cast(
+                    List[Tensor],
+                    sum(
+                        (
+                            (
+                                list(x)
+                                if isinstance(x, (tuple, list))
+                                else cast(List[Tensor], [x])
+                            )
+                            for x in tup
+                        ),
+                        [],
+                    ),
+                )
             )
 
         if self.device_ids is None:
@@ -520,16 +526,18 @@ def flatten_tuple(tup):
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )
-
+        input_layer_list: List[Tuple[Tensor, ...]]
         # if we have one output
         if not isinstance(self.layer, list):
-            inputs_layer = (inputs_layer,)
+            input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
+        else:
+            input_layer_list = inputs_layer
 
-        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
         num_outputs_cumsum = torch.cumsum(
             torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
         )
-        inputs_layer = flatten_tuple(inputs_layer)
+        inputs_layer = flatten_tuple(input_layer_list)
 
         baselines_layer = _forward_layer_eval(
             self.forward_func,
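
Illustrative note, not part of the patch: the annotation change from Callable to Callable[..., Tensor] spells out what forward_func is already expected to be, a callable (typically a model or its forward method) returning a Tensor of scores or logits. The sketch below shows that call pattern with a hypothetical toy embedding model; only the public LayerIntegratedGradients API is assumed, everything else is made up for illustration.

    # Sketch only; ToyModel and all shapes here are hypothetical.
    import torch
    import torch.nn as nn
    from captum.attr import LayerIntegratedGradients

    class ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embedding = nn.Embedding(100, 8)
            self.classifier = nn.Linear(8, 2)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            # Returns a Tensor, matching forward_func: Callable[..., Tensor]
            return self.classifier(self.embedding(token_ids).mean(dim=1))

    model = ToyModel()
    lig = LayerIntegratedGradients(model, model.embedding)

    token_ids = torch.randint(0, 100, (4, 6))
    baseline_ids = torch.zeros_like(token_ids)

    # Attribute the class-1 score to the embedding layer's output.
    attributions, delta = lig.attribute(
        token_ids,
        baselines=baseline_ids,
        target=1,
        return_convergence_delta=True,
    )

In this pattern the forward hook swaps in the layer activations scaled between baseline and input, which is why the typed layer_forward_hook and scattered_inputs_dict above are on the hot path of every attribution call.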