From 9a7ef2e92a97623566011748c1ec6a56d80708a8 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 23 Dec 2024 15:41:43 -0800 Subject: [PATCH] Fix deeplift mypy error (#1459) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1459 Currently, Captum OSS tests are failing due to mypy failures (likely from new version) in DeepLift test cases. Adds fix for type failure caused by different signature between DeepLift and DeepLiftShap. Reviewed By: cyrjano Differential Revision: D67538043 fbshipit-source-id: cc0236fa819c666c08e31d78a51f89e0807d9791 --- tests/attr/test_deeplift_classification.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/attr/test_deeplift_classification.py b/tests/attr/test_deeplift_classification.py index 85a9db00d..fac39040d 100644 --- a/tests/attr/test_deeplift_classification.py +++ b/tests/attr/test_deeplift_classification.py @@ -2,7 +2,7 @@ # pyre-unsafe -from typing import Union +from typing import TypeVar, Union import torch from captum._utils.typing import TargetType @@ -21,6 +21,8 @@ from torch import Tensor from torch.nn import Module +DeepLiftAttrMethod = TypeVar("DeepLiftAttrMethod", DeepLift, DeepLiftShap) + class Test(BaseTest): def test_sigmoid_classification(self) -> None: @@ -155,7 +157,7 @@ def test_convnet_with_maxpool1d_large_baselines(self) -> None: def softmax_classification( self, model: Module, - attr_method: Union[DeepLift, DeepLiftShap], + attr_method: DeepLiftAttrMethod, input: Tensor, baselines: Union[float, int, Tensor], target: TargetType,