Skip to content

Commit

Permalink
Fix neuron_integrated_gradients pyre fixme issues (#1457)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1457

Differential Revision: D67523072
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 22, 2024
1 parent dd3fa2b commit 4c154f6
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions captum/attr/_core/neuron/neuron_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -76,8 +75,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, slice[int, int, int]], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
additional_forward_args: Optional[object] = None,
n_steps: int = 50,
Expand Down

0 comments on commit 4c154f6

Please sign in to comment.