Skip to content

Commit

Permalink
Fix neuron deep lift pyre fixme issues (pytorch#1461)
Browse files Browse the repository at this point in the history
Summary:


Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: craymichael

Differential Revision: D67704291
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 31, 2024
1 parent 8ca325f commit 2ac9060
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions captum/attr/_core/neuron/neuron_deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from typing import Callable, cast, Optional, Tuple, Union

from captum._utils.gradient import construct_neuron_grad_fn
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import (
BaselineType,
SliceIntType,
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
from captum.log import log_usage
Expand Down Expand Up @@ -79,8 +83,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: BaselineType = None,
additional_forward_args: Optional[object] = None,
attribute_to_neuron_input: bool = False,
Expand Down Expand Up @@ -309,8 +316,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
Expand Down

0 comments on commit 2ac9060

Please sign in to comment.