Skip to content

Commit

Permalink
Fix pyre errors in Saliency (pytorch#1404)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Saliency

Reviewed By: craymichael

Differential Revision: D64677352
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 587bd8a commit ea2c054
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions captum/attr/_core/saliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict

from typing import Any, Callable
from typing import Callable

import torch
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
Expand All @@ -13,6 +13,7 @@
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor


class Saliency(GradientAttribution):
Expand All @@ -25,8 +26,7 @@ class Saliency(GradientAttribution):
https://arxiv.org/abs/1312.6034
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:
Expand All @@ -41,8 +41,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
abs: bool = True,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
Expand Down Expand Up @@ -124,29 +123,21 @@ def attribute(
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

# No need to format additional_forward_args here.
# They are being formatted in the `_run_forward` function in `common.py`
gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
self.forward_func, inputs_tuple, target, additional_forward_args
)
if abs:
attributions = tuple(torch.abs(gradient) for gradient in gradients)
else:
attributions = gradients
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(inputs_tuple, gradient_mask)
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
# `Tuple[Tensor, ...]`.
return _format_output(is_inputs_tuple, attributions)
Expand Down

0 comments on commit ea2c054

Please sign in to comment.