diff --git a/captum/_utils/common.py b/captum/_utils/common.py index f1b5fd9a7..6336fc4a8 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... def _is_tuple(inputs: Tensor) -> Literal[False]: ... +@typing.overload +def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ... + + def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: return isinstance(inputs, tuple) diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 598c031b2..538135003 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -2,25 +2,11 @@ # pyre-strict -from typing import ( - List, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module -if TYPE_CHECKING: - from typing import Literal -else: - Literal = {True: bool, False: bool, (True, False): bool, "pt": str} - TensorOrTupleOfTensorsGeneric = TypeVar( "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] ) diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 92c1ccafb..9cb38b10d 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... +@typing.overload +def _format_input_baseline( # type: ignore + inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... + + def _format_input_baseline( inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: @@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: bool = False, +) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] +]: ... + + # FIXME: GradientAttribution is provided as a string due to a circular import. # This should be fixed when common is refactored into separate files. def _compute_conv_delta_and_format_attrs(