Skip to content

Commit

Permalink
Fix pyre errors in FeaturePermutation (pytorch#1393)
Browse files Browse the repository at this point in the history
Summary:

Initial work on fixing Pyre errors in Feature Permutation

Reviewed By: jjuncho

Differential Revision: D64677344
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Oct 22, 2024
1 parent 0dba8ad commit 463068a
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions captum/attr/_core/feature_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,8 @@ class FeaturePermutation(FeatureAblation):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
perm_func: Callable = _permute_feature,
forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
) -> None:
r"""
Args:
Expand All @@ -101,8 +99,7 @@ def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down Expand Up @@ -283,8 +280,7 @@ def attribute_future(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down

0 comments on commit 463068a

Please sign in to comment.