Skip to content

Commit

Permalink
Fix neuron gradient shap pyre fixme issues (pytorch#1463)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1463

Differential Revision: D67705098
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent b08fba7 commit 44fc19c
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions captum/attr/_core/neuron/neuron_gradient_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Callable, List, Optional, Tuple, Union

from captum._utils.gradient import construct_neuron_grad_fn
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.gradient_shap import GradientShap
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


Expand Down Expand Up @@ -50,8 +51,7 @@ class NeuronGradientShap(NeuronAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Union[int, float, Tensor]],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -97,8 +97,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
Expand Down

0 comments on commit 44fc19c

Please sign in to comment.