Fix neuron conductance pyre fixme issues (pytorch#1458)
Summary: Pull Request resolved: pytorch#1458

Differential Revision: D67523217
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 2320db3 commit 3526ec8
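
Note on the change: the `# pyre-fixme[N]` comments removed in this diff are inline suppressions for specific Pyre type-check errors (for example code 24, "Generic type `Callable` expects 2 type parameters"). The commit resolves the errors by tightening the annotations instead of silencing the checker. A minimal sketch of that pattern, not taken from the repository:

    from typing import Callable

    from torch import Tensor

    # Suppression approach: keep the loose `Callable` annotation and tell Pyre
    # to ignore the "expects 2 type parameters" error (code 24) on the next line.
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    forward_func_loose: Callable

    # Approach taken in this commit: spell out parameter and return types,
    # so no suppression comment is needed.
    forward_func_typed: Callable[..., Tensor]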
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions captum/attr/_core/neuron/neuron_conductance.py
@@ -14,7 +14,12 @@
     _verify_select_neuron,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineType,
+    SliceIntType,
+    TargetType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
 from captum.attr._utils.batching import _batch_attribution
@@ -39,8 +44,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -94,8 +98,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[int, ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: BaselineType = None,
         target: TargetType = None,
         additional_forward_args: Optional[object] = None,
@@ -285,28 +292,24 @@ def attribute(
                 " results.",
                 stacklevel=1,
             )
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        _validate_input(inputs, baselines, n_steps, method)
+        formatted_inputs, formatted_baselines = _format_input_baseline(
+            inputs, baselines
+        )
+        _validate_input(formatted_inputs, formatted_baselines, n_steps, method)
 
-        num_examples = inputs[0].shape[0]
+        num_examples = formatted_inputs[0].shape[0]
 
         if internal_batch_size is not None:
-            num_examples = inputs[0].shape[0]
+            num_examples = formatted_inputs[0].shape[0]
             attrs = _batch_attribution(
                 self,
                 num_examples,
                 internal_batch_size,
                 n_steps,
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                 neuron_selector=neuron_selector,
                 target=target,
                 additional_forward_args=additional_forward_args,
@@ -315,11 +318,9 @@ def attribute(
             )
         else:
             attrs = self._attribute(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `TensorOrTupleOfTensorsGeneric`.
-                inputs=inputs,
+                inputs=formatted_inputs,
                 neuron_selector=neuron_selector,
-                baselines=baselines,
+                baselines=formatted_baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 n_steps=n_steps,
@@ -334,8 +335,11 @@ def attribute(
     def _attribute(
         self,
         inputs: Tuple[Tensor, ...],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[int, ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: Tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
         additional_forward_args: Optional[object] = None,
@@ -409,8 +413,9 @@ def _attribute(
 
         # Aggregates across all steps for each tensor in the input tuple
         total_grads = tuple(
-            # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
-            _reshape_and_sum(scaled_grad, n_steps, num_examples, input_grad.shape[1:])
+            _reshape_and_sum(
+                scaled_grad, n_steps, num_examples, tuple(input_grad.shape[1:])
+            )
             for (scaled_grad, input_grad) in zip(scaled_grads, input_grads)
         )
 
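
For reference, a sketch of the three `neuron_selector` forms reflected in the new annotation (an integer index, a tuple of ints/slices, and a callable over the layer output). The toy model, layer choice, variable names, and shapes below are illustrative assumptions, not code from this commit:

    import torch
    import torch.nn as nn

    from captum.attr import NeuronConductance

    # Hypothetical toy model; attributions are taken w.r.t. neurons of the
    # first Linear layer, whose output has shape (batch, 16).
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    inputs = torch.randn(2, 8)
    baselines = torch.zeros(2, 8)

    neuron_cond = NeuronConductance(net, net[0])

    # 1) int: a single neuron index in the layer's non-batch dimension.
    attr_int = neuron_cond.attribute(
        inputs, neuron_selector=3, baselines=baselines, target=0
    )

    # 2) tuple of ints/slices (the new SliceIntType): select a range of neurons.
    attr_range = neuron_cond.attribute(
        inputs, neuron_selector=(slice(0, 4),), baselines=baselines, target=0
    )

    # 3) callable: receives the layer output and returns the selected
    #    neuron(s) per example.
    attr_fn = neuron_cond.attribute(
        inputs, neuron_selector=lambda out: out[:, 3], baselines=baselines, target=0
    )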
