diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 0f5be9381..19287b6ec 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -74,10 +74,8 @@ class FeaturePermutation(FeatureAblation): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perm_func: Callable = _permute_feature, + forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]], + perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature, ) -> None: r""" Args: @@ -101,8 +99,7 @@ def attribute( # type: ignore self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, @@ -283,8 +280,7 @@ def attribute_future( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False,