From 1373e929cab90f69892aa522f1a3270f819053d8 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 16:00:41 -0800 Subject: [PATCH] Fix neuron gradient shap pyre fixme issues (#1463) Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: craymichael Differential Revision: D67705098 --- captum/attr/_core/neuron/neuron_gradient_shap.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/captum/attr/_core/neuron/neuron_gradient_shap.py b/captum/attr/_core/neuron/neuron_gradient_shap.py index 18b6507231..b0b82084f5 100644 --- a/captum/attr/_core/neuron/neuron_gradient_shap.py +++ b/captum/attr/_core/neuron/neuron_gradient_shap.py @@ -4,10 +4,11 @@ from typing import Callable, List, Optional, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn -from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric from captum.attr._core.gradient_shap import GradientShap from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution from captum.log import log_usage +from torch import Tensor from torch.nn import Module @@ -50,8 +51,7 @@ class NeuronGradientShap(NeuronAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Union[int, float, Tensor]], layer: Module, device_ids: Union[None, List[int]] = None, multiply_by_inputs: bool = True, @@ -97,8 +97,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], baselines: Union[ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ],