From 14314721d7db03065ff48ef10fe078a5cce38047 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:11 -0700 Subject: [PATCH] Add additional overload signatures for shared methods to resolve pyre errors (#1406) Summary: Add a few additional overload signatures to shared methods for resolving pyre errors Also remove separate cases for typing Literal since the split was necessary due to previous support for Python < 3.8 Reviewed By: csauper Differential Revision: D64677349 --- captum/_utils/common.py | 4 ++++ captum/_utils/typing.py | 16 +--------------- captum/attr/_utils/common.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index f1b5fd9a7..6336fc4a8 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... def _is_tuple(inputs: Tensor) -> Literal[False]: ... +@typing.overload +def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ... + + def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: return isinstance(inputs, tuple) diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 598c031b2..538135003 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -2,25 +2,11 @@ # pyre-strict -from typing import ( - List, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module -if TYPE_CHECKING: - from typing import Literal -else: - Literal = {True: bool, False: bool, (True, False): bool, "pt": str} - TensorOrTupleOfTensorsGeneric = TypeVar( "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] ) diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 92c1ccafb..9cb38b10d 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... +@typing.overload +def _format_input_baseline( # type: ignore + inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... + + def _format_input_baseline( inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: @@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: bool = False, +) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] +]: ... + + # FIXME: GradientAttribution is provided as a string due to a circular import. # This should be fixed when common is refactored into separate files. def _compute_conv_delta_and_format_attrs(