From ffee56d057a89d2885cb38aceb7477c7b019f017 Mon Sep 17 00:00:00 2001
From: Zach Carmichael
Date: Wed, 23 Oct 2024 09:54:27 -0700
Subject: [PATCH] Correct remaining typing.Literal imports (#1412)

Summary:
Change remaining imports of `Literal` to be from the `typing` library

Reviewed By: vivekmig

Differential Revision: D64807610
---
 captum/_utils/common.py                       | 42 +++++++------------
 captum/_utils/gradient.py                     | 20 +++++----
 captum/_utils/progress.py                     | 10 +----
 captum/attr/_core/layer/layer_conductance.py  |  9 +---
 captum/attr/_core/layer/layer_deep_lift.py    | 29 +------------
 .../attr/_core/layer/layer_gradient_shap.py   | 20 +--------
 .../_core/layer/layer_integrated_gradients.py | 12 +-----
 captum/attr/_core/layer/layer_lrp.py          | 12 +-----
 captum/attr/_utils/common.py                  | 18 +-------
 tests/attr/helpers/attribution_delta_util.py  |  2 +-
 tests/attr/layer/test_layer_lrp.py            |  3 --
 tests/attr/test_interpretable_input.py        |  8 +---
 12 files changed, 42 insertions(+), 143 deletions(-)

diff --git a/captum/_utils/common.py b/captum/_utils/common.py
index 6336fc4a8..0a9a42770 100644
--- a/captum/_utils/common.py
+++ b/captum/_utils/common.py
@@ -5,13 +5,23 @@
 from enum import Enum
 from functools import reduce
 from inspect import signature
-from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    overload,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 import torch
 from captum._utils.typing import (
     BaselineType,
-    Literal,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
     TupleOrTensorOrBoolGeneric,
@@ -71,23 +81,17 @@ def safe_div(
 
 
 @typing.overload
-# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
-# is incompatible with the return type of the implementation (`bool`).
-# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
-# is incompatible with the return type of the implementation (`bool`).
-# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 
 @typing.overload
-def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
+def _is_tuple(
+    inputs: TensorOrTupleOfTensorsGeneric,
+) -> bool: ...  # type: ignore
 
 
 def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -480,22 +484,14 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_output` does not accept all
-# possible arguments of overload defined on line `449`.
 def _format_output(
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
    is_inputs_tuple: Literal[True],
     output: Tuple[Tensor, ...],
 ) -> Tuple[Tensor, ...]: ...


 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_output` does not accept all
-# possible arguments of overload defined on line `455`.
 def _format_output(
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[False],
     output: Tuple[Tensor, ...],
 ) -> Tensor: ...
@@ -526,22 +522,14 @@ def _format_output(
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
-# possible arguments of overload defined on line `483`.
 def _format_outputs(
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_multiple_inputs: Literal[False],
     outputs: List[Tuple[Tensor, ...]],
 ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
-# possible arguments of overload defined on line `489`.
 def _format_outputs(
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_multiple_inputs: Literal[True],
     outputs: List[Tuple[Tensor, ...]],
 ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...
diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py
index cc74ef92c..2dab8154d 100644
--- a/captum/_utils/gradient.py
+++ b/captum/_utils/gradient.py
@@ -5,7 +5,18 @@
 import typing
 import warnings
 from collections import defaultdict
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -16,7 +27,6 @@
 )
 from captum._utils.sample_gradient import SampleGradientWrapper
 from captum._utils.typing import (
-    Literal,
     ModuleOrModuleList,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
@@ -226,9 +236,6 @@ def _forward_layer_distributed_eval(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any = None,
     attribute_to_layer_input: bool = False,
-    # pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     forward_hook_with_return: Literal[False] = False,
     require_layer_grads: bool = False,
 ) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ...
@@ -246,8 +253,6 @@ def _forward_layer_distributed_eval(
     additional_forward_args: Any = None,
     attribute_to_layer_input: bool = False,
     *,
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     forward_hook_with_return: Literal[True],
     require_layer_grads: bool = False,
 ) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ...
@@ -675,7 +680,6 @@ def compute_layer_gradients_and_eval(
             target_ind=target_ind,
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
-            # pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`.
             forward_hook_with_return=True,
             require_layer_grads=True,
         )
diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py
index 47e391735..2e025006c 100644
--- a/captum/_utils/progress.py
+++ b/captum/_utils/progress.py
@@ -5,9 +5,7 @@
 import sys
 import warnings
 from time import time
-from typing import Any, cast, Iterable, Optional, Sized, TextIO
-
-from captum._utils.typing import Literal
+from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO
 
 try:
     from tqdm.auto import tqdm
@@ -75,10 +73,7 @@ def __enter__(self) -> "NullProgress":
         return self
 
     # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
-        # pyre-fixme[7]: Expected `Literal[]` but got `bool`.
         return False
 
     # pyre-fixme[3]: Return type must be annotated.
@@ -139,11 +134,8 @@ def __enter__(self) -> "SimpleProgress":
         return self
 
     # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
         self.close()
-        # pyre-fixme[7]: Expected `Literal[]` but got `bool`.
         return False
 
     # pyre-fixme[3]: Return type must be annotated.
diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py
index dc74a76c9..54ec6fdb2 100644
--- a/captum/attr/_core/layer/layer_conductance.py
+++ b/captum/attr/_core/layer/layer_conductance.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,7 +12,7 @@
     _format_output,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import BaselineType, Literal, TargetType
+from captum._utils.typing import BaselineType, TargetType
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.batching import _batch_attribution
@@ -86,8 +86,6 @@ def attribute(
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -105,9 +103,6 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py
index 2c4c10bbf..85d81cd5e 100644
--- a/captum/attr/_core/layer/layer_deep_lift.py
+++ b/captum/attr/_core/layer/layer_deep_lift.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -13,12 +13,7 @@
     ExpansionTypes,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
 from captum.attr._utils.attribution import LayerAttribution
 from captum.attr._utils.common import (
@@ -101,8 +96,6 @@ def __init__(
 
     # Ignoring mypy error for inconsistent signature with DeepLift
     @typing.overload  # type: ignore
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `117`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,8 +104,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -120,8 +111,6 @@ def attribute(
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `104`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -129,9 +118,6 @@ def attribute(
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             inputs,
             additional_forward_args,
             target,
-            # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
-            # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
             cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
         )
 
@@ -464,8 +448,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -483,9 +465,6 @@ def attribute(
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -686,10 +665,6 @@ def attribute(
             target=exp_target,
             additional_forward_args=exp_addit_args,
             return_convergence_delta=cast(
-                # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
-                # type.
-                # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
-                # parameters.
                 Literal[True, False],
                 return_convergence_delta,
             ),
diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py
index 8c94c13b5..dcfe109fa 100644
--- a/captum/attr/_core/layer/layer_gradient_shap.py
+++ b/captum/attr/_core/layer/layer_gradient_shap.py
@@ -3,12 +3,12 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval
-from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.gradient_shap import _scale_input
 from captum.attr._core.noise_tunnel import NoiseTunnel
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
@@ -117,8 +117,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
@@ -135,9 +133,6 @@ def attribute(
         stdevs: Union[float, Tuple[float, ...]] = 0.0,
         target: TargetType = None,
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
@@ -392,8 +387,6 @@ def __init__(
         self._multiply_by_inputs = multiply_by_inputs
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `385`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -402,16 +395,12 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `373`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -419,9 +408,6 @@ def attribute(
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -505,8 +491,6 @@ def attribute(  # type: ignore
             inputs,
             additional_forward_args,
             target,
-            # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
-            # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
             cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
         )
 
diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py
index 146c5c552..474b7d2fc 100644
--- a/captum/attr/_core/layer/layer_integrated_gradients.py
+++ b/captum/attr/_core/layer/layer_integrated_gradients.py
@@ -3,7 +3,7 @@
 # pyre-strict
 import functools
 import warnings
-from typing import Any, Callable, cast, List, overload, Tuple, Union
+from typing import Any, Callable, cast, List, Literal, overload, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,7 +12,7 @@
     _format_outputs,
 )
 from captum._utils.gradient import _forward_layer_eval, _run_forward
-from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
+from captum._utils.typing import BaselineType, ModuleOrModuleList, TargetType
 from captum.attr._core.integrated_gradients import IntegratedGradients
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.common import (
@@ -110,8 +110,6 @@ def __init__(
         )
 
     @overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `112`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -122,15 +120,11 @@ def attribute(
         n_steps: int,
         method: str,
         internal_batch_size: Union[None, int],
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False],
         attribute_to_layer_input: bool,
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
 
     @overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `126`.
     def attribute(  # type: ignore
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -141,8 +135,6 @@ def attribute(  # type: ignore
         n_steps: int,
         method: str,
         internal_batch_size: Union[None, int],
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool,
     ) -> Tuple[
diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py
index 705cb2a91..cd774d855 100644
--- a/captum/attr/_core/layer/layer_lrp.py
+++ b/captum/attr/_core/layer/layer_lrp.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, cast, List, Tuple, Union
+from typing import Any, cast, List, Literal, Tuple, Union
 
 from captum._utils.common import (
     _format_tensor_into_tuples,
@@ -15,7 +15,6 @@
     undo_gradient_requirements,
 )
 from captum._utils.typing import (
-    Literal,
     ModuleOrModuleList,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
@@ -64,8 +63,6 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
            self.device_ids = cast(List[int], self.model.device_ids)
 
     @typing.overload  # type: ignore
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `77`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
@@ -73,8 +70,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         verbose: bool = False,
@@ -84,17 +79,12 @@ def attribute(
     ]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `66`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         verbose: bool = False,
diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py
index 9cb38b10d..09889cd52 100644
--- a/captum/attr/_utils/common.py
+++ b/captum/attr/_utils/common.py
@@ -3,7 +3,7 @@
 # pyre-strict
 import typing
 from inspect import signature
-from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, List, Literal, Tuple, TYPE_CHECKING, Union
 
 import torch
 from captum._utils.common import (
@@ -12,12 +12,7 @@
     _format_tensor_into_tuples,
     _validate_input as _validate_input_basic,
 )
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
 from torch import Tensor
 
@@ -206,8 +201,6 @@
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
-# not accept all possible arguments of overload defined on line `212`.
 def _compute_conv_delta_and_format_attrs(
     attr_algo: "GradientAttribution",
     return_convergence_delta: bool,
@@ -217,15 +210,11 @@ def _compute_conv_delta_and_format_attrs(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any,
     target: TargetType,
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[True],
 ) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
-# not accept all possible arguments of overload defined on line `199`.
 def _compute_conv_delta_and_format_attrs(
     attr_algo: "GradientAttribution",
     return_convergence_delta: bool,
@@ -235,9 +224,6 @@ def _compute_conv_delta_and_format_attrs(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any,
     target: TargetType,
-    # pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[False] = False,
 ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ...
 
diff --git a/tests/attr/helpers/attribution_delta_util.py b/tests/attr/helpers/attribution_delta_util.py
index 3640cbdc9..4fcdc09ed 100644
--- a/tests/attr/helpers/attribution_delta_util.py
+++ b/tests/attr/helpers/attribution_delta_util.py
@@ -4,8 +4,8 @@
 from typing import Tuple, Union
 
 import torch
-from captum._utils.typing import Tensor
 from tests.helpers import BaseTest
+from torch import Tensor
 
 
 def assert_attribution_delta(
diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py
index acc3aa064..ccc56377e 100644
--- a/tests/attr/layer/test_layer_lrp.py
+++ b/tests/attr/layer/test_layer_lrp.py
@@ -65,7 +65,6 @@ def test_lrp_basic_attributions(self) -> None:
         relevance, delta = lrp.attribute(  # type: ignore
             inputs,
             classIndex.item(),
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         assertTensorAlmostEqual(
@@ -82,7 +81,6 @@ def test_lrp_simple_attributions(self) -> None:
         relevance_upper, delta = lrp_upper.attribute(
             inputs,
             attribute_to_layer_input=True,
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         lrp_lower = LayerLRP(model, model.linear)
@@ -185,7 +183,6 @@ def test_lrp_simple_attributions_all_layers_delta(self) -> None:
         relevance, delta = lrp.attribute(
             inputs,
             attribute_to_layer_input=True,
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         self.assertEqual(len(relevance), len(delta))
diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py
index 0550b3562..085813b09 100644
--- a/tests/attr/test_interpretable_input.py
+++ b/tests/attr/test_interpretable_input.py
@@ -2,10 +2,9 @@
 
 # pyre-unsafe
 
-from typing import List, Optional, overload, Union
+from typing import List, Literal, Optional, overload, Union
 
 import torch
-from captum._utils.typing import Literal
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized
 from tests.helpers import BaseTest
@@ -22,10 +21,7 @@ def __init__(self, vocab_list) -> None:
     @overload
     def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     @overload
-    # pyre-fixme[43]: Incompatible overload. The implementation of
-    # `DummyTokenizer.encode` does not accept all possible arguments of overload.
-    # pyre-ignore[11]: Annotation `pt` is not defined as a type
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...  # type: ignore  # noqa: E501 line too long
+    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
 
     def encode(
         self, text: str, return_tensors: Optional[str] = "pt"
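
Note on the pattern this patch standardizes (a minimal, self-contained sketch, not taken from the commit itself): `Literal` now comes from the standard `typing` module (Python 3.8+), where it can parameterize overload return types directly, as in the `_is_tuple` overloads of captum/_utils/common.py above. The sketch assumes torch is installed:

    import typing
    from typing import Literal, Tuple, Union

    from torch import Tensor

    @typing.overload
    def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
    @typing.overload
    def _is_tuple(inputs: Tensor) -> Literal[False]: ...

    def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
        # At runtime this returns an ordinary bool; the Literal[True] and
        # Literal[False] overloads only let a static type checker narrow call
        # sites whose argument type is statically known.
        return isinstance(inputs, tuple)

For example, a checker infers `_is_tuple((t,))` as `Literal[True]`, which is why the pyre-fixme suppressions removed above are unnecessary once the standard `typing.Literal` is used.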