diff --git a/captum/attr/_core/layer/internal_influence.py b/captum/attr/_core/layer/internal_influence.py
index 47b69ffb2..a0bbffee2 100644
--- a/captum/attr/_core/layer/internal_influence.py
+++ b/captum/attr/_core/layer/internal_influence.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -293,7 +292,7 @@ def _attribute(
         # Returns gradient of output with respect to hidden layer.
         layer_gradients, _ = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             target_ind=expanded_target,
             additional_forward_args=input_additional_args,
@@ -304,9 +303,7 @@
         # flattening grads so that we can multiply it with step-size
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = tuple(
-            # pyre-fixme[16]: `tuple` has no attribute `contiguous`.
             layer_grad.contiguous().view(n_steps, -1)
-            # pyre-fixme[16]: `tuple` has no attribute `device`.
             * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
             for layer_grad in layer_gradients
         )
@@ -317,8 +314,7 @@
                 scaled_grad,
                 n_steps,
                 inputs[0].shape[0],
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_grad.shape[1:],
+                tuple(layer_grad.shape[1:]),
             )
             for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
         )
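
Reviewer note: the load-bearing change in this patch is `layer=cast(Module, self.layer)` together with the tighter `forward_func: Callable[..., Tensor]` annotation; both remove the need for the deleted pyre-fixme suppressions. The snippet below is not part of the patch and is only a minimal sketch of why the cast satisfies the type checker, assuming (as in Captum's `LayerAttribution` base) that `self.layer` is statically typed as a union of a single `Module` and a list of `Module`s, while the callee here presumably expects a single `Module`. `typing.cast` only narrows the static type; it does nothing at runtime.

# Illustration only (hypothetical names); mirrors the assumed
# Union[Module, List[Module]] typing of `self.layer`.
from typing import List, Union, cast

from torch.nn import Linear, Module

layer: Union[Module, List[Module]] = Linear(4, 2)

def count_params(m: Module) -> int:
    # Accepts only a single Module, analogous to the `layer` argument
    # being passed in the hunk above.
    return sum(p.numel() for p in m.parameters())

# count_params(layer)                      # flagged by pyre/mypy: value may be a list
print(count_params(cast(Module, layer)))   # accepted; the cast is a runtime no-op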