diff --git a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py index 03f3e1418..4b3720c96 100644 --- a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py +++ b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py @@ -4,10 +4,11 @@ from typing import Callable, List, Optional, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn -from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric from captum.attr._core.guided_backprop_deconvnet import Deconvolution, GuidedBackprop from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage +from torch import Tensor from torch.nn import Module @@ -60,8 +61,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -215,8 +219,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: