Skip to content

Commit

Permalink
Correct remaining typing.Literal imports (#1412)
Browse files Browse the repository at this point in the history
Summary:

Change remaining imports of `Literal` to be from `typing` library

Reviewed By: vivekmig

Differential Revision: D64807610
  • Loading branch information
craymichael authored and facebook-github-bot committed Oct 23, 2024
1 parent b80e488 commit ffee56d
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 143 deletions.
42 changes: 15 additions & 27 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@
from enum import Enum
from functools import reduce
from inspect import signature
from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
overload,
Sequence,
Tuple,
Union,
)

import numpy as np
import torch
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
TupleOrTensorOrBoolGeneric,
Expand Down Expand Up @@ -71,23 +81,17 @@ def safe_div(


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
def _is_tuple(
inputs: TensorOrTupleOfTensorsGeneric,
) -> bool: ... # type: ignore


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
Expand Down Expand Up @@ -480,22 +484,14 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `449`.
def _format_output(
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[True],
output: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `455`.
def _format_output(
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[False],
output: Tuple[Tensor, ...],
) -> Tensor: ...
Expand Down Expand Up @@ -526,22 +522,14 @@ def _format_output(


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `483`.
def _format_outputs(
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_multiple_inputs: Literal[False],
outputs: List[Tuple[Tensor, ...]],
) -> Union[Tensor, Tuple[Tensor, ...]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `489`.
def _format_outputs(
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_multiple_inputs: Literal[True],
outputs: List[Tuple[Tensor, ...]],
) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...
Expand Down
20 changes: 12 additions & 8 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import torch
from captum._utils.common import (
Expand All @@ -16,7 +27,6 @@
)
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
Literal,
ModuleOrModuleList,
TargetType,
TensorOrTupleOfTensorsGeneric,
Expand Down Expand Up @@ -226,9 +236,6 @@ def _forward_layer_distributed_eval(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
# pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
forward_hook_with_return: Literal[False] = False,
require_layer_grads: bool = False,
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ...
Expand All @@ -246,8 +253,6 @@ def _forward_layer_distributed_eval(
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
forward_hook_with_return: Literal[True],
require_layer_grads: bool = False,
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ...
Expand Down Expand Up @@ -675,7 +680,6 @@ def compute_layer_gradients_and_eval(
target_ind=target_ind,
additional_forward_args=additional_forward_args,
attribute_to_layer_input=attribute_to_layer_input,
# pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`.
forward_hook_with_return=True,
require_layer_grads=True,
)
Expand Down
10 changes: 1 addition & 9 deletions captum/_utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import sys
import warnings
from time import time
from typing import Any, cast, Iterable, Optional, Sized, TextIO

from captum._utils.typing import Literal
from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO

try:
from tqdm.auto import tqdm
Expand Down Expand Up @@ -75,10 +73,7 @@ def __enter__(self) -> "NullProgress":
return self

# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
return False

# pyre-fixme[3]: Return type must be annotated.
Expand Down Expand Up @@ -139,11 +134,8 @@ def __enter__(self) -> "SimpleProgress":
return self

# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
self.close()
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
return False

# pyre-fixme[3]: Return type must be annotated.
Expand Down
9 changes: 2 additions & 7 deletions captum/attr/_core/layer/layer_conductance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -12,7 +12,7 @@
_format_output,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import BaselineType, Literal, TargetType
from captum._utils.typing import BaselineType, TargetType
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.batching import _batch_attribution
Expand Down Expand Up @@ -86,8 +86,6 @@ def attribute(
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -105,9 +103,6 @@ def attribute(
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
29 changes: 2 additions & 27 deletions captum/attr/_core/layer/layer_deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -13,12 +13,7 @@
ExpansionTypes,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
from captum.attr._utils.attribution import LayerAttribution
from captum.attr._utils.common import (
Expand Down Expand Up @@ -101,8 +96,6 @@ def __init__(

# Ignoring mypy error for inconsistent signature with DeepLift
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `117`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -111,27 +104,20 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `104`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
Expand Down Expand Up @@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
inputs,
additional_forward_args,
target,
# pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
)

Expand Down Expand Up @@ -464,8 +448,6 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
Expand All @@ -483,9 +465,6 @@ def attribute(
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
Expand Down Expand Up @@ -686,10 +665,6 @@ def attribute(
target=exp_target,
additional_forward_args=exp_addit_args,
return_convergence_delta=cast(
# pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
# type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
# parameters.
Literal[True, False],
return_convergence_delta,
),
Expand Down
Loading

0 comments on commit ffee56d

Please sign in to comment.