Skip to content

Commit

Permalink
Fix pyre errors in Integrated Gradients (pytorch#1398)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Integrated Gradients

Reviewed By: csauper

Differential Revision: D64677345
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent b216194 commit a33c834
Showing 1 changed file with 14 additions and 31 deletions.
45 changes: 14 additions & 31 deletions captum/attr/_core/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Literal, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -12,12 +12,7 @@
_format_output,
_is_tuple,
)
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution
from captum.attr._utils.batching import _batch_attribution
Expand Down Expand Up @@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
multiply_by_inputs: bool = True,
) -> None:
r"""
Expand Down Expand Up @@ -80,21 +74,16 @@ def __init__(
# and when return_convergence_delta is True, the return type is
# a tuple with both attributions and deltas.
@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `95`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

Expand All @@ -111,9 +100,6 @@ def attribute(
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
) -> TensorOrTupleOfTensorsGeneric: ...

Expand Down Expand Up @@ -275,37 +261,35 @@ def attribute( # type: ignore
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs, baselines = _format_input_baseline(inputs, baselines)
formatted_inputs, formatted_baselines = _format_input_baseline(
inputs, baselines
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
_validate_input(inputs, baselines, n_steps, method)
_validate_input(formatted_inputs, formatted_baselines, n_steps, method)

if internal_batch_size is not None:
num_examples = inputs[0].shape[0]
num_examples = formatted_inputs[0].shape[0]
attributions = _batch_attribution(
self,
num_examples,
internal_batch_size,
n_steps,
inputs=inputs,
baselines=baselines,
inputs=formatted_inputs,
baselines=formatted_baselines,
target=target,
additional_forward_args=additional_forward_args,
method=method,
)
else:
attributions = self._attribute(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `TensorOrTupleOfTensorsGeneric`.
inputs=inputs,
baselines=baselines,
inputs=formatted_inputs,
baselines=formatted_baselines,
target=target,
additional_forward_args=additional_forward_args,
n_steps=n_steps,
Expand Down Expand Up @@ -344,8 +328,7 @@ def _attribute(
inputs: Tuple[Tensor, ...],
baselines: Tuple[Union[Tensor, int, float], ...],
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_steps: int = 50,
method: str = "gausslegendre",
step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
Expand Down

0 comments on commit a33c834

Please sign in to comment.