Skip to content

Commit

Permalink
Fix neuron gradient pyre fixme issues (pytorch#1464)
Browse files Browse the repository at this point in the history
Summary:

Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: craymichael

Differential Revision: D67704365
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent a309fe2 commit 6ce1522
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions captum/attr/_core/neuron/neuron_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
apply_gradient_requirements,
undo_gradient_requirements,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


Expand All @@ -28,8 +29,7 @@ class NeuronGradient(NeuronAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Union[int, float, Tensor]],
layer: Module,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -60,8 +60,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
additional_forward_args: Optional[object] = None,
attribute_to_neuron_input: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
Expand Down Expand Up @@ -162,18 +165,12 @@ def attribute(
>>> # index (4,1,2).
>>> attribution = neuron_ig.attribute(input, (4,1,2))
"""
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
additional_forward_args = _format_additional_forward_args(
additional_forward_args
)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

_, input_grads = _forward_layer_eval_with_neuron_grads(
self.forward_func,
Expand All @@ -185,9 +182,9 @@ def attribute(
attribute_to_layer_input=attribute_to_neuron_input,
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
# `Tuple[Tensor, ...]`.
undo_gradient_requirements(inputs_tuple, gradient_mask)

# pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
# [Tensor, typing.Tuple[Tensor, ...]]]` but got `Union[Tensor,
# typing.Tuple[Tensor, ...]]`.
return _format_output(is_inputs_tuple, input_grads)

0 comments on commit 6ce1522

Please sign in to comment.