Skip to content

Commit

Permalink
Add additional overload signatures for shared methods to resolve pyre…
Browse files Browse the repository at this point in the history
… errors (pytorch#1406)

Summary:

Add a few additional overload signatures to shared methods for resolving pyre errors

Also remove separate cases for typing Literal since the split was necessary due to previous support for Python < 3.8

Reviewed By: csauper

Differential Revision: D64677349
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 1ba8977 commit 3c066f1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
4 changes: 4 additions & 0 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
return isinstance(inputs, tuple)

Expand Down
16 changes: 1 addition & 15 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,11 @@

# pyre-strict

from typing import (
List,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union

from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
from typing import Literal
else:
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
Expand Down
21 changes: 21 additions & 0 deletions captum/attr/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ...


@typing.overload
def _format_input_baseline( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ...


def _format_input_baseline(
inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
Expand Down Expand Up @@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs(
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ...


@typing.overload
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
attributions: Tuple[Tensor, ...],
start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
end_point: Union[Tensor, Tuple[Tensor, ...]],
additional_forward_args: Any,
target: TargetType,
is_inputs_tuple: bool = False,
) -> Union[
Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
]: ...


# FIXME: GradientAttribution is provided as a string due to a circular import.
# This should be fixed when common is refactored into separate files.
def _compute_conv_delta_and_format_attrs(
Expand Down

0 comments on commit 3c066f1

Please sign in to comment.