From f9006fd5a1ee03ad727bf65dc6ce76524a7a44d7 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 06:11:39 -0800 Subject: [PATCH] Fix neuron deep lift pyre fixme issues (#1461) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1461 Differential Revision: D67704291 --- captum/attr/_core/neuron/neuron_deep_lift.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/captum/attr/_core/neuron/neuron_deep_lift.py b/captum/attr/_core/neuron/neuron_deep_lift.py index f4648b43b5..e7e3f2a77e 100644 --- a/captum/attr/_core/neuron/neuron_deep_lift.py +++ b/captum/attr/_core/neuron/neuron_deep_lift.py @@ -4,7 +4,11 @@ from typing import Callable, cast, Optional, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn -from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import ( + BaselineType, + SliceIntType, + TensorOrTupleOfTensorsGeneric, +) from captum.attr._core.deep_lift import DeepLift, DeepLiftShap from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage @@ -79,8 +83,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], baselines: BaselineType = None, additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, @@ -309,8 +316,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], baselines: Union[ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ],