From 4aee7ce2408c1df9757260dc854fe0121e9dfe1b Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 07:38:08 -0800 Subject: [PATCH] Fix neuron guided backprop pyre fixme issues (#1465) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1465 Differential Revision: D67705091 --- .../neuron/neuron_guided_backprop_deconvnet.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py index 03f3e14186..4b3720c96f 100644 --- a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py +++ b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py @@ -4,10 +4,11 @@ from typing import Callable, List, Optional, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn -from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric from captum.attr._core.guided_backprop_deconvnet import Deconvolution, GuidedBackprop from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage +from torch import Tensor from torch.nn import Module @@ -60,8 +61,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -215,8 +219,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], additional_forward_args: Optional[object] = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: