diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 33c153110..62ac38e84 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -37,8 +37,7 @@ class Occlusion(FeatureAblation): /tensorflow/methods.py#L401 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -58,8 +57,7 @@ def attribute( # type: ignore ] = None, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, perturbations_per_eval: int = 1, show_progress: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -377,9 +375,7 @@ def _occlusion_mask( padded_tensor = torch.nn.functional.pad( sliding_window_tsr, tuple(pad_values) # type: ignore ) - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and - # `Size`. - return padded_tensor.reshape((1,) + padded_tensor.shape) + return padded_tensor.reshape((1,) + tuple(padded_tensor.shape)) def _get_feature_range_and_mask( self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any @@ -389,8 +385,7 @@ def _get_feature_range_and_mask( def _get_feature_counts( self, - # pyre-fixme[2]: Parameter must be annotated. - inputs, + inputs: TensorOrTupleOfTensorsGeneric, feature_mask: Tuple[Tensor, ...], **kwargs: Any, ) -> Tuple[int, ...]: