diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 69502b7443..1e2b827ab4 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -28,6 +28,7 @@ from captum._utils.sample_gradient import SampleGradientWrapper from captum._utils.typing import ( ModuleOrModuleList, + SliceIntType, TargetType, TensorOrTupleOfTensorsGeneric, ) @@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval( def construct_neuron_grad_fn( layer: Module, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], device_ids: Union[None, List[int]] = None, attribute_to_neuron_input: bool = False, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 10a2385611..80cf22d451 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -41,6 +41,13 @@ TensorLikeList5D, ] +try: + # Subscripted slice syntax is not supported in previous Python versions, + # falling back to slice type. + SliceIntType = slice[int, int, int] +except TypeError: + # pyre-fixme[24]: Generic type `slice` expects 3 type parameters. + SliceIntType = slice # type: ignore # Necessary for Python >=3.7 and <3.9! if TYPE_CHECKING: diff --git a/captum/attr/_core/neuron/neuron_integrated_gradients.py b/captum/attr/_core/neuron/neuron_integrated_gradients.py index 8e56221d77..0e4504bee9 100644 --- a/captum/attr/_core/neuron/neuron_integrated_gradients.py +++ b/captum/attr/_core/neuron/neuron_integrated_gradients.py @@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn -from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage @@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], layer: Module, device_ids: Union[None, List[int]] = None, multiply_by_inputs: bool = True, @@ -76,8 +75,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None, additional_forward_args: Optional[object] = None, n_steps: int = 50,