Skip to content

Commit

Permalink
Fix deeplift mypy error (#1459)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1459

Currently, Captum OSS tests are failing due to mypy failures (likely from new version) in DeepLift test cases. Adds fix for type failure caused by different signature between DeepLift and DeepLiftShap.

Reviewed By: cyrjano

Differential Revision: D67538043

fbshipit-source-id: cc0236fa819c666c08e31d78a51f89e0807d9791
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 23, 2024
1 parent ad6795b commit 9a7ef2e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/attr/test_deeplift_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-unsafe

from typing import Union
from typing import TypeVar, Union

import torch
from captum._utils.typing import TargetType
Expand All @@ -21,6 +21,8 @@
from torch import Tensor
from torch.nn import Module

DeepLiftAttrMethod = TypeVar("DeepLiftAttrMethod", DeepLift, DeepLiftShap)


class Test(BaseTest):
def test_sigmoid_classification(self) -> None:
Expand Down Expand Up @@ -155,7 +157,7 @@ def test_convnet_with_maxpool1d_large_baselines(self) -> None:
def softmax_classification(
self,
model: Module,
attr_method: Union[DeepLift, DeepLiftShap],
attr_method: DeepLiftAttrMethod,
input: Tensor,
baselines: Union[float, int, Tensor],
target: TargetType,
Expand Down

0 comments on commit 9a7ef2e

Please sign in to comment.