diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py index 86115bb03..bfaa75def 100644 --- a/captum/attr/_core/input_x_gradient.py +++ b/captum/attr/_core/input_x_gradient.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable +from typing import Callable from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple from captum._utils.gradient import ( @@ -11,6 +11,7 @@ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.log import log_usage +from torch import Tensor class InputXGradient(GradientAttribution): @@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution): https://arxiv.org/abs/1605.01713 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -35,8 +35,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -113,28 +112,20 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) attributions = tuple( - input * gradient for input, gradient in zip(inputs, gradients) + input * gradient for input, gradient in zip(inputs_tuple, gradients) ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions)