diff --git a/captum/attr/_core/neuron/neuron_gradient.py b/captum/attr/_core/neuron/neuron_gradient.py index 0e74382d3..b806c1f4c 100644 --- a/captum/attr/_core/neuron/neuron_gradient.py +++ b/captum/attr/_core/neuron/neuron_gradient.py @@ -14,9 +14,10 @@ apply_gradient_requirements, undo_gradient_requirements, ) -from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage +from torch import Tensor from torch.nn import Module @@ -28,8 +29,7 @@ class NeuronGradient(NeuronAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Union[int, float, Tensor]], layer: Module, device_ids: Union[None, List[int]] = None, ) -> None: @@ -60,8 +60,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -162,18 +165,12 @@ def attribute( >>> # index (4,1,2). >>> attribution = neuron_ig.attribute(input, (4,1,2)) """ - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) additional_forward_args = _format_additional_forward_args( additional_forward_args ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) _, input_grads = _forward_layer_eval_with_neuron_grads( self.forward_func, @@ -185,9 +182,9 @@ def attribute( attribute_to_layer_input=attribute_to_neuron_input, ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. + undo_gradient_requirements(inputs_tuple, gradient_mask) + + # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: + # [Tensor, typing.Tuple[Tensor, ...]]]` but got `Union[Tensor, + # typing.Tuple[Tensor, ...]]`. return _format_output(is_inputs_tuple, input_grads)