Fix layer conductance pyre fixme issues
Differential Revision: D67705320
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent dffd461 commit 408155a
Showing 1 changed file with 6 additions and 14 deletions.
captum/attr/_core/layer/layer_conductance.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
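
Note on the `Callable` change: pyre's strict mode flags a bare `Callable` (error [24]) because it omits parameter and return types, while `Callable[..., Tensor]` accepts any argument list but pins the return type. A minimal sketch of the pattern, outside this diff:

from typing import Callable

import torch
from torch import Tensor

def run(forward_func: Callable[..., Tensor], x: Tensor) -> Tensor:
    # `Callable[..., Tensor]` allows any arguments but fixes the return
    # type; a bare `Callable` would trigger pyre error [24] under strict mode.
    return forward_func(x)

print(run(torch.nn.ReLU(), torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])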
@@ -73,8 +72,7 @@ def has_convergence_delta(self) -> bool:
         return True
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `75`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -91,8 +89,7 @@ def attribute(
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `91`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -108,8 +105,6 @@ def attribute(
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
     @log_usage()
-    # pyre-fixme[43]: This definition does not have the same decorators as the
-    #  preceding overload(s).
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
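
Note on the overload changes: pyre error [43] fires when `@typing.overload` stubs and the implementation carry different decorators, so the fix applies `@log_usage()` to the stubs as well. A toy sketch of the pattern, with a hypothetical no-op stand-in for captum's `log_usage`:

import typing
from typing import Union

def log_usage():  # hypothetical stand-in for captum.log.log_usage
    def wrap(fn):
        return fn
    return wrap

class Example:
    @typing.overload
    @log_usage()
    def attribute(self, inputs: int) -> int: ...

    @typing.overload
    @log_usage()
    def attribute(self, inputs: str) -> str: ...

    @log_usage()  # same decorators on stubs and implementation satisfy pyre
    def attribute(self, inputs: Union[int, str]) -> Union[int, str]:
        return inputs

print(Example().attribute(3))  # 3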
@@ -376,7 +371,7 @@ def _attribute(
             layer_evals,
         ) = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             additional_forward_args=input_additional_args,
             target_ind=expanded_target,
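
Note on the `cast`: `self.layer` is presumably typed broadly elsewhere in captum (it can also hold a list of modules), while this call site expects a single `Module`; `typing.cast` narrows the static type with no runtime effect. A minimal sketch:

from typing import List, Union, cast

from torch import nn

layer: Union[nn.Module, List[nn.Module]] = nn.Linear(4, 2)
narrowed = cast(nn.Module, layer)  # no-op at runtime; only narrows the static type
assert narrowed is layer
print(type(narrowed).__name__)  # Linear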
@@ -389,8 +384,6 @@
         # This approximates the total input gradient of each step multiplied
         # by the step size.
         grad_diffs = tuple(
-            # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor,
-            #  ...]` and `Tuple[Tensor, ...]`.
             layer_eval[num_examples:] - layer_eval[:-num_examples]
             for layer_eval in layer_evals
         )
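
Note on `grad_diffs`: the scaled inputs for all interpolation steps are stacked along dim 0 in chunks of `num_examples`, so this slice difference yields, per example, the change in the layer evaluation between consecutive steps. A small numeric sketch with made-up values:

import torch

num_examples = 2  # examples per interpolation step
# toy layer evaluations for 3 steps, stacked along dim 0: shape (3 * 2,)
layer_eval = torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0, 6.0])
grad_diff = layer_eval[num_examples:] - layer_eval[:-num_examples]
print(grad_diff)  # tensor([1., 2., 2., 4.]) -- step-to-step deltas per example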
@@ -403,8 +396,7 @@
                 grad_diff * layer_gradient[:-num_examples],
                 n_steps,
                 num_examples,
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_eval.shape[1:],
+                tuple(layer_eval.shape[1:]),
             )
             for layer_gradient, layer_eval, grad_diff in zip(
                 layer_gradients, layer_evals, grad_diffs
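
Note on the shape change: at runtime `Tensor.shape[1:]` already returns a `torch.Size` (a `tuple` subclass), so wrapping it in `tuple(...)` only changes the static picture, handing the reshape helper a plain `Tuple[int, ...]`. A quick sketch of the runtime behavior:

import torch

layer_eval = torch.zeros(6, 4, 5)
tail = layer_eval.shape[1:]     # torch.Size([4, 5]); torch.Size subclasses tuple
print(tuple(tail))              # (4, 5) -- plain tuple of ints for the type checker
print(isinstance(tail, tuple))  # True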
