From 968ab38c4471bfd132a30980ad9b11a5e85b3f01 Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Mon, 30 Dec 2024 16:00:37 -0800
Subject: [PATCH] Fix internal influence pyre fixme issues (#1467)

Summary: Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: craymichael

Differential Revision: D67705214
---
 captum/attr/_core/layer/internal_influence.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/captum/attr/_core/layer/internal_influence.py b/captum/attr/_core/layer/internal_influence.py
index 47b69ffb2b..a0bbffee20 100644
--- a/captum/attr/_core/layer/internal_influence.py
+++ b/captum/attr/_core/layer/internal_influence.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -293,7 +292,7 @@ def _attribute(
         # Returns gradient of output with respect to hidden layer.
         layer_gradients, _ = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             target_ind=expanded_target,
             additional_forward_args=input_additional_args,
@@ -304,9 +303,7 @@ def _attribute(
         # flattening grads so that we can multiply it with step-size
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = tuple(
-            # pyre-fixme[16]: `tuple` has no attribute `contiguous`.
             layer_grad.contiguous().view(n_steps, -1)
-            # pyre-fixme[16]: `tuple` has no attribute `device`.
             * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
             for layer_grad in layer_gradients
         )
@@ -317,8 +314,7 @@ def _attribute(
                 scaled_grad,
                 n_steps,
                 inputs[0].shape[0],
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_grad.shape[1:],
+                tuple(layer_grad.shape[1:]),
             )
             for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
         )
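
Note on the patterns above: `Callable[..., Tensor]` leaves the parameter list
unconstrained while pinning the return type; `cast(Module, self.layer)`
narrows the base class's Module-or-list-of-Modules annotation for the type
checker at no runtime cost; and `tuple(layer_grad.shape[1:])` converts a
`torch.Size` slice into the plain `Tuple[int, ...]` that the callee's
signature declares (a runtime no-op, since `torch.Size` subclasses `tuple`).
Below is a minimal, self-contained sketch of the same three fixes in a toy
setting; `ToyLayerAttribution`, `layer_forward`, and `summed_over_batch` are
hypothetical stand-ins, not captum's actual API.

# Minimal sketch of the three typing fixes in this patch. The names here are
# illustrative assumptions, not captum code.
from typing import Callable, List, Tuple, Union, cast

import torch
from torch import Tensor
from torch.nn import Module


def summed_over_batch(out: Tensor, layer_size: Tuple[int, ...]) -> Tensor:
    # Hypothetical helper that, like _reshape_and_sum in captum, declares a
    # plain Tuple[int, ...] parameter rather than accepting torch.Size.
    return out.view(-1, *layer_size).sum(dim=0)


def layer_forward(layer: Module, inp: Tensor) -> Tensor:
    # Statically requires a single Module, as compute_layer_gradients_and_eval
    # does at the patched call site.
    return layer(inp)


class ToyLayerAttribution:
    def __init__(
        self,
        # Fix 1: parameterize Callable instead of leaving it bare. `...`
        # keeps the argument list open; the return type is pinned to Tensor.
        forward_func: Callable[..., Tensor],
        # Declared as Module-or-list, mirroring the base-class annotation
        # that forced the cast in the patch.
        layer: Union[Module, List[Module]],
    ) -> None:
        self.forward_func = forward_func
        self.layer = layer

    def attribute(self, inp: Tensor) -> Tensor:
        # Fix 2: cast() narrows Union[Module, List[Module]] to Module for the
        # type checker only; at runtime it returns its argument unchanged.
        out = layer_forward(cast(Module, self.layer), inp)
        # Fix 3: tuple(...) turns the torch.Size slice into the Tuple[int, ...]
        # the helper declares; a no-op at runtime because torch.Size is a
        # tuple subclass.
        return summed_over_batch(out, tuple(out.shape[1:]))


if __name__ == "__main__":
    net = torch.nn.Linear(4, 2)
    attr = ToyLayerAttribution(forward_func=net, layer=net)
    print(attr.attribute(torch.randn(3, 4)).shape)  # torch.Size([2])

Checked with a strict type checker such as pyre, the un-annotated variants of
this sketch reproduce the original fixme classes (bare generic Callable,
Union passed where Module is required, torch.Size where a tuple is declared);
the annotations above resolve them with no behavioral change.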