diff --git a/tests/attr/test_deeplift_classification.py b/tests/attr/test_deeplift_classification.py index 85a9db00d..fac39040d 100644 --- a/tests/attr/test_deeplift_classification.py +++ b/tests/attr/test_deeplift_classification.py @@ -2,7 +2,7 @@ # pyre-unsafe -from typing import Union +from typing import TypeVar, Union import torch from captum._utils.typing import TargetType @@ -21,6 +21,8 @@ from torch import Tensor from torch.nn import Module +DeepLiftAttrMethod = TypeVar("DeepLiftAttrMethod", DeepLift, DeepLiftShap) + class Test(BaseTest): def test_sigmoid_classification(self) -> None: @@ -155,7 +157,7 @@ def test_convnet_with_maxpool1d_large_baselines(self) -> None: def softmax_classification( self, model: Module, - attr_method: Union[DeepLift, DeepLiftShap], + attr_method: DeepLiftAttrMethod, input: Tensor, baselines: Union[float, int, Tensor], target: TargetType,