Skip to content

Commit

Permalink
Fix neuron_integrated_gradients pyre fixme issues (pytorch#1457)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1457

Differential Revision: D67523072
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 9a7ef2e commit 01120d3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
8 changes: 6 additions & 2 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
ModuleOrModuleList,
SliceIntType,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
Expand Down Expand Up @@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval(

def construct_neuron_grad_fn(
layer: Module,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
device_ids: Union[None, List[int]] = None,
attribute_to_neuron_input: bool = False,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
Expand Down
7 changes: 7 additions & 0 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
TensorLikeList5D,
]

try:
# Subscripted slice syntax is not supported in previous Python versions,
# falling back to slice type.
SliceIntType = slice[int, int, int]
except TypeError:
# pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
SliceIntType = slice # type: ignore

# Necessary for Python >=3.7 and <3.9!
if TYPE_CHECKING:
Expand Down
12 changes: 7 additions & 5 deletions captum/attr/_core/neuron/neuron_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, List, Optional, Tuple, Union

from captum._utils.gradient import construct_neuron_grad_fn
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
from captum.log import log_usage
Expand All @@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -76,8 +75,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
additional_forward_args: Optional[object] = None,
n_steps: int = 50,
Expand Down

0 comments on commit 01120d3

Please sign in to comment.