Skip to content

Commit

Permalink
Fix pyre errors in InputXGradient (pytorch#1397)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in InputXGradient

Reviewed By: csauper

Differential Revision: D64677348
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 848e805 commit 5aeee56
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions captum/attr/_core/input_x_gradient.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable
from typing import Callable

from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
from captum._utils.gradient import (
Expand All @@ -11,6 +11,7 @@
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor


class InputXGradient(GradientAttribution):
Expand All @@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution):
https://arxiv.org/abs/1605.01713
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:
Expand All @@ -35,8 +35,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
Expand Down Expand Up @@ -113,28 +112,20 @@ def attribute(
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
self.forward_func, inputs_tuple, target, additional_forward_args
)

attributions = tuple(
input * gradient for input, gradient in zip(inputs, gradients)
input * gradient for input, gradient in zip(inputs_tuple, gradients)
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(inputs_tuple, gradient_mask)
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
# `Tuple[Tensor, ...]`.
return _format_output(is_inputs_tuple, attributions)
Expand Down

0 comments on commit 5aeee56

Please sign in to comment.