diff --git a/captum/attr/_core/layer/layer_gradient_x_activation.py b/captum/attr/_core/layer/layer_gradient_x_activation.py
index c828a262e..f56265c2e 100644
--- a/captum/attr/_core/layer/layer_gradient_x_activation.py
+++ b/captum/attr/_core/layer/layer_gradient_x_activation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 from captum._utils.common import (
     _format_additional_forward_args,
@@ -24,8 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -186,11 +185,10 @@ def attribute(
         if isinstance(self.layer, Module):
             return _format_output(
                 len(layer_evals) > 1,
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `List[typing.Tuple[Tensor, ...]]`.
-                # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
-                #  got `List[typing.Tuple[Tensor, ...]]`.
-                self.multiply_gradient_acts(layer_gradients, layer_evals),
+                self.multiply_gradient_acts(
+                    cast(Tuple[Tensor, ...], layer_gradients),
+                    cast(Tuple[Tensor, ...], layer_evals),
+                ),
             )
         else:
             return [
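
For context, a minimal sketch (not part of the patch) of the technique the diff relies on: `typing.cast` is a runtime no-op that only narrows the type seen by pyre/mypy, so a value that is actually a list can be passed where a tuple is annotated without changing behavior. `multiply_pairs` below is a hypothetical stand-in for `multiply_gradient_acts`, using ints instead of Tensors to stay self-contained.

```python
from typing import List, Tuple, cast


def multiply_pairs(grads: Tuple[int, ...], evals: Tuple[int, ...]) -> Tuple[int, ...]:
    # Hypothetical stand-in for multiply_gradient_acts: elementwise product.
    return tuple(g * e for g, e in zip(grads, evals))


# At runtime these are lists; cast() returns them unchanged and only tells the
# static checker to treat them as Tuple[int, ...], mirroring how the patch
# replaces the pyre-fixme suppressions.
grads: List[int] = [1, 2, 3]
evals: List[int] = [4, 5, 6]
result = multiply_pairs(cast(Tuple[int, ...], grads), cast(Tuple[int, ...], evals))
print(result)  # (4, 10, 18)
```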